// Files
// go-face/face_test.go
// 2020-05-13 20:31:38 +03:00
//
// 234 lines
// 5.0 KiB
// Go

package face_test
import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
"testing"
"unsafe"
"github.com/Kagami/go-face"
)
// rec is the recognizer shared by the tests in this file. TestInit creates
// it and TestClose closes it.
var rec *face.Recognizer

// idolTests maps a test image file name (resolved with getTPath) to the
// classification TestIdols expects for it, written as "Name, Band".
var idolTests = map[string]string{
	"chaeyeon.jpg":   "Chaeyeon, DIA",
	"chaeyeon3.jpg":  "Chaeyeon, DIA",
	"chaeyoung.jpg":  "Chaeyoung, Twice",
	"chaeyoung2.jpg": "Chaeyoung, Twice",
	"elkie.jpg":      "Elkie, CLC",
	"jimin.jpg":      "Jimin, AOA",
	"jimin2.jpg":     "Jimin, AOA",
	"jimin4.jpg":     "Jimin, AOA",
	"joy.jpg":        "Joy, Red Velvet",
	"luda2.jpg":      "Luda, WJSN",
	"meiqi.jpg":      "Mei Qi, WJSN",
	"nayoung.jpg":    "Nayoung, PRISTIN",
	"sejeong.jpg":    "Sejeong, Gugudan",
	"tzuyu2.jpg":     "Tzuyu, Twice",
}
// Idol is one entry of the "idols" array in testdata/idols.json.
type Idol struct {
	ID       string `json:"id"`
	Name     string `json:"name"`
	BandName string `json:"band_name"`
}
// IdolFace is one entry of the "faces" array in testdata/idols.json.
type IdolFace struct {
	// Descriptor is a base64-encoded face descriptor, decoded by str2descr.
	Descriptor string `json:"descriptor"`
	// IdolID is the Idol.ID this face belongs to.
	IdolID string `json:"idol_id"`
}
// IdolData is the decoded contents of testdata/idols.json.
type IdolData struct {
	Idols []Idol     `json:"idols"`
	Faces []IdolFace `json:"faces"`
	// byID indexes Idols by Idol.ID. It is unexported, so encoding/json
	// leaves it alone; getIdolData fills it in after decoding.
	byID map[string]*Idol
}
// TrainData holds the parallel training slices built by getTrainData:
// samples[i] belongs to category cats[i], and labels[c] is the IdolID that
// category c stands for.
type TrainData struct {
	samples []face.Descriptor
	cats    []int32
	labels  []string
}
// getTPath returns the path of the named test image inside testdata/images.
func getTPath(fname string) string {
	imagesDir := filepath.Join("testdata", "images")
	return filepath.Join(imagesDir, fname)
}
// getIdolData loads testdata/idols.json and indexes its idols by ID.
//
// On success the returned IdolData has byID populated with pointers into its
// own Idols slice. On failure it returns a nil *IdolData and an error saying
// which step (reading or decoding) went wrong. A half-decoded value is never
// returned.
func getIdolData() (idata *IdolData, err error) {
	path := filepath.Join("testdata", "idols.json")
	data, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading idol data: %w", err)
	}
	var parsed IdolData
	if err = json.Unmarshal(data, &parsed); err != nil {
		return nil, fmt.Errorf("decoding %s: %w", path, err)
	}
	parsed.byID = make(map[string]*Idol, len(parsed.Idols))
	for i := range parsed.Idols {
		// Take the address of the slice element, not of a range copy, so the
		// index shares storage with Idols.
		idol := &parsed.Idols[i]
		parsed.byID[idol.ID] = idol
	}
	return &parsed, nil
}
// str2descr decodes a base64-encoded face descriptor from the idol fixture
// data into a face.Descriptor.
//
// The decoded bytes are reinterpreted in place as the descriptor's raw
// memory. They must therefore be exactly unsafe.Sizeof(face.Descriptor)
// bytes long and laid out in the host's native memory format.
// NOTE(review): presumably idols.json was produced on a little-endian
// machine; confirm before running on big-endian hosts.
//
// It panics on malformed input, which is acceptable for test fixtures.
func str2descr(s string) face.Descriptor {
	b, err := base64.StdEncoding.DecodeString(s)
	if err != nil {
		panic(err)
	}
	// Without this check, an empty payload panics on b[0] and a short one
	// makes the unsafe cast below read past the end of b.
	var zero face.Descriptor
	if want := int(unsafe.Sizeof(zero)); len(b) != want {
		panic(fmt.Sprintf("descriptor has %d bytes, want %d", len(b), want))
	}
	return *(*face.Descriptor)(unsafe.Pointer(&b[0]))
}
// getTrainData converts the idol fixture faces into the parallel slices that
// are passed to rec.SetSamples.
//
// Each face's descriptor becomes one sample. Consecutive faces with the same
// IdolID share a category. A new category is opened for the first face and
// again whenever IdolID changes, and labels[c] records the IdolID of
// category c. Faces for one idol are expected to be contiguous in
// idata.Faces. If they are not, that idol gets several categories whose
// labels repeat the same IdolID, and lookups still resolve correctly.
func getTrainData(idata *IdolData) (tdata *TrainData) {
	n := len(idata.Faces)
	samples := make([]face.Descriptor, 0, n)
	cats := make([]int32, 0, n)
	var labels []string
	var catID int32 = -1
	var prevIdolID string
	for i := range idata.Faces {
		iface := &idata.Faces[i]
		samples = append(samples, str2descr(iface.Descriptor))
		// Always open a category for the first face. Comparing only against
		// the zero-value prevIdolID would leave catID at -1 (and add no label)
		// when the first face has an empty IdolID.
		if i == 0 || iface.IdolID != prevIdolID {
			catID++
			labels = append(labels, iface.IdolID)
		}
		cats = append(cats, catID)
		prevIdolID = iface.IdolID
	}
	return &TrainData{
		samples: samples,
		cats:    cats,
		labels:  labels,
	}
}
// recognizeAndClassify detects a single face in the image at fpath and
// classifies its descriptor against the samples loaded into rec.
//
// A negative tolerance selects rec.Classify. Otherwise rec.ClassifyThreshold
// is used with that tolerance. The returned id is -1 when recognition fails
// (err is then non-nil) or when no face is found (err is nil). In all other
// cases id is the classifier's result.
func recognizeAndClassify(fpath string, tolerance float32) (id int, err error) {
	found, err := rec.RecognizeSingleFile(fpath)
	switch {
	case err != nil, found == nil:
		return -1, err
	case tolerance < 0:
		return rec.Classify(found.Descriptor), nil
	default:
		return rec.ClassifyThreshold(found.Descriptor, tolerance), nil
	}
}
// TestSerializationError checks that creating a recognizer from a nonexistent
// model directory fails with a face.SerializationError.
func TestSerializationError(t *testing.T) {
	_, err := face.NewRecognizer("/notexist")
	if _, ok := err.(face.SerializationError); !ok {
		t.Fatalf("Wrong error: %v", err)
	}
}
// TestInit creates the shared recognizer rec from testdata/models. go test
// runs top-level tests in source order, so this must stay ahead of every
// test that uses rec.
func TestInit(t *testing.T) {
	modelsDir := filepath.Join("testdata", "models")
	var err error
	if rec, err = face.NewRecognizer(modelsDir); err != nil {
		t.Fatalf("Can't init face recognizer: %v", err)
	}
}
// TestImageLoadError checks that bytes that are not a decodable image make
// rec.Recognize fail with a face.ImageLoadError.
func TestImageLoadError(t *testing.T) {
	garbage := []byte{1, 2, 3}
	_, err := rec.Recognize(garbage)
	if _, isLoadErr := err.(face.ImageLoadError); !isLoadErr {
		t.Fatalf("Wrong error: %v", err)
	}
}
// TestNumFaces checks that exactly 10 faces are detected in pristin.jpg.
func TestNumFaces(t *testing.T) {
	const want = 10
	faces, err := rec.RecognizeFile(getTPath("pristin.jpg"))
	if err != nil {
		t.Fatalf("Can't get faces: %v", err)
	}
	if got := len(faces); got != want {
		t.Fatalf("Wrong number of faces: %d", got)
	}
}
// TestEmptyClassify checks that classifying a zero-valued descriptor yields
// a negative (unrecognized) category. In this file it runs before TestIdols
// calls rec.SetSamples.
func TestEmptyClassify(t *testing.T) {
	var zero face.Descriptor
	id := rec.Classify(zero)
	if id < 0 {
		return
	}
	t.Fatalf("Shouldn't recognize but got %d category", id)
}
// TestIdols trains rec on the idol fixture descriptors and checks that each
// image in idolTests is classified as the expected "Name, Band".
//
// Unexpected classifier output fails the subtest with t.Fatalf. This covers
// a category outside labels and a label with no matching idol. Previously
// these cases panicked and took the whole test binary down.
func TestIdols(t *testing.T) {
	idata, err := getIdolData()
	if err != nil {
		t.Fatalf("Can't get idol data: %v", err)
	}
	tdata := getTrainData(idata)
	rec.SetSamples(tdata.samples, tdata.cats)
	for fname, expected := range idolTests {
		// Subtests run synchronously (no t.Parallel), so capturing the loop
		// variables is safe even with pre-Go 1.22 loop semantics.
		t.Run(fname, func(t *testing.T) {
			names := strings.SplitN(expected, ", ", 2)
			if len(names) != 2 {
				t.Fatalf("%s: malformed expectation %q, want \"Name, Band\"", fname, expected)
			}
			expectedIname, expectedBname := names[0], names[1]
			catID, err := recognizeAndClassify(getTPath(fname), -1)
			if err != nil {
				t.Fatalf("Can't recognize: %v", err)
			}
			if catID < 0 {
				t.Errorf("%s: expected “%s” but not recognized", fname, expected)
				return
			}
			if catID >= len(tdata.labels) {
				t.Fatalf("%s: category %d out of range (%d labels)", fname, catID, len(tdata.labels))
			}
			idolID := tdata.labels[catID]
			idol := idata.byID[idolID]
			if idol == nil {
				t.Fatalf("%s: category %d maps to unknown idol ID %q", fname, catID, idolID)
			}
			actualIname := idol.Name
			actualBname := idol.BandName
			if expectedIname != actualIname || expectedBname != actualBname {
				actual := fmt.Sprintf("%s, %s", actualIname, actualBname)
				t.Errorf("%s: expected “%s” but got “%s”", fname, expected, actual)
			}
		})
	}
}
// TestClassifyThreshold checks that nana.jpg is not matched with a tolerance
// of 0.1 but is matched with a tolerance of 0.8. The cases run in that order
// and stop at the first failure.
func TestClassifyThreshold(t *testing.T) {
	checks := []struct {
		tolerance float32
		wantMatch bool
	}{
		{tolerance: 0.1, wantMatch: false},
		{tolerance: 0.8, wantMatch: true},
	}
	for _, c := range checks {
		id, err := recognizeAndClassify(getTPath("nana.jpg"), c.tolerance)
		if err != nil {
			t.Fatalf("Can't recognize: %v", err)
		}
		switch {
		case !c.wantMatch && id >= 0:
			t.Fatalf("Shouldn't recognize but got %d category", id)
		case c.wantMatch && id < 0:
			t.Fatalf("Should have recognized but got %d category", id)
		}
	}
}
// TestClose calls Close on the shared recognizer rec. Tests that use rec must
// be declared before this one, because go test runs top-level tests in
// source order.
func TestClose(t *testing.T) {
	shared := rec
	shared.Close()
}