Commit b17fcc310fe7a93eb37fb08ac17c94676aa00a4b
1 parent
d254fe2a
cleaner l2
Showing
1 changed file
with
9 additions
and
1 deletions
openbr/plugins/cuda/cudal2.cpp
| ... | ... | @@ -42,8 +42,16 @@ class CUDAL2Distance : public UntrainableDistance |
| 42 | 42 | float* cudaAPtr = (float*)srcDataPtr[0]; |
| 43 | 43 | int rows = *((int*)srcDataPtr[1]); |
| 44 | 44 | int cols = *((int*)srcDataPtr[2]); |
| 45 | + int srcType = *((int*)srcDataPtr[3]); | |
| 45 | 46 | |
| 46 | - float* cudaBPtr = (float*)b.ptr<void*>()[0]; | |
| 47 | + void* const* dstDataPtr = b.ptr<void*>(); | |
| 48 | + float* cudaBPtr = (float*)dstDataPtr[0]; | |
| 49 | + int dstType = *((int*)dstDataPtr[3]); | |
| 50 | + | |
| 51 | + if (srcType != dstType) { | |
| 52 | + cout << "ERR: Type mismatch" << endl; | |
| 53 | + throw 0; | |
| 54 | + } | |
| 47 | 55 | |
| 48 | 56 | float out; |
| 49 | 57 | cuda::L2::wrapper(cudaAPtr, cudaBPtr, rows*cols, &out); | ... | ... |