Commit 04037199a0aceeaef93e4a5404ea5a04de99032c
1 parent
7b195032
cleaned up cosine distance
Showing
2 changed files
with
14 additions
and
18 deletions
sdk/plugins/compare.cpp
| @@ -42,7 +42,7 @@ public: | @@ -42,7 +42,7 @@ public: | ||
| 42 | INF, | 42 | INF, |
| 43 | L1, | 43 | L1, |
| 44 | L2, | 44 | L2, |
| 45 | - CosineSimilarity }; | 45 | + Cosine }; |
| 46 | 46 | ||
| 47 | private: | 47 | private: |
| 48 | BR_PROPERTY(Metric, metric, L2) | 48 | BR_PROPERTY(Metric, metric, L2) |
| @@ -76,8 +76,8 @@ private: | @@ -76,8 +76,8 @@ private: | ||
| 76 | case L2: | 76 | case L2: |
| 77 | result = norm(a, b, NORM_L2); | 77 | result = norm(a, b, NORM_L2); |
| 78 | break; | 78 | break; |
| 79 | - case CosineSimilarity: | ||
| 80 | - result = cosineSimilarity(a, b); | 79 | + case Cosine: |
| 80 | + result = cosine(a, b); | ||
| 81 | break; | 81 | break; |
| 82 | default: | 82 | default: |
| 83 | qFatal("Invalid metric"); | 83 | qFatal("Invalid metric"); |
| @@ -89,33 +89,29 @@ private: | @@ -89,33 +89,29 @@ private: | ||
| 89 | return -log(result+1); | 89 | return -log(result+1); |
| 90 | } | 90 | } |
| 91 | 91 | ||
| 92 | - static float cosineSimilarity(const Mat &a, const Mat &b) | 92 | + static float cosine(const Mat &a, const Mat &b) |
| 93 | { | 93 | { |
| 94 | - assert((a.type() == CV_32FC1) && (b.type() == CV_32FC1)); | ||
| 95 | - assert((a.rows == b.rows) && (a.cols == b.cols)); | ||
| 96 | - | ||
| 97 | - float denom = 0; | ||
| 98 | - float tnum = 0; | ||
| 99 | - float qnum = 0; | 94 | + float dot = 0; |
| 95 | + float magA = 0; | ||
| 96 | + float magB = 0; | ||
| 100 | 97 | ||
| 101 | for (int row=0; row<a.rows; row++) { | 98 | for (int row=0; row<a.rows; row++) { |
| 102 | for (int col=0; col<a.cols; col++) { | 99 | for (int col=0; col<a.cols; col++) { |
| 103 | - float target = a.at<float>(row,col); | ||
| 104 | - float query = b.at<float>(row,col); | 100 | + const float target = a.at<float>(row,col); |
| 101 | + const float query = b.at<float>(row,col); | ||
| 105 | 102 | ||
| 106 | - denom += target * query; | ||
| 107 | - tnum += target * target; | ||
| 108 | - qnum += query * query; | 103 | + dot += target * query; |
| 104 | + magA += target * target; | ||
| 105 | + magB += query * query; | ||
| 109 | } | 106 | } |
| 110 | } | 107 | } |
| 111 | 108 | ||
| 112 | - return denom / (sqrt(tnum)*sqrt(qnum)); | 109 | + return dot / (sqrt(magA)*sqrt(magB)); |
| 113 | } | 110 | } |
| 114 | }; | 111 | }; |
| 115 | 112 | ||
| 116 | BR_REGISTER(Distance, Dist) | 113 | BR_REGISTER(Distance, Dist) |
| 117 | 114 | ||
| 118 | - | ||
| 119 | /*! | 115 | /*! |
| 120 | * \ingroup distances | 116 | * \ingroup distances |
| 121 | * \brief Fast 8-bit L1 distance | 117 | * \brief Fast 8-bit L1 distance |