// cluster.cpp (9.28 KB)
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
 * Copyright 2012 The MITRE Corporation                                      *
 *                                                                           *
 * Licensed under the Apache License, Version 2.0 (the "License");           *
 * you may not use this file except in compliance with the License.          *
 * You may obtain a copy of the License at                                   *
 *                                                                           *
 *     http://www.apache.org/licenses/LICENSE-2.0                            *
 *                                                                           *
 * Unless required by applicable law or agreed to in writing, software       *
 * distributed under the License is distributed on an "AS IS" BASIS,         *
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  *
 * See the License for the specific language governing permissions and       *
 * limitations under the License.                                            *
 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */

#include <opencv2/flann/flann.hpp>

#include "openbr_internal.h"
#include "openbr/core/common.h"
#include "openbr/core/opencvutils.h"
#include <fstream>

using namespace cv;

namespace br
{

/*!
 * \ingroup transforms
 * \brief Wraps OpenCV kmeans and flann.
 * \author Josh Klontz \cite jklontz
 */
class KMeansTransform : public Transform
{
    Q_OBJECT
    Q_PROPERTY(int kTrain READ get_kTrain WRITE set_kTrain RESET reset_kTrain STORED false)
    Q_PROPERTY(int kSearch READ get_kSearch WRITE set_kSearch RESET reset_kSearch STORED false)
    BR_PROPERTY(int, kTrain, 256)
    BR_PROPERTY(int, kSearch, 1)

    Mat centers;
    mutable QScopedPointer<flann::Index> index;
    mutable QMutex mutex;

    void reindex()
    {
        index.reset(new flann::Index(centers, flann::LinearIndexParams()));
    }

    void train(const TemplateList &data)
    {
        Mat bestLabels;
        const double compactness = kmeans(OpenCVUtils::toMatByRow(data.data()), kTrain, bestLabels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, centers);
        qDebug("KMeans compactness = %f", compactness);
        reindex();
    }

    void project(const Template &src, Template &dst) const
    {
        QMutexLocker locker(&mutex);
        Mat dists, indicies;
        index->knnSearch(src, indicies, dists, kSearch);
        dst = indicies.reshape(1, 1);
    }

    void load(QDataStream &stream)
    {
        stream >> centers;
        reindex();
    }

    void store(QDataStream &stream) const
    {
        stream << centers;
    }
};

BR_REGISTER(Transform, KMeansTransform)

/*!
 * \ingroup transforms
 * \brief K nearest neighbors classifier.
 * \author Josh Klontz \cite jklontz
 */
class KNNTransform : public Transform
{
    Q_OBJECT
    Q_PROPERTY(int k READ get_k WRITE set_k RESET reset_k STORED false)
    Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
    Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false)
    Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false)
    Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
    Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
    Q_PROPERTY(QString galleryName READ get_galleryName WRITE set_galleryName RESET reset_galleryName STORED false)
    BR_PROPERTY(int, k, 1)
    BR_PROPERTY(br::Distance*, distance, NULL)
    BR_PROPERTY(bool, weighted, false)
    BR_PROPERTY(int, numSubjects, 1)
    BR_PROPERTY(QString, inputVariable, "Label")
    BR_PROPERTY(QString, outputVariable, "KNN")
    BR_PROPERTY(QString, galleryName, "")

    TemplateList gallery;

    void train(const TemplateList &data)
    {
        distance->train(data);
        gallery = data;
    }

    void project(const Template &src, Template &dst) const
    {
        QList< QPair<float, int> > sortedScores = Common::Sort(distance->compare(gallery, src), true);

        QStringList subjects;
        for (int i=0; i<numSubjects; i++) {
            QHash<QString, float> votes;
            const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size());
            for (int j=0; j<max; j++)
                votes[gallery[sortedScores[j].second].file.get<QString>(inputVariable)] += (weighted ? sortedScores[j].first : 1);
            subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]);

            // Remove subject from consideration
            if (subjects.size() < numSubjects)
                for (int j=sortedScores.size()-1; j>=0; j--)
                    if (gallery[sortedScores[j].second].file.get<QString>(inputVariable) == subjects.last())
                        sortedScores.removeAt(j);
        }

        dst.file.set(outputVariable, subjects.size() > 1 ? "[" + subjects.join(",") + "]" : subjects.first());
        dst.file.set("Nearest", gallery[sortedScores[0].second].file.name);
    }

    void store(QDataStream &stream) const
    {
        stream << gallery;
    }

    void load(QDataStream &stream)
    {
        stream >> gallery;
    }

    void init()
    {
        if (!galleryName.isEmpty())
            gallery = TemplateList::fromGallery(galleryName);
    }
};

BR_REGISTER(Transform, KNNTransform)

/*!
 * \ingroup transforms
 * \brief Chooses k random points to be centroids.
 * \author Austin Blanton \cite imaus10
 * \see KMeansTransform
 */
class RandomCentroidsTransform : public Transform
{
    Q_OBJECT
    Q_PROPERTY(int kTrain READ get_kTrain WRITE set_kTrain RESET reset_kTrain STORED false)
    Q_PROPERTY(int kSearch READ get_kSearch WRITE set_kSearch RESET reset_kSearch STORED false)
    BR_PROPERTY(int, kTrain, 256)
    BR_PROPERTY(int, kSearch, 1)

    Mat centers;
    mutable QScopedPointer<flann::Index> index;
    mutable QMutex mutex;

    void reindex()
    {
        index.reset(new flann::Index(centers, flann::LinearIndexParams()));
    }

    void train(const TemplateList &data)
    {
        Mat flat = OpenCVUtils::toMatByRow(data.data());
        QList<int> sample = Common::RandSample(kTrain, flat.rows, 0, true);
        foreach (const int &idx, sample)
            centers.push_back(flat.row(idx));
        reindex();
    }

    void project(const Template &src, Template &dst) const
    {
        QMutexLocker locker(&mutex);
        Mat dists, indicies;
        index->knnSearch(src, indicies, dists, kSearch);
        dst = indicies.reshape(1, 1);
    }

    void load(QDataStream &stream)
    {
        stream >> centers;
        reindex();
    }

    void store(QDataStream &stream) const
    {
        stream << centers;
    }
};

BR_REGISTER(Transform, RandomCentroidsTransform)

/*!
 * \brief Registers br::Neighbors with the Qt meta-type system.
 *
 * CollectNNTransform stores a Neighbors value in a QVariant and
 * LogNNTransform reads it back.
 */
class RegInitializer : public Initializer
{
    Q_OBJECT

    // Performs the meta-type registration.
    void initialize() const
    {
        // Inside namespace br the unqualified name refers to br::Neighbors.
        qRegisterMetaType<Neighbors>();
    }
};
BR_REGISTER(Initializer, RegInitializer)

/*!
 * \ingroup transforms
 * \brief Turns the first row of a float matrix into a neighbor list.
 *
 * Each column index is paired with the value stored in that column; the column
 * equal to the template's "FrameNumber" is left out. The first \c keep entries
 * under the compareNeighbors ordering are stored in the "neighbors" metadata key.
 */
class CollectNNTransform : public UntrainableMetaTransform
{
    Q_OBJECT

    Q_PROPERTY(int keep READ get_keep WRITE set_keep RESET reset_keep STORED false)
    BR_PROPERTY(int, keep, 20)

    // Copies the metadata of src, empties the matrix data and attaches the neighbor list.
    void project(const Template &src, Template &dst) const
    {
        // The output keeps the source metadata but carries only an empty matrix.
        dst.file = src.file;
        dst.clear();
        dst.m() = cv::Mat();

        // One candidate per column, leaving out the column that refers to this template itself.
        Neighbors candidates;
        const int columns = src.m().cols;
        for (int col = 0; col < columns; ++col) {
            if (col != src.file.get<int>("FrameNumber"))
                candidates.append(Neighbor(col, src.m().at<float>(0, col)));
        }

        // Only the leading part of the list needs to be in order.
        const int kept = std::min(keep, candidates.size());
        std::partial_sort(candidates.begin(), candidates.begin() + kept, candidates.end(), compareNeighbors);

        const Neighbors selected = candidates.mid(0, kept);
        dst.file.set("neighbors", QVariant::fromValue(selected));
    }
};
BR_REGISTER(Transform, CollectNNTransform)

class LogNNTransform : public TimeVaryingTransform
{
    Q_OBJECT

    Q_PROPERTY(QString fileName READ get_fileName WRITE set_fileName RESET reset_fileName STORED false)
    BR_PROPERTY(QString, fileName, "")

    std::fstream fout;

    void projectUpdate(const Template &src, Template &dst)
    {
        dst = src;

        if (!dst.file.contains("neighbors")) {
            fout << std::endl;
            return;
        }

        Neighbors neighbors = dst.file.get<Neighbors>("neighbors");
        if (neighbors.isEmpty() ) {
            fout << std::endl;
            return;
        }

        QString aLine;
        aLine.append(QString::number(neighbors[0].first)+":"+QString::number(neighbors[0].second));
        for (int i=1; i < neighbors.size();i++)
            aLine.append(","+QString::number(neighbors[i].first)+":"+QString::number(neighbors[i].second));

        fout << qPrintable(aLine) << std::endl;
    }

    void init()
    {
        if (!fileName.isEmpty())
            fout.open(qPrintable(fileName), std::ios_base::out);
    }

    void finalize(TemplateList &output)
    {
        (void) output;
        fout.close();
    }

public:
    LogNNTransform() : TimeVaryingTransform(false, false) {}
};
BR_REGISTER(Transform, LogNNTransform)

} // namespace br

#include "cluster.moc"