Simplify classify

We don't need additional level of indirection and can store category in
distances vector right away.
This commit is contained in:
Kagami Hiiragi
2019-08-18 11:43:46 +03:00
parent 84d557bef3
commit ef56f89511
3 changed files with 8 additions and 20 deletions

View File

@@ -1,9 +1,10 @@
#include <unordered_map>
#include <dlib/graph_utils.h> #include <dlib/graph_utils.h>
#include "classify.h" #include "classify.h"
int classify( int classify(
const std::vector<descriptor>& samples, const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats, const std::vector<int>& cats,
const descriptor& test_sample, const descriptor& test_sample,
float tolerance float tolerance
) { ) {
@@ -17,7 +18,7 @@ int classify(
for (const auto& sample : samples) { for (const auto& sample : samples) {
float dist = dist_func(sample, test_sample); float dist = dist_func(sample, test_sample);
if (dist >= tolerance) { if (dist >= tolerance) {
distances.push_back({idx, dist}); distances.push_back({cats[idx], dist});
} }
idx++; idx++;
} }
@@ -33,18 +34,12 @@ int classify(
int len = std::min((int)distances.size(), 10); int len = std::min((int)distances.size(), 10);
std::unordered_map<int, std::pair<int, float>> hits_by_cat; std::unordered_map<int, std::pair<int, float>> hits_by_cat;
for (int i = 0; i < len; i++) { for (int i = 0; i < len; i++) {
int idx = distances[i].first; int cat_idx = distances[i].first;
float dist = distances[i].second; float dist = distances[i].second;
auto cat = cats.find(idx);
if (cat == cats.end())
continue;
int cat_idx = cat->second;
auto hit = hits_by_cat.find(cat_idx); auto hit = hits_by_cat.find(cat_idx);
if (hit == hits_by_cat.end()) { if (hit == hits_by_cat.end()) {
// printf("1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
hits_by_cat[cat_idx] = {1, dist}; hits_by_cat[cat_idx] = {1, dist};
} else { } else {
// printf("+1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
hits_by_cat[cat_idx].first++; hits_by_cat[cat_idx].first++;
} }
} }
@@ -60,6 +55,5 @@ int classify(
return hits1 < hits2; return hits1 < hits2;
} }
); );
// printf("Found cat with max hits: %d\n", hit->first); fflush(stdout);
return hit->first; return hit->first;
} }

View File

@@ -1,12 +1,10 @@
#pragma once #pragma once
#include <unordered_map>
typedef dlib::matrix<float,0,1> descriptor; typedef dlib::matrix<float,0,1> descriptor;
int classify( int classify(
const std::vector<descriptor>& samples, const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats, const std::vector<int>& cats,
const descriptor& test_sample, const descriptor& test_sample,
float tolerance float tolerance
); );

View File

@@ -87,7 +87,7 @@ public:
return {std::move(rects), std::move(descrs)}; return {std::move(rects), std::move(descrs)};
} }
void SetSamples(std::vector<descriptor>&& samples, std::unordered_map<int, int>&& cats) { void SetSamples(std::vector<descriptor>&& samples, std::vector<int>&& cats) {
std::unique_lock<std::shared_mutex> lock(samples_mutex_); std::unique_lock<std::shared_mutex> lock(samples_mutex_);
samples_ = std::move(samples); samples_ = std::move(samples);
cats_ = std::move(cats); cats_ = std::move(cats);
@@ -105,7 +105,7 @@ private:
shape_predictor sp_; shape_predictor sp_;
anet_type net_; anet_type net_;
std::vector<descriptor> samples_; std::vector<descriptor> samples_;
std::unordered_map<int, int> cats_; std::vector<int> cats_;
}; };
// Plain C interface for Go. // Plain C interface for Go.
@@ -177,11 +177,7 @@ void facerec_set_samples(
descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1); descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1);
samples.push_back(std::move(sample)); samples.push_back(std::move(sample));
} }
std::unordered_map<int, int> cats; std::vector<int> cats(c_cats, c_cats + len);
cats.reserve(len);
for (int i = 0; i < len; i++) {
cats[i] = c_cats[i];
}
cls->SetSamples(std::move(samples), std::move(cats)); cls->SetSamples(std::move(samples), std::move(cats));
} }