mirror of
				https://github.com/Kagami/go-face.git
				synced 2025-10-31 10:56:26 +08:00 
			
		
		
		
	Add idol tests
This commit is contained in:
		
							
								
								
									
										177
									
								
								face_test.go
									
									
									
									
									
								
							
							
						
						
									
										177
									
								
								face_test.go
									
									
									
									
									
								
							| @@ -1,16 +1,150 @@ | ||||
| package face | ||||
|  | ||||
| import ( | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| func TestNumFaces(t *testing.T) { | ||||
| 	rec, err := NewRecognizer("testdata") | ||||
| var ( | ||||
| 	rec *Recognizer | ||||
|  | ||||
| 	idolTests = map[string]string{ | ||||
| 		"elkie.jpg":      "Elkie, CLC", | ||||
| 		"chaeyoung.jpg":  "Chaeyoung, Twice", | ||||
| 		"chaeyoung2.jpg": "Chaeyoung, Twice", | ||||
| 		"sejeong.jpg":    "Sejeong, Gugudan", | ||||
| 		"jimin.jpg":      "Jimin, AOA", | ||||
| 		"jimin2.jpg":     "Jimin, AOA", | ||||
| 		"jimin4.jpg":     "Jimin, AOA", | ||||
| 		"meiqi.jpg":      "Mei Qi, WJSN", | ||||
| 		"chaeyeon.jpg":   "Chaeyeon, DIA", | ||||
| 		"chaeyeon3.jpg":  "Chaeyeon, DIA", | ||||
| 		"tzuyu2.jpg":     "Tzuyu, Twice", | ||||
| 		"nayoung.jpg":    "Nayoung, PRISTIN", | ||||
| 		"luda2.jpg":      "Luda, WJSN", | ||||
| 		"joy.jpg":        "Joy, Red Velvet", | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| type Idol struct { | ||||
| 	ID       string `json:"id"` | ||||
| 	Name     string `json:"name"` | ||||
| 	BandName string `json:"band_name"` | ||||
| } | ||||
|  | ||||
| type IdolFace struct { | ||||
| 	Descriptor string `json:"descriptor"` | ||||
| 	IdolID     string `json:"idol_id"` | ||||
| } | ||||
|  | ||||
| type IdolData struct { | ||||
| 	Idols []Idol     `json:"idols"` | ||||
| 	Faces []IdolFace `json:"faces"` | ||||
| 	byID  map[string]*Idol | ||||
| } | ||||
|  | ||||
| type TrainData struct { | ||||
| 	samples []Descriptor | ||||
| 	cats    []int32 | ||||
| 	labels  map[int]string | ||||
| } | ||||
|  | ||||
| func getTPath(fname string) string { | ||||
| 	return filepath.Join("testdata", fname) | ||||
| } | ||||
|  | ||||
| func getIdolData() (idata *IdolData, err error) { | ||||
| 	data, err := ioutil.ReadFile(getTPath("idols.json")) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	idata = &IdolData{} | ||||
| 	err = json.Unmarshal(data, idata) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	idata.byID = make(map[string]*Idol) | ||||
| 	for i, _ := range idata.Idols { | ||||
| 		idol := &idata.Idols[i] | ||||
| 		idata.byID[idol.ID] = idol | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func str2descr(s string) Descriptor { | ||||
| 	b, err := base64.StdEncoding.DecodeString(s) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return *(*Descriptor)(unsafe.Pointer(&b[0])) | ||||
| } | ||||
|  | ||||
| func getTrainData(idata *IdolData) (tdata *TrainData) { | ||||
| 	var samples []Descriptor | ||||
| 	var cats []int32 | ||||
| 	labels := make(map[int]string) | ||||
|  | ||||
| 	var catID int32 | ||||
| 	var prevIdolID string | ||||
| 	catID = -1 | ||||
| 	for i, _ := range idata.Faces { | ||||
| 		iface := &idata.Faces[i] | ||||
| 		descriptor := str2descr(iface.Descriptor) | ||||
| 		samples = append(samples, descriptor) | ||||
| 		if iface.IdolID != prevIdolID { | ||||
| 			catID++ | ||||
| 			labels[int(catID)] = iface.IdolID | ||||
| 		} | ||||
| 		cats = append(cats, catID) | ||||
| 		prevIdolID = iface.IdolID | ||||
| 	} | ||||
|  | ||||
| 	tdata = &TrainData{ | ||||
| 		samples: samples, | ||||
| 		cats:    cats, | ||||
| 		labels:  labels, | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func recognizeFile(fpath string) (catID *int, err error) { | ||||
| 	fd, err := os.Open(fpath) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	imgData, err := ioutil.ReadAll(fd) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 	f, err := rec.RecognizeSingle(imgData) | ||||
| 	if err != nil || f == nil { | ||||
| 		return | ||||
| 	} | ||||
| 	id := rec.Classify(f.Descriptor) | ||||
| 	if id < 0 { | ||||
| 		return | ||||
| 	} | ||||
| 	catID = &id | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func TestInit(t *testing.T) { | ||||
| 	var err error | ||||
| 	rec, err = NewRecognizer("testdata") | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Can't init face recognizer: %v", err) | ||||
| 	} | ||||
| 	defer rec.Close() | ||||
| 	faces, err := rec.RecognizeFile("testdata/pristin.jpg") | ||||
| } | ||||
|  | ||||
| func TestNumFaces(t *testing.T) { | ||||
| 	faces, err := rec.RecognizeFile(getTPath("pristin.jpg")) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Can't get faces: %v", err) | ||||
| 	} | ||||
| @@ -19,3 +153,38 @@ func TestNumFaces(t *testing.T) { | ||||
| 		t.Fatalf("Wrong number of faces: %d", numFaces) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIdols(t *testing.T) { | ||||
| 	idata, err := getIdolData() | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("Can't get idol data: %v", err) | ||||
| 	} | ||||
| 	tdata := getTrainData(idata) | ||||
| 	rec.SetSamples(tdata.samples, tdata.cats) | ||||
|  | ||||
| 	for fname, expected := range idolTests { | ||||
| 		t.Run(fname, func(t *testing.T) { | ||||
| 			names := strings.Split(expected, ", ") | ||||
| 			expectedIname := names[0] | ||||
| 			expectedBname := names[1] | ||||
|  | ||||
| 			catID, err := recognizeFile(getTPath(fname)) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
| 			if catID == nil { | ||||
| 				t.Errorf("%s: expected “%s” but not recognized", fname, expected) | ||||
| 				return | ||||
| 			} | ||||
| 			idolID := tdata.labels[*catID] | ||||
| 			idol := idata.byID[idolID] | ||||
| 			actualIname := idol.Name | ||||
| 			actualBname := idol.BandName | ||||
|  | ||||
| 			if expectedIname != actualIname || expectedBname != actualBname { | ||||
| 				actual := fmt.Sprintf("%s, %s", actualIname, actualBname) | ||||
| 				t.Errorf("%s: expected “%s” but got “%s”", fname, expected, actual) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Kagami Hiiragi
					Kagami Hiiragi