Files
go-face/classify.cc
Kagami Hiiragi 0956e40d34 Fix stupid bug with comparing threshold
It was working backwards.

Fixes #33 #16
2019-08-18 12:30:04 +03:00

60 lines
1.4 KiB
C++

#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>
#include <dlib/graph_utils.h>
#include "classify.h"
int classify(
const std::vector<descriptor>& samples,
const std::vector<int>& cats,
const descriptor& test_sample,
float tolerance
) {
if (samples.size() == 0)
return -1;
std::vector<std::pair<int, float>> distances;
distances.reserve(samples.size());
auto dist_func = dlib::squared_euclidean_distance();
int idx = 0;
for (const auto& sample : samples) {
float dist = dist_func(sample, test_sample);
if (tolerance < 0 || dist <= tolerance) {
distances.push_back({cats[idx], dist});
}
idx++;
}
if (distances.size() == 0)
return -1;
std::sort(
distances.begin(), distances.end(),
[](const auto a, const auto b) { return a.second < b.second; }
);
int len = std::min((int)distances.size(), 10);
std::unordered_map<int, std::pair<int, float>> hits_by_cat;
for (int i = 0; i < len; i++) {
int cat_idx = distances[i].first;
float dist = distances[i].second;
auto hit = hits_by_cat.find(cat_idx);
if (hit == hits_by_cat.end()) {
hits_by_cat[cat_idx] = {1, dist};
} else {
hits_by_cat[cat_idx].first++;
}
}
auto hit = std::max_element(
hits_by_cat.begin(), hits_by_cat.end(),
[](const auto a, const auto b) {
auto hits1 = a.second.first;
auto hits2 = b.second.first;
auto dist1 = a.second.second;
auto dist2 = b.second.second;
if (hits1 == hits2) return dist1 > dist2;
return hits1 < hits2;
}
);
return hit->first;
}