From ef56f89511818f57c79dc7dc1c99f4546b5f5add Mon Sep 17 00:00:00 2001 From: Kagami Hiiragi Date: Sun, 18 Aug 2019 11:43:46 +0300 Subject: [PATCH] Simplify classify We don't need additional level of indirection and can store category in distances vector right away. --- classify.cc | 14 ++++---------- classify.h | 4 +--- facerec.cc | 10 +++------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/classify.cc b/classify.cc index 49e6179..d9a914b 100644 --- a/classify.cc +++ b/classify.cc @@ -1,9 +1,10 @@ +#include #include #include "classify.h" int classify( const std::vector& samples, - const std::unordered_map& cats, + const std::vector& cats, const descriptor& test_sample, float tolerance ) { @@ -17,7 +18,7 @@ int classify( for (const auto& sample : samples) { float dist = dist_func(sample, test_sample); if (dist >= tolerance) { - distances.push_back({idx, dist}); + distances.push_back({cats[idx], dist}); } idx++; } @@ -33,18 +34,12 @@ int classify( int len = std::min((int)distances.size(), 10); std::unordered_map> hits_by_cat; for (int i = 0; i < len; i++) { - int idx = distances[i].first; + int cat_idx = distances[i].first; float dist = distances[i].second; - auto cat = cats.find(idx); - if (cat == cats.end()) - continue; - int cat_idx = cat->second; auto hit = hits_by_cat.find(cat_idx); if (hit == hits_by_cat.end()) { - // printf("1 hit for %d (%d: %f)\n", cat_idx, idx, dist); hits_by_cat[cat_idx] = {1, dist}; } else { - // printf("+1 hit for %d (%d: %f)\n", cat_idx, idx, dist); hits_by_cat[cat_idx].first++; } } @@ -60,6 +55,5 @@ int classify( return hits1 < hits2; } ); - // printf("Found cat with max hits: %d\n", hit->first); fflush(stdout); return hit->first; } diff --git a/classify.h b/classify.h index e21fbc0..ddc4642 100644 --- a/classify.h +++ b/classify.h @@ -1,12 +1,10 @@ #pragma once -#include - typedef dlib::matrix descriptor; int classify( const std::vector& samples, - const std::unordered_map& cats, + const std::vector& cats, const descriptor& test_sample, float tolerance ); diff --git a/facerec.cc b/facerec.cc index 144070b..09d11e7 100644 --- a/facerec.cc +++ b/facerec.cc @@ -87,7 +87,7 @@ public: return {std::move(rects), std::move(descrs)}; } - void SetSamples(std::vector&& samples, std::unordered_map&& cats) { + void SetSamples(std::vector&& samples, std::vector&& cats) { std::unique_lock lock(samples_mutex_); samples_ = std::move(samples); cats_ = std::move(cats); @@ -105,7 +105,7 @@ private: shape_predictor sp_; anet_type net_; std::vector samples_; - std::unordered_map cats_; + std::vector cats_; }; // Plain C interface for Go. @@ -177,11 +177,7 @@ void facerec_set_samples( descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1); samples.push_back(std::move(sample)); } - std::unordered_map cats; - cats.reserve(len); - for (int i = 0; i < len; i++) { - cats[i] = c_cats[i]; - } + std::vector cats(c_cats, c_cats + len); cls->SetSamples(std::move(samples), std::move(cats)); }