Commit 04037199a0aceeaef93e4a5404ea5a04de99032c
1 parent
7b195032
cleaned up cosine distance
Showing
2 changed files
with
14 additions
and
18 deletions
sdk/plugins/compare.cpp
| ... | ... | @@ -42,7 +42,7 @@ public: |
| 42 | 42 | INF, |
| 43 | 43 | L1, |
| 44 | 44 | L2, |
| 45 | - CosineSimilarity }; | |
| 45 | + Cosine }; | |
| 46 | 46 | |
| 47 | 47 | private: |
| 48 | 48 | BR_PROPERTY(Metric, metric, L2) |
| ... | ... | @@ -76,8 +76,8 @@ private: |
| 76 | 76 | case L2: |
| 77 | 77 | result = norm(a, b, NORM_L2); |
| 78 | 78 | break; |
| 79 | - case CosineSimilarity: | |
| 80 | - result = cosineSimilarity(a, b); | |
| 79 | + case Cosine: | |
| 80 | + result = cosine(a, b); | |
| 81 | 81 | break; |
| 82 | 82 | default: |
| 83 | 83 | qFatal("Invalid metric"); |
| ... | ... | @@ -89,33 +89,29 @@ private: |
| 89 | 89 | return -log(result+1); |
| 90 | 90 | } |
| 91 | 91 | |
| 92 | - static float cosineSimilarity(const Mat &a, const Mat &b) | |
| 92 | + static float cosine(const Mat &a, const Mat &b) | |
| 93 | 93 | { |
| 94 | - assert((a.type() == CV_32FC1) && (b.type() == CV_32FC1)); | |
| 95 | - assert((a.rows == b.rows) && (a.cols == b.cols)); | |
| 96 | - | |
| 97 | - float denom = 0; | |
| 98 | - float tnum = 0; | |
| 99 | - float qnum = 0; | |
| 94 | + float dot = 0; | |
| 95 | + float magA = 0; | |
| 96 | + float magB = 0; | |
| 100 | 97 | |
| 101 | 98 | for (int row=0; row<a.rows; row++) { |
| 102 | 99 | for (int col=0; col<a.cols; col++) { |
| 103 | - float target = a.at<float>(row,col); | |
| 104 | - float query = b.at<float>(row,col); | |
| 100 | + const float target = a.at<float>(row,col); | |
| 101 | + const float query = b.at<float>(row,col); | |
| 105 | 102 | |
| 106 | - denom += target * query; | |
| 107 | - tnum += target * target; | |
| 108 | - qnum += query * query; | |
| 103 | + dot += target * query; | |
| 104 | + magA += target * target; | |
| 105 | + magB += query * query; | |
| 109 | 106 | } |
| 110 | 107 | } |
| 111 | 108 | |
| 112 | - return denom / (sqrt(tnum)*sqrt(qnum)); | |
| 109 | + return dot / (sqrt(magA)*sqrt(magB)); | |
| 113 | 110 | } |
| 114 | 111 | }; |
| 115 | 112 | |
| 116 | 113 | BR_REGISTER(Distance, Dist) |
| 117 | 114 | |
| 118 | - | |
| 119 | 115 | /*! |
| 120 | 116 | * \ingroup distances |
| 121 | 117 | * \brief Fast 8-bit L1 distance | ... | ... |