mirror of
https://github.com/Kagami/go-face.git
synced 2025-09-27 03:55:51 +08:00
Simplify classify
We don't need additional level of indirection and can store category in distances vector right away.
This commit is contained in:
14
classify.cc
14
classify.cc
@@ -1,9 +1,10 @@
|
|||||||
|
#include <unordered_map>
|
||||||
#include <dlib/graph_utils.h>
|
#include <dlib/graph_utils.h>
|
||||||
#include "classify.h"
|
#include "classify.h"
|
||||||
|
|
||||||
int classify(
|
int classify(
|
||||||
const std::vector<descriptor>& samples,
|
const std::vector<descriptor>& samples,
|
||||||
const std::unordered_map<int, int>& cats,
|
const std::vector<int>& cats,
|
||||||
const descriptor& test_sample,
|
const descriptor& test_sample,
|
||||||
float tolerance
|
float tolerance
|
||||||
) {
|
) {
|
||||||
@@ -17,7 +18,7 @@ int classify(
|
|||||||
for (const auto& sample : samples) {
|
for (const auto& sample : samples) {
|
||||||
float dist = dist_func(sample, test_sample);
|
float dist = dist_func(sample, test_sample);
|
||||||
if (dist >= tolerance) {
|
if (dist >= tolerance) {
|
||||||
distances.push_back({idx, dist});
|
distances.push_back({cats[idx], dist});
|
||||||
}
|
}
|
||||||
idx++;
|
idx++;
|
||||||
}
|
}
|
||||||
@@ -33,18 +34,12 @@ int classify(
|
|||||||
int len = std::min((int)distances.size(), 10);
|
int len = std::min((int)distances.size(), 10);
|
||||||
std::unordered_map<int, std::pair<int, float>> hits_by_cat;
|
std::unordered_map<int, std::pair<int, float>> hits_by_cat;
|
||||||
for (int i = 0; i < len; i++) {
|
for (int i = 0; i < len; i++) {
|
||||||
int idx = distances[i].first;
|
int cat_idx = distances[i].first;
|
||||||
float dist = distances[i].second;
|
float dist = distances[i].second;
|
||||||
auto cat = cats.find(idx);
|
|
||||||
if (cat == cats.end())
|
|
||||||
continue;
|
|
||||||
int cat_idx = cat->second;
|
|
||||||
auto hit = hits_by_cat.find(cat_idx);
|
auto hit = hits_by_cat.find(cat_idx);
|
||||||
if (hit == hits_by_cat.end()) {
|
if (hit == hits_by_cat.end()) {
|
||||||
// printf("1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
|
|
||||||
hits_by_cat[cat_idx] = {1, dist};
|
hits_by_cat[cat_idx] = {1, dist};
|
||||||
} else {
|
} else {
|
||||||
// printf("+1 hit for %d (%d: %f)\n", cat_idx, idx, dist);
|
|
||||||
hits_by_cat[cat_idx].first++;
|
hits_by_cat[cat_idx].first++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -60,6 +55,5 @@ int classify(
|
|||||||
return hits1 < hits2;
|
return hits1 < hits2;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
// printf("Found cat with max hits: %d\n", hit->first); fflush(stdout);
|
|
||||||
return hit->first;
|
return hit->first;
|
||||||
}
|
}
|
||||||
|
@@ -1,12 +1,10 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
typedef dlib::matrix<float,0,1> descriptor;
|
typedef dlib::matrix<float,0,1> descriptor;
|
||||||
|
|
||||||
int classify(
|
int classify(
|
||||||
const std::vector<descriptor>& samples,
|
const std::vector<descriptor>& samples,
|
||||||
const std::unordered_map<int, int>& cats,
|
const std::vector<int>& cats,
|
||||||
const descriptor& test_sample,
|
const descriptor& test_sample,
|
||||||
float tolerance
|
float tolerance
|
||||||
);
|
);
|
||||||
|
10
facerec.cc
10
facerec.cc
@@ -87,7 +87,7 @@ public:
|
|||||||
return {std::move(rects), std::move(descrs)};
|
return {std::move(rects), std::move(descrs)};
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetSamples(std::vector<descriptor>&& samples, std::unordered_map<int, int>&& cats) {
|
void SetSamples(std::vector<descriptor>&& samples, std::vector<int>&& cats) {
|
||||||
std::unique_lock<std::shared_mutex> lock(samples_mutex_);
|
std::unique_lock<std::shared_mutex> lock(samples_mutex_);
|
||||||
samples_ = std::move(samples);
|
samples_ = std::move(samples);
|
||||||
cats_ = std::move(cats);
|
cats_ = std::move(cats);
|
||||||
@@ -105,7 +105,7 @@ private:
|
|||||||
shape_predictor sp_;
|
shape_predictor sp_;
|
||||||
anet_type net_;
|
anet_type net_;
|
||||||
std::vector<descriptor> samples_;
|
std::vector<descriptor> samples_;
|
||||||
std::unordered_map<int, int> cats_;
|
std::vector<int> cats_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Plain C interface for Go.
|
// Plain C interface for Go.
|
||||||
@@ -177,11 +177,7 @@ void facerec_set_samples(
|
|||||||
descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1);
|
descriptor sample = mat(c_samples + i*DESCR_LEN, DESCR_LEN, 1);
|
||||||
samples.push_back(std::move(sample));
|
samples.push_back(std::move(sample));
|
||||||
}
|
}
|
||||||
std::unordered_map<int, int> cats;
|
std::vector<int> cats(c_cats, c_cats + len);
|
||||||
cats.reserve(len);
|
|
||||||
for (int i = 0; i < len; i++) {
|
|
||||||
cats[i] = c_cats[i];
|
|
||||||
}
|
|
||||||
cls->SetSamples(std::move(samples), std::move(cats));
|
cls->SetSamples(std::move(samples), std::move(cats));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user