Simplify classify

We don't need additional level of indirection and can store category in
distances vector right away.
This commit is contained in:
Kagami Hiiragi
2019-08-18 11:43:46 +03:00
parent 84d557bef3
commit ef56f89511
3 changed files with 8 additions and 20 deletions

View File

@@ -1,9 +1,10 @@
#include <unordered_map>
#include <dlib/graph_utils.h>
#include "classify.h"
int classify(
const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats,
const std::vector<int>& cats,
const descriptor& test_sample,
float tolerance
) {
@@ -17,7 +18,7 @@ int classify(
for (const auto& sample : samples) {
float dist = dist_func(sample, test_sample);
if (dist >= tolerance) {
distances.push_back({idx, dist});
distances.push_back({cats[idx], dist});
}
idx++;
}
@@ -33,18 +34,12 @@ int classify(
int len = std::min((int)distances.size(), 10);
std::unordered_map<int, std::pair<int, float>> hits_by_cat;
for (int i = 0; i < len; i++) {
int idx = distances[i].first;
int cat_idx = distances[i].first;
float dist = distances[i].second;
auto cat = cats.find(idx);
if (cat == cats.end())
continue;
int cat_idx = cat->second;
auto hit = hits_by_cat.find(cat_idx);
if (hit == hits_by_cat.end()) {
// printf("1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
hits_by_cat[cat_idx] = {1, dist};
} else {
// printf("+1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
hits_by_cat[cat_idx].first++;
}
}
@@ -60,6 +55,5 @@ int classify(
return hits1 < hits2;
}
);
// printf("Found cat with max hits: %d\n", hit->first); fflush(stdout);
return hit->first;
}

View File

@@ -1,12 +1,10 @@
#pragma once
#include <unordered_map>
typedef dlib::matrix<float,0,1> descriptor;
int classify(
const std::vector<descriptor>& samples,
const std::unordered_map<int, int>& cats,
const std::vector<int>& cats,
const descriptor& test_sample,
float tolerance
);

View File

@@ -87,7 +87,7 @@ public:
return {std::move(rects), std::move(descrs)};
}
void SetSamples(std::vector<descriptor>&& samples, std::unordered_map<int, int>&& cats) {
void SetSamples(std::vector<descriptor>&& samples, std::vector<int>&& cats) {
std::unique_lock<std::shared_mutex> lock(samples_mutex_);
samples_ = std::move(samples);
cats_ = std::move(cats);
@@ -105,7 +105,7 @@ private:
shape_predictor sp_;
anet_type net_;
std::vector<descriptor> samples_;
std::unordered_map<int, int> cats_;
std::vector<int> cats_;
};
// Plain C interface for Go.
@@ -177,11 +177,7 @@ void facerec_set_samples(
descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1);
samples.push_back(std::move(sample));
}
std::unordered_map<int, int> cats;
cats.reserve(len);
for (int i = 0; i < len; i++) {
cats[i] = c_cats[i];
}
std::vector<int> cats(c_cats, c_cats + len);
cls->SetSamples(std::move(samples), std::move(cats));
}