mirror of
https://github.com/bububa/openvision.git
synced 2025-09-26 17:51:13 +08:00
Merge branch 'release/v1.0.1'
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -70,3 +70,5 @@ _testmain.go
|
|||||||
test
|
test
|
||||||
.vim
|
.vim
|
||||||
dist/
|
dist/
|
||||||
|
|
||||||
|
libtorch/
|
||||||
|
@@ -33,6 +33,9 @@ cmake .. # optional -DNCNN_VULKAN=OFF -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COM
|
|||||||
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
|
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
|
||||||
- tracker (for face IOU calculation bettween frames)
|
- tracker (for face IOU calculation bettween frames)
|
||||||
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
|
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
|
||||||
|
- hair (for hair segmentation) [Google Drive](https://drive.google.com/drive/folders/14DOBaFrxTL1k4T1ved5qfRUUziurItT8?usp=sharing)
|
||||||
|
- eye
|
||||||
|
- lenet (eye status detector) [Google Drive](https://drive.google.com/drive/folders/1jaonx6PeXFLA8gBKo4eQGuxsncVnqS7o?usp=sharing)
|
||||||
- pose
|
- pose
|
||||||
- detector (for pose detection/estimation)
|
- detector (for pose detection/estimation)
|
||||||
- ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
|
- ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
|
||||||
@@ -50,10 +53,14 @@ cmake .. # optional -DNCNN_VULKAN=OFF -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COM
|
|||||||
- nanodet [Google Drive](https://drive.google.com/drive/folders/1ywH7r_clqqA_BAOFSzA92Q0lxJtWlN3z?usp=sharing)
|
- nanodet [Google Drive](https://drive.google.com/drive/folders/1ywH7r_clqqA_BAOFSzA92Q0lxJtWlN3z?usp=sharing)
|
||||||
- pose (for hand pose estimation)
|
- pose (for hand pose estimation)
|
||||||
- handnet [Google Drive](https://drive.google.com/drive/folders/1DsCGmiVaZobbMWRp5Oec8GbIpeg7CsNR?usp=sharing)
|
- handnet [Google Drive](https://drive.google.com/drive/folders/1DsCGmiVaZobbMWRp5Oec8GbIpeg7CsNR?usp=sharing)
|
||||||
|
- pose3d (for 3d handpose detection)
|
||||||
|
- mediapipe [Google Drive](https://drive.google.com/drive/folders/1LsqIGB55dusZJqmP1uhnQUnNE2tLzifp?usp=sharing)
|
||||||
- styletransfer
|
- styletransfer
|
||||||
- animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
|
- animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
|
||||||
- tracker
|
- tracker
|
||||||
- lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
|
- lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
|
||||||
|
- counter
|
||||||
|
- p2pnet [Google Drive](https://drive.google.com/drive/folders/1kmtBsPIS79C3hMAwm_Tv9tAPvJLV9k35?usp=sharing)
|
||||||
- golang binding (github.com/bububa/openvision/go)
|
- golang binding (github.com/bububa/openvision/go)
|
||||||
|
|
||||||
## Reference
|
## Reference
|
||||||
|
3
data/font/.gitignore
vendored
Normal file
3
data/font/.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
*
|
||||||
|
*/
|
||||||
|
!.gitignore
|
@@ -45,6 +45,9 @@ make -j 4
|
|||||||
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
|
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
|
||||||
- tracker (for face IOU calculation bettween frames)
|
- tracker (for face IOU calculation bettween frames)
|
||||||
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
|
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
|
||||||
|
- hair (for hair segmentation) [Google Drive](https://drive.google.com/drive/folders/14DOBaFrxTL1k4T1ved5qfRUUziurItT8?usp=sharing)
|
||||||
|
- eye
|
||||||
|
- lenet (eye status detector) [Google Drive](https://drive.google.com/drive/folders/1jaonx6PeXFLA8gBKo4eQGuxsncVnqS7o?usp=sharing)
|
||||||
- pose
|
- pose
|
||||||
- detector (for pose detection/estimation)
|
- detector (for pose detection/estimation)
|
||||||
- ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
|
- ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
|
||||||
@@ -66,3 +69,5 @@ make -j 4
|
|||||||
- animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
|
- animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
|
||||||
- tracker
|
- tracker
|
||||||
- lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
|
- lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
|
||||||
|
- counter
|
||||||
|
- p2pnet [Google Drive](https://drive.google.com/drive/folders/1kmtBsPIS79C3hMAwm_Tv9tAPvJLV9k35?usp=sharing)
|
||||||
|
2
go/classifier/doc.go
Normal file
2
go/classifier/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package classifier implement different classifiers
|
||||||
|
package classifier
|
40
go/classifier/svm/binary_classifier.go
Normal file
40
go/classifier/svm/binary_classifier.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_classifier.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// BinaryClassifier represents svm classifier
|
||||||
|
type BinaryClassifier struct {
|
||||||
|
d C.ISVMClassifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBinaryClassifier returns a new BinaryClassifier
|
||||||
|
func NewBinaryClassifier() *BinaryClassifier {
|
||||||
|
return &BinaryClassifier{
|
||||||
|
d: C.new_svm_binary_classifier(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.ISVMClassifier
|
||||||
|
func (t *BinaryClassifier) Destroy() {
|
||||||
|
DestroyClassifier(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel load model
|
||||||
|
func (t *BinaryClassifier) LoadModel(modelPath string) {
|
||||||
|
LoadClassifierModel(t.d, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predict returns predicted score
|
||||||
|
func (t *BinaryClassifier) Predict(vec []float32) float64 {
|
||||||
|
return Predict(t.d, vec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classify returns classifid scores
|
||||||
|
func (t *BinaryClassifier) Classify(vec []float32) ([]float64, error) {
|
||||||
|
return Classify(t.d, vec)
|
||||||
|
}
|
50
go/classifier/svm/binary_trainer.go
Normal file
50
go/classifier/svm/binary_trainer.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_trainer.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// BinaryTrainer represents svm trainer
|
||||||
|
type BinaryTrainer struct {
|
||||||
|
d C.ISVMTrainer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBinaryTrainer returns a new BinaryTrainer
|
||||||
|
func NewBinaryTrainer() *BinaryTrainer {
|
||||||
|
return &BinaryTrainer{
|
||||||
|
d: C.new_svm_binary_trainer(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.ISVMTrainer
|
||||||
|
func (t *BinaryTrainer) Destroy() {
|
||||||
|
DestroyTrainer(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset reset C.ISVMTrainer
|
||||||
|
func (t *BinaryTrainer) Reset() {
|
||||||
|
ResetTrainer(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLabels set total labels
|
||||||
|
func (t *BinaryTrainer) Labels(labels int) {
|
||||||
|
SetLabels(t.d, labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFeatures set total features
|
||||||
|
func (t *BinaryTrainer) SetFeatures(feats int) {
|
||||||
|
SetFeatures(t.d, feats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddData add data with label
|
||||||
|
func (t *BinaryTrainer) AddData(labelID int, feats []float32) {
|
||||||
|
AddData(t.d, labelID, feats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train train model
|
||||||
|
func (t *BinaryTrainer) Train(modelPath string) error {
|
||||||
|
return Train(t.d, modelPath)
|
||||||
|
}
|
12
go/classifier/svm/cgo.go
Normal file
12
go/classifier/svm/cgo.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !vulkan
|
||||||
|
// +build !vulkan
|
||||||
|
|
||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
12
go/classifier/svm/cgo_vulkan.go
Normal file
12
go/classifier/svm/cgo_vulkan.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build vulkan
|
||||||
|
// +build vulkan
|
||||||
|
|
||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
58
go/classifier/svm/classifier.go
Normal file
58
go/classifier/svm/classifier.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_classifier.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Classifier represents svm classifier
|
||||||
|
type Classifier interface {
|
||||||
|
LoadModel(string)
|
||||||
|
Destroy()
|
||||||
|
Predict(vec []float32) float64
|
||||||
|
Classify(vec []float32) (scores []float64, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.ISVMClassifier
|
||||||
|
func DestroyClassifier(d C.ISVMClassifier) {
|
||||||
|
C.destroy_svm_classifier(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel load model
|
||||||
|
func LoadClassifierModel(d C.ISVMClassifier, modelPath string) {
|
||||||
|
cPath := C.CString(modelPath)
|
||||||
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
C.svm_classifier_load_model(d, cPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Predict(d C.ISVMClassifier, vec []float32) float64 {
|
||||||
|
cvals := make([]C.float, 0, len(vec))
|
||||||
|
for _, v := range vec {
|
||||||
|
cvals = append(cvals, C.float(v))
|
||||||
|
}
|
||||||
|
score := C.svm_predict(d, &cvals[0])
|
||||||
|
return float64(score)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classify returns class scores
|
||||||
|
func Classify(d C.ISVMClassifier, vec []float32) ([]float64, error) {
|
||||||
|
cvals := make([]C.float, 0, len(vec))
|
||||||
|
for _, v := range vec {
|
||||||
|
cvals = append(cvals, C.float(v))
|
||||||
|
}
|
||||||
|
cScores := common.NewCFloatVector()
|
||||||
|
defer common.FreeCFloatVector(cScores)
|
||||||
|
errCode := C.svm_classify(d, &cvals[0], (*C.FloatVector)(unsafe.Pointer(cScores)))
|
||||||
|
if errCode != 0 {
|
||||||
|
return nil, openvision.ClassifyError(int(errCode))
|
||||||
|
}
|
||||||
|
return common.GoFloatVector(cScores), nil
|
||||||
|
}
|
2
go/classifier/svm/doc.go
Normal file
2
go/classifier/svm/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package svm implement svm classifier
|
||||||
|
package svm
|
40
go/classifier/svm/multiclass_classifier.go
Normal file
40
go/classifier/svm/multiclass_classifier.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_classifier.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// MultiClassClassifier represents svm classifier
|
||||||
|
type MultiClassClassifier struct {
|
||||||
|
d C.ISVMClassifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMultiClassClassifier returns a new MultiClassClassifier
|
||||||
|
func NewMultiClassClassifier() *MultiClassClassifier {
|
||||||
|
return &MultiClassClassifier{
|
||||||
|
d: C.new_svm_multiclass_classifier(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.ISVMClassifier
|
||||||
|
func (t *MultiClassClassifier) Destroy() {
|
||||||
|
DestroyClassifier(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel load model
|
||||||
|
func (t *MultiClassClassifier) LoadModel(modelPath string) {
|
||||||
|
LoadClassifierModel(t.d, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Predict returns predicted score
|
||||||
|
func (t *MultiClassClassifier) Predict(vec []float32) float64 {
|
||||||
|
return Predict(t.d, vec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Classify returns classifid scores
|
||||||
|
func (t *MultiClassClassifier) Classify(vec []float32) ([]float64, error) {
|
||||||
|
return Classify(t.d, vec)
|
||||||
|
}
|
50
go/classifier/svm/multiclass_trainer.go
Normal file
50
go/classifier/svm/multiclass_trainer.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_trainer.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
// MultiClassTrainer represents svm trainer
|
||||||
|
type MultiClassTrainer struct {
|
||||||
|
d C.ISVMTrainer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMultiClassTrainer returns a new MultiClassTrainer
|
||||||
|
func NewMultiClassTrainer() *MultiClassTrainer {
|
||||||
|
return &MultiClassTrainer{
|
||||||
|
d: C.new_svm_multiclass_trainer(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.ISVMTrainer
|
||||||
|
func (t *MultiClassTrainer) Destroy() {
|
||||||
|
DestroyTrainer(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset reset C.ISVMTrainer
|
||||||
|
func (t *MultiClassTrainer) Reset() {
|
||||||
|
ResetTrainer(t.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLabels set total labels
|
||||||
|
func (t *MultiClassTrainer) SetLabels(labels int) {
|
||||||
|
SetLabels(t.d, labels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFeatures set total features
|
||||||
|
func (t *MultiClassTrainer) SetFeatures(feats int) {
|
||||||
|
SetFeatures(t.d, feats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddData add data with label
|
||||||
|
func (t *MultiClassTrainer) AddData(labelID int, feats []float32) {
|
||||||
|
AddData(t.d, labelID, feats)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train train model
|
||||||
|
func (t *MultiClassTrainer) Train(modelPath string) error {
|
||||||
|
return Train(t.d, modelPath)
|
||||||
|
}
|
63
go/classifier/svm/trainer.go
Normal file
63
go/classifier/svm/trainer.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
package svm
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/classifier/svm_trainer.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Trainer represents svm trainer
|
||||||
|
type Trainer interface {
|
||||||
|
Reset()
|
||||||
|
Destroy()
|
||||||
|
SetLabels(int)
|
||||||
|
SetFeatures(int)
|
||||||
|
AddData(int, []float32)
|
||||||
|
Train(modelPath string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestroyTrainer destroy C.ISVMTrainer
|
||||||
|
func DestroyTrainer(d C.ISVMTrainer) {
|
||||||
|
C.destroy_svm_trainer(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetTrainer reset C.ISVMTrainer
|
||||||
|
func ResetTrainer(d C.ISVMTrainer) {
|
||||||
|
C.svm_trainer_reset(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLabels set total labels
|
||||||
|
func SetLabels(d C.ISVMTrainer, labels int) {
|
||||||
|
C.svm_trainer_set_labels(d, C.int(labels))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFeatures set total features
|
||||||
|
func SetFeatures(d C.ISVMTrainer, feats int) {
|
||||||
|
C.svm_trainer_set_features(d, C.int(feats))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddData add data with label
|
||||||
|
func AddData(d C.ISVMTrainer, labelID int, feats []float32) {
|
||||||
|
vec := make([]C.float, 0, len(feats))
|
||||||
|
for _, v := range feats {
|
||||||
|
vec = append(vec, C.float(v))
|
||||||
|
}
|
||||||
|
C.svm_trainer_add_data(d, C.int(labelID), &vec[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Train train model
|
||||||
|
func Train(d C.ISVMTrainer, modelPath string) error {
|
||||||
|
cPath := C.CString(modelPath)
|
||||||
|
defer C.free(unsafe.Pointer(cPath))
|
||||||
|
errCode := C.svm_train(d, cPath)
|
||||||
|
if errCode != 0 {
|
||||||
|
return openvision.TrainingError(int(errCode))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
@@ -38,6 +38,9 @@ func parseHexColor(x string) (r, g, b, a uint32) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
White = "#FFFFFF"
|
||||||
|
Black = "#000000"
|
||||||
|
Gray = "#333333"
|
||||||
Green = "#64DD17"
|
Green = "#64DD17"
|
||||||
Pink = "#E91E63"
|
Pink = "#E91E63"
|
||||||
Red = "#FF1744"
|
Red = "#FF1744"
|
||||||
|
45
go/common/font.go
Normal file
45
go/common/font.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/golang/freetype/truetype"
|
||||||
|
"github.com/llgcode/draw2d"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Font font info
|
||||||
|
type Font struct {
|
||||||
|
// Cache FontCache
|
||||||
|
Cache draw2d.FontCache
|
||||||
|
// Size font size
|
||||||
|
Size float64 `json:"size,omitempty"`
|
||||||
|
// Data font setting
|
||||||
|
Data *draw2d.FontData `json:"data,omitempty"`
|
||||||
|
// Font
|
||||||
|
Font *truetype.Font `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load font from font cache
|
||||||
|
func (f *Font) Load(cache draw2d.FontCache) error {
|
||||||
|
if f.Font != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if f.Data == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cache == nil {
|
||||||
|
return errors.New("missing font cache")
|
||||||
|
}
|
||||||
|
ft, err := cache.Load(*f.Data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.Cache = cache
|
||||||
|
f.Font = ft
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFontCache load font cache
|
||||||
|
func NewFontCache(fontFolder string) *draw2d.SyncFolderFontCache {
|
||||||
|
return draw2d.NewSyncFolderFontCache(fontFolder)
|
||||||
|
}
|
@@ -60,6 +60,8 @@ func GoRect(c *C.Rect, w float64, h float64) Rectangle {
|
|||||||
|
|
||||||
var ZR = Rectangle{}
|
var ZR = Rectangle{}
|
||||||
|
|
||||||
|
var FullRect = Rect(0, 0, 1, 1)
|
||||||
|
|
||||||
// Point represents a Point
|
// Point represents a Point
|
||||||
type Point struct {
|
type Point struct {
|
||||||
X float64
|
X float64
|
||||||
@@ -88,6 +90,9 @@ func NewCPoint2fVector() *C.Point2fVector {
|
|||||||
|
|
||||||
// GoPoint2fVector convert C.Point2fVector to []Point
|
// GoPoint2fVector convert C.Point2fVector to []Point
|
||||||
func GoPoint2fVector(cVector *C.Point2fVector, w float64, h float64) []Point {
|
func GoPoint2fVector(cVector *C.Point2fVector, w float64, h float64) []Point {
|
||||||
|
if cVector == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
l := int(cVector.length)
|
l := int(cVector.length)
|
||||||
ret := make([]Point, 0, l)
|
ret := make([]Point, 0, l)
|
||||||
ptr := unsafe.Pointer(cVector.points)
|
ptr := unsafe.Pointer(cVector.points)
|
||||||
@@ -103,3 +108,52 @@ func FreeCPoint2fVector(c *C.Point2fVector) {
|
|||||||
C.FreePoint2fVector(c)
|
C.FreePoint2fVector(c)
|
||||||
C.free(unsafe.Pointer(c))
|
C.free(unsafe.Pointer(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Point3d represents a 3dPoint
|
||||||
|
type Point3d struct {
|
||||||
|
X float64
|
||||||
|
Y float64
|
||||||
|
Z float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pt3d returns a New Point3d
|
||||||
|
func Pt3d(x, y, z float64) Point3d {
|
||||||
|
return Point3d{x, y, z}
|
||||||
|
}
|
||||||
|
|
||||||
|
var ZP3d = Point3d{}
|
||||||
|
|
||||||
|
// GoPoint3d conver C.Point3d to Point3d
|
||||||
|
func GoPoint3d(c *C.Point3d) Point3d {
|
||||||
|
return Pt3d(
|
||||||
|
float64(c.x),
|
||||||
|
float64(c.y),
|
||||||
|
float64(c.z),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCPoint3dVector retruns C.Point3dVector pointer
|
||||||
|
func NewCPoint3dVector() *C.Point3dVector {
|
||||||
|
return (*C.Point3dVector)(C.malloc(C.sizeof_Point3d))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoPoint3dVector convert C.Point3dVector to []Point3d
|
||||||
|
func GoPoint3dVector(cVector *C.Point3dVector) []Point3d {
|
||||||
|
if cVector == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
l := int(cVector.length)
|
||||||
|
ret := make([]Point3d, 0, l)
|
||||||
|
ptr := unsafe.Pointer(cVector.points)
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
cPoint3d := (*C.Point3d)(unsafe.Pointer(uintptr(ptr) + uintptr(C.sizeof_Point3d*C.int(i))))
|
||||||
|
ret = append(ret, GoPoint3d(cPoint3d))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// FreeCPoint3dVector release C.Point3dVector memory
|
||||||
|
func FreeCPoint3dVector(c *C.Point3dVector) {
|
||||||
|
C.FreePoint3dVector(c)
|
||||||
|
C.free(unsafe.Pointer(c))
|
||||||
|
}
|
||||||
|
@@ -26,6 +26,9 @@ type Image struct {
|
|||||||
// NewImage returns a new Image
|
// NewImage returns a new Image
|
||||||
func NewImage(img image.Image) *Image {
|
func NewImage(img image.Image) *Image {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
|
if img == nil {
|
||||||
|
return &Image{buffer: buf}
|
||||||
|
}
|
||||||
Image2RGB(buf, img)
|
Image2RGB(buf, img)
|
||||||
return &Image{
|
return &Image{
|
||||||
Image: img,
|
Image: img,
|
||||||
@@ -33,6 +36,21 @@ func NewImage(img image.Image) *Image {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (i *Image) Reset() {
|
||||||
|
i.Image = nil
|
||||||
|
if i.buffer != nil {
|
||||||
|
i.buffer.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write write bytes to buffer
|
||||||
|
func (i *Image) Write(b []byte) {
|
||||||
|
if i.buffer == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
i.buffer.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
// Bytes returns image bytes in rgb
|
// Bytes returns image bytes in rgb
|
||||||
func (i Image) Bytes() []byte {
|
func (i Image) Bytes() []byte {
|
||||||
if i.buffer == nil {
|
if i.buffer == nil {
|
||||||
@@ -74,20 +92,23 @@ func NewCImage() *C.Image {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FreeCImage free C.Image
|
||||||
func FreeCImage(c *C.Image) {
|
func FreeCImage(c *C.Image) {
|
||||||
C.FreeImage(c)
|
C.FreeImage(c)
|
||||||
C.free(unsafe.Pointer(c))
|
C.free(unsafe.Pointer(c))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GoImage(c *C.Image) (image.Image, error) {
|
// GoImage returns Image from C.Image
|
||||||
|
func GoImage(c *C.Image, out *Image) {
|
||||||
w := int(c.width)
|
w := int(c.width)
|
||||||
h := int(c.height)
|
h := int(c.height)
|
||||||
channels := int(c.channels)
|
channels := int(c.channels)
|
||||||
data := C.GoBytes(unsafe.Pointer(c.data), C.int(w*h*channels)*C.sizeof_uchar)
|
data := C.GoBytes(unsafe.Pointer(c.data), C.int(w*h*channels)*C.sizeof_uchar)
|
||||||
return NewImageFromBytes(data, w, h, channels)
|
NewImageFromBytes(data, w, h, channels, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewImageFromBytes(data []byte, w int, h int, channels int) (image.Image, error) {
|
// NewImageFromBytes returns Image by []byte
|
||||||
|
func NewImageFromBytes(data []byte, w int, h int, channels int, out *Image) {
|
||||||
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
img := image.NewRGBA(image.Rect(0, 0, w, h))
|
||||||
for y := 0; y < h; y++ {
|
for y := 0; y < h; y++ {
|
||||||
for x := 0; x < w; x++ {
|
for x := 0; x < w; x++ {
|
||||||
@@ -97,9 +118,10 @@ func NewImageFromBytes(data []byte, w int, h int, channels int) (image.Image, er
|
|||||||
alpha = data[pos+3]
|
alpha = data[pos+3]
|
||||||
}
|
}
|
||||||
img.SetRGBA(x, y, color.RGBA{uint8(data[pos]), uint8(data[pos+1]), uint8(data[pos+2]), uint8(alpha)})
|
img.SetRGBA(x, y, color.RGBA{uint8(data[pos]), uint8(data[pos+1]), uint8(data[pos+2]), uint8(alpha)})
|
||||||
|
out.Write([]byte{byte(data[pos]), byte(data[pos+1]), byte(data[pos+2]), byte(alpha)})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return img, nil
|
out.Image = img
|
||||||
}
|
}
|
||||||
|
|
||||||
// Image2RGB write image rgbdata to buffer
|
// Image2RGB write image rgbdata to buffer
|
||||||
@@ -170,3 +192,44 @@ func DrawCircle(gc *draw2dimg.GraphicContext, pt Point, r float64, borderColor s
|
|||||||
gc.Stroke()
|
gc.Stroke()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DrawLabel draw label text to image
|
||||||
|
func DrawLabel(gc *draw2dimg.GraphicContext, font *Font, label string, pt Point, txtColor string, bgColor string, scale float64) {
|
||||||
|
if font == nil || font.Cache == nil || font.Data == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gc.FontCache = font.Cache
|
||||||
|
gc.SetFontData(*font.Data)
|
||||||
|
gc.SetFontSize(font.Size * scale)
|
||||||
|
var (
|
||||||
|
x = float64(pt.X)
|
||||||
|
y = float64(pt.Y)
|
||||||
|
padding = 2.0 * scale
|
||||||
|
)
|
||||||
|
left, top, right, bottom := gc.GetStringBounds(label)
|
||||||
|
height := bottom - top
|
||||||
|
width := right - left
|
||||||
|
if bgColor != "" {
|
||||||
|
gc.SetFillColor(ColorFromHex(bgColor))
|
||||||
|
draw2dkit.Rectangle(gc, x, y, x+width+padding*2, y+height+padding*2)
|
||||||
|
gc.Fill()
|
||||||
|
}
|
||||||
|
gc.SetFillColor(ColorFromHex(txtColor))
|
||||||
|
gc.FillStringAt(label, x-left+padding, y-top+padding)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DrawLabelInWidth draw label text to image in width restrict
|
||||||
|
func DrawLabelInWidth(gc *draw2dimg.GraphicContext, font *Font, label string, pt Point, txtColor string, bgColor string, boundWidth float64) {
|
||||||
|
if font == nil || font.Cache == nil || font.Data == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
gc.FontCache = font.Cache
|
||||||
|
gc.SetFontData(*font.Data)
|
||||||
|
gc.SetFontSize(font.Size)
|
||||||
|
left, _, right, _ := gc.GetStringBounds(label)
|
||||||
|
padding := 2.0
|
||||||
|
width := right - left
|
||||||
|
fontWidth := width + padding*2
|
||||||
|
scale := boundWidth / fontWidth
|
||||||
|
DrawLabel(gc, font, label, pt, txtColor, bgColor, scale)
|
||||||
|
}
|
||||||
|
@@ -20,6 +20,8 @@ type ObjectInfo struct {
|
|||||||
Rect Rectangle
|
Rect Rectangle
|
||||||
// Points keypoints
|
// Points keypoints
|
||||||
Keypoints []Keypoint
|
Keypoints []Keypoint
|
||||||
|
// Name
|
||||||
|
Name string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoObjectInfo convert C.ObjectInfo to go type
|
// GoObjectInfo convert C.ObjectInfo to go type
|
||||||
|
57
go/common/palmobject.go
Normal file
57
go/common/palmobject.go
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/common/common.h"
|
||||||
|
#include "openvision/hand/pose3d.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PalmObject
|
||||||
|
type PalmObject struct {
|
||||||
|
Name string
|
||||||
|
Score float64
|
||||||
|
Rotation float64
|
||||||
|
RectPoints []Point
|
||||||
|
Landmarks []Point
|
||||||
|
Skeleton []Point
|
||||||
|
Skeleton3d []Point3d
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCPalmObjectVector returns *C.PalmObjectVector
|
||||||
|
func NewCPalmObjectVector() *C.PalmObjectVector {
|
||||||
|
return (*C.PalmObjectVector)(C.malloc(C.sizeof_PalmObjectVector))
|
||||||
|
}
|
||||||
|
|
||||||
|
// FreeCPalmObjectVector release *C.PalmObjectVector memory
|
||||||
|
func FreeCPalmObjectVector(p *C.PalmObjectVector) {
|
||||||
|
C.FreePalmObjectVector(p)
|
||||||
|
C.free(unsafe.Pointer(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GoPalmObject convert C.PalmObject to Go type
|
||||||
|
func GoPalmObject(cObj *C.PalmObject, w float64, h float64) PalmObject {
|
||||||
|
return PalmObject{
|
||||||
|
Score: float64(cObj.score),
|
||||||
|
Rotation: float64(cObj.rotation),
|
||||||
|
RectPoints: GoPoint2fVector(cObj.rect, w, h),
|
||||||
|
Landmarks: GoPoint2fVector(cObj.landmarks, w, h),
|
||||||
|
Skeleton: GoPoint2fVector(cObj.skeleton, w, h),
|
||||||
|
Skeleton3d: GoPoint3dVector(cObj.skeleton3d),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func GoPalmObjectVector(c *C.PalmObjectVector, w float64, h float64) []PalmObject {
|
||||||
|
l := int(c.length)
|
||||||
|
ret := make([]PalmObject, 0, l)
|
||||||
|
ptr := unsafe.Pointer(c.items)
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
cObj := (*C.PalmObject)(unsafe.Pointer(uintptr(ptr) + uintptr(C.sizeof_PalmObject*C.int(i))))
|
||||||
|
ret = append(ret, GoPalmObject(cObj, w, h))
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
12
go/counter/cgo.go
Normal file
12
go/counter/cgo.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !vulkan
|
||||||
|
// +build !vulkan
|
||||||
|
|
||||||
|
package counter
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
11
go/counter/cgo_vulkan.go
Normal file
11
go/counter/cgo_vulkan.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build vulkan
|
||||||
|
|
||||||
|
package counter
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
40
go/counter/counter.go
Normal file
40
go/counter/counter.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package counter
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/common/common.h"
|
||||||
|
#include "openvision/counter/counter.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Counter represents Object Counter interface
|
||||||
|
type Counter interface {
|
||||||
|
common.Estimator
|
||||||
|
CrowdCount(img *common.Image) ([]common.Keypoint, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CrowdCount returns object counter
|
||||||
|
func CrowdCount(d Counter, img *common.Image) ([]common.Keypoint, error) {
|
||||||
|
imgWidth := img.WidthF64()
|
||||||
|
imgHeight := img.HeightF64()
|
||||||
|
data := img.Bytes()
|
||||||
|
ptsC := common.NewCKeypointVector()
|
||||||
|
defer common.FreeCKeypointVector(ptsC)
|
||||||
|
errCode := C.crowd_count(
|
||||||
|
(C.ICounter)(d.Pointer()),
|
||||||
|
(*C.uchar)(unsafe.Pointer(&data[0])),
|
||||||
|
C.int(imgWidth),
|
||||||
|
C.int(imgHeight),
|
||||||
|
(*C.KeypointVector)(unsafe.Pointer(ptsC)))
|
||||||
|
if errCode != 0 {
|
||||||
|
return nil, openvision.CounterError(int(errCode))
|
||||||
|
}
|
||||||
|
return common.GoKeypointVector(ptsC, imgWidth, imgHeight), nil
|
||||||
|
}
|
2
go/counter/doc.go
Normal file
2
go/counter/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package counter include object counter
|
||||||
|
package counter
|
45
go/counter/p2pnet.go
Normal file
45
go/counter/p2pnet.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package counter
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/counter/counter.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// P2PNet represents p2pnet counter
|
||||||
|
type P2PNet struct {
|
||||||
|
d C.ICounter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewP2PNet returns a new P2PNet
|
||||||
|
func NewP2PNet() *P2PNet {
|
||||||
|
return &P2PNet{
|
||||||
|
d: C.new_p2pnet_crowd_counter(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy free tracker
|
||||||
|
func (d *P2PNet) Destroy() {
|
||||||
|
common.DestroyEstimator(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pointer implement Estimator interface
|
||||||
|
func (d *P2PNet) Pointer() unsafe.Pointer {
|
||||||
|
return unsafe.Pointer(d.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel load model for detecter
|
||||||
|
func (d *P2PNet) LoadModel(modelPath string) error {
|
||||||
|
return common.EstimatorLoadModel(d, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CrowdCount implement Object Counter interface
|
||||||
|
func (d *P2PNet) CrowdCount(img *common.Image) ([]common.Keypoint, error) {
|
||||||
|
return CrowdCount(d, img)
|
||||||
|
}
|
24
go/error.go
24
go/error.go
@@ -56,6 +56,12 @@ var (
|
|||||||
Message: "detect head pose failed",
|
Message: "detect head pose failed",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
HairMattingError = func(code int) Error {
|
||||||
|
return Error{
|
||||||
|
Code: code,
|
||||||
|
Message: "hair matting failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
DetectHandError = func(code int) Error {
|
DetectHandError = func(code int) Error {
|
||||||
return Error{
|
return Error{
|
||||||
Code: code,
|
Code: code,
|
||||||
@@ -74,10 +80,28 @@ var (
|
|||||||
Message: "object tracker error",
|
Message: "object tracker error",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
CounterError = func(code int) Error {
|
||||||
|
return Error{
|
||||||
|
Code: code,
|
||||||
|
Message: "object counter error",
|
||||||
|
}
|
||||||
|
}
|
||||||
RealsrError = func(code int) Error {
|
RealsrError = func(code int) Error {
|
||||||
return Error{
|
return Error{
|
||||||
Code: code,
|
Code: code,
|
||||||
Message: "super-resolution process error",
|
Message: "super-resolution process error",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TrainingError = func(code int) Error {
|
||||||
|
return Error{
|
||||||
|
Code: code,
|
||||||
|
Message: "training process failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ClassifyError = func(code int) Error {
|
||||||
|
return Error{
|
||||||
|
Code: code,
|
||||||
|
Message: "classify process failed",
|
||||||
|
}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
@@ -50,8 +50,9 @@ func align(d detecter.Detecter, a *aligner.Aligner, imgPath string, filename str
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
aligned := common.NewImage(nil)
|
||||||
for idx, face := range faces {
|
for idx, face := range faces {
|
||||||
aligned, err := a.Align(img, face)
|
err := a.Align(img, face, aligned)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
92
go/examples/counter/main.go
Normal file
92
go/examples/counter/main.go
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
"github.com/bububa/openvision/go/counter"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wd, _ := os.Getwd()
|
||||||
|
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
||||||
|
imgPath := filepath.Join(dataPath, "./images")
|
||||||
|
modelPath := filepath.Join(dataPath, "./models")
|
||||||
|
common.CreateGPUInstance()
|
||||||
|
defer common.DestroyGPUInstance()
|
||||||
|
cpuCores := common.GetBigCPUCount()
|
||||||
|
common.SetOMPThreads(cpuCores)
|
||||||
|
log.Printf("CPU big cores:%d\n", cpuCores)
|
||||||
|
d := p2pnet(modelPath)
|
||||||
|
defer d.Destroy()
|
||||||
|
common.SetEstimatorThreads(d, cpuCores)
|
||||||
|
crowdCount(d, imgPath, "congested2.jpg")
|
||||||
|
}
|
||||||
|
|
||||||
|
func p2pnet(modelPath string) counter.Counter {
|
||||||
|
modelPath = filepath.Join(modelPath, "p2pnet")
|
||||||
|
d := counter.NewP2PNet()
|
||||||
|
if err := d.LoadModel(modelPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func crowdCount(d counter.Counter, imgPath string, filename string) {
|
||||||
|
inPath := filepath.Join(imgPath, filename)
|
||||||
|
imgLoaded, err := loadImage(inPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("load image failed,", err)
|
||||||
|
}
|
||||||
|
img := common.NewImage(imgLoaded)
|
||||||
|
pts, err := d.CrowdCount(img)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
log.Printf("count: %d\n", len(pts))
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadImage(filePath string) (image.Image, error) {
|
||||||
|
fn, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
img, _, err := image.Decode(fn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveImage(img image.Image, filePath string) error {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if err := jpeg.Encode(buf, img, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fn, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
fn.Write(buf.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanPath(wd string, path string) string {
|
||||||
|
usr, _ := user.Current()
|
||||||
|
dir := usr.HomeDir
|
||||||
|
if path == "~" {
|
||||||
|
return dir
|
||||||
|
} else if strings.HasPrefix(path, "~/") {
|
||||||
|
return filepath.Join(dir, path[2:])
|
||||||
|
}
|
||||||
|
return filepath.Join(wd, path)
|
||||||
|
}
|
@@ -9,10 +9,14 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/llgcode/draw2d"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
"github.com/bububa/openvision/go/face/detecter"
|
"github.com/bububa/openvision/go/face/detecter"
|
||||||
|
"github.com/bububa/openvision/go/face/drawer"
|
||||||
facedrawer "github.com/bububa/openvision/go/face/drawer"
|
facedrawer "github.com/bububa/openvision/go/face/drawer"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,16 +25,35 @@ func main() {
|
|||||||
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
||||||
imgPath := filepath.Join(dataPath, "./images")
|
imgPath := filepath.Join(dataPath, "./images")
|
||||||
modelPath := filepath.Join(dataPath, "./models")
|
modelPath := filepath.Join(dataPath, "./models")
|
||||||
|
fontPath := filepath.Join(dataPath, "./font")
|
||||||
common.CreateGPUInstance()
|
common.CreateGPUInstance()
|
||||||
defer common.DestroyGPUInstance()
|
defer common.DestroyGPUInstance()
|
||||||
cpuCores := common.GetBigCPUCount()
|
cpuCores := common.GetBigCPUCount()
|
||||||
common.SetOMPThreads(cpuCores)
|
common.SetOMPThreads(cpuCores)
|
||||||
log.Printf("CPU big cores:%d\n", cpuCores)
|
log.Printf("CPU big cores:%d\n", cpuCores)
|
||||||
test_detect(imgPath, modelPath, cpuCores)
|
test_detect(imgPath, modelPath, fontPath, cpuCores)
|
||||||
test_mask(imgPath, modelPath, cpuCores)
|
test_mask(imgPath, modelPath, fontPath, cpuCores)
|
||||||
}
|
}
|
||||||
|
|
||||||
func test_detect(imgPath string, modelPath string, threads int) {
|
func load_font(fontPath string) *common.Font {
|
||||||
|
fontCache := common.NewFontCache(fontPath)
|
||||||
|
fnt := &common.Font{
|
||||||
|
Size: 9,
|
||||||
|
Data: &draw2d.FontData{
|
||||||
|
Name: "NotoSansCJKsc",
|
||||||
|
//Name: "Roboto",
|
||||||
|
Family: draw2d.FontFamilySans,
|
||||||
|
Style: draw2d.FontStyleNormal,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := fnt.Load(fontCache); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return fnt
|
||||||
|
}
|
||||||
|
func test_detect(imgPath string, modelPath string, fontPath string, threads int) {
|
||||||
|
fnt := load_font(fontPath)
|
||||||
|
drawer := facedrawer.New(drawer.WithFont(fnt))
|
||||||
for idx, d := range []detecter.Detecter{
|
for idx, d := range []detecter.Detecter{
|
||||||
retinaface(modelPath),
|
retinaface(modelPath),
|
||||||
centerface(modelPath),
|
centerface(modelPath),
|
||||||
@@ -39,16 +62,22 @@ func test_detect(imgPath string, modelPath string, threads int) {
|
|||||||
scrfd(modelPath),
|
scrfd(modelPath),
|
||||||
} {
|
} {
|
||||||
common.SetEstimatorThreads(d, threads)
|
common.SetEstimatorThreads(d, threads)
|
||||||
detect(d, imgPath, idx, "4.jpg", false)
|
detect(d, drawer, imgPath, idx, "4.jpg")
|
||||||
d.Destroy()
|
d.Destroy()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func test_mask(imgPath string, modelPath string, threads int) {
|
func test_mask(imgPath string, modelPath string, fontPath string, threads int) {
|
||||||
|
fnt := load_font(fontPath)
|
||||||
|
drawer := facedrawer.New(
|
||||||
|
facedrawer.WithBorderColor(common.Red),
|
||||||
|
facedrawer.WithMaskColor(common.Green),
|
||||||
|
facedrawer.WithFont(fnt),
|
||||||
|
)
|
||||||
d := anticonv(modelPath)
|
d := anticonv(modelPath)
|
||||||
common.SetEstimatorThreads(d, threads)
|
common.SetEstimatorThreads(d, threads)
|
||||||
defer d.Destroy()
|
defer d.Destroy()
|
||||||
detect(d, imgPath, 0, "mask3.jpg", true)
|
detect(d, drawer, imgPath, 0, "mask3.jpg")
|
||||||
}
|
}
|
||||||
|
|
||||||
func retinaface(modelPath string) detecter.Detecter {
|
func retinaface(modelPath string) detecter.Detecter {
|
||||||
@@ -105,7 +134,7 @@ func anticonv(modelPath string) detecter.Detecter {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
func detect(d detecter.Detecter, imgPath string, idx int, filename string, mask bool) {
|
func detect(d detecter.Detecter, drawer *facedrawer.Drawer, imgPath string, idx int, filename string) {
|
||||||
inPath := filepath.Join(imgPath, filename)
|
inPath := filepath.Join(imgPath, filename)
|
||||||
img, err := loadImage(inPath)
|
img, err := loadImage(inPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -116,18 +145,11 @@ func detect(d detecter.Detecter, imgPath string, idx int, filename string, mask
|
|||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%d-%s", idx, filename))
|
for idx, face := range faces {
|
||||||
|
faces[idx].Label = strconv.FormatFloat(float64(face.Score), 'f', 4, 64)
|
||||||
var drawer *facedrawer.Drawer
|
|
||||||
if mask {
|
|
||||||
drawer = facedrawer.New(
|
|
||||||
facedrawer.WithBorderColor(common.Red),
|
|
||||||
facedrawer.WithMaskColor(common.Green),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
drawer = facedrawer.New()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%d-%s", idx, filename))
|
||||||
out := drawer.Draw(img, faces)
|
out := drawer.Draw(img, faces)
|
||||||
|
|
||||||
if err := saveImage(out, outPath); err != nil {
|
if err := saveImage(out, outPath); err != nil {
|
||||||
|
115
go/examples/eye/main.go
Normal file
115
go/examples/eye/main.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
"github.com/bububa/openvision/go/face/detecter"
|
||||||
|
"github.com/bububa/openvision/go/face/eye"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wd, _ := os.Getwd()
|
||||||
|
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
||||||
|
imgPath := filepath.Join(dataPath, "./images")
|
||||||
|
modelPath := filepath.Join(dataPath, "./models")
|
||||||
|
common.CreateGPUInstance()
|
||||||
|
defer common.DestroyGPUInstance()
|
||||||
|
cpuCores := common.GetBigCPUCount()
|
||||||
|
common.SetOMPThreads(cpuCores)
|
||||||
|
log.Printf("CPU big cores:%d\n", cpuCores)
|
||||||
|
d := retinaface(modelPath)
|
||||||
|
defer d.Destroy()
|
||||||
|
common.SetEstimatorThreads(d, cpuCores)
|
||||||
|
e := lenet(modelPath)
|
||||||
|
defer e.Destroy()
|
||||||
|
common.SetEstimatorThreads(e, cpuCores)
|
||||||
|
for _, fn := range []string{"eye-open.jpg", "eye-close.jpg", "eye-half.jpg"} {
|
||||||
|
detect(d, e, imgPath, fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func retinaface(modelPath string) detecter.Detecter {
|
||||||
|
modelPath = filepath.Join(modelPath, "fd")
|
||||||
|
d := detecter.NewRetinaFace()
|
||||||
|
if err := d.LoadModel(modelPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func lenet(modelPath string) eye.Detecter {
|
||||||
|
modelPath = filepath.Join(modelPath, "eye/lenet")
|
||||||
|
d := eye.NewLenet()
|
||||||
|
if err := d.LoadModel(modelPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func detect(d detecter.Detecter, e eye.Detecter, imgPath string, filename string) {
|
||||||
|
inPath := filepath.Join(imgPath, filename)
|
||||||
|
imgLoaded, err := loadImage(inPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("load image failed,", err)
|
||||||
|
}
|
||||||
|
img := common.NewImage(imgLoaded)
|
||||||
|
faces, err := d.Detect(img)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
for _, face := range faces {
|
||||||
|
rect := face.Rect
|
||||||
|
closed, err := e.IsClosed(img, rect)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
fmt.Printf("fn: %s, closed: %+v\n", filename, closed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadImage(filePath string) (image.Image, error) {
|
||||||
|
fn, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
img, _, err := image.Decode(fn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveImage(img image.Image, filePath string) error {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if err := jpeg.Encode(buf, img, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fn, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
fn.Write(buf.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanPath(wd string, path string) string {
|
||||||
|
usr, _ := user.Current()
|
||||||
|
dir := usr.HomeDir
|
||||||
|
if path == "~" {
|
||||||
|
return dir
|
||||||
|
} else if strings.HasPrefix(path, "~/") {
|
||||||
|
return filepath.Join(dir, path[2:])
|
||||||
|
}
|
||||||
|
return filepath.Join(wd, path)
|
||||||
|
}
|
94
go/examples/hair/main.go
Normal file
94
go/examples/hair/main.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"image"
|
||||||
|
"image/jpeg"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
"github.com/bububa/openvision/go/face/hair"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
wd, _ := os.Getwd()
|
||||||
|
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
||||||
|
imgPath := filepath.Join(dataPath, "./images")
|
||||||
|
modelPath := filepath.Join(dataPath, "./models")
|
||||||
|
common.CreateGPUInstance()
|
||||||
|
defer common.DestroyGPUInstance()
|
||||||
|
d := estimator(modelPath)
|
||||||
|
defer d.Destroy()
|
||||||
|
matting(d, imgPath, "hair1.jpg")
|
||||||
|
}
|
||||||
|
|
||||||
|
func estimator(modelPath string) *hair.Hair {
|
||||||
|
modelPath = filepath.Join(modelPath, "hair")
|
||||||
|
d := hair.NewHair()
|
||||||
|
if err := d.LoadModel(modelPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func matting(d *hair.Hair, imgPath string, filename string) {
|
||||||
|
inPath := filepath.Join(imgPath, filename)
|
||||||
|
imgLoaded, err := loadImage(inPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("load image failed,", err)
|
||||||
|
}
|
||||||
|
img := common.NewImage(imgLoaded)
|
||||||
|
out := common.NewImage(nil)
|
||||||
|
if err := d.Matting(img, out); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("hair-matting-%s", filename))
|
||||||
|
|
||||||
|
if err := saveImage(out, outPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadImage(filePath string) (image.Image, error) {
|
||||||
|
fn, err := os.Open(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
img, _, err := image.Decode(fn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return img, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func saveImage(img image.Image, filePath string) error {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
if err := jpeg.Encode(buf, img, nil); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fn, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fn.Close()
|
||||||
|
fn.Write(buf.Bytes())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanPath(wd string, path string) string {
|
||||||
|
usr, _ := user.Current()
|
||||||
|
dir := usr.HomeDir
|
||||||
|
if path == "~" {
|
||||||
|
return dir
|
||||||
|
} else if strings.HasPrefix(path, "~/") {
|
||||||
|
return filepath.Join(dir, path[2:])
|
||||||
|
}
|
||||||
|
return filepath.Join(wd, path)
|
||||||
|
}
|
@@ -15,6 +15,8 @@ import (
|
|||||||
"github.com/bububa/openvision/go/hand/detecter"
|
"github.com/bububa/openvision/go/hand/detecter"
|
||||||
handdrawer "github.com/bububa/openvision/go/hand/drawer"
|
handdrawer "github.com/bububa/openvision/go/hand/drawer"
|
||||||
"github.com/bububa/openvision/go/hand/pose"
|
"github.com/bububa/openvision/go/hand/pose"
|
||||||
|
"github.com/bububa/openvision/go/hand/pose3d"
|
||||||
|
"github.com/llgcode/draw2d"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -22,22 +24,26 @@ func main() {
|
|||||||
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
|
||||||
imgPath := filepath.Join(dataPath, "./images")
|
imgPath := filepath.Join(dataPath, "./images")
|
||||||
modelPath := filepath.Join(dataPath, "./models")
|
modelPath := filepath.Join(dataPath, "./models")
|
||||||
|
fontPath := filepath.Join(dataPath, "./font")
|
||||||
common.CreateGPUInstance()
|
common.CreateGPUInstance()
|
||||||
defer common.DestroyGPUInstance()
|
defer common.DestroyGPUInstance()
|
||||||
cpuCores := common.GetBigCPUCount()
|
cpuCores := common.GetBigCPUCount()
|
||||||
common.SetOMPThreads(cpuCores)
|
common.SetOMPThreads(cpuCores)
|
||||||
log.Printf("CPU big cores:%d\n", cpuCores)
|
log.Printf("CPU big cores:%d\n", cpuCores)
|
||||||
estimator := handpose(modelPath)
|
// estimator := handpose(modelPath)
|
||||||
defer estimator.Destroy()
|
// defer estimator.Destroy()
|
||||||
common.SetEstimatorThreads(estimator, cpuCores)
|
// common.SetEstimatorThreads(estimator, cpuCores)
|
||||||
for idx, d := range []detecter.Detecter{
|
// for idx, d := range []detecter.Detecter{
|
||||||
yolox(modelPath),
|
// yolox(modelPath),
|
||||||
nanodet(modelPath),
|
// nanodet(modelPath),
|
||||||
} {
|
// } {
|
||||||
defer d.Destroy()
|
// defer d.Destroy()
|
||||||
common.SetEstimatorThreads(d, cpuCores)
|
// common.SetEstimatorThreads(d, cpuCores)
|
||||||
detect(d, estimator, imgPath, "hand1.jpg", idx)
|
// detect(d, estimator, imgPath, "hand2.jpg", idx)
|
||||||
}
|
// }
|
||||||
|
d3d := mediapipe(modelPath)
|
||||||
|
detect3d(d3d, imgPath, fontPath, "hand1.jpg")
|
||||||
|
detect3d(d3d, imgPath, fontPath, "hand2.jpg")
|
||||||
}
|
}
|
||||||
|
|
||||||
func yolox(modelPath string) detecter.Detecter {
|
func yolox(modelPath string) detecter.Detecter {
|
||||||
@@ -67,6 +73,16 @@ func handpose(modelPath string) pose.Estimator {
|
|||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mediapipe(modelPath string) *pose3d.Mediapipe {
|
||||||
|
palmPath := filepath.Join(modelPath, "mediapipe/palm/full")
|
||||||
|
handPath := filepath.Join(modelPath, "mediapipe/hand/full")
|
||||||
|
d := pose3d.NewMediapipe()
|
||||||
|
if err := d.LoadModel(palmPath, handPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
func detect(d detecter.Detecter, e pose.Estimator, imgPath string, filename string, idx int) {
|
func detect(d detecter.Detecter, e pose.Estimator, imgPath string, filename string, idx int) {
|
||||||
inPath := filepath.Join(imgPath, filename)
|
inPath := filepath.Join(imgPath, filename)
|
||||||
imgSrc, err := loadImage(inPath)
|
imgSrc, err := loadImage(inPath)
|
||||||
@@ -104,6 +120,37 @@ func detect(d detecter.Detecter, e pose.Estimator, imgPath string, filename stri
|
|||||||
if err := saveImage(out, outPath); err != nil {
|
if err := saveImage(out, outPath); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func detect3d(d *pose3d.Mediapipe, imgPath string, fontPath string, filename string) {
|
||||||
|
inPath := filepath.Join(imgPath, filename)
|
||||||
|
imgSrc, err := loadImage(inPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("load image failed,", err)
|
||||||
|
}
|
||||||
|
img := common.NewImage(imgSrc)
|
||||||
|
rois, err := d.Detect(img)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
// log.Printf("%+v\n", rois)
|
||||||
|
fnt := load_font(fontPath)
|
||||||
|
drawer := handdrawer.New(handdrawer.WithFont(fnt))
|
||||||
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("pose3d-hand-%s", filename))
|
||||||
|
out := drawer.DrawPalm(img, rois)
|
||||||
|
|
||||||
|
if err := saveImage(out, outPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, roi := range rois {
|
||||||
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("pose3d-palm3d-%d-%s", idx, filename))
|
||||||
|
out := drawer.DrawPalm3D(roi, 400, "#442519")
|
||||||
|
|
||||||
|
if err := saveImage(out, outPath); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -144,3 +191,20 @@ func cleanPath(wd string, path string) string {
|
|||||||
}
|
}
|
||||||
return filepath.Join(wd, path)
|
return filepath.Join(wd, path)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func load_font(fontPath string) *common.Font {
|
||||||
|
fontCache := common.NewFontCache(fontPath)
|
||||||
|
fnt := &common.Font{
|
||||||
|
Size: 9,
|
||||||
|
Data: &draw2d.FontData{
|
||||||
|
Name: "NotoSansCJKsc",
|
||||||
|
//Name: "Roboto",
|
||||||
|
Family: draw2d.FontFamilySans,
|
||||||
|
Style: draw2d.FontStyleNormal,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := fnt.Load(fontCache); err != nil {
|
||||||
|
log.Fatalln(err)
|
||||||
|
}
|
||||||
|
return fnt
|
||||||
|
}
|
||||||
|
@@ -77,8 +77,8 @@ func videomatting(seg segmentor.Segmentor, imgPath string, filename string, idx
|
|||||||
log.Fatalln("load image failed,", err)
|
log.Fatalln("load image failed,", err)
|
||||||
}
|
}
|
||||||
img := common.NewImage(imgLoaded)
|
img := common.NewImage(imgLoaded)
|
||||||
out, err := seg.Matting(img)
|
out := common.NewImage(nil)
|
||||||
if err != nil {
|
if err := seg.Matting(img, out); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
outPath := filepath.Join(imgPath, "./results/videomatting", fmt.Sprintf("%d.jpeg", idx))
|
outPath := filepath.Join(imgPath, "./results/videomatting", fmt.Sprintf("%d.jpeg", idx))
|
||||||
@@ -95,8 +95,8 @@ func matting(seg segmentor.Segmentor, imgPath string, filename string, idx int)
|
|||||||
log.Fatalln("load image failed,", err)
|
log.Fatalln("load image failed,", err)
|
||||||
}
|
}
|
||||||
img := common.NewImage(imgLoaded)
|
img := common.NewImage(imgLoaded)
|
||||||
out, err := seg.Matting(img)
|
out := common.NewImage(nil)
|
||||||
if err != nil {
|
if err := seg.Matting(img, out); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-matting-%d-%s", idx, filename))
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-matting-%d-%s", idx, filename))
|
||||||
@@ -119,8 +119,8 @@ func merge(seg segmentor.Segmentor, imgPath string, filename string, bgFilename
|
|||||||
log.Fatalln("load bg image failed,", err)
|
log.Fatalln("load bg image failed,", err)
|
||||||
}
|
}
|
||||||
bg := common.NewImage(bgLoaded)
|
bg := common.NewImage(bgLoaded)
|
||||||
out, err := seg.Merge(img, bg)
|
out := common.NewImage(nil)
|
||||||
if err != nil {
|
if err := seg.Merge(img, bg, out); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-merge-%d-%s", idx, filename))
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-merge-%d-%s", idx, filename))
|
||||||
|
@@ -49,15 +49,14 @@ func transform(transfer styletransfer.StyleTransfer, imgPath string, filename st
|
|||||||
log.Fatalln("load image failed,", err)
|
log.Fatalln("load image failed,", err)
|
||||||
}
|
}
|
||||||
img := common.NewImage(imgLoaded)
|
img := common.NewImage(imgLoaded)
|
||||||
out, err := transfer.Transform(img)
|
out := common.NewImage(nil)
|
||||||
if err != nil {
|
if err := transfer.Transform(img, out); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%s-%s", modelName, filename))
|
outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%s-%s", modelName, filename))
|
||||||
if err := saveImage(out, outPath); err != nil {
|
if err := saveImage(out, outPath); err != nil {
|
||||||
log.Fatalln(err)
|
log.Fatalln(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadImage(filePath string) (image.Image, error) {
|
func loadImage(filePath string) (image.Image, error) {
|
||||||
|
@@ -8,7 +8,6 @@ package aligner
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
openvision "github.com/bububa/openvision/go"
|
openvision "github.com/bububa/openvision/go"
|
||||||
@@ -39,7 +38,7 @@ func (a *Aligner) SetThreads(n int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Align face
|
// Align face
|
||||||
func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo) (image.Image, error) {
|
func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo, out *common.Image) error {
|
||||||
imgWidth := img.WidthF64()
|
imgWidth := img.WidthF64()
|
||||||
imgHeight := img.HeightF64()
|
imgHeight := img.HeightF64()
|
||||||
data := img.Bytes()
|
data := img.Bytes()
|
||||||
@@ -61,7 +60,8 @@ func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo) (image.Image,
|
|||||||
(*C.Image)(unsafe.Pointer(outImgC)),
|
(*C.Image)(unsafe.Pointer(outImgC)),
|
||||||
)
|
)
|
||||||
if errCode != 0 {
|
if errCode != 0 {
|
||||||
return nil, openvision.AlignFaceError(int(errCode))
|
return openvision.AlignFaceError(int(errCode))
|
||||||
}
|
}
|
||||||
return common.GoImage(outImgC)
|
common.GoImage(outImgC, out)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -17,4 +17,6 @@ const (
|
|||||||
DefaultKeypointStrokeWidth = 2
|
DefaultKeypointStrokeWidth = 2
|
||||||
// DefaultInvalidBorderColor default drawer invalid border color
|
// DefaultInvalidBorderColor default drawer invalid border color
|
||||||
DefaultInvalidBorderColor = common.Red
|
DefaultInvalidBorderColor = common.Red
|
||||||
|
// DefaultLabelColor default label color
|
||||||
|
DefaultLabelColor = common.White
|
||||||
)
|
)
|
||||||
|
@@ -26,6 +26,10 @@ type Drawer struct {
|
|||||||
MaskColor string
|
MaskColor string
|
||||||
// InvalidBorderColor
|
// InvalidBorderColor
|
||||||
InvalidBorderColor string
|
InvalidBorderColor string
|
||||||
|
// LabelColor string
|
||||||
|
LabelColor string
|
||||||
|
// Font
|
||||||
|
Font *common.Font
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new Drawer
|
// New returns a new Drawer
|
||||||
@@ -38,6 +42,7 @@ func New(options ...Option) *Drawer {
|
|||||||
KeypointRadius: DefaultKeypointRadius,
|
KeypointRadius: DefaultKeypointRadius,
|
||||||
InvalidBorderColor: DefaultInvalidBorderColor,
|
InvalidBorderColor: DefaultInvalidBorderColor,
|
||||||
MaskColor: DefaultBorderColor,
|
MaskColor: DefaultBorderColor,
|
||||||
|
LabelColor: DefaultLabelColor,
|
||||||
}
|
}
|
||||||
for _, opt := range options {
|
for _, opt := range options {
|
||||||
opt.apply(d)
|
opt.apply(d)
|
||||||
@@ -69,6 +74,9 @@ func (d *Drawer) Draw(img image.Image, faces []face.FaceInfo) image.Image {
|
|||||||
for _, pt := range face.Keypoints {
|
for _, pt := range face.Keypoints {
|
||||||
common.DrawCircle(gc, common.Pt(pt.X*imgW, pt.Y*imgH), d.KeypointRadius, d.KeypointColor, "", d.KeypointStrokeWidth)
|
common.DrawCircle(gc, common.Pt(pt.X*imgW, pt.Y*imgH), d.KeypointRadius, d.KeypointColor, "", d.KeypointStrokeWidth)
|
||||||
}
|
}
|
||||||
|
if face.Label != "" {
|
||||||
|
common.DrawLabelInWidth(gc, d.Font, face.Label, common.Pt(rect.X, rect.MaxY()), d.LabelColor, borderColor, rect.Width)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
@@ -1,5 +1,7 @@
|
|||||||
package drawer
|
package drawer
|
||||||
|
|
||||||
|
import "github.com/bububa/openvision/go/common"
|
||||||
|
|
||||||
// Option represents Drawer option interface
|
// Option represents Drawer option interface
|
||||||
type Option interface {
|
type Option interface {
|
||||||
apply(*Drawer)
|
apply(*Drawer)
|
||||||
@@ -59,3 +61,17 @@ func WithMaskColor(color string) Option {
|
|||||||
d.MaskColor = color
|
d.MaskColor = color
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithLabelColor set Drawer LabelColor
|
||||||
|
func WithLabelColor(color string) Option {
|
||||||
|
return optionFunc(func(d *Drawer) {
|
||||||
|
d.LabelColor = color
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFont set Drawer Font
|
||||||
|
func WithFont(font *common.Font) Option {
|
||||||
|
return optionFunc(func(d *Drawer) {
|
||||||
|
d.Font = font
|
||||||
|
})
|
||||||
|
}
|
||||||
|
11
go/face/eye/cgo.go
Normal file
11
go/face/eye/cgo.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build !vulkan
|
||||||
|
|
||||||
|
package eye
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
11
go/face/eye/cgo_vulkan.go
Normal file
11
go/face/eye/cgo_vulkan.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build vulkan
|
||||||
|
|
||||||
|
package eye
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
44
go/face/eye/detecter.go
Normal file
44
go/face/eye/detecter.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package eye
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/common/common.h"
|
||||||
|
#include "openvision/face/eye.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Detecter represents Eye Detector interface
|
||||||
|
type Detecter interface {
|
||||||
|
common.Estimator
|
||||||
|
IsClosed(img *common.Image, face common.Rectangle) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed check whether eyes are closed
|
||||||
|
func IsClosed(r Detecter, img *common.Image, faceRect common.Rectangle) (bool, error) {
|
||||||
|
imgWidth := img.WidthF64()
|
||||||
|
imgHeight := img.HeightF64()
|
||||||
|
data := img.Bytes()
|
||||||
|
scoresC := common.NewCFloatVector()
|
||||||
|
defer common.FreeCFloatVector(scoresC)
|
||||||
|
CRect := faceRect.CRect(imgWidth, imgHeight)
|
||||||
|
errCode := C.eye_status(
|
||||||
|
(C.IEye)(r.Pointer()),
|
||||||
|
(*C.uchar)(unsafe.Pointer(&data[0])),
|
||||||
|
C.int(imgWidth), C.int(imgHeight),
|
||||||
|
(*C.Rect)(unsafe.Pointer(CRect)),
|
||||||
|
(*C.FloatVector)(unsafe.Pointer(scoresC)),
|
||||||
|
)
|
||||||
|
C.free(unsafe.Pointer(CRect))
|
||||||
|
if errCode != 0 {
|
||||||
|
return false, openvision.RecognizeFaceError(int(errCode))
|
||||||
|
}
|
||||||
|
scores := common.GoFloatVector(scoresC)
|
||||||
|
return len(scores) > 0 && scores[0] == 1, nil
|
||||||
|
}
|
2
go/face/eye/doc.go
Normal file
2
go/face/eye/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package eye include eye status detector
|
||||||
|
package eye
|
45
go/face/eye/lenet.go
Normal file
45
go/face/eye/lenet.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package eye
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/face/eye.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Lenet represents Lenet detecter
|
||||||
|
type Lenet struct {
|
||||||
|
d C.IEye
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetLenet returns a new Lenet detecter
|
||||||
|
func NewLenet() *Lenet {
|
||||||
|
return &Lenet{
|
||||||
|
d: C.new_lenet_eye(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pointer implement Estimator interface
|
||||||
|
func (d *Lenet) Pointer() unsafe.Pointer {
|
||||||
|
return unsafe.Pointer(d.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel implement Recognizer interface
|
||||||
|
func (d *Lenet) LoadModel(modelPath string) error {
|
||||||
|
return common.EstimatorLoadModel(d, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy implement Recognizer interface
|
||||||
|
func (d *Lenet) Destroy() {
|
||||||
|
common.DestroyEstimator(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed implement Eye Detecter interface
|
||||||
|
func (d *Lenet) IsClosed(img *common.Image, faceRect common.Rectangle) (bool, error) {
|
||||||
|
return IsClosed(d, img, faceRect)
|
||||||
|
}
|
@@ -22,6 +22,8 @@ type FaceInfo struct {
|
|||||||
Keypoints [5]common.Point
|
Keypoints [5]common.Point
|
||||||
// Mask has mask or not
|
// Mask has mask or not
|
||||||
Mask bool
|
Mask bool
|
||||||
|
// Label
|
||||||
|
Label string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GoFaceInfo convert c FaceInfo to go type
|
// GoFaceInfo convert c FaceInfo to go type
|
||||||
|
11
go/face/hair/cgo.go
Normal file
11
go/face/hair/cgo.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build !vulkan
|
||||||
|
|
||||||
|
package hair
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
11
go/face/hair/cgo_vulkan.go
Normal file
11
go/face/hair/cgo_vulkan.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build vulkan
|
||||||
|
|
||||||
|
package hair
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
2
go/face/hair/doc.go
Normal file
2
go/face/hair/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package hair include hair segmentation
|
||||||
|
package hair
|
61
go/face/hair/hair.go
Normal file
61
go/face/hair/hair.go
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
package hair
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/face/hair.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Hair represents Hair segmentor
|
||||||
|
type Hair struct {
|
||||||
|
d C.IHair
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHair returns a new Hair
|
||||||
|
func NewHair() *Hair {
|
||||||
|
return &Hair{
|
||||||
|
d: C.new_hair(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pointer implement Estimator interface
|
||||||
|
func (h *Hair) Pointer() unsafe.Pointer {
|
||||||
|
return unsafe.Pointer(h.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadModel load detecter model
|
||||||
|
func (h *Hair) LoadModel(modelPath string) error {
|
||||||
|
return common.EstimatorLoadModel(h, modelPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Destroy destroy C.IHair
|
||||||
|
func (h *Hair) Destroy() {
|
||||||
|
common.DestroyEstimator(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Matting returns hair matting image
|
||||||
|
func (h *Hair) Matting(img *common.Image, out *common.Image) error {
|
||||||
|
imgWidth := img.WidthF64()
|
||||||
|
imgHeight := img.HeightF64()
|
||||||
|
data := img.Bytes()
|
||||||
|
outImgC := common.NewCImage()
|
||||||
|
defer common.FreeCImage(outImgC)
|
||||||
|
errCode := C.hair_matting(
|
||||||
|
(C.IHair)(h.Pointer()),
|
||||||
|
(*C.uchar)(unsafe.Pointer(&data[0])),
|
||||||
|
C.int(imgWidth), C.int(imgHeight),
|
||||||
|
(*C.Image)(unsafe.Pointer(outImgC)),
|
||||||
|
)
|
||||||
|
if errCode != 0 {
|
||||||
|
return openvision.HairMattingError(int(errCode))
|
||||||
|
}
|
||||||
|
common.GoImage(outImgC, out)
|
||||||
|
return nil
|
||||||
|
}
|
@@ -34,7 +34,6 @@ func (h *Hopenet) Pointer() unsafe.Pointer {
|
|||||||
// LoadModel load detecter model
|
// LoadModel load detecter model
|
||||||
func (h *Hopenet) LoadModel(modelPath string) error {
|
func (h *Hopenet) LoadModel(modelPath string) error {
|
||||||
return common.EstimatorLoadModel(h, modelPath)
|
return common.EstimatorLoadModel(h, modelPath)
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destroy destroy C.IHopeNet
|
// Destroy destroy C.IHopeNet
|
||||||
|
@@ -7,71 +7,16 @@ import (
|
|||||||
const (
|
const (
|
||||||
// DefaultBorderColor default drawer border color
|
// DefaultBorderColor default drawer border color
|
||||||
DefaultBorderColor = common.Green
|
DefaultBorderColor = common.Green
|
||||||
|
// DefaultKeypointColor default drawer keypoint color
|
||||||
|
DefaultKeypointColor = common.Pink
|
||||||
// DefaultBorderStrokeWidth default drawer border stroke width
|
// DefaultBorderStrokeWidth default drawer border stroke width
|
||||||
DefaultBorderStrokeWidth = 3
|
DefaultBorderStrokeWidth = 3
|
||||||
// DefaultKeypointRadius default drawer keypoint radius
|
// DefaultKeypointRadius default drawer keypoint radius
|
||||||
DefaultKeypointRadius = 3
|
DefaultKeypointRadius = 3
|
||||||
// DefaultKeypointStrokeWidth default drawer keypoint stroke width
|
// DefaultKeypointStrokeWidth default drawer keypoint stroke width
|
||||||
DefaultKeypointStrokeWidth = 1
|
DefaultKeypointStrokeWidth = 1
|
||||||
)
|
// DefaultLabelColor default label color
|
||||||
|
DefaultLabelColor = common.White
|
||||||
// CocoPart coco part define
|
|
||||||
type CocoPart = int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// CocoPartNose nose
|
|
||||||
CocoPartNose CocoPart = iota
|
|
||||||
// CocoPartLEye left eye
|
|
||||||
CocoPartLEye
|
|
||||||
// CocoPartREye right eye
|
|
||||||
CocoPartREye
|
|
||||||
// CocoPartLEar left ear
|
|
||||||
CocoPartLEar
|
|
||||||
// CocoPartREar right ear
|
|
||||||
CocoPartREar
|
|
||||||
// CocoPartLShoulder left sholder
|
|
||||||
CocoPartLShoulder
|
|
||||||
// CocoPartRShoulder right sholder
|
|
||||||
CocoPartRShoulder
|
|
||||||
// CocoPartLElbow left elbow
|
|
||||||
CocoPartLElbow
|
|
||||||
// CocoPartRElbow right elbow
|
|
||||||
CocoPartRElbow
|
|
||||||
// CocoPartLWrist left wrist
|
|
||||||
CocoPartLWrist
|
|
||||||
// CocoPartRWrist right wrist
|
|
||||||
CocoPartRWrist
|
|
||||||
// CocoPartLHip left hip
|
|
||||||
CocoPartLHip
|
|
||||||
// CocoPartRHip right hip
|
|
||||||
CocoPartRHip
|
|
||||||
// CocoPartLKnee left knee
|
|
||||||
CocoPartLKnee
|
|
||||||
// CocoPartRKnee right knee
|
|
||||||
CocoPartRKnee
|
|
||||||
// CocoPartRAnkle right ankle
|
|
||||||
CocoPartRAnkle
|
|
||||||
// CocoPartLAnkle left ankle
|
|
||||||
CocoPartLAnkle
|
|
||||||
// CocoPartNeck neck
|
|
||||||
CocoPartNeck
|
|
||||||
// CocoPartBackground background
|
|
||||||
CocoPartBackground
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// CocoPair represents joints pair
|
|
||||||
CocoPair = [16][2]CocoPart{
|
|
||||||
{0, 1}, {1, 3}, {0, 2}, {2, 4}, {5, 6}, {5, 7}, {7, 9}, {6, 8}, {8, 10}, {5, 11}, {6, 12}, {11, 12}, {11, 13}, {12, 14}, {13, 15}, {14, 16},
|
|
||||||
}
|
|
||||||
// CocoColors represents color for coco parts
|
|
||||||
CocoColors = [17]string{
|
|
||||||
"#ff0000", "#ff5500", "#ffaa00", "#ffff00",
|
|
||||||
"#aaff00", "#55ff00", "#00ff00", "#00ff55", "#00ffaa",
|
|
||||||
"#00ffff", "#00aaff", "#0055ff",
|
|
||||||
"#0000ff", "#aa00ff", "#ff00ff",
|
|
||||||
"#ff00aa", "#ff0055",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@@ -2,8 +2,12 @@ package drawer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"image"
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/llgcode/draw2d"
|
||||||
"github.com/llgcode/draw2d/draw2dimg"
|
"github.com/llgcode/draw2d/draw2dimg"
|
||||||
|
"github.com/llgcode/draw2d/draw2dkit"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
)
|
)
|
||||||
@@ -18,6 +22,12 @@ type Drawer struct {
|
|||||||
KeypointStrokeWidth float64
|
KeypointStrokeWidth float64
|
||||||
// KeypointRadius represents keypoints circle radius
|
// KeypointRadius represents keypoints circle radius
|
||||||
KeypointRadius float64
|
KeypointRadius float64
|
||||||
|
// KeypointColor represents keypoint color
|
||||||
|
KeypointColor string
|
||||||
|
// LabelColor string
|
||||||
|
LabelColor string
|
||||||
|
// Font
|
||||||
|
Font *common.Font
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new Drawer
|
// New returns a new Drawer
|
||||||
@@ -27,6 +37,8 @@ func New(options ...Option) *Drawer {
|
|||||||
BorderStrokeWidth: DefaultBorderStrokeWidth,
|
BorderStrokeWidth: DefaultBorderStrokeWidth,
|
||||||
KeypointStrokeWidth: DefaultKeypointStrokeWidth,
|
KeypointStrokeWidth: DefaultKeypointStrokeWidth,
|
||||||
KeypointRadius: DefaultKeypointRadius,
|
KeypointRadius: DefaultKeypointRadius,
|
||||||
|
KeypointColor: DefaultKeypointColor,
|
||||||
|
LabelColor: DefaultLabelColor,
|
||||||
}
|
}
|
||||||
for _, opt := range options {
|
for _, opt := range options {
|
||||||
opt.apply(d)
|
opt.apply(d)
|
||||||
@@ -42,15 +54,15 @@ func (d *Drawer) Draw(img image.Image, rois []common.ObjectInfo, drawBorder bool
|
|||||||
gc := draw2dimg.NewGraphicContext(out)
|
gc := draw2dimg.NewGraphicContext(out)
|
||||||
gc.DrawImage(img)
|
gc.DrawImage(img)
|
||||||
for _, roi := range rois {
|
for _, roi := range rois {
|
||||||
|
rect := common.Rect(
|
||||||
|
roi.Rect.X*imgW,
|
||||||
|
roi.Rect.Y*imgH,
|
||||||
|
roi.Rect.Width*imgW,
|
||||||
|
roi.Rect.Height*imgH,
|
||||||
|
)
|
||||||
|
borderColor := d.BorderColor
|
||||||
if drawBorder {
|
if drawBorder {
|
||||||
// draw rect
|
// draw rect
|
||||||
rect := common.Rect(
|
|
||||||
roi.Rect.X*imgW,
|
|
||||||
roi.Rect.Y*imgH,
|
|
||||||
roi.Rect.Width*imgW,
|
|
||||||
roi.Rect.Height*imgH,
|
|
||||||
)
|
|
||||||
borderColor := d.BorderColor
|
|
||||||
common.DrawRectangle(gc, rect, borderColor, "", d.BorderStrokeWidth)
|
common.DrawRectangle(gc, rect, borderColor, "", d.BorderStrokeWidth)
|
||||||
}
|
}
|
||||||
l := len(roi.Keypoints)
|
l := len(roi.Keypoints)
|
||||||
@@ -95,6 +107,127 @@ func (d *Drawer) Draw(img image.Image, rois []common.ObjectInfo, drawBorder bool
|
|||||||
poseColor := PoseColors[colorIdx]
|
poseColor := PoseColors[colorIdx]
|
||||||
common.DrawCircle(gc, common.Pt(pt.Point.X*imgW, pt.Point.Y*imgH), d.KeypointRadius, poseColor, "", d.KeypointStrokeWidth)
|
common.DrawCircle(gc, common.Pt(pt.Point.X*imgW, pt.Point.Y*imgH), d.KeypointRadius, poseColor, "", d.KeypointStrokeWidth)
|
||||||
}
|
}
|
||||||
|
// draw name
|
||||||
|
if roi.Name != "" {
|
||||||
|
common.DrawLabelInWidth(gc, d.Font, roi.Name, common.Pt(rect.X, rect.MaxY()), d.LabelColor, borderColor, rect.Width)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// DrawPalm draw PalmObject
|
||||||
|
func (d *Drawer) DrawPalm(img image.Image, rois []common.PalmObject) image.Image {
|
||||||
|
imgW := float64(img.Bounds().Dx())
|
||||||
|
imgH := float64(img.Bounds().Dy())
|
||||||
|
out := image.NewRGBA(img.Bounds())
|
||||||
|
gc := draw2dimg.NewGraphicContext(out)
|
||||||
|
gc.DrawImage(img)
|
||||||
|
for _, roi := range rois {
|
||||||
|
gc.SetLineWidth(d.BorderStrokeWidth)
|
||||||
|
gc.SetStrokeColor(common.ColorFromHex(d.BorderColor))
|
||||||
|
gc.BeginPath()
|
||||||
|
for idx, pt := range roi.RectPoints {
|
||||||
|
gc.MoveTo(pt.X*imgW, pt.Y*imgH)
|
||||||
|
if idx == 3 {
|
||||||
|
gc.LineTo(roi.RectPoints[0].X*imgW, roi.RectPoints[0].Y*imgH)
|
||||||
|
} else {
|
||||||
|
gc.LineTo(roi.RectPoints[idx+1].X*imgW, roi.RectPoints[idx+1].Y*imgH)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
gc.Close()
|
||||||
|
gc.Stroke()
|
||||||
|
|
||||||
|
l := len(roi.Skeleton)
|
||||||
|
if l == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// draw skeleton
|
||||||
|
for idx := range roi.Skeleton[:l-1] {
|
||||||
|
var (
|
||||||
|
p0 common.Point
|
||||||
|
p1 common.Point
|
||||||
|
poseColor = PoseColors[idx/4]
|
||||||
|
)
|
||||||
|
gc.SetStrokeColor(common.ColorFromHex(poseColor))
|
||||||
|
if idx == 5 || idx == 9 || idx == 13 || idx == 17 {
|
||||||
|
p0 = roi.Skeleton[0]
|
||||||
|
p1 = roi.Skeleton[idx]
|
||||||
|
gc.BeginPath()
|
||||||
|
gc.MoveTo(p0.X*imgW, p0.Y*imgH)
|
||||||
|
gc.LineTo(p1.X*imgW, p1.Y*imgH)
|
||||||
|
gc.Close()
|
||||||
|
gc.Stroke()
|
||||||
|
} else if idx == 4 || idx == 8 || idx == 12 || idx == 16 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p0 = roi.Skeleton[idx]
|
||||||
|
p1 = roi.Skeleton[idx+1]
|
||||||
|
gc.BeginPath()
|
||||||
|
gc.MoveTo(p0.X*imgW, p0.Y*imgH)
|
||||||
|
gc.LineTo(p1.X*imgW, p1.Y*imgH)
|
||||||
|
gc.Close()
|
||||||
|
gc.Stroke()
|
||||||
|
}
|
||||||
|
for _, pt := range roi.Landmarks {
|
||||||
|
common.DrawCircle(gc, common.Pt(pt.X*imgW, pt.Y*imgH), d.KeypointRadius, d.KeypointColor, "", d.KeypointStrokeWidth)
|
||||||
|
}
|
||||||
|
// draw name
|
||||||
|
if roi.Name != "" {
|
||||||
|
deltaX := (roi.RectPoints[2].X - roi.RectPoints[3].X) * imgW
|
||||||
|
deltaY := (roi.RectPoints[2].Y - roi.RectPoints[3].Y) * imgH
|
||||||
|
width := math.Sqrt(math.Abs(deltaX*deltaX) + math.Abs(deltaY*deltaY))
|
||||||
|
metrix := draw2d.NewRotationMatrix(roi.Rotation)
|
||||||
|
ptX, ptY := metrix.InverseTransformPoint(roi.RectPoints[3].X*imgW, roi.RectPoints[3].Y*imgH)
|
||||||
|
gc.Save()
|
||||||
|
gc.Rotate(roi.Rotation)
|
||||||
|
common.DrawLabelInWidth(gc, d.Font, roi.Name, common.Pt(ptX, ptY), d.LabelColor, d.BorderColor, width)
|
||||||
|
gc.Restore()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// DrawPalm3D draw 3d PalmObject
|
||||||
|
func (d *Drawer) DrawPalm3D(roi common.PalmObject, size float64, bg string) image.Image {
|
||||||
|
out := image.NewRGBA(image.Rect(0, 0, int(size), int(size)))
|
||||||
|
gc := draw2dimg.NewGraphicContext(out)
|
||||||
|
l := len(roi.Skeleton3d)
|
||||||
|
if l == 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
if bg != "" {
|
||||||
|
bgColor := common.ColorFromHex(bg)
|
||||||
|
gc.SetFillColor(bgColor)
|
||||||
|
draw2dkit.Rectangle(gc, 0, 0, size, size)
|
||||||
|
gc.Fill()
|
||||||
|
gc.SetFillColor(color.Transparent)
|
||||||
|
}
|
||||||
|
// draw skeleton3d
|
||||||
|
for idx := range roi.Skeleton3d[:l-1] {
|
||||||
|
var (
|
||||||
|
p0 common.Point3d
|
||||||
|
p1 common.Point3d
|
||||||
|
poseColor = PoseColors[idx/4]
|
||||||
|
)
|
||||||
|
gc.SetStrokeColor(common.ColorFromHex(poseColor))
|
||||||
|
if idx == 5 || idx == 9 || idx == 13 || idx == 17 {
|
||||||
|
p0 = roi.Skeleton3d[0]
|
||||||
|
p1 = roi.Skeleton3d[idx]
|
||||||
|
gc.BeginPath()
|
||||||
|
gc.MoveTo(p0.X*size, p0.Y*size)
|
||||||
|
gc.LineTo(p1.X*size, p1.Y*size)
|
||||||
|
gc.Close()
|
||||||
|
gc.Stroke()
|
||||||
|
} else if idx == 4 || idx == 8 || idx == 12 || idx == 16 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
p0 = roi.Skeleton3d[idx]
|
||||||
|
p1 = roi.Skeleton3d[idx+1]
|
||||||
|
gc.BeginPath()
|
||||||
|
gc.MoveTo(p0.X*size, p0.Y*size)
|
||||||
|
gc.LineTo(p1.X*size, p1.Y*size)
|
||||||
|
gc.Close()
|
||||||
|
gc.Stroke()
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
@@ -1,5 +1,9 @@
|
|||||||
package drawer
|
package drawer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
// Option represents Drawer option interface
|
// Option represents Drawer option interface
|
||||||
type Option interface {
|
type Option interface {
|
||||||
apply(*Drawer)
|
apply(*Drawer)
|
||||||
@@ -38,3 +42,17 @@ func WithKeypointStrokeWidth(w float64) Option {
|
|||||||
d.KeypointStrokeWidth = w
|
d.KeypointStrokeWidth = w
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithKeypointColor set Drawer KeypointColor
|
||||||
|
func WithKeypointColor(color string) Option {
|
||||||
|
return optionFunc(func(d *Drawer) {
|
||||||
|
d.KeypointColor = color
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithFont set Drawer Font
|
||||||
|
func WithFont(font *common.Font) Option {
|
||||||
|
return optionFunc(func(d *Drawer) {
|
||||||
|
d.Font = font
|
||||||
|
})
|
||||||
|
}
|
||||||
|
11
go/hand/pose3d/cgo.go
Normal file
11
go/hand/pose3d/cgo.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build !vulkan
|
||||||
|
|
||||||
|
package pose3d
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
11
go/hand/pose3d/cgo_vulkan.go
Normal file
11
go/hand/pose3d/cgo_vulkan.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
// +build vulkan
|
||||||
|
|
||||||
|
package pose3d
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo CXXFLAGS: --std=c++11 -fopenmp
|
||||||
|
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
|
||||||
|
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
|
||||||
|
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
|
||||||
|
*/
|
||||||
|
import "C"
|
2
go/hand/pose3d/doc.go
Normal file
2
go/hand/pose3d/doc.go
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// Package pose hand 3d pose estimator
|
||||||
|
package pose3d
|
62
go/hand/pose3d/mediapipe.go
Normal file
62
go/hand/pose3d/mediapipe.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package pose3d
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include "openvision/common/common.h"
|
||||||
|
#include "openvision/hand/pose3d.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
openvision "github.com/bububa/openvision/go"
|
||||||
|
"github.com/bububa/openvision/go/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mediapipe represents mediapipe estimator interface
|
||||||
|
type Mediapipe struct {
|
||||||
|
d C.IHandPose3DEstimator
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMediapipe() *Mediapipe {
|
||||||
|
return &Mediapipe{
|
||||||
|
d: C.new_mediapipe_hand(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mediapipe) Destroy() {
|
||||||
|
C.destroy_mediapipe_hand(m.d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Mediapipe) LoadModel(palmPath string, handPath string) error {
|
||||||
|
cPalm := C.CString(palmPath)
|
||||||
|
defer C.free(unsafe.Pointer(cPalm))
|
||||||
|
cHand := C.CString(handPath)
|
||||||
|
defer C.free(unsafe.Pointer(cHand))
|
||||||
|
retCode := C.mediapipe_hand_load_model(m.d, cPalm, cHand)
|
||||||
|
if retCode != 0 {
|
||||||
|
return openvision.LoadModelError(int(retCode))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect detect hand 3d pose
|
||||||
|
func (m *Mediapipe) Detect(img *common.Image) ([]common.PalmObject, error) {
|
||||||
|
imgWidth := img.WidthF64()
|
||||||
|
imgHeight := img.HeightF64()
|
||||||
|
data := img.Bytes()
|
||||||
|
cObjs := common.NewCPalmObjectVector()
|
||||||
|
defer common.FreeCPalmObjectVector(cObjs)
|
||||||
|
errCode := C.mediapipe_hand_detect(
|
||||||
|
m.d,
|
||||||
|
(*C.uchar)(unsafe.Pointer(&data[0])),
|
||||||
|
C.int(imgWidth), C.int(imgHeight),
|
||||||
|
(*C.PalmObjectVector)(unsafe.Pointer(cObjs)),
|
||||||
|
)
|
||||||
|
if errCode != 0 {
|
||||||
|
return nil, openvision.DetectHandError(int(errCode))
|
||||||
|
}
|
||||||
|
return common.GoPalmObjectVector(cObjs, imgWidth, imgHeight), nil
|
||||||
|
}
|
@@ -7,7 +7,6 @@ package segmentor
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
@@ -41,11 +40,11 @@ func (d *Deeplabv3plus) LoadModel(modelPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Matting implement Segmentor interface
|
// Matting implement Segmentor interface
|
||||||
func (d *Deeplabv3plus) Matting(img *common.Image) (image.Image, error) {
|
func (d *Deeplabv3plus) Matting(img *common.Image, out *common.Image) error {
|
||||||
return Matting(d, img)
|
return Matting(d, img, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge implement Segmentor interface
|
// Merge implement Segmentor interface
|
||||||
func (d *Deeplabv3plus) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
|
func (d *Deeplabv3plus) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
|
||||||
return Merge(d, img, bg)
|
return Merge(d, img, bg, out)
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,6 @@ package segmentor
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
@@ -41,11 +40,11 @@ func (d *ERDNet) LoadModel(modelPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Matting implement Segmentor interface
|
// Matting implement Segmentor interface
|
||||||
func (d *ERDNet) Matting(img *common.Image) (image.Image, error) {
|
func (d *ERDNet) Matting(img *common.Image, out *common.Image) error {
|
||||||
return Matting(d, img)
|
return Matting(d, img, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge implement Segmentor interface
|
// Merge implement Segmentor interface
|
||||||
func (d *ERDNet) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
|
func (d *ERDNet) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
|
||||||
return Merge(d, img, bg)
|
return Merge(d, img, bg, out)
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,6 @@ package segmentor
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
@@ -44,11 +43,11 @@ func (d *RVM) LoadModel(modelPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Matting implement Segmentor interface
|
// Matting implement Segmentor interface
|
||||||
func (d *RVM) Matting(img *common.Image) (image.Image, error) {
|
func (d *RVM) Matting(img *common.Image, out *common.Image) error {
|
||||||
return Matting(d, img)
|
return Matting(d, img, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge implement Segmentor interface
|
// Merge implement Segmentor interface
|
||||||
func (d *RVM) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
|
func (d *RVM) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
|
||||||
return Merge(d, img, bg)
|
return Merge(d, img, bg, out)
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,6 @@ package segmentor
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
openvision "github.com/bububa/openvision/go"
|
openvision "github.com/bububa/openvision/go"
|
||||||
@@ -18,12 +17,12 @@ import (
|
|||||||
// Segmentor represents segmentor interface
|
// Segmentor represents segmentor interface
|
||||||
type Segmentor interface {
|
type Segmentor interface {
|
||||||
common.Estimator
|
common.Estimator
|
||||||
Matting(img *common.Image) (image.Image, error)
|
Matting(img *common.Image, out *common.Image) error
|
||||||
Merge(img *common.Image, bg *common.Image) (image.Image, error)
|
Merge(img *common.Image, bg *common.Image, out *common.Image) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matting returns pose segment matting image
|
// Matting returns pose segment matting image
|
||||||
func Matting(d Segmentor, img *common.Image) (image.Image, error) {
|
func Matting(d Segmentor, img *common.Image, out *common.Image) error {
|
||||||
imgWidth := img.WidthF64()
|
imgWidth := img.WidthF64()
|
||||||
imgHeight := img.HeightF64()
|
imgHeight := img.HeightF64()
|
||||||
data := img.Bytes()
|
data := img.Bytes()
|
||||||
@@ -36,13 +35,14 @@ func Matting(d Segmentor, img *common.Image) (image.Image, error) {
|
|||||||
C.int(imgHeight),
|
C.int(imgHeight),
|
||||||
(*C.Image)(unsafe.Pointer(outImgC)))
|
(*C.Image)(unsafe.Pointer(outImgC)))
|
||||||
if errCode != 0 {
|
if errCode != 0 {
|
||||||
return nil, openvision.DetectPoseError(int(errCode))
|
return openvision.DetectPoseError(int(errCode))
|
||||||
}
|
}
|
||||||
return common.GoImage(outImgC)
|
common.GoImage(outImgC, out)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge merge pose with background
|
// Merge merge pose with background
|
||||||
func Merge(d Segmentor, img *common.Image, bg *common.Image) (image.Image, error) {
|
func Merge(d Segmentor, img *common.Image, bg *common.Image, out *common.Image) error {
|
||||||
imgWidth := img.WidthF64()
|
imgWidth := img.WidthF64()
|
||||||
imgHeight := img.HeightF64()
|
imgHeight := img.HeightF64()
|
||||||
data := img.Bytes()
|
data := img.Bytes()
|
||||||
@@ -59,7 +59,8 @@ func Merge(d Segmentor, img *common.Image, bg *common.Image) (image.Image, error
|
|||||||
C.int(bgWidth), C.int(bgHeight),
|
C.int(bgWidth), C.int(bgHeight),
|
||||||
(*C.Image)(unsafe.Pointer(outImgC)))
|
(*C.Image)(unsafe.Pointer(outImgC)))
|
||||||
if errCode != 0 {
|
if errCode != 0 {
|
||||||
return nil, openvision.DetectPoseError(int(errCode))
|
return openvision.DetectPoseError(int(errCode))
|
||||||
}
|
}
|
||||||
return common.GoImage(outImgC)
|
common.GoImage(outImgC, out)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -7,7 +7,6 @@ package styletransfer
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/bububa/openvision/go/common"
|
"github.com/bububa/openvision/go/common"
|
||||||
@@ -41,6 +40,6 @@ func (d *AnimeGan2) LoadModel(modelPath string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transform implement StyleTransfer interface
|
// Transform implement StyleTransfer interface
|
||||||
func (d *AnimeGan2) Transform(img *common.Image) (image.Image, error) {
|
func (d *AnimeGan2) Transform(img *common.Image, out *common.Image) error {
|
||||||
return Transform(d, img)
|
return Transform(d, img, out)
|
||||||
}
|
}
|
||||||
|
@@ -8,7 +8,6 @@ package styletransfer
|
|||||||
*/
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
import (
|
import (
|
||||||
"image"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
openvision "github.com/bububa/openvision/go"
|
openvision "github.com/bububa/openvision/go"
|
||||||
@@ -18,11 +17,11 @@ import (
|
|||||||
// StyleTransfer represents Style Transfer interface
|
// StyleTransfer represents Style Transfer interface
|
||||||
type StyleTransfer interface {
|
type StyleTransfer interface {
|
||||||
common.Estimator
|
common.Estimator
|
||||||
Transform(img *common.Image) (image.Image, error)
|
Transform(img *common.Image, out *common.Image) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transform returns style transform image
|
// Transform returns style transform image
|
||||||
func Transform(d StyleTransfer, img *common.Image) (image.Image, error) {
|
func Transform(d StyleTransfer, img *common.Image, out *common.Image) error {
|
||||||
imgWidth := img.WidthF64()
|
imgWidth := img.WidthF64()
|
||||||
imgHeight := img.HeightF64()
|
imgHeight := img.HeightF64()
|
||||||
data := img.Bytes()
|
data := img.Bytes()
|
||||||
@@ -35,7 +34,8 @@ func Transform(d StyleTransfer, img *common.Image) (image.Image, error) {
|
|||||||
C.int(imgHeight),
|
C.int(imgHeight),
|
||||||
(*C.Image)(unsafe.Pointer(outImgC)))
|
(*C.Image)(unsafe.Pointer(outImgC)))
|
||||||
if errCode != 0 {
|
if errCode != 0 {
|
||||||
return nil, openvision.DetectPoseError(int(errCode))
|
return openvision.DetectPoseError(int(errCode))
|
||||||
}
|
}
|
||||||
return common.GoImage(outImgC)
|
common.GoImage(outImgC, out)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
file(GLOB_RECURSE SRC_FILES
|
file(GLOB_RECURSE SRC_FILES
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/*.cxx
|
${CMAKE_CURRENT_SOURCE_DIR}/*.cxx
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/*.c
|
||||||
)
|
)
|
||||||
|
|
||||||
message(${SRC_FILES})
|
message(${SRC_FILES})
|
||||||
@@ -10,6 +11,11 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O2 -fPIC -std=c++11 -fopenmp")
|
|||||||
add_library(openvision STATIC ${SRC_FILES})
|
add_library(openvision STATIC ${SRC_FILES})
|
||||||
target_link_libraries(openvision PUBLIC ncnn)
|
target_link_libraries(openvision PUBLIC ncnn)
|
||||||
|
|
||||||
|
# set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
|
||||||
|
# find_package(Torch REQUIRED)
|
||||||
|
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||||
|
# target_link_libraries(openvision PUBLIC ${TORCH_LIBRARIES})
|
||||||
|
|
||||||
if(OV_OPENMP)
|
if(OV_OPENMP)
|
||||||
find_package(OpenMP)
|
find_package(OpenMP)
|
||||||
if(NOT TARGET OpenMP::OpenMP_CXX AND (OpenMP_CXX_FOUND OR OPENMP_FOUND))
|
if(NOT TARGET OpenMP::OpenMP_CXX AND (OpenMP_CXX_FOUND OR OPENMP_FOUND))
|
||||||
@@ -57,13 +63,15 @@ target_include_directories(openvision
|
|||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/recognizer/mobilefacenet>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/recognizer/mobilefacenet>
|
||||||
|
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/tracker>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/tracker>
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/hair>
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/eye>
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet>
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/aligner>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/aligner>
|
||||||
|
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand>
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter>
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/pose>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/pose>
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/pose3d>
|
||||||
|
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose>
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose/detecter>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose/detecter>
|
||||||
@@ -73,6 +81,11 @@ target_include_directories(openvision
|
|||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/styletransfer>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/styletransfer>
|
||||||
|
|
||||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/tracker>
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/tracker>
|
||||||
|
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/counter>
|
||||||
|
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/classifier>
|
||||||
|
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm>
|
||||||
)
|
)
|
||||||
|
|
||||||
#install(TARGETS openvision EXPORT openvision ARCHIVE DESTINATION ${LIBRARY_OUTPUT_PATH})
|
#install(TARGETS openvision EXPORT openvision ARCHIVE DESTINATION ${LIBRARY_OUTPUT_PATH})
|
||||||
@@ -89,12 +102,15 @@ file(COPY
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/face/tracker.h
|
${CMAKE_CURRENT_SOURCE_DIR}/face/tracker.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet.h
|
${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/face/aligner.h
|
${CMAKE_CURRENT_SOURCE_DIR}/face/aligner.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/face/hair.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/face/eye.h
|
||||||
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/face
|
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/face
|
||||||
)
|
)
|
||||||
|
|
||||||
file(COPY
|
file(COPY
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter.h
|
${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter.h
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/hand/pose.h
|
${CMAKE_CURRENT_SOURCE_DIR}/hand/pose.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/hand/pose3d.h
|
||||||
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/hand
|
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/hand
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,3 +131,13 @@ file(COPY
|
|||||||
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/tracker
|
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/tracker
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file(COPY
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/counter/counter.h
|
||||||
|
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/counter
|
||||||
|
)
|
||||||
|
|
||||||
|
file(COPY
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm_trainer.h
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm_classifier.h
|
||||||
|
DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/classifier
|
||||||
|
)
|
||||||
|
55
src/classifier/svm/svm_binary_classifier.cpp
Normal file
55
src/classifier/svm/svm_binary_classifier.cpp
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
#include "svm_binary_classifier.hpp"
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
SVMBinaryClassifier::SVMBinaryClassifier() {}
|
||||||
|
|
||||||
|
SVMBinaryClassifier::~SVMBinaryClassifier() {
|
||||||
|
if (model_ != NULL) {
|
||||||
|
free_model(model_, 1);
|
||||||
|
model_ = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int SVMBinaryClassifier::LoadModel(const char *modelfile) {
|
||||||
|
if (model_ != NULL) {
|
||||||
|
free_model(model_, 1);
|
||||||
|
}
|
||||||
|
model_ = (MODEL *)my_malloc(sizeof(MODEL));
|
||||||
|
model_ = read_model((char *)modelfile);
|
||||||
|
if (model_->kernel_parm.kernel_type == 0) { /* linear kernel */
|
||||||
|
/* compute weight vector */
|
||||||
|
add_weight_vector_to_linear_model(model_);
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
double SVMBinaryClassifier::Predict(const float *vec) {
|
||||||
|
WORD *words = (WORD *)malloc(sizeof(WORD) * (model_->totwords + 10));
|
||||||
|
for (int i = 0; i < (model_->totwords + 10); ++i) {
|
||||||
|
if (i >= model_->totwords) {
|
||||||
|
words[i].wnum = 0;
|
||||||
|
words[i].weight = 0;
|
||||||
|
} else {
|
||||||
|
words[i].wnum = i + 1;
|
||||||
|
words[i].weight = vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DOC *doc =
|
||||||
|
create_example(-1, 0, 0, 0.0, create_svector(words, (char *)"", 1.0));
|
||||||
|
free(words);
|
||||||
|
double dist;
|
||||||
|
if (model_->kernel_parm.kernel_type == 0) {
|
||||||
|
dist = classify_example_linear(model_, doc);
|
||||||
|
} else {
|
||||||
|
dist = classify_example(model_, doc);
|
||||||
|
}
|
||||||
|
free_example(doc, 1);
|
||||||
|
return dist;
|
||||||
|
}
|
||||||
|
|
||||||
|
int SVMBinaryClassifier::Classify(const float *vec,
|
||||||
|
std::vector<float> &scores) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
} // namespace ovclassifier
|
19
src/classifier/svm/svm_binary_classifier.hpp
Normal file
19
src/classifier/svm/svm_binary_classifier.hpp
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#ifndef _CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
||||||
|
#define _CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
||||||
|
|
||||||
|
#include "svm_classifier.hpp"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMBinaryClassifier : public SVMClassifier {
|
||||||
|
public:
|
||||||
|
SVMBinaryClassifier();
|
||||||
|
~SVMBinaryClassifier();
|
||||||
|
int LoadModel(const char *modelfile);
|
||||||
|
double Predict(const float *vec);
|
||||||
|
int Classify(const float *vec, std::vector<float> &scores);
|
||||||
|
|
||||||
|
private:
|
||||||
|
MODEL *model_ = NULL;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // !_CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
126
src/classifier/svm/svm_binary_trainer.cpp
Normal file
126
src/classifier/svm/svm_binary_trainer.cpp
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
#include "svm_binary_trainer.hpp"
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include "svm_light/svm_learn.h"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
|
||||||
|
SVMBinaryTrainer::SVMBinaryTrainer() {
|
||||||
|
learn_parm = (LEARN_PARM *)malloc(sizeof(LEARN_PARM));
|
||||||
|
strcpy(learn_parm->predfile, "trans_predictions");
|
||||||
|
strcpy(learn_parm->alphafile, "");
|
||||||
|
learn_parm->biased_hyperplane = 1;
|
||||||
|
learn_parm->sharedslack = 0;
|
||||||
|
learn_parm->remove_inconsistent = 0;
|
||||||
|
learn_parm->skip_final_opt_check = 0;
|
||||||
|
learn_parm->svm_maxqpsize = 10;
|
||||||
|
learn_parm->svm_newvarsinqp = 0;
|
||||||
|
learn_parm->svm_iter_to_shrink = -9999;
|
||||||
|
learn_parm->maxiter = 100000;
|
||||||
|
learn_parm->kernel_cache_size = 40;
|
||||||
|
learn_parm->svm_c = 0.0;
|
||||||
|
learn_parm->eps = 0.1;
|
||||||
|
learn_parm->transduction_posratio = -1.0;
|
||||||
|
learn_parm->svm_costratio = 1.0;
|
||||||
|
learn_parm->svm_costratio_unlab = 1.0;
|
||||||
|
learn_parm->svm_unlabbound = 1E-5;
|
||||||
|
learn_parm->epsilon_crit = 0.001;
|
||||||
|
learn_parm->epsilon_a = 1E-15;
|
||||||
|
learn_parm->compute_loo = 0;
|
||||||
|
learn_parm->rho = 1.0;
|
||||||
|
learn_parm->xa_depth = 0;
|
||||||
|
kernel_parm = (KERNEL_PARM *)malloc(sizeof(KERNEL_PARM));
|
||||||
|
kernel_parm->kernel_type = 0;
|
||||||
|
kernel_parm->poly_degree = 3;
|
||||||
|
kernel_parm->rbf_gamma = 1.0;
|
||||||
|
kernel_parm->coef_lin = 1;
|
||||||
|
kernel_parm->coef_const = 1;
|
||||||
|
strcpy(kernel_parm->custom, "empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
SVMBinaryTrainer::~SVMBinaryTrainer() {
|
||||||
|
if (learn_parm != NULL) {
|
||||||
|
free(learn_parm);
|
||||||
|
learn_parm = NULL;
|
||||||
|
}
|
||||||
|
if (kernel_parm != NULL) {
|
||||||
|
free(kernel_parm);
|
||||||
|
kernel_parm = NULL;
|
||||||
|
}
|
||||||
|
if (kernel_cache != NULL) {
|
||||||
|
kernel_cache_cleanup(kernel_cache);
|
||||||
|
kernel_cache = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SVMBinaryTrainer::Reset() {
|
||||||
|
feats_ = 0;
|
||||||
|
items_.clear();
|
||||||
|
if (kernel_cache != NULL) {
|
||||||
|
kernel_cache_cleanup(kernel_cache);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SVMBinaryTrainer::SetLabels(int labels) {}
|
||||||
|
|
||||||
|
void SVMBinaryTrainer::SetFeatures(int feats) { feats_ = feats; }
|
||||||
|
|
||||||
|
void SVMBinaryTrainer::AddData(int label, const float *vec) {
|
||||||
|
if (label != 1 && label != -1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
LabelItem itm;
|
||||||
|
itm.label = label;
|
||||||
|
for (int i = 0; i < feats_; ++i) {
|
||||||
|
itm.vec.push_back(vec[i]);
|
||||||
|
}
|
||||||
|
items_.push_back(itm);
|
||||||
|
}
|
||||||
|
|
||||||
|
int SVMBinaryTrainer::Train(const char *modelfile) {
|
||||||
|
int totdoc = items_.size();
|
||||||
|
if (totdoc == 0 || feats_ == 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
kernel_cache = kernel_cache_init(totdoc, learn_parm->kernel_cache_size);
|
||||||
|
double *labels = (double *)malloc(sizeof(double) * totdoc);
|
||||||
|
double *alphas = (double *)malloc(sizeof(double) * totdoc);
|
||||||
|
DOC **docs = (DOC **)malloc(sizeof(DOC *) * totdoc);
|
||||||
|
WORD *words = (WORD *)malloc(sizeof(WORD) * (feats_ + 10));
|
||||||
|
for (int dnum = 0; dnum < totdoc; ++dnum) {
|
||||||
|
const int docFeats = items_[dnum].vec.size();
|
||||||
|
for (int i = 0; i < (feats_ + 10); ++i) {
|
||||||
|
if (i >= feats_) {
|
||||||
|
words[i].wnum = 0;
|
||||||
|
} else {
|
||||||
|
words[i].wnum = i + 1;
|
||||||
|
}
|
||||||
|
if (i >= docFeats) {
|
||||||
|
words[i].weight = 0;
|
||||||
|
} else {
|
||||||
|
words[i].weight = items_[dnum].vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
labels[dnum] = items_[dnum].label;
|
||||||
|
docs[dnum] =
|
||||||
|
create_example(dnum, 0, 0, 0, create_svector(words, (char *)"", 1.0));
|
||||||
|
}
|
||||||
|
free(words);
|
||||||
|
|
||||||
|
MODEL *model_ = (MODEL *)malloc(sizeof(MODEL));
|
||||||
|
svm_learn_classification(docs, labels, (long int)totdoc, (long int)feats_,
|
||||||
|
learn_parm, kernel_parm, kernel_cache, model_,
|
||||||
|
alphas);
|
||||||
|
write_model((char *)modelfile, model_);
|
||||||
|
free(labels);
|
||||||
|
labels = NULL;
|
||||||
|
free(alphas);
|
||||||
|
alphas = NULL;
|
||||||
|
for (int i = 0; i < totdoc; i++) {
|
||||||
|
free_example(docs[i], 1);
|
||||||
|
}
|
||||||
|
free_model(model_, 0);
|
||||||
|
model_ = NULL;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ovclassifier
|
28
src/classifier/svm/svm_binary_trainer.hpp
Normal file
28
src/classifier/svm/svm_binary_trainer.hpp
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
#ifndef _SVM_BINARY_TRAINER_H_
|
||||||
|
#define _SVM_BINARY_TRAINER_H_
|
||||||
|
|
||||||
|
#include "svm_common.hpp"
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include "svm_trainer.hpp"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMBinaryTrainer : public SVMTrainer {
|
||||||
|
public:
|
||||||
|
SVMBinaryTrainer();
|
||||||
|
~SVMBinaryTrainer();
|
||||||
|
void Reset();
|
||||||
|
void SetLabels(int labels);
|
||||||
|
void SetFeatures(int feats);
|
||||||
|
void AddData(int label, const float *vec);
|
||||||
|
int Train(const char *modelfile);
|
||||||
|
|
||||||
|
private:
|
||||||
|
KERNEL_PARM *kernel_parm = NULL;
|
||||||
|
LEARN_PARM *learn_parm = NULL;
|
||||||
|
KERNEL_CACHE *kernel_cache = NULL;
|
||||||
|
int feats_;
|
||||||
|
std::vector<LabelItem> items_;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // _SVM_BINARY_TRAINER_H_
|
33
src/classifier/svm/svm_classifier.cpp
Normal file
33
src/classifier/svm/svm_classifier.cpp
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#include "../svm_classifier.h"
|
||||||
|
#include "svm_binary_classifier.hpp"
|
||||||
|
#include "svm_multiclass_classifier.hpp"
|
||||||
|
|
||||||
|
ISVMClassifier new_svm_binary_classifier() {
|
||||||
|
return new ovclassifier::SVMBinaryClassifier();
|
||||||
|
}
|
||||||
|
ISVMClassifier new_svm_multiclass_classifier() {
|
||||||
|
return new ovclassifier::SVMMultiClassClassifier();
|
||||||
|
}
|
||||||
|
void destroy_svm_classifier(ISVMClassifier e) {
|
||||||
|
delete static_cast<ovclassifier::SVMClassifier *>(e);
|
||||||
|
}
|
||||||
|
int svm_classifier_load_model(ISVMClassifier e, const char *modelfile) {
|
||||||
|
return static_cast<ovclassifier::SVMClassifier *>(e)->LoadModel(modelfile);
|
||||||
|
}
|
||||||
|
double svm_predict(ISVMClassifier e, const float *vec) {
|
||||||
|
return static_cast<ovclassifier::SVMClassifier *>(e)->Predict(vec);
|
||||||
|
}
|
||||||
|
int svm_classify(ISVMClassifier e, const float *vec, FloatVector *scores) {
|
||||||
|
std::vector<float> scores_;
|
||||||
|
int ret =
|
||||||
|
static_cast<ovclassifier::SVMClassifier *>(e)->Classify(vec, scores_);
|
||||||
|
if (ret != 0) {
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
scores->length = scores_.size();
|
||||||
|
scores->values = (float *)malloc(sizeof(float) * scores->length);
|
||||||
|
for (int i = 0; i < scores->length; ++i) {
|
||||||
|
scores->values[i] = scores_[i];
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
16
src/classifier/svm/svm_classifier.hpp
Normal file
16
src/classifier/svm/svm_classifier.hpp
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#ifndef _CLASSIFIER_SVM_CLASSIFIER_H_
|
||||||
|
#define _CLASSIFIER_SVM_CLASSIFIER_H_
|
||||||
|
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMClassifier {
|
||||||
|
public:
|
||||||
|
virtual ~SVMClassifier(){};
|
||||||
|
virtual int LoadModel(const char *modelfile) = 0;
|
||||||
|
virtual double Predict(const float *vec) = 0;
|
||||||
|
virtual int Classify(const float *vec, std::vector<float> &scores) = 0;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // !_CLASSIFIER_SVM_CLASSIFIER_H_
|
11
src/classifier/svm/svm_common.hpp
Normal file
11
src/classifier/svm/svm_common.hpp
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#ifndef _SVM_COMMON_HPP_
|
||||||
|
#define _SVM_COMMON_HPP_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
namespace ovclassifier {
|
||||||
|
struct LabelItem {
|
||||||
|
int label;
|
||||||
|
std::vector<float> vec;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // !_SVM_COMMON_HPP_
|
40
src/classifier/svm/svm_light/kernel.h
Normal file
40
src/classifier/svm/svm_light/kernel.h
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/************************************************************************/
|
||||||
|
/* */
|
||||||
|
/* kernel.h */
|
||||||
|
/* */
|
||||||
|
/* User defined kernel function. Feel free to plug in your own. */
|
||||||
|
/* */
|
||||||
|
/* Copyright: Thorsten Joachims */
|
||||||
|
/* Date: 16.12.97 */
|
||||||
|
/* */
|
||||||
|
/************************************************************************/
|
||||||
|
|
||||||
|
/* KERNEL_PARM is defined in svm_common.h The field 'custom' is reserved for */
|
||||||
|
/* parameters of the user defined kernel. You can also access and use */
|
||||||
|
/* the parameters of the other kernels. Just replace the line
|
||||||
|
return((double)(1.0));
|
||||||
|
with your own kernel. */
|
||||||
|
|
||||||
|
/* Example: The following computes the polynomial kernel. sprod_ss
|
||||||
|
computes the inner product between two sparse vectors.
|
||||||
|
|
||||||
|
return((CFLOAT)pow(kernel_parm->coef_lin*sprod_ss(a,b)
|
||||||
|
+kernel_parm->coef_const,(double)kernel_parm->poly_degree));
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* If you are implementing a kernel that is not based on a
|
||||||
|
feature/value representation, you might want to make use of the
|
||||||
|
field "userdefined" in SVECTOR. By default, this field will contain
|
||||||
|
whatever string you put behind a # sign in the example file. So, if
|
||||||
|
a line in your training file looks like
|
||||||
|
|
||||||
|
-1 1:3 5:6 #abcdefg
|
||||||
|
|
||||||
|
then the SVECTOR field "words" will contain the vector 1:3 5:6, and
|
||||||
|
"userdefined" will contain the string "abcdefg". */
|
||||||
|
|
||||||
|
double custom_kernel(KERNEL_PARM *kernel_parm, SVECTOR *a, SVECTOR *b)
|
||||||
|
/* plug in you favorite kernel */
|
||||||
|
{
|
||||||
|
return((double)(1.0));
|
||||||
|
}
|
619
src/classifier/svm/svm_light/pr_loqo/pr_loqo.c
Normal file
619
src/classifier/svm/svm_light/pr_loqo/pr_loqo.c
Normal file
@@ -0,0 +1,619 @@
|
|||||||
|
/*
|
||||||
|
* File: pr_loqo.c
|
||||||
|
* Purpose: solves quadratic programming problem for pattern recognition
|
||||||
|
* for support vectors
|
||||||
|
*
|
||||||
|
* Author: Alex J. Smola
|
||||||
|
* Created: 10/14/97
|
||||||
|
* Updated: 11/08/97
|
||||||
|
* Updated: 13/08/98 (removed exit(1) as it crashes svm lite when the margin
|
||||||
|
* in a not sufficiently conservative manner)
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* Copyright (c) 1997 GMD Berlin - All rights reserved
|
||||||
|
* THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE of GMD Berlin
|
||||||
|
* The copyright notice above does not evidence any
|
||||||
|
* actual or intended publication of this work.
|
||||||
|
*
|
||||||
|
* Unauthorized commercial use of this software is not allowed
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "pr_loqo.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <time.h>
|
||||||
|
|
||||||
|
#define max(A, B) ((A) > (B) ? (A) : (B))
|
||||||
|
#define min(A, B) ((A) < (B) ? (A) : (B))
|
||||||
|
#define sqr(A) ((A) * (A))
|
||||||
|
#define ABS(A) ((A) > 0 ? (A) : (-(A)))
|
||||||
|
|
||||||
|
#define PREDICTOR 1
|
||||||
|
#define CORRECTOR 2
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
replace this by any other function that will exit gracefully
|
||||||
|
in a larger system
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
void nrerror(char error_text[]) {
|
||||||
|
printf("ERROR: terminating optimizer - %s\n", error_text);
|
||||||
|
/* exit(1); */
|
||||||
|
}
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
taken from numerical recipes and modified to accept pointers
|
||||||
|
moreover numerical recipes code seems to be buggy (at least the
|
||||||
|
ones on the web)
|
||||||
|
|
||||||
|
cholesky solver and backsubstitution
|
||||||
|
leaves upper right triangle intact (rows first order)
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
void choldc(double a[], int n, double p[]) {
|
||||||
|
void nrerror(char error_text[]);
|
||||||
|
int i, j, k;
|
||||||
|
double sum;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
for (j = i; j < n; j++) {
|
||||||
|
sum = a[n * i + j];
|
||||||
|
for (k = i - 1; k >= 0; k--)
|
||||||
|
sum -= a[n * i + k] * a[n * j + k];
|
||||||
|
if (i == j) {
|
||||||
|
if (sum <= 0.0) {
|
||||||
|
nrerror((char *)"choldc failed, matrix not positive definite");
|
||||||
|
sum = 0.0;
|
||||||
|
}
|
||||||
|
p[i] = sqrt(sum);
|
||||||
|
} else
|
||||||
|
a[n * j + i] = sum / p[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void cholsb(double a[], int n, double p[], double b[], double x[]) {
|
||||||
|
int i, k;
|
||||||
|
double sum;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
sum = b[i];
|
||||||
|
for (k = i - 1; k >= 0; k--)
|
||||||
|
sum -= a[n * i + k] * x[k];
|
||||||
|
x[i] = sum / p[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = n - 1; i >= 0; i--) {
|
||||||
|
sum = x[i];
|
||||||
|
for (k = i + 1; k < n; k++)
|
||||||
|
sum -= a[n * k + i] * x[k];
|
||||||
|
x[i] = sum / p[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
sometimes we only need the forward or backward pass of the
|
||||||
|
backsubstitution, hence we provide these two routines separately
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
void chol_forward(double a[], int n, double p[], double b[], double x[]) {
|
||||||
|
int i, k;
|
||||||
|
double sum;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
sum = b[i];
|
||||||
|
for (k = i - 1; k >= 0; k--)
|
||||||
|
sum -= a[n * i + k] * x[k];
|
||||||
|
x[i] = sum / p[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void chol_backward(double a[], int n, double p[], double b[], double x[]) {
|
||||||
|
int i, k;
|
||||||
|
double sum;
|
||||||
|
|
||||||
|
for (i = n - 1; i >= 0; i--) {
|
||||||
|
sum = b[i];
|
||||||
|
for (k = i + 1; k < n; k++)
|
||||||
|
sum -= a[n * k + i] * x[k];
|
||||||
|
x[i] = sum / p[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
solves the system | -H_x A' | |x_x| = |c_x|
|
||||||
|
| A H_y| |x_y| |c_y|
|
||||||
|
|
||||||
|
with H_x (and H_y) positive (semidefinite) matrices
|
||||||
|
and n, m the respective sizes of H_x and H_y
|
||||||
|
|
||||||
|
for variables see pg. 48 of notebook or do the calculations on a
|
||||||
|
sheet of paper again
|
||||||
|
|
||||||
|
predictor solves the whole thing, corrector assues that H_x didn't
|
||||||
|
change and relies on the results of the predictor. therefore do
|
||||||
|
_not_ modify workspace
|
||||||
|
|
||||||
|
if you want to speed tune anything in the code here's the right
|
||||||
|
place to do so: about 95% of the time is being spent in
|
||||||
|
here. something like an iterative refinement would be nice,
|
||||||
|
especially when switching from double to single precision. if you
|
||||||
|
have a fast parallel cholesky use it instead of the numrec
|
||||||
|
implementations.
|
||||||
|
|
||||||
|
side effects: changes H_y (but this is just the unit matrix or zero anyway
|
||||||
|
in our case)
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
void solve_reduced(int n, int m, double h_x[], double h_y[], double a[],
|
||||||
|
double x_x[], double x_y[], double c_x[], double c_y[],
|
||||||
|
double workspace[], int step) {
|
||||||
|
int i, j, k;
|
||||||
|
|
||||||
|
double *p_x;
|
||||||
|
double *p_y;
|
||||||
|
double *t_a;
|
||||||
|
double *t_c;
|
||||||
|
double *t_y;
|
||||||
|
|
||||||
|
p_x = workspace; /* together n + m + n*m + n + m = n*(m+2)+2*m */
|
||||||
|
p_y = p_x + n;
|
||||||
|
t_a = p_y + m;
|
||||||
|
t_c = t_a + n * m;
|
||||||
|
t_y = t_c + n;
|
||||||
|
|
||||||
|
if (step == PREDICTOR) {
|
||||||
|
choldc(h_x, n, p_x); /* do cholesky decomposition */
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++) /* forward pass for A' */
|
||||||
|
chol_forward(h_x, n, p_x, a + i * n, t_a + i * n);
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++) /* compute (h_y + a h_x^-1A') */
|
||||||
|
for (j = i; j < m; j++)
|
||||||
|
for (k = 0; k < n; k++)
|
||||||
|
h_y[m * i + j] += t_a[n * j + k] * t_a[n * i + k];
|
||||||
|
|
||||||
|
choldc(h_y, m, p_y); /* and cholesky decomposition */
|
||||||
|
}
|
||||||
|
|
||||||
|
chol_forward(h_x, n, p_x, c_x, t_c);
|
||||||
|
/* forward pass for c */
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++) { /* and solve for x_y */
|
||||||
|
t_y[i] = c_y[i];
|
||||||
|
for (j = 0; j < n; j++)
|
||||||
|
t_y[i] += t_a[i * n + j] * t_c[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
cholsb(h_y, m, p_y, t_y, x_y);
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) { /* finally solve for x_x */
|
||||||
|
t_c[i] = -t_c[i];
|
||||||
|
for (j = 0; j < m; j++)
|
||||||
|
t_c[i] += t_a[j * n + i] * x_y[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
chol_backward(h_x, n, p_x, t_c, x_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
matrix vector multiplication (symmetric matrix but only one triangle
|
||||||
|
given). computes m*x = y
|
||||||
|
no need to tune it as it's only of O(n^2) but cholesky is of
|
||||||
|
O(n^3). so don't waste your time _here_ although it isn't very
|
||||||
|
elegant.
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
void matrix_vector(int n, double m[], double x[], double y[]) {
|
||||||
|
int i, j;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
y[i] = m[(n + 1) * i] * x[i];
|
||||||
|
|
||||||
|
for (j = 0; j < i; j++)
|
||||||
|
y[i] += m[i + n * j] * x[j];
|
||||||
|
|
||||||
|
for (j = i + 1; j < n; j++)
|
||||||
|
y[i] += m[n * i + j] * x[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*****************************************************************
|
||||||
|
call only this routine; this is the only one you're interested in
|
||||||
|
for doing quadratical optimization
|
||||||
|
|
||||||
|
the restart feature exists but it may not be of much use due to the
|
||||||
|
fact that an initial setting, although close but not very close the
|
||||||
|
the actual solution will result in very good starting diagnostics
|
||||||
|
(primal and dual feasibility and small infeasibility gap) but incur
|
||||||
|
later stalling of the optimizer afterwards as we have to enforce
|
||||||
|
positivity of the slacks.
|
||||||
|
***************************************************************/
|
||||||
|
|
||||||
|
int pr_loqo(int n, int m, double c[], double h_x[], double a[], double b[],
|
||||||
|
double l[], double u[], double primal[], double dual[], int verb,
|
||||||
|
double sigfig_max, int counter_max, double margin, double bound,
|
||||||
|
int restart) {
|
||||||
|
/* the knobs to be tuned ... */
|
||||||
|
/* double margin = -0.95; we will go up to 95% of the
|
||||||
|
distance between old variables and zero */
|
||||||
|
/* double bound = 10; preset value for the start. small
|
||||||
|
values give good initial
|
||||||
|
feasibility but may result in slow
|
||||||
|
convergence afterwards: we're too
|
||||||
|
close to zero */
|
||||||
|
/* to be allocated */
|
||||||
|
double *workspace;
|
||||||
|
double *diag_h_x;
|
||||||
|
double *h_y;
|
||||||
|
double *c_x;
|
||||||
|
double *c_y;
|
||||||
|
double *h_dot_x;
|
||||||
|
double *rho;
|
||||||
|
double *nu;
|
||||||
|
double *tau;
|
||||||
|
double *sigma;
|
||||||
|
double *gamma_z;
|
||||||
|
double *gamma_s;
|
||||||
|
|
||||||
|
double *hat_nu;
|
||||||
|
double *hat_tau;
|
||||||
|
|
||||||
|
double *delta_x;
|
||||||
|
double *delta_y;
|
||||||
|
double *delta_s;
|
||||||
|
double *delta_z;
|
||||||
|
double *delta_g;
|
||||||
|
double *delta_t;
|
||||||
|
|
||||||
|
double *d;
|
||||||
|
|
||||||
|
/* from the header - pointers into primal and dual */
|
||||||
|
double *x;
|
||||||
|
double *y;
|
||||||
|
double *g;
|
||||||
|
double *z;
|
||||||
|
double *s;
|
||||||
|
double *t;
|
||||||
|
|
||||||
|
/* auxiliary variables */
|
||||||
|
double b_plus_1;
|
||||||
|
double c_plus_1;
|
||||||
|
|
||||||
|
double x_h_x;
|
||||||
|
double primal_inf;
|
||||||
|
double dual_inf;
|
||||||
|
|
||||||
|
double sigfig;
|
||||||
|
double primal_obj, dual_obj;
|
||||||
|
double mu;
|
||||||
|
double alfa, step;
|
||||||
|
int counter = 0;
|
||||||
|
|
||||||
|
int status = STILL_RUNNING;
|
||||||
|
|
||||||
|
int i, j, k;
|
||||||
|
|
||||||
|
/* memory allocation */
|
||||||
|
workspace = malloc((n * (m + 2) + 2 * m) * sizeof(double));
|
||||||
|
diag_h_x = malloc(n * sizeof(double));
|
||||||
|
h_y = malloc(m * m * sizeof(double));
|
||||||
|
c_x = malloc(n * sizeof(double));
|
||||||
|
c_y = malloc(m * sizeof(double));
|
||||||
|
h_dot_x = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
rho = malloc(m * sizeof(double));
|
||||||
|
nu = malloc(n * sizeof(double));
|
||||||
|
tau = malloc(n * sizeof(double));
|
||||||
|
sigma = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
gamma_z = malloc(n * sizeof(double));
|
||||||
|
gamma_s = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
hat_nu = malloc(n * sizeof(double));
|
||||||
|
hat_tau = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
delta_x = malloc(n * sizeof(double));
|
||||||
|
delta_y = malloc(m * sizeof(double));
|
||||||
|
delta_s = malloc(n * sizeof(double));
|
||||||
|
delta_z = malloc(n * sizeof(double));
|
||||||
|
delta_g = malloc(n * sizeof(double));
|
||||||
|
delta_t = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
d = malloc(n * sizeof(double));
|
||||||
|
|
||||||
|
/* pointers into the external variables */
|
||||||
|
x = primal; /* n */
|
||||||
|
g = x + n; /* n */
|
||||||
|
t = g + n; /* n */
|
||||||
|
|
||||||
|
y = dual; /* m */
|
||||||
|
z = y + m; /* n */
|
||||||
|
s = z + n; /* n */
|
||||||
|
|
||||||
|
/* initial settings */
|
||||||
|
b_plus_1 = 1;
|
||||||
|
c_plus_1 = 0;
|
||||||
|
for (i = 0; i < n; i++)
|
||||||
|
c_plus_1 += c[i];
|
||||||
|
|
||||||
|
/* get diagonal terms */
|
||||||
|
for (i = 0; i < n; i++)
|
||||||
|
diag_h_x[i] = h_x[(n + 1) * i];
|
||||||
|
|
||||||
|
/* starting point */
|
||||||
|
if (restart == 1) {
|
||||||
|
/* x, y already preset */
|
||||||
|
for (i = 0; i < n; i++) { /* compute g, t for primal feasibility */
|
||||||
|
g[i] = max(ABS(x[i] - l[i]), bound);
|
||||||
|
t[i] = max(ABS(u[i] - x[i]), bound);
|
||||||
|
}
|
||||||
|
|
||||||
|
matrix_vector(n, h_x, x, h_dot_x); /* h_dot_x = h_x * x */
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) { /* sigma is a dummy variable to calculate z, s */
|
||||||
|
sigma[i] = c[i] + h_dot_x[i];
|
||||||
|
for (j = 0; j < m; j++)
|
||||||
|
sigma[i] -= a[n * j + i] * y[j];
|
||||||
|
|
||||||
|
if (sigma[i] > 0) {
|
||||||
|
s[i] = bound;
|
||||||
|
z[i] = sigma[i] + bound;
|
||||||
|
} else {
|
||||||
|
s[i] = bound - sigma[i];
|
||||||
|
z[i] = bound;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else { /* use default start settings */
|
||||||
|
for (i = 0; i < m; i++)
|
||||||
|
for (j = i; j < m; j++)
|
||||||
|
h_y[i * m + j] = (i == j) ? 1 : 0;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
c_x[i] = c[i];
|
||||||
|
h_x[(n + 1) * i] += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++)
|
||||||
|
c_y[i] = b[i];
|
||||||
|
|
||||||
|
/* and solve the system [-H_x A'; A H_y] [x, y] = [c_x; c_y] */
|
||||||
|
solve_reduced(n, m, h_x, h_y, a, x, y, c_x, c_y, workspace, PREDICTOR);
|
||||||
|
|
||||||
|
/* initialize the other variables */
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
g[i] = max(ABS(x[i] - l[i]), bound);
|
||||||
|
z[i] = max(ABS(x[i]), bound);
|
||||||
|
t[i] = max(ABS(u[i] - x[i]), bound);
|
||||||
|
s[i] = max(ABS(x[i]), bound);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0, mu = 0; i < n; i++)
|
||||||
|
mu += z[i] * g[i] + s[i] * t[i];
|
||||||
|
mu = mu / (2 * n);
|
||||||
|
|
||||||
|
/* the main loop */
|
||||||
|
if (verb >= STATUS) {
|
||||||
|
printf("counter | pri_inf | dual_inf | pri_obj | dual_obj | ");
|
||||||
|
printf("sigfig | alpha | nu \n");
|
||||||
|
printf("-------------------------------------------------------");
|
||||||
|
printf("---------------------------\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
while (status == STILL_RUNNING) {
|
||||||
|
/* predictor */
|
||||||
|
|
||||||
|
/* put back original diagonal values */
|
||||||
|
for (i = 0; i < n; i++)
|
||||||
|
h_x[(n + 1) * i] = diag_h_x[i];
|
||||||
|
|
||||||
|
matrix_vector(n, h_x, x, h_dot_x); /* compute h_dot_x = h_x * x */
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++) {
|
||||||
|
rho[i] = b[i];
|
||||||
|
for (j = 0; j < n; j++)
|
||||||
|
rho[i] -= a[n * i + j] * x[j];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
nu[i] = l[i] - x[i] + g[i];
|
||||||
|
tau[i] = u[i] - x[i] - t[i];
|
||||||
|
|
||||||
|
sigma[i] = c[i] - z[i] + s[i] + h_dot_x[i];
|
||||||
|
for (j = 0; j < m; j++)
|
||||||
|
sigma[i] -= a[n * j + i] * y[j];
|
||||||
|
|
||||||
|
gamma_z[i] = -z[i];
|
||||||
|
gamma_s[i] = -s[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
/* instrumentation */
|
||||||
|
x_h_x = 0;
|
||||||
|
primal_inf = 0;
|
||||||
|
dual_inf = 0;
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
x_h_x += h_dot_x[i] * x[i];
|
||||||
|
primal_inf += sqr(tau[i]);
|
||||||
|
primal_inf += sqr(nu[i]);
|
||||||
|
dual_inf += sqr(sigma[i]);
|
||||||
|
}
|
||||||
|
for (i = 0; i < m; i++)
|
||||||
|
primal_inf += sqr(rho[i]);
|
||||||
|
primal_inf = sqrt(primal_inf) / b_plus_1;
|
||||||
|
dual_inf = sqrt(dual_inf) / c_plus_1;
|
||||||
|
|
||||||
|
primal_obj = 0.5 * x_h_x;
|
||||||
|
dual_obj = -0.5 * x_h_x;
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
primal_obj += c[i] * x[i];
|
||||||
|
dual_obj += l[i] * z[i] - u[i] * s[i];
|
||||||
|
}
|
||||||
|
for (i = 0; i < m; i++)
|
||||||
|
dual_obj += b[i] * y[i];
|
||||||
|
|
||||||
|
sigfig = log10(ABS(primal_obj) + 1) - log10(ABS(primal_obj - dual_obj));
|
||||||
|
sigfig = max(sigfig, 0);
|
||||||
|
|
||||||
|
/* the diagnostics - after we computed our results we will
|
||||||
|
analyze them */
|
||||||
|
|
||||||
|
if (counter > counter_max)
|
||||||
|
status = ITERATION_LIMIT;
|
||||||
|
if (sigfig > sigfig_max)
|
||||||
|
status = OPTIMAL_SOLUTION;
|
||||||
|
if (primal_inf > 10e100)
|
||||||
|
status = PRIMAL_INFEASIBLE;
|
||||||
|
if (dual_inf > 10e100)
|
||||||
|
status = DUAL_INFEASIBLE;
|
||||||
|
if ((primal_inf > 10e100) & (dual_inf > 10e100))
|
||||||
|
status = PRIMAL_AND_DUAL_INFEASIBLE;
|
||||||
|
if (ABS(primal_obj) > 10e100)
|
||||||
|
status = PRIMAL_UNBOUNDED;
|
||||||
|
if (ABS(dual_obj) > 10e100)
|
||||||
|
status = DUAL_UNBOUNDED;
|
||||||
|
|
||||||
|
/* write some nice routine to enforce the time limit if you
|
||||||
|
_really_ want, however it's quite useless as you can compute
|
||||||
|
the time from the maximum number of iterations as every
|
||||||
|
iteration costs one cholesky decomposition plus a couple of
|
||||||
|
backsubstitutions */
|
||||||
|
|
||||||
|
/* generate report */
|
||||||
|
if ((verb >= FLOOD) | ((verb == STATUS) & (status != 0)))
|
||||||
|
printf("%7i | %.2e | %.2e | % .2e | % .2e | %6.3f | %.4f | %.2e\n",
|
||||||
|
counter, primal_inf, dual_inf, primal_obj, dual_obj, sigfig, alfa,
|
||||||
|
mu);
|
||||||
|
|
||||||
|
counter++;
|
||||||
|
|
||||||
|
if (status == 0) { /* we may keep on going, otherwise
|
||||||
|
it'll cost one loop extra plus a
|
||||||
|
messed up main diagonal of h_x */
|
||||||
|
/* intermediate variables (the ones with hat) */
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
hat_nu[i] = nu[i] + g[i] * gamma_z[i] / z[i];
|
||||||
|
hat_tau[i] = tau[i] - t[i] * gamma_s[i] / s[i];
|
||||||
|
/* diagonal terms */
|
||||||
|
d[i] = z[i] / g[i] + s[i] / t[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
/* initialization before the cholesky solver */
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
h_x[(n + 1) * i] = diag_h_x[i] + d[i];
|
||||||
|
c_x[i] = sigma[i] - z[i] * hat_nu[i] / g[i] - s[i] * hat_tau[i] / t[i];
|
||||||
|
}
|
||||||
|
for (i = 0; i < m; i++) {
|
||||||
|
c_y[i] = rho[i];
|
||||||
|
for (j = i; j < m; j++)
|
||||||
|
h_y[m * i + j] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* and do it */
|
||||||
|
solve_reduced(n, m, h_x, h_y, a, delta_x, delta_y, c_x, c_y, workspace,
|
||||||
|
PREDICTOR);
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
/* backsubstitution */
|
||||||
|
delta_s[i] = s[i] * (delta_x[i] - hat_tau[i]) / t[i];
|
||||||
|
delta_z[i] = z[i] * (hat_nu[i] - delta_x[i]) / g[i];
|
||||||
|
|
||||||
|
delta_g[i] = g[i] * (gamma_z[i] - delta_z[i]) / z[i];
|
||||||
|
delta_t[i] = t[i] * (gamma_s[i] - delta_s[i]) / s[i];
|
||||||
|
|
||||||
|
/* central path (corrector) */
|
||||||
|
gamma_z[i] = mu / g[i] - z[i] - delta_z[i] * delta_g[i] / g[i];
|
||||||
|
gamma_s[i] = mu / t[i] - s[i] - delta_s[i] * delta_t[i] / t[i];
|
||||||
|
|
||||||
|
/* (some more intermediate variables) the hat variables */
|
||||||
|
hat_nu[i] = nu[i] + g[i] * gamma_z[i] / z[i];
|
||||||
|
hat_tau[i] = tau[i] - t[i] * gamma_s[i] / s[i];
|
||||||
|
|
||||||
|
/* initialization before the cholesky */
|
||||||
|
c_x[i] = sigma[i] - z[i] * hat_nu[i] / g[i] - s[i] * hat_tau[i] / t[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++) { /* comput c_y and rho */
|
||||||
|
c_y[i] = rho[i];
|
||||||
|
for (j = i; j < m; j++)
|
||||||
|
h_y[m * i + j] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* and do it */
|
||||||
|
solve_reduced(n, m, h_x, h_y, a, delta_x, delta_y, c_x, c_y, workspace,
|
||||||
|
CORRECTOR);
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
/* backsubstitution */
|
||||||
|
delta_s[i] = s[i] * (delta_x[i] - hat_tau[i]) / t[i];
|
||||||
|
delta_z[i] = z[i] * (hat_nu[i] - delta_x[i]) / g[i];
|
||||||
|
|
||||||
|
delta_g[i] = g[i] * (gamma_z[i] - delta_z[i]) / z[i];
|
||||||
|
delta_t[i] = t[i] * (gamma_s[i] - delta_s[i]) / s[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
alfa = -1;
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
alfa = min(alfa, delta_g[i] / g[i]);
|
||||||
|
alfa = min(alfa, delta_t[i] / t[i]);
|
||||||
|
alfa = min(alfa, delta_s[i] / s[i]);
|
||||||
|
alfa = min(alfa, delta_z[i] / z[i]);
|
||||||
|
}
|
||||||
|
alfa = (margin - 1) / alfa;
|
||||||
|
|
||||||
|
/* compute mu */
|
||||||
|
for (i = 0, mu = 0; i < n; i++)
|
||||||
|
mu += z[i] * g[i] + s[i] * t[i];
|
||||||
|
mu = mu / (2 * n);
|
||||||
|
mu = mu * sqr((alfa - 1) / (alfa + 10));
|
||||||
|
|
||||||
|
for (i = 0; i < n; i++) {
|
||||||
|
x[i] += alfa * delta_x[i];
|
||||||
|
g[i] += alfa * delta_g[i];
|
||||||
|
t[i] += alfa * delta_t[i];
|
||||||
|
z[i] += alfa * delta_z[i];
|
||||||
|
s[i] += alfa * delta_s[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i = 0; i < m; i++)
|
||||||
|
y[i] += alfa * delta_y[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ((status == 1) && (verb >= STATUS)) {
|
||||||
|
printf("-------------------------------------------------------------------"
|
||||||
|
"---------------\n");
|
||||||
|
printf("optimization converged\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* free memory */
|
||||||
|
free(workspace);
|
||||||
|
free(diag_h_x);
|
||||||
|
free(h_y);
|
||||||
|
free(c_x);
|
||||||
|
free(c_y);
|
||||||
|
free(h_dot_x);
|
||||||
|
|
||||||
|
free(rho);
|
||||||
|
free(nu);
|
||||||
|
free(tau);
|
||||||
|
free(sigma);
|
||||||
|
free(gamma_z);
|
||||||
|
free(gamma_s);
|
||||||
|
|
||||||
|
free(hat_nu);
|
||||||
|
free(hat_tau);
|
||||||
|
|
||||||
|
free(delta_x);
|
||||||
|
free(delta_y);
|
||||||
|
free(delta_s);
|
||||||
|
free(delta_z);
|
||||||
|
free(delta_g);
|
||||||
|
free(delta_t);
|
||||||
|
|
||||||
|
free(d);
|
||||||
|
|
||||||
|
/* and return to sender */
|
||||||
|
return status;
|
||||||
|
}
|
93
src/classifier/svm/svm_light/pr_loqo/pr_loqo.h
Normal file
93
src/classifier/svm/svm_light/pr_loqo/pr_loqo.h
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
/*
|
||||||
|
* File: pr_loqo.h
|
||||||
|
* Purpose: solves quadratic programming problem for pattern recognition
|
||||||
|
* for support vectors
|
||||||
|
*
|
||||||
|
* Author: Alex J. Smola
|
||||||
|
* Created: 10/14/97
|
||||||
|
* Updated: 11/08/97
|
||||||
|
*
|
||||||
|
*
|
||||||
|
* Copyright (c) 1997 GMD Berlin - All rights reserved
|
||||||
|
* THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE of GMD Berlin
|
||||||
|
* The copyright notice above does not evidence any
|
||||||
|
* actual or intended publication of this work.
|
||||||
|
*
|
||||||
|
* Unauthorized commercial use of this software is not allowed
|
||||||
|
*/
|
||||||
|
|
||||||
|
/* verbosity levels */
|
||||||
|
#ifndef _PR_LOQO_H_
|
||||||
|
#define _PR_LOQO_H_
|
||||||
|
|
||||||
|
#define QUIET 0
|
||||||
|
#define STATUS 1
|
||||||
|
#define FLOOD 2
|
||||||
|
|
||||||
|
/* status outputs */
|
||||||
|
|
||||||
|
#define STILL_RUNNING 0
|
||||||
|
#define OPTIMAL_SOLUTION 1
|
||||||
|
#define SUBOPTIMAL_SOLUTION 2
|
||||||
|
#define ITERATION_LIMIT 3
|
||||||
|
#define PRIMAL_INFEASIBLE 4
|
||||||
|
#define DUAL_INFEASIBLE 5
|
||||||
|
#define PRIMAL_AND_DUAL_INFEASIBLE 6
|
||||||
|
#define INCONSISTENT 7
|
||||||
|
#define PRIMAL_UNBOUNDED 8
|
||||||
|
#define DUAL_UNBOUNDED 9
|
||||||
|
#define TIME_LIMIT 10
|
||||||
|
|
||||||
|
/*
|
||||||
|
* solve the quadratic programming problem
|
||||||
|
*
|
||||||
|
* minimize c' * x + 1/2 x' * H * x
|
||||||
|
* subject to A*x = b
|
||||||
|
* l <= x <= u
|
||||||
|
*
|
||||||
|
* for a documentation see R. Vanderbei, LOQO: an Interior Point Code
|
||||||
|
* for Quadratic Programming
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* n : number of primal variables
|
||||||
|
* m : number of constraints (typically 1)
|
||||||
|
* h_x : dot product matrix (n.n)
|
||||||
|
* a : constraint matrix (n.m)
|
||||||
|
* b : constant term (m)
|
||||||
|
* l : lower bound (n)
|
||||||
|
* u : upper bound (m)
|
||||||
|
*
|
||||||
|
* primal : workspace for primal variables, has to be of size 3 n
|
||||||
|
*
|
||||||
|
* x = primal; n
|
||||||
|
* g = x + n; n
|
||||||
|
* t = g + n; n
|
||||||
|
*
|
||||||
|
* dual : workspace for dual variables, has to be of size m + 2 n
|
||||||
|
*
|
||||||
|
* y = dual; m
|
||||||
|
* z = y + m; n
|
||||||
|
* s = z + n; n
|
||||||
|
*
|
||||||
|
* verb : verbosity level
|
||||||
|
* sigfig_max : number of significant digits
|
||||||
|
* counter_max: stopping criterion
|
||||||
|
* restart : 1 if restart desired
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
int pr_loqo(int n, int m, double c[], double h_x[], double a[], double b[],
|
||||||
|
double l[], double u[], double primal[], double dual[], int verb,
|
||||||
|
double sigfig_max, int counter_max, double margin, double bound,
|
||||||
|
int restart);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* compile with
|
||||||
|
cc -O4 -c pr_loqo.c
|
||||||
|
cc -xO4 -fast -xarch=v8plus -xchip=ultra -xparallel -c pr_loqo.c
|
||||||
|
mex pr_loqo_c.c pr_loqo.o
|
||||||
|
cmex4 pr_loqo_c.c pr_loqo.o -DMATLAB4 -o pr_loqo_c4
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
#endif // !_PR_LOQO_H_
|
198
src/classifier/svm/svm_light/svm_classify.c
Normal file
198
src/classifier/svm/svm_light/svm_classify.c
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_classify.c */
|
||||||
|
/* */
|
||||||
|
/* Classification module of Support Vector Machine. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 02.07.02 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2002 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/************************************************************************/
|
||||||
|
|
||||||
|
# include "svm_common.h"
|
||||||
|
|
||||||
|
char docfile[200];
|
||||||
|
char modelfile[200];
|
||||||
|
char predictionsfile[200];
|
||||||
|
|
||||||
|
void read_input_parameters(int, char **, char *, char *, char *, long *,
|
||||||
|
long *);
|
||||||
|
void print_help(void);
|
||||||
|
|
||||||
|
|
||||||
|
int main (int argc, char* argv[])
|
||||||
|
{
|
||||||
|
DOC *doc; /* test example */
|
||||||
|
WORD *words;
|
||||||
|
long max_docs,max_words_doc,lld;
|
||||||
|
long totdoc=0,queryid,slackid;
|
||||||
|
long correct=0,incorrect=0,no_accuracy=0;
|
||||||
|
long res_a=0,res_b=0,res_c=0,res_d=0,wnum,pred_format;
|
||||||
|
long j;
|
||||||
|
double t1,runtime=0;
|
||||||
|
double dist,doc_label,costfactor;
|
||||||
|
char *line,*comment;
|
||||||
|
FILE *predfl,*docfl;
|
||||||
|
MODEL *model;
|
||||||
|
|
||||||
|
read_input_parameters(argc,argv,docfile,modelfile,predictionsfile,
|
||||||
|
&verbosity,&pred_format);
|
||||||
|
|
||||||
|
nol_ll(docfile,&max_docs,&max_words_doc,&lld); /* scan size of input file */
|
||||||
|
max_words_doc+=2;
|
||||||
|
lld+=2;
|
||||||
|
|
||||||
|
line = (char *)my_malloc(sizeof(char)*lld);
|
||||||
|
words = (WORD *)my_malloc(sizeof(WORD)*(max_words_doc+10));
|
||||||
|
|
||||||
|
model=read_model(modelfile);
|
||||||
|
|
||||||
|
if(model->kernel_parm.kernel_type == 0) { /* linear kernel */
|
||||||
|
/* compute weight vector */
|
||||||
|
add_weight_vector_to_linear_model(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("Classifying test examples.."); fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((docfl = fopen (docfile, "r")) == NULL)
|
||||||
|
{ perror (docfile); exit (1); }
|
||||||
|
if ((predfl = fopen (predictionsfile, "w")) == NULL)
|
||||||
|
{ perror (predictionsfile); exit (1); }
|
||||||
|
|
||||||
|
while((!feof(docfl)) && fgets(line,(int)lld,docfl)) {
|
||||||
|
if(line[0] == '#') continue; /* line contains comments */
|
||||||
|
parse_document(line,words,&doc_label,&queryid,&slackid,&costfactor,&wnum,
|
||||||
|
max_words_doc,&comment);
|
||||||
|
totdoc++;
|
||||||
|
if(model->kernel_parm.kernel_type == LINEAR) {/* For linear kernel, */
|
||||||
|
for(j=0;(words[j]).wnum != 0;j++) { /* check if feature numbers */
|
||||||
|
if((words[j]).wnum>model->totwords) /* are not larger than in */
|
||||||
|
(words[j]).wnum=0; /* model. Remove feature if */
|
||||||
|
} /* necessary. */
|
||||||
|
}
|
||||||
|
doc = create_example(-1,0,0,0.0,create_svector(words,comment,1.0));
|
||||||
|
t1=get_runtime();
|
||||||
|
|
||||||
|
if(model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
|
||||||
|
dist=classify_example_linear(model,doc);
|
||||||
|
}
|
||||||
|
else { /* non-linear kernel */
|
||||||
|
dist=classify_example(model,doc);
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime+=(get_runtime()-t1);
|
||||||
|
free_example(doc,1);
|
||||||
|
|
||||||
|
if(dist>0) {
|
||||||
|
if(pred_format==0) { /* old weired output format */
|
||||||
|
fprintf(predfl,"%.8g:+1 %.8g:-1\n",dist,-dist);
|
||||||
|
}
|
||||||
|
if(doc_label>0) correct++; else incorrect++;
|
||||||
|
if(doc_label>0) res_a++; else res_b++;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if(pred_format==0) { /* old weired output format */
|
||||||
|
fprintf(predfl,"%.8g:-1 %.8g:+1\n",-dist,dist);
|
||||||
|
}
|
||||||
|
if(doc_label<0) correct++; else incorrect++;
|
||||||
|
if(doc_label>0) res_c++; else res_d++;
|
||||||
|
}
|
||||||
|
if(pred_format==1) { /* output the value of decision function */
|
||||||
|
fprintf(predfl,"%.8g\n",dist);
|
||||||
|
}
|
||||||
|
if((int)(0.01+(doc_label*doc_label)) != 1)
|
||||||
|
{ no_accuracy=1; } /* test data is not binary labeled */
|
||||||
|
if(verbosity>=2) {
|
||||||
|
if(totdoc % 100 == 0) {
|
||||||
|
printf("%ld..",totdoc); fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fclose(predfl);
|
||||||
|
fclose(docfl);
|
||||||
|
free(line);
|
||||||
|
free(words);
|
||||||
|
free_model(model,1);
|
||||||
|
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("done\n");
|
||||||
|
|
||||||
|
/* Note by Gary Boone Date: 29 April 2000 */
|
||||||
|
/* o Timing is inaccurate. The timer has 0.01 second resolution. */
|
||||||
|
/* Because classification of a single vector takes less than */
|
||||||
|
/* 0.01 secs, the timer was underflowing. */
|
||||||
|
printf("Runtime (without IO) in cpu-seconds: %.2f\n",
|
||||||
|
(float)(runtime/100.0));
|
||||||
|
|
||||||
|
}
|
||||||
|
if((!no_accuracy) && (verbosity>=1)) {
|
||||||
|
printf("Accuracy on test set: %.2f%% (%ld correct, %ld incorrect, %ld total)\n",(float)(correct)*100.0/totdoc,correct,incorrect,totdoc);
|
||||||
|
printf("Precision/recall on test set: %.2f%%/%.2f%%\n",(float)(res_a)*100.0/(res_a+res_b),(float)(res_a)*100.0/(res_a+res_c));
|
||||||
|
}
|
||||||
|
|
||||||
|
return(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_input_parameters(int argc, char **argv, char *docfile,
|
||||||
|
char *modelfile, char *predictionsfile,
|
||||||
|
long int *verbosity, long int *pred_format)
|
||||||
|
{
|
||||||
|
long i;
|
||||||
|
|
||||||
|
/* set default */
|
||||||
|
strcpy (modelfile, "svm_model");
|
||||||
|
strcpy (predictionsfile, "svm_predictions");
|
||||||
|
(*verbosity)=2;
|
||||||
|
(*pred_format)=1;
|
||||||
|
|
||||||
|
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||||
|
switch ((argv[i])[1])
|
||||||
|
{
|
||||||
|
case 'h': print_help(); exit(0);
|
||||||
|
case 'v': i++; (*verbosity)=atol(argv[i]); break;
|
||||||
|
case 'f': i++; (*pred_format)=atol(argv[i]); break;
|
||||||
|
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if((i+1)>=argc) {
|
||||||
|
printf("\nNot enough input parameters!\n\n");
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
strcpy (docfile, argv[i]);
|
||||||
|
strcpy (modelfile, argv[i+1]);
|
||||||
|
if((i+2)<argc) {
|
||||||
|
strcpy (predictionsfile, argv[i+2]);
|
||||||
|
}
|
||||||
|
if(((*pred_format) != 0) && ((*pred_format) != 1)) {
|
||||||
|
printf("\nOutput format can only take the values 0 or 1!\n\n");
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_help(void)
|
||||||
|
{
|
||||||
|
printf("\nSVM-light %s: Support Vector Machine, classification module %s\n",VERSION,VERSION_DATE);
|
||||||
|
copyright_notice();
|
||||||
|
printf(" usage: svm_classify [options] example_file model_file output_file\n\n");
|
||||||
|
printf("options: -h -> this help\n");
|
||||||
|
printf(" -v [0..3] -> verbosity level (default 2)\n");
|
||||||
|
printf(" -f [0,1] -> 0: old output format of V1.0\n");
|
||||||
|
printf(" -> 1: output the value of decision function (default)\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
1854
src/classifier/svm/svm_light/svm_common.c
Normal file
1854
src/classifier/svm/svm_light/svm_common.c
Normal file
File diff suppressed because it is too large
Load Diff
385
src/classifier/svm/svm_light/svm_common.h
Normal file
385
src/classifier/svm/svm_light/svm_common.h
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
/************************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_common.h */
|
||||||
|
/* */
|
||||||
|
/* Definitions and functions used in both svm_learn and svm_classify. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 31.10.05 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2005 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/************************************************************************/
|
||||||
|
|
||||||
|
#ifndef SVM_COMMON
|
||||||
|
#define SVM_COMMON
|
||||||
|
|
||||||
|
#include <ctype.h>
|
||||||
|
#include <float.h>
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <time.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define VERSION "V6.20"
|
||||||
|
#define VERSION_DATE "14.08.08"
|
||||||
|
|
||||||
|
#define CFLOAT float /* the type of float to use for caching */
|
||||||
|
/* kernel evaluations. Using float saves */
|
||||||
|
/* us some memory, but you can use double, too */
|
||||||
|
#define FNUM int32_t /* the type used for storing feature ids */
|
||||||
|
#define FNUM_MAX 2147483647 /* maximum value that FNUM type can take */
|
||||||
|
#define FVAL float /* the type used for storing feature values */
|
||||||
|
#define MAXFEATNUM \
|
||||||
|
99999999 /* maximum feature number (must be in \
|
||||||
|
valid range of FNUM type and long int!) */
|
||||||
|
|
||||||
|
#define LINEAR 0 /* linear kernel type */
|
||||||
|
#define POLY 1 /* polynomial kernel type */
|
||||||
|
#define RBF 2 /* rbf kernel type */
|
||||||
|
#define SIGMOID 3 /* sigmoid kernel type */
|
||||||
|
#define CUSTOM 4 /* userdefined kernel function from kernel.h */
|
||||||
|
#define GRAM 5 /* use explicit gram matrix from kernel_parm */
|
||||||
|
|
||||||
|
#define CLASSIFICATION 1 /* train classification model */
|
||||||
|
#define REGRESSION 2 /* train regression model */
|
||||||
|
#define RANKING 3 /* train ranking model */
|
||||||
|
#define OPTIMIZATION 4 /* train on general set of constraints */
|
||||||
|
|
||||||
|
#define MAXSHRINK 50000 /* maximum number of shrinking rounds */
|
||||||
|
|
||||||
|
typedef struct word {
|
||||||
|
FNUM wnum; /* word number */
|
||||||
|
FVAL weight; /* word weight */
|
||||||
|
} WORD;
|
||||||
|
|
||||||
|
typedef struct svector {
|
||||||
|
WORD *words; /* The features/values in the vector by
|
||||||
|
increasing feature-number. Feature
|
||||||
|
numbers that are skipped are
|
||||||
|
interpreted as having value zero. */
|
||||||
|
double twonorm_sq; /* The squared euclidian length of the
|
||||||
|
vector. Used to speed up the RBF kernel. */
|
||||||
|
char *userdefined; /* You can put additional information
|
||||||
|
here. This can be useful, if you are
|
||||||
|
implementing your own kernel that
|
||||||
|
does not work with feature/values
|
||||||
|
representations (for example a
|
||||||
|
string kernel). By default,
|
||||||
|
svm-light will put here the string
|
||||||
|
after the # sign from each line of
|
||||||
|
the input file. */
|
||||||
|
long kernel_id; /* Feature vectors with different
|
||||||
|
kernel_id's are orthogonal (ie. the
|
||||||
|
feature number do not match). This
|
||||||
|
is used for computing component
|
||||||
|
kernels for linear constraints which
|
||||||
|
are a sum of several different
|
||||||
|
weight vectors. (currently not
|
||||||
|
implemented). */
|
||||||
|
struct svector *next; /* Let's you set up a list of SVECTOR's
|
||||||
|
for linear constraints which are a
|
||||||
|
sum of multiple feature
|
||||||
|
vectors. List is terminated by
|
||||||
|
NULL. */
|
||||||
|
double factor; /* Factor by which this feature vector
|
||||||
|
is multiplied in the sum. */
|
||||||
|
} SVECTOR;
|
||||||
|
|
||||||
|
typedef struct doc {
|
||||||
|
long docnum; /* Document ID. This has to be the position of
|
||||||
|
the document in the training set array. */
|
||||||
|
long queryid; /* for learning rankings, constraints are
|
||||||
|
generated for documents with the same
|
||||||
|
queryID. */
|
||||||
|
double costfactor; /* Scales the cost of misclassifying this
|
||||||
|
document by this factor. The effect of this
|
||||||
|
value is, that the upper bound on the alpha
|
||||||
|
for this example is scaled by this factor.
|
||||||
|
The factors are set by the feature
|
||||||
|
'cost:<val>' in the training data. */
|
||||||
|
long slackid; /* Index of the slack variable
|
||||||
|
corresponding to this
|
||||||
|
constraint. All constraints with the
|
||||||
|
same slackid share the same slack
|
||||||
|
variable. This can only be used for
|
||||||
|
svm_learn_optimization. */
|
||||||
|
long kernelid; /* Position in gram matrix where kernel
|
||||||
|
value can be found when using an
|
||||||
|
explicit gram matrix
|
||||||
|
(i.e. kernel_type=GRAM). */
|
||||||
|
SVECTOR *fvec; /* Feature vector of the example. The
|
||||||
|
feature vector can actually be a
|
||||||
|
list of feature vectors. For
|
||||||
|
example, the list will have two
|
||||||
|
elements, if this DOC is a
|
||||||
|
preference constraint. The one
|
||||||
|
vector that is supposed to be ranked
|
||||||
|
higher, will have a factor of +1,
|
||||||
|
the lower ranked one should have a
|
||||||
|
factor of -1. */
|
||||||
|
} DOC;
|
||||||
|
|
||||||
|
typedef struct learn_parm {
|
||||||
|
long type; /* selects between regression and
|
||||||
|
classification */
|
||||||
|
double svm_c; /* upper bound C on alphas */
|
||||||
|
double eps; /* regression epsilon (eps=1.0 for
|
||||||
|
classification */
|
||||||
|
double svm_costratio; /* factor to multiply C for positive examples */
|
||||||
|
double transduction_posratio; /* fraction of unlabeled examples to be */
|
||||||
|
/* classified as positives */
|
||||||
|
long biased_hyperplane; /* if nonzero, use hyperplane w*x+b=0
|
||||||
|
otherwise w*x=0 */
|
||||||
|
long sharedslack; /* if nonzero, it will use the shared
|
||||||
|
slack variable mode in
|
||||||
|
svm_learn_optimization. It requires
|
||||||
|
that the slackid is set for every
|
||||||
|
training example */
|
||||||
|
long svm_maxqpsize; /* size q of working set */
|
||||||
|
long svm_newvarsinqp; /* new variables to enter the working set
|
||||||
|
in each iteration */
|
||||||
|
long kernel_cache_size; /* size of kernel cache in megabytes */
|
||||||
|
double epsilon_crit; /* tolerable error for distances used
|
||||||
|
in stopping criterion */
|
||||||
|
double epsilon_shrink; /* how much a multiplier should be above
|
||||||
|
zero for shrinking */
|
||||||
|
long svm_iter_to_shrink; /* iterations h after which an example can
|
||||||
|
be removed by shrinking */
|
||||||
|
long maxiter; /* number of iterations after which the
|
||||||
|
optimizer terminates, if there was
|
||||||
|
no progress in maxdiff */
|
||||||
|
long remove_inconsistent; /* exclude examples with alpha at C and
|
||||||
|
retrain */
|
||||||
|
long skip_final_opt_check; /* do not check KT-Conditions at the end of
|
||||||
|
optimization for examples removed by
|
||||||
|
shrinking. WARNING: This might lead to
|
||||||
|
sub-optimal solutions! */
|
||||||
|
long compute_loo; /* if nonzero, computes leave-one-out
|
||||||
|
estimates */
|
||||||
|
double rho; /* parameter in xi/alpha-estimates and for
|
||||||
|
pruning leave-one-out range [1..2] */
|
||||||
|
long xa_depth; /* parameter in xi/alpha-estimates upper
|
||||||
|
bounding the number of SV the current
|
||||||
|
alpha_t is distributed over */
|
||||||
|
char predfile[200]; /* file for predicitions on unlabeled examples
|
||||||
|
in transduction */
|
||||||
|
char alphafile[200]; /* file to store optimal alphas in. use
|
||||||
|
empty string if alphas should not be
|
||||||
|
output */
|
||||||
|
|
||||||
|
/* you probably do not want to touch the following */
|
||||||
|
double epsilon_const; /* tolerable error on eq-constraint */
|
||||||
|
double epsilon_a; /* tolerable error on alphas at bounds */
|
||||||
|
double opt_precision; /* precision of solver, set to e.g. 1e-21
|
||||||
|
if you get convergence problems */
|
||||||
|
|
||||||
|
/* the following are only for internal use */
|
||||||
|
long svm_c_steps; /* do so many steps for finding optimal C */
|
||||||
|
double svm_c_factor; /* increase C by this factor every step */
|
||||||
|
double svm_costratio_unlab;
|
||||||
|
double svm_unlabbound;
|
||||||
|
double *svm_cost; /* individual upper bounds for each var */
|
||||||
|
long totwords; /* number of features */
|
||||||
|
} LEARN_PARM;
|
||||||
|
|
||||||
|
typedef struct matrix {
|
||||||
|
int n; /* number of rows */
|
||||||
|
int m; /* number of colums */
|
||||||
|
double **element;
|
||||||
|
} MATRIX;
|
||||||
|
|
||||||
|
typedef struct kernel_parm {
|
||||||
|
long kernel_type; /* 0=linear, 1=poly, 2=rbf, 3=sigmoid,
|
||||||
|
4=custom, 5=matrix */
|
||||||
|
long poly_degree;
|
||||||
|
double rbf_gamma;
|
||||||
|
double coef_lin;
|
||||||
|
double coef_const;
|
||||||
|
char custom[50]; /* for user supplied kernel */
|
||||||
|
MATRIX *gram_matrix; /* here one can directly supply the kernel
|
||||||
|
matrix. The matrix is accessed if
|
||||||
|
kernel_type=5 is selected. */
|
||||||
|
} KERNEL_PARM;
|
||||||
|
|
||||||
|
typedef struct model {
|
||||||
|
long sv_num;
|
||||||
|
long at_upper_bound;
|
||||||
|
double b;
|
||||||
|
DOC **supvec;
|
||||||
|
double *alpha;
|
||||||
|
long *index; /* index from docnum to position in model */
|
||||||
|
long totwords; /* number of features */
|
||||||
|
long totdoc; /* number of training documents */
|
||||||
|
KERNEL_PARM kernel_parm; /* kernel */
|
||||||
|
|
||||||
|
/* the following values are not written to file */
|
||||||
|
double loo_error, loo_recall, loo_precision; /* leave-one-out estimates */
|
||||||
|
double xa_error, xa_recall, xa_precision; /* xi/alpha estimates */
|
||||||
|
double *lin_weights; /* weights for linear case using
|
||||||
|
folding */
|
||||||
|
double maxdiff; /* precision, up to which this
|
||||||
|
model is accurate */
|
||||||
|
} MODEL;
|
||||||
|
|
||||||
|
/* The following specifies a quadratic problem of the following form
|
||||||
|
|
||||||
|
minimize g0 * x + 1/2 x' * G * x
|
||||||
|
subject to ce*x - ce0 = 0
|
||||||
|
l <= x <= u
|
||||||
|
*/
|
||||||
|
typedef struct quadratic_program {
|
||||||
|
long opt_n; /* number of variables */
|
||||||
|
long opt_m; /* number of linear equality constraints */
|
||||||
|
double *opt_ce, *opt_ce0; /* linear equality constraints
|
||||||
|
opt_ce[i]*x - opt_ceo[i]=0 */
|
||||||
|
double *opt_g; /* hessian of objective */
|
||||||
|
double *opt_g0; /* linear part of objective */
|
||||||
|
double *opt_xinit; /* initial value for variables */
|
||||||
|
double *opt_low, *opt_up; /* box constraints */
|
||||||
|
} QP;
|
||||||
|
|
||||||
|
typedef struct kernel_cache {
|
||||||
|
long *index; /* cache some kernel evalutations */
|
||||||
|
CFLOAT *buffer; /* to improve speed */
|
||||||
|
long *invindex;
|
||||||
|
long *active2totdoc;
|
||||||
|
long *totdoc2active;
|
||||||
|
long *lru;
|
||||||
|
long *occu;
|
||||||
|
long elems;
|
||||||
|
long max_elems;
|
||||||
|
long time;
|
||||||
|
long activenum;
|
||||||
|
long buffsize;
|
||||||
|
} KERNEL_CACHE;
|
||||||
|
|
||||||
|
typedef struct timing_profile {
|
||||||
|
double time_kernel;
|
||||||
|
double time_opti;
|
||||||
|
double time_shrink;
|
||||||
|
double time_update;
|
||||||
|
double time_model;
|
||||||
|
double time_check;
|
||||||
|
double time_select;
|
||||||
|
} TIMING;
|
||||||
|
|
||||||
|
typedef struct shrink_state {
|
||||||
|
long *active;
|
||||||
|
long *inactive_since;
|
||||||
|
long deactnum;
|
||||||
|
double **a_history; /* for shrinking with non-linear kernel */
|
||||||
|
long maxhistory;
|
||||||
|
double *last_a; /* for shrinking with linear kernel */
|
||||||
|
double *last_lin; /* for shrinking with linear kernel */
|
||||||
|
} SHRINK_STATE;
|
||||||
|
|
||||||
|
typedef struct randpair {
|
||||||
|
long val, sort;
|
||||||
|
} RANDPAIR;
|
||||||
|
|
||||||
|
double classify_example(MODEL *, DOC *);
|
||||||
|
double classify_example_linear(MODEL *, DOC *);
|
||||||
|
double kernel(KERNEL_PARM *, DOC *, DOC *);
|
||||||
|
double single_kernel(KERNEL_PARM *, SVECTOR *, SVECTOR *);
|
||||||
|
double custom_kernel(KERNEL_PARM *, SVECTOR *, SVECTOR *);
|
||||||
|
SVECTOR *create_svector(WORD *, char *, double);
|
||||||
|
SVECTOR *create_svector_shallow(WORD *, char *, double);
|
||||||
|
SVECTOR *create_svector_n(double *, long, char *, double);
|
||||||
|
SVECTOR *create_svector_n_r(double *, long, char *, double, double);
|
||||||
|
SVECTOR *copy_svector(SVECTOR *);
|
||||||
|
SVECTOR *copy_svector_shallow(SVECTOR *);
|
||||||
|
void free_svector(SVECTOR *);
|
||||||
|
void free_svector_shallow(SVECTOR *);
|
||||||
|
double sprod_ss(SVECTOR *, SVECTOR *);
|
||||||
|
SVECTOR *sub_ss(SVECTOR *, SVECTOR *);
|
||||||
|
SVECTOR *sub_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *add_ss(SVECTOR *, SVECTOR *);
|
||||||
|
SVECTOR *add_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *multadd_ss(SVECTOR *a, SVECTOR *b, double fa, double fb);
|
||||||
|
SVECTOR *multadd_ss_r(SVECTOR *a, SVECTOR *b, double fa, double fb,
|
||||||
|
double min_non_zero);
|
||||||
|
SVECTOR *add_list_ns(SVECTOR *a);
|
||||||
|
SVECTOR *add_dual_list_ns_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *add_list_ns_r(SVECTOR *a, double min_non_zero);
|
||||||
|
SVECTOR *add_list_ss(SVECTOR *);
|
||||||
|
SVECTOR *add_dual_list_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *add_list_ss_r(SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *add_list_sort_ss(SVECTOR *);
|
||||||
|
SVECTOR *add_dual_list_sort_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||||
|
SVECTOR *add_list_sort_ss_r(SVECTOR *, double min_non_zero);
|
||||||
|
void add_list_n_ns(double *vec_n, SVECTOR *vec_s, double faktor);
|
||||||
|
void append_svector_list(SVECTOR *a, SVECTOR *b);
|
||||||
|
void mult_svector_list(SVECTOR *a, double factor);
|
||||||
|
void setfactor_svector_list(SVECTOR *a, double factor);
|
||||||
|
SVECTOR *smult_s(SVECTOR *, double);
|
||||||
|
SVECTOR *shift_s(SVECTOR *a, long shift);
|
||||||
|
int featvec_eq(SVECTOR *, SVECTOR *);
|
||||||
|
double model_length_s(MODEL *);
|
||||||
|
double model_length_n(MODEL *);
|
||||||
|
void mult_vector_ns(double *, SVECTOR *, double);
|
||||||
|
void add_vector_ns(double *, SVECTOR *, double);
|
||||||
|
double sprod_ns(double *, SVECTOR *);
|
||||||
|
void add_weight_vector_to_linear_model(MODEL *);
|
||||||
|
DOC *create_example(long, long, long, double, SVECTOR *);
|
||||||
|
void free_example(DOC *, long);
|
||||||
|
long *random_order(long n);
|
||||||
|
void print_percent_progress(long *progress, long maximum, long percentperdot,
|
||||||
|
char *symbol);
|
||||||
|
MATRIX *create_matrix(int n, int m);
|
||||||
|
MATRIX *realloc_matrix(MATRIX *matrix, int n, int m);
|
||||||
|
double *create_nvector(int n);
|
||||||
|
void clear_nvector(double *vec, long int n);
|
||||||
|
MATRIX *copy_matrix(MATRIX *matrix);
|
||||||
|
void free_matrix(MATRIX *matrix);
|
||||||
|
void free_nvector(double *vector);
|
||||||
|
MATRIX *transpose_matrix(MATRIX *matrix);
|
||||||
|
MATRIX *cholesky_matrix(MATRIX *A);
|
||||||
|
double *find_indep_subset_of_matrix(MATRIX *A, double epsilon);
|
||||||
|
MATRIX *invert_ltriangle_matrix(MATRIX *L);
|
||||||
|
double *prod_nvector_matrix(double *v, MATRIX *A);
|
||||||
|
double *prod_matrix_nvector(MATRIX *A, double *v);
|
||||||
|
double *prod_nvector_ltmatrix(double *v, MATRIX *A);
|
||||||
|
double *prod_ltmatrix_nvector(MATRIX *A, double *v);
|
||||||
|
MATRIX *prod_matrix_matrix(MATRIX *A, MATRIX *B);
|
||||||
|
void print_matrix(MATRIX *matrix);
|
||||||
|
MODEL *read_model(char *);
|
||||||
|
MODEL *copy_model(MODEL *);
|
||||||
|
MODEL *compact_linear_model(MODEL *model);
|
||||||
|
void free_model(MODEL *, int);
|
||||||
|
void read_documents(char *, DOC ***, double **, long *, long *);
|
||||||
|
int parse_document(char *, WORD *, double *, long *, long *, double *, long *,
|
||||||
|
long, char **);
|
||||||
|
int read_word(char *in, char *out);
|
||||||
|
double *read_alphas(char *, long);
|
||||||
|
void set_learning_defaults(LEARN_PARM *, KERNEL_PARM *);
|
||||||
|
int check_learning_parms(LEARN_PARM *, KERNEL_PARM *);
|
||||||
|
void nol_ll(char *, long *, long *, long *);
|
||||||
|
long minl(long, long);
|
||||||
|
long maxl(long, long);
|
||||||
|
double get_runtime(void);
|
||||||
|
int space_or_null(int);
|
||||||
|
void *my_malloc(size_t);
|
||||||
|
void copyright_notice(void);
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
int isnan(double);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
extern long verbosity; /* verbosity level (0-4) */
|
||||||
|
extern long kernel_cache_statistic;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
1054
src/classifier/svm/svm_light/svm_hideo.c
Normal file
1054
src/classifier/svm/svm_light/svm_hideo.c
Normal file
File diff suppressed because it is too large
Load Diff
4223
src/classifier/svm/svm_light/svm_learn.c
Normal file
4223
src/classifier/svm/svm_light/svm_learn.c
Normal file
File diff suppressed because it is too large
Load Diff
169
src/classifier/svm/svm_light/svm_learn.h
Normal file
169
src/classifier/svm/svm_light/svm_learn.h
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_learn.h */
|
||||||
|
/* */
|
||||||
|
/* Declarations for learning module of Support Vector Machine. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 02.07.02 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2002 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#ifndef SVM_LEARN
|
||||||
|
#define SVM_LEARN
|
||||||
|
|
||||||
|
#include "svm_common.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void svm_learn_classification(DOC **, double *, long, long, LEARN_PARM *,
|
||||||
|
KERNEL_PARM *, KERNEL_CACHE *, MODEL *, double *);
|
||||||
|
void svm_learn_regression(DOC **, double *, long, long, LEARN_PARM *,
|
||||||
|
KERNEL_PARM *, KERNEL_CACHE **, MODEL *);
|
||||||
|
void svm_learn_ranking(DOC **, double *, long, long, LEARN_PARM *,
|
||||||
|
KERNEL_PARM *, KERNEL_CACHE **, MODEL *);
|
||||||
|
void svm_learn_optimization(DOC **, double *, long, long, LEARN_PARM *,
|
||||||
|
KERNEL_PARM *, KERNEL_CACHE *, MODEL *, double *);
|
||||||
|
long optimize_to_convergence(DOC **, long *, long, long, LEARN_PARM *,
|
||||||
|
KERNEL_PARM *, KERNEL_CACHE *, SHRINK_STATE *,
|
||||||
|
MODEL *, long *, long *, double *, double *,
|
||||||
|
double *, TIMING *, double *, long, long);
|
||||||
|
long optimize_to_convergence_sharedslack(DOC **, long *, long, long,
|
||||||
|
LEARN_PARM *, KERNEL_PARM *,
|
||||||
|
KERNEL_CACHE *, SHRINK_STATE *,
|
||||||
|
MODEL *, double *, double *, double *,
|
||||||
|
TIMING *, double *);
|
||||||
|
double compute_objective_function(double *, double *, double *, double, long *,
|
||||||
|
long *);
|
||||||
|
void clear_index(long *);
|
||||||
|
void add_to_index(long *, long);
|
||||||
|
long compute_index(long *, long, long *);
|
||||||
|
void optimize_svm(DOC **, long *, long *, long *, double, long *, long *,
|
||||||
|
MODEL *, long, long *, long, double *, double *, double *,
|
||||||
|
LEARN_PARM *, CFLOAT *, KERNEL_PARM *, QP *, double *);
|
||||||
|
void compute_matrices_for_optimization(DOC **, long *, long *, long *, double,
|
||||||
|
long *, long *, long *, MODEL *,
|
||||||
|
double *, double *, double *, long, long,
|
||||||
|
LEARN_PARM *, CFLOAT *, KERNEL_PARM *,
|
||||||
|
QP *);
|
||||||
|
long calculate_svm_model(DOC **, long *, long *, double *, double *, double *,
|
||||||
|
double *, LEARN_PARM *, long *, long *, MODEL *);
|
||||||
|
long check_optimality(MODEL *, long *, long *, double *, double *, double *,
|
||||||
|
long, LEARN_PARM *, double *, double, long *, long *,
|
||||||
|
long *, long *, long, KERNEL_PARM *);
|
||||||
|
long check_optimality_sharedslack(
|
||||||
|
DOC **docs, MODEL *model, long int *label, double *a, double *lin,
|
||||||
|
double *c, double *slack, double *alphaslack, long int totdoc,
|
||||||
|
LEARN_PARM *learn_parm, double *maxdiff, double epsilon_crit_org,
|
||||||
|
long int *misclassified, long int *active2dnum,
|
||||||
|
long int *last_suboptimal_at, long int iteration, KERNEL_PARM *kernel_parm);
|
||||||
|
void compute_shared_slacks(DOC **docs, long int *label, double *a, double *lin,
|
||||||
|
double *c, long int *active2dnum,
|
||||||
|
LEARN_PARM *learn_parm, double *slack,
|
||||||
|
double *alphaslack);
|
||||||
|
long identify_inconsistent(double *, long *, long *, long, LEARN_PARM *, long *,
|
||||||
|
long *);
|
||||||
|
long identify_misclassified(double *, long *, long *, long, MODEL *, long *,
|
||||||
|
long *);
|
||||||
|
long identify_one_misclassified(double *, long *, long *, long, MODEL *, long *,
|
||||||
|
long *);
|
||||||
|
long incorporate_unlabeled_examples(MODEL *, long *, long *, long *, double *,
|
||||||
|
double *, long, double *, long *, long *,
|
||||||
|
long, KERNEL_PARM *, LEARN_PARM *);
|
||||||
|
void update_linear_component(DOC **, long *, long *, double *, double *, long *,
|
||||||
|
long, long, KERNEL_PARM *, KERNEL_CACHE *,
|
||||||
|
double *, CFLOAT *, double *);
|
||||||
|
long select_next_qp_subproblem_grad(long *, long *, double *, double *,
|
||||||
|
double *, long, long, LEARN_PARM *, long *,
|
||||||
|
long *, long *, double *, long *,
|
||||||
|
KERNEL_CACHE *, long, long *, long *);
|
||||||
|
long select_next_qp_subproblem_rand(long *, long *, double *, double *,
|
||||||
|
double *, long, long, LEARN_PARM *, long *,
|
||||||
|
long *, long *, double *, long *,
|
||||||
|
KERNEL_CACHE *, long *, long *, long);
|
||||||
|
long select_next_qp_slackset(DOC **docs, long int *label, double *a,
|
||||||
|
double *lin, double *slack, double *alphaslack,
|
||||||
|
double *c, LEARN_PARM *learn_parm,
|
||||||
|
long int *active2dnum, double *maxviol);
|
||||||
|
void select_top_n(double *, long, long *, long);
|
||||||
|
void init_shrink_state(SHRINK_STATE *, long, long);
|
||||||
|
void shrink_state_cleanup(SHRINK_STATE *);
|
||||||
|
long shrink_problem(DOC **, LEARN_PARM *, SHRINK_STATE *, KERNEL_PARM *, long *,
|
||||||
|
long *, long, long, long, double *, long *);
|
||||||
|
void reactivate_inactive_examples(long *, long *, double *, SHRINK_STATE *,
|
||||||
|
double *, double *, long, long, long,
|
||||||
|
LEARN_PARM *, long *, DOC **, KERNEL_PARM *,
|
||||||
|
KERNEL_CACHE *, MODEL *, CFLOAT *, double *,
|
||||||
|
double *);
|
||||||
|
|
||||||
|
/* cache kernel evalutations to improve speed */
|
||||||
|
KERNEL_CACHE *kernel_cache_init(long, long);
|
||||||
|
void kernel_cache_cleanup(KERNEL_CACHE *);
|
||||||
|
void get_kernel_row(KERNEL_CACHE *, DOC **, long, long, long *, CFLOAT *,
|
||||||
|
KERNEL_PARM *);
|
||||||
|
void cache_kernel_row(KERNEL_CACHE *, DOC **, long, KERNEL_PARM *);
|
||||||
|
void cache_multiple_kernel_rows(KERNEL_CACHE *, DOC **, long *, long,
|
||||||
|
KERNEL_PARM *);
|
||||||
|
void kernel_cache_shrink(KERNEL_CACHE *, long, long, long *);
|
||||||
|
void kernel_cache_reset_lru(KERNEL_CACHE *);
|
||||||
|
long kernel_cache_malloc(KERNEL_CACHE *);
|
||||||
|
void kernel_cache_free(KERNEL_CACHE *, long);
|
||||||
|
long kernel_cache_free_lru(KERNEL_CACHE *);
|
||||||
|
CFLOAT *kernel_cache_clean_and_malloc(KERNEL_CACHE *, long);
|
||||||
|
long kernel_cache_touch(KERNEL_CACHE *, long);
|
||||||
|
long kernel_cache_check(KERNEL_CACHE *, long);
|
||||||
|
long kernel_cache_space_available(KERNEL_CACHE *);
|
||||||
|
|
||||||
|
void compute_xa_estimates(MODEL *, long *, long *, long, DOC **, double *,
|
||||||
|
double *, KERNEL_PARM *, LEARN_PARM *, double *,
|
||||||
|
double *, double *);
|
||||||
|
double xa_estimate_error(MODEL *, long *, long *, long, DOC **, double *,
|
||||||
|
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||||
|
double xa_estimate_recall(MODEL *, long *, long *, long, DOC **, double *,
|
||||||
|
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||||
|
double xa_estimate_precision(MODEL *, long *, long *, long, DOC **, double *,
|
||||||
|
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||||
|
void avg_similarity_of_sv_of_one_class(MODEL *, DOC **, double *, long *,
|
||||||
|
KERNEL_PARM *, double *, double *);
|
||||||
|
double most_similar_sv_of_same_class(MODEL *, DOC **, double *, long, long *,
|
||||||
|
KERNEL_PARM *, LEARN_PARM *);
|
||||||
|
double distribute_alpha_t_greedily(long *, long, DOC **, double *, long, long *,
|
||||||
|
KERNEL_PARM *, LEARN_PARM *, double);
|
||||||
|
double distribute_alpha_t_greedily_noindex(MODEL *, DOC **, double *, long,
|
||||||
|
long *, KERNEL_PARM *, LEARN_PARM *,
|
||||||
|
double);
|
||||||
|
void estimate_transduction_quality(MODEL *, long *, long *, long, DOC **,
|
||||||
|
double *);
|
||||||
|
double estimate_margin_vcdim(MODEL *, double, double);
|
||||||
|
double estimate_sphere(MODEL *);
|
||||||
|
double estimate_r_delta_average(DOC **, long, KERNEL_PARM *);
|
||||||
|
double estimate_r_delta(DOC **, long, KERNEL_PARM *);
|
||||||
|
double length_of_longest_document_vector(DOC **, long, KERNEL_PARM *);
|
||||||
|
|
||||||
|
void write_model(char *, MODEL *);
|
||||||
|
void write_prediction(char *, MODEL *, double *, double *, long *, long *, long,
|
||||||
|
LEARN_PARM *);
|
||||||
|
void write_alphas(char *, double *, long *, long);
|
||||||
|
|
||||||
|
typedef struct cache_parm_s {
|
||||||
|
KERNEL_CACHE *kernel_cache;
|
||||||
|
CFLOAT *cache;
|
||||||
|
DOC **docs;
|
||||||
|
long m;
|
||||||
|
KERNEL_PARM *kernel_parm;
|
||||||
|
long offset, stepsize;
|
||||||
|
} cache_parm_t;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
303
src/classifier/svm/svm_light/svm_learn_main.c
Normal file
303
src/classifier/svm/svm_light/svm_learn_main.c
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_learn_main.c */
|
||||||
|
/* */
|
||||||
|
/* Command line interface to the learning module of the */
|
||||||
|
/* Support Vector Machine. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 02.07.02 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2000 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
/* if svm-learn is used out of C++, define it as extern "C" */
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
# include "svm_common.h"
|
||||||
|
# include "svm_learn.h"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
char docfile[200]; /* file with training examples */
|
||||||
|
char modelfile[200]; /* file for resulting classifier */
|
||||||
|
char restartfile[200]; /* file with initial alphas */
|
||||||
|
|
||||||
|
void read_input_parameters(int, char **, char *, char *, char *, long *,
|
||||||
|
LEARN_PARM *, KERNEL_PARM *);
|
||||||
|
void wait_any_key();
|
||||||
|
void print_help();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
int main (int argc, char* argv[])
|
||||||
|
{
|
||||||
|
DOC **docs; /* training examples */
|
||||||
|
long totwords,totdoc,i;
|
||||||
|
double *target;
|
||||||
|
double *alpha_in=NULL;
|
||||||
|
KERNEL_CACHE *kernel_cache;
|
||||||
|
LEARN_PARM learn_parm;
|
||||||
|
KERNEL_PARM kernel_parm;
|
||||||
|
MODEL *model=(MODEL *)my_malloc(sizeof(MODEL));
|
||||||
|
|
||||||
|
read_input_parameters(argc,argv,docfile,modelfile,restartfile,&verbosity,
|
||||||
|
&learn_parm,&kernel_parm);
|
||||||
|
read_documents(docfile,&docs,&target,&totwords,&totdoc);
|
||||||
|
if(restartfile[0]) alpha_in=read_alphas(restartfile,totdoc);
|
||||||
|
|
||||||
|
if(kernel_parm.kernel_type == LINEAR) { /* don't need the cache */
|
||||||
|
kernel_cache=NULL;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
/* Always get a new kernel cache. It is not possible to use the
|
||||||
|
same cache for two different training runs */
|
||||||
|
kernel_cache=kernel_cache_init(totdoc,learn_parm.kernel_cache_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(learn_parm.type == CLASSIFICATION) {
|
||||||
|
svm_learn_classification(docs,target,totdoc,totwords,&learn_parm,
|
||||||
|
&kernel_parm,kernel_cache,model,alpha_in);
|
||||||
|
}
|
||||||
|
else if(learn_parm.type == REGRESSION) {
|
||||||
|
svm_learn_regression(docs,target,totdoc,totwords,&learn_parm,
|
||||||
|
&kernel_parm,&kernel_cache,model);
|
||||||
|
}
|
||||||
|
else if(learn_parm.type == RANKING) {
|
||||||
|
svm_learn_ranking(docs,target,totdoc,totwords,&learn_parm,
|
||||||
|
&kernel_parm,&kernel_cache,model);
|
||||||
|
}
|
||||||
|
else if(learn_parm.type == OPTIMIZATION) {
|
||||||
|
svm_learn_optimization(docs,target,totdoc,totwords,&learn_parm,
|
||||||
|
&kernel_parm,kernel_cache,model,alpha_in);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(kernel_cache) {
|
||||||
|
/* Free the memory used for the cache. */
|
||||||
|
kernel_cache_cleanup(kernel_cache);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Warning: The model contains references to the original data 'docs'.
|
||||||
|
If you want to free the original data, and only keep the model, you
|
||||||
|
have to make a deep copy of 'model'. */
|
||||||
|
/* deep_copy_of_model=copy_model(model); */
|
||||||
|
write_model(modelfile,model);
|
||||||
|
|
||||||
|
free(alpha_in);
|
||||||
|
free_model(model,0);
|
||||||
|
for(i=0;i<totdoc;i++)
|
||||||
|
free_example(docs[i],1);
|
||||||
|
free(docs);
|
||||||
|
free(target);
|
||||||
|
|
||||||
|
return(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*---------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
void read_input_parameters(int argc,char *argv[],char *docfile,char *modelfile,
|
||||||
|
char *restartfile,long *verbosity,
|
||||||
|
LEARN_PARM *learn_parm,KERNEL_PARM *kernel_parm)
|
||||||
|
{
|
||||||
|
long i;
|
||||||
|
char type[100];
|
||||||
|
|
||||||
|
/* set default */
|
||||||
|
set_learning_defaults(learn_parm, kernel_parm);
|
||||||
|
strcpy (modelfile, "svm_model");
|
||||||
|
strcpy (restartfile, "");
|
||||||
|
(*verbosity)=1;
|
||||||
|
strcpy(type,"c");
|
||||||
|
|
||||||
|
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||||
|
switch ((argv[i])[1])
|
||||||
|
{
|
||||||
|
case '?': print_help(); exit(0);
|
||||||
|
case 'z': i++; strcpy(type,argv[i]); break;
|
||||||
|
case 'v': i++; (*verbosity)=atol(argv[i]); break;
|
||||||
|
case 'b': i++; learn_parm->biased_hyperplane=atol(argv[i]); break;
|
||||||
|
case 'i': i++; learn_parm->remove_inconsistent=atol(argv[i]); break;
|
||||||
|
case 'f': i++; learn_parm->skip_final_opt_check=!atol(argv[i]); break;
|
||||||
|
case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break;
|
||||||
|
case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break;
|
||||||
|
case '#': i++; learn_parm->maxiter=atol(argv[i]); break;
|
||||||
|
case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break;
|
||||||
|
case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break;
|
||||||
|
case 'c': i++; learn_parm->svm_c=atof(argv[i]); break;
|
||||||
|
case 'w': i++; learn_parm->eps=atof(argv[i]); break;
|
||||||
|
case 'p': i++; learn_parm->transduction_posratio=atof(argv[i]); break;
|
||||||
|
case 'j': i++; learn_parm->svm_costratio=atof(argv[i]); break;
|
||||||
|
case 'e': i++; learn_parm->epsilon_crit=atof(argv[i]); break;
|
||||||
|
case 'o': i++; learn_parm->rho=atof(argv[i]); break;
|
||||||
|
case 'k': i++; learn_parm->xa_depth=atol(argv[i]); break;
|
||||||
|
case 'x': i++; learn_parm->compute_loo=atol(argv[i]); break;
|
||||||
|
case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break;
|
||||||
|
case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break;
|
||||||
|
case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break;
|
||||||
|
case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break;
|
||||||
|
case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break;
|
||||||
|
case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break;
|
||||||
|
case 'l': i++; strcpy(learn_parm->predfile,argv[i]); break;
|
||||||
|
case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break;
|
||||||
|
case 'y': i++; strcpy(restartfile,argv[i]); break;
|
||||||
|
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(i>=argc) {
|
||||||
|
printf("\nNot enough input parameters!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
strcpy (docfile, argv[i]);
|
||||||
|
if((i+1)<argc) {
|
||||||
|
strcpy (modelfile, argv[i+1]);
|
||||||
|
}
|
||||||
|
if(learn_parm->svm_iter_to_shrink == -9999) {
|
||||||
|
if(kernel_parm->kernel_type == LINEAR)
|
||||||
|
learn_parm->svm_iter_to_shrink=2;
|
||||||
|
else
|
||||||
|
learn_parm->svm_iter_to_shrink=100;
|
||||||
|
}
|
||||||
|
if(strcmp(type,"c")==0) {
|
||||||
|
learn_parm->type=CLASSIFICATION;
|
||||||
|
}
|
||||||
|
else if(strcmp(type,"r")==0) {
|
||||||
|
learn_parm->type=REGRESSION;
|
||||||
|
}
|
||||||
|
else if(strcmp(type,"p")==0) {
|
||||||
|
learn_parm->type=RANKING;
|
||||||
|
}
|
||||||
|
else if(strcmp(type,"o")==0) {
|
||||||
|
learn_parm->type=OPTIMIZATION;
|
||||||
|
}
|
||||||
|
else if(strcmp(type,"s")==0) {
|
||||||
|
learn_parm->type=OPTIMIZATION;
|
||||||
|
learn_parm->sharedslack=1;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
printf("\nUnknown type '%s': Valid types are 'c' (classification), 'r' regession, and 'p' preference ranking.\n",type);
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if (!check_learning_parms(learn_parm, kernel_parm)) {
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait_any_key()
|
||||||
|
{
|
||||||
|
printf("\n(more)\n");
|
||||||
|
(void)getc(stdin);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_help()
|
||||||
|
{
|
||||||
|
printf("\nSVM-light %s: Support Vector Machine, learning module %s\n",VERSION,VERSION_DATE);
|
||||||
|
copyright_notice();
|
||||||
|
printf(" usage: svm_learn [options] example_file model_file\n\n");
|
||||||
|
printf("Arguments:\n");
|
||||||
|
printf(" example_file-> file with training data\n");
|
||||||
|
printf(" model_file -> file to store learned decision rule in\n");
|
||||||
|
|
||||||
|
printf("General options:\n");
|
||||||
|
printf(" -? -> this help\n");
|
||||||
|
printf(" -v [0..3] -> verbosity level (default 1)\n");
|
||||||
|
printf("Learning options:\n");
|
||||||
|
printf(" -z {c,r,p} -> select between classification (c), regression (r),\n");
|
||||||
|
printf(" and preference ranking (p) (default classification)\n");
|
||||||
|
printf(" -c float -> C: trade-off between training error\n");
|
||||||
|
printf(" and margin (default [avg. x*x]^-1)\n");
|
||||||
|
printf(" -w [0..] -> epsilon width of tube for regression\n");
|
||||||
|
printf(" (default 0.1)\n");
|
||||||
|
printf(" -j float -> Cost: cost-factor, by which training errors on\n");
|
||||||
|
printf(" positive examples outweight errors on negative\n");
|
||||||
|
printf(" examples (default 1) (see [4])\n");
|
||||||
|
printf(" -b [0,1] -> use biased hyperplane (i.e. x*w+b>0) instead\n");
|
||||||
|
printf(" of unbiased hyperplane (i.e. x*w>0) (default 1)\n");
|
||||||
|
printf(" -i [0,1] -> remove inconsistent training examples\n");
|
||||||
|
printf(" and retrain (default 0)\n");
|
||||||
|
printf("Performance estimation options:\n");
|
||||||
|
printf(" -x [0,1] -> compute leave-one-out estimates (default 0)\n");
|
||||||
|
printf(" (see [5])\n");
|
||||||
|
printf(" -o ]0..2] -> value of rho for XiAlpha-estimator and for pruning\n");
|
||||||
|
printf(" leave-one-out computation (default 1.0) (see [2])\n");
|
||||||
|
printf(" -k [0..100] -> search depth for extended XiAlpha-estimator \n");
|
||||||
|
printf(" (default 0)\n");
|
||||||
|
printf("Transduction options (see [3]):\n");
|
||||||
|
printf(" -p [0..1] -> fraction of unlabeled examples to be classified\n");
|
||||||
|
printf(" into the positive class (default is the ratio of\n");
|
||||||
|
printf(" positive and negative examples in the training data)\n");
|
||||||
|
printf("Kernel options:\n");
|
||||||
|
printf(" -t int -> type of kernel function:\n");
|
||||||
|
printf(" 0: linear (default)\n");
|
||||||
|
printf(" 1: polynomial (s a*b+c)^d\n");
|
||||||
|
printf(" 2: radial basis function exp(-gamma ||a-b||^2)\n");
|
||||||
|
printf(" 3: sigmoid tanh(s a*b + c)\n");
|
||||||
|
printf(" 4: user defined kernel from kernel.h\n");
|
||||||
|
printf(" -d int -> parameter d in polynomial kernel\n");
|
||||||
|
printf(" -g float -> parameter gamma in rbf kernel\n");
|
||||||
|
printf(" -s float -> parameter s in sigmoid/poly kernel\n");
|
||||||
|
printf(" -r float -> parameter c in sigmoid/poly kernel\n");
|
||||||
|
printf(" -u string -> parameter of user defined kernel\n");
|
||||||
|
printf("Optimization options (see [1]):\n");
|
||||||
|
printf(" -q [2..] -> maximum size of QP-subproblems (default 10)\n");
|
||||||
|
printf(" -n [2..q] -> number of new variables entering the working set\n");
|
||||||
|
printf(" in each iteration (default n = q). Set n < q to \n");
|
||||||
|
printf(" prevent zig-zagging.\n");
|
||||||
|
printf(" -m [5..] -> size of cache for kernel evaluations in MB (default 40)\n");
|
||||||
|
printf(" The larger the faster...\n");
|
||||||
|
printf(" -e float -> eps: Allow that error for termination criterion\n");
|
||||||
|
printf(" [y [w*x+b] - 1] >= eps (default 0.001)\n");
|
||||||
|
printf(" -y [0,1] -> restart the optimization from alpha values in file\n");
|
||||||
|
printf(" specified by -a option. (default 0)\n");
|
||||||
|
printf(" -h [5..] -> number of iterations a variable needs to be\n");
|
||||||
|
printf(" optimal before considered for shrinking (default 100)\n");
|
||||||
|
printf(" -f [0,1] -> do final optimality check for variables removed\n");
|
||||||
|
printf(" by shrinking. Although this test is usually \n");
|
||||||
|
printf(" positive, there is no guarantee that the optimum\n");
|
||||||
|
printf(" was found if the test is omitted. (default 1)\n");
|
||||||
|
printf(" -y string -> if option is given, reads alphas from file with given\n");
|
||||||
|
printf(" and uses them as starting point. (default 'disabled')\n");
|
||||||
|
printf(" -# int -> terminate optimization, if no progress after this\n");
|
||||||
|
printf(" number of iterations. (default 100000)\n");
|
||||||
|
printf("Output options:\n");
|
||||||
|
printf(" -l string -> file to write predicted labels of unlabeled\n");
|
||||||
|
printf(" examples into after transductive learning\n");
|
||||||
|
printf(" -a string -> write all alphas to this file after learning\n");
|
||||||
|
printf(" (in the same order as in the training set)\n");
|
||||||
|
wait_any_key();
|
||||||
|
printf("\nMore details in:\n");
|
||||||
|
printf("[1] T. Joachims, Making Large-Scale SVM Learning Practical. Advances in\n");
|
||||||
|
printf(" Kernel Methods - Support Vector Learning, B. Sch<63>lkopf and C. Burges and\n");
|
||||||
|
printf(" A. Smola (ed.), MIT Press, 1999.\n");
|
||||||
|
printf("[2] T. Joachims, Estimating the Generalization performance of an SVM\n");
|
||||||
|
printf(" Efficiently. International Conference on Machine Learning (ICML), 2000.\n");
|
||||||
|
printf("[3] T. Joachims, Transductive Inference for Text Classification using Support\n");
|
||||||
|
printf(" Vector Machines. International Conference on Machine Learning (ICML),\n");
|
||||||
|
printf(" 1999.\n");
|
||||||
|
printf("[4] K. Morik, P. Brockhausen, and T. Joachims, Combining statistical learning\n");
|
||||||
|
printf(" with a knowledge-based approach - A case study in intensive care \n");
|
||||||
|
printf(" monitoring. International Conference on Machine Learning (ICML), 1999.\n");
|
||||||
|
printf("[5] T. Joachims, Learning to Classify Text Using Support Vector\n");
|
||||||
|
printf(" Machines: Methods, Theory, and Algorithms. Dissertation, Kluwer,\n");
|
||||||
|
printf(" 2002.\n\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
211
src/classifier/svm/svm_light/svm_loqo.c
Normal file
211
src/classifier/svm/svm_light/svm_loqo.c
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_loqo.c */
|
||||||
|
/* */
|
||||||
|
/* Interface to the PR_LOQO optimization package for SVM. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 19.07.99 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 1999 Universitaet Dortmund - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
# include <math.h>
|
||||||
|
# include "pr_loqo/pr_loqo.h"
|
||||||
|
# include "svm_common.h"
|
||||||
|
|
||||||
|
/* Common Block Declarations */
|
||||||
|
|
||||||
|
long verbosity;
|
||||||
|
|
||||||
|
/* /////////////////////////////////////////////////////////////// */
|
||||||
|
|
||||||
|
# define DEF_PRECISION_LINEAR 1E-8
|
||||||
|
# define DEF_PRECISION_NONLINEAR 1E-14
|
||||||
|
|
||||||
|
double *optimize_qp();
|
||||||
|
double *primal=0,*dual=0;
|
||||||
|
double init_margin=0.15;
|
||||||
|
long init_iter=500,precision_violations=0;
|
||||||
|
double model_b;
|
||||||
|
double opt_precision=DEF_PRECISION_LINEAR;
|
||||||
|
|
||||||
|
/* /////////////////////////////////////////////////////////////// */
|
||||||
|
|
||||||
|
void *my_malloc();
|
||||||
|
|
||||||
|
double *optimize_qp(qp,epsilon_crit,nx,threshold,learn_parm)
|
||||||
|
QP *qp;
|
||||||
|
double *epsilon_crit;
|
||||||
|
long nx; /* Maximum number of variables in QP */
|
||||||
|
double *threshold;
|
||||||
|
LEARN_PARM *learn_parm;
|
||||||
|
/* start the optimizer and return the optimal values */
|
||||||
|
{
|
||||||
|
register long i,j,result;
|
||||||
|
double margin,obj_before,obj_after;
|
||||||
|
double sigdig,dist,epsilon_loqo;
|
||||||
|
int iter;
|
||||||
|
|
||||||
|
if(!primal) { /* allocate memory at first call */
|
||||||
|
primal=(double *)my_malloc(sizeof(double)*nx*3);
|
||||||
|
dual=(double *)my_malloc(sizeof(double)*(nx*2+1));
|
||||||
|
}
|
||||||
|
|
||||||
|
if(verbosity>=4) { /* really verbose */
|
||||||
|
printf("\n\n");
|
||||||
|
for(i=0;i<qp->opt_n;i++) {
|
||||||
|
printf("%f: ",qp->opt_g0[i]);
|
||||||
|
for(j=0;j<qp->opt_n;j++) {
|
||||||
|
printf("%f ",qp->opt_g[i*qp->opt_n+j]);
|
||||||
|
}
|
||||||
|
printf(": a%ld=%.10f < %f",i,qp->opt_xinit[i],qp->opt_up[i]);
|
||||||
|
printf(": y=%f\n",qp->opt_ce[i]);
|
||||||
|
}
|
||||||
|
for(j=0;j<qp->opt_m;j++) {
|
||||||
|
printf("EQ-%ld: %f*a0",j,qp->opt_ce[j]);
|
||||||
|
for(i=1;i<qp->opt_n;i++) {
|
||||||
|
printf(" + %f*a%ld",qp->opt_ce[i],i);
|
||||||
|
}
|
||||||
|
printf(" = %f\n\n",-qp->opt_ce0[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
obj_before=0; /* calculate objective before optimization */
|
||||||
|
for(i=0;i<qp->opt_n;i++) {
|
||||||
|
obj_before+=(qp->opt_g0[i]*qp->opt_xinit[i]);
|
||||||
|
obj_before+=(0.5*qp->opt_xinit[i]*qp->opt_xinit[i]*qp->opt_g[i*qp->opt_n+i]);
|
||||||
|
for(j=0;j<i;j++) {
|
||||||
|
obj_before+=(qp->opt_xinit[j]*qp->opt_xinit[i]*qp->opt_g[j*qp->opt_n+i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result=STILL_RUNNING;
|
||||||
|
qp->opt_ce0[0]*=(-1.0);
|
||||||
|
/* Run pr_loqo. If a run fails, try again with parameters which lead */
|
||||||
|
/* to a slower, but more robust setting. */
|
||||||
|
for(margin=init_margin,iter=init_iter;
|
||||||
|
(margin<=0.9999999) && (result!=OPTIMAL_SOLUTION);) {
|
||||||
|
sigdig=-log10(opt_precision);
|
||||||
|
|
||||||
|
result=pr_loqo((int)qp->opt_n,(int)qp->opt_m,
|
||||||
|
(double *)qp->opt_g0,(double *)qp->opt_g,
|
||||||
|
(double *)qp->opt_ce,(double *)qp->opt_ce0,
|
||||||
|
(double *)qp->opt_low,(double *)qp->opt_up,
|
||||||
|
(double *)primal,(double *)dual,
|
||||||
|
(int)(verbosity-2),
|
||||||
|
(double)sigdig,(int)iter,
|
||||||
|
(double)margin,(double)(qp->opt_up[0])/4.0,(int)0);
|
||||||
|
|
||||||
|
if(isnan(dual[0])) { /* check for choldc problem */
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("NOTICE: Restarting PR_LOQO with more conservative parameters.\n");
|
||||||
|
}
|
||||||
|
if(init_margin<0.80) { /* become more conservative in general */
|
||||||
|
init_margin=(4.0*margin+1.0)/5.0;
|
||||||
|
}
|
||||||
|
margin=(margin+1.0)/2.0;
|
||||||
|
(opt_precision)*=10.0; /* reduce precision */
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("NOTICE: Reducing precision of PR_LOQO.\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(result!=OPTIMAL_SOLUTION) {
|
||||||
|
iter+=2000;
|
||||||
|
init_iter+=10;
|
||||||
|
(opt_precision)*=10.0; /* reduce precision */
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("NOTICE: Reducing precision of PR_LOQO due to (%ld).\n",result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(qp->opt_m) /* Thanks to Alex Smola for this hint */
|
||||||
|
model_b=dual[0];
|
||||||
|
else
|
||||||
|
model_b=0;
|
||||||
|
|
||||||
|
/* Check the precision of the alphas. If results of current optimization */
|
||||||
|
/* violate KT-Conditions, relax the epsilon on the bounds on alphas. */
|
||||||
|
epsilon_loqo=1E-10;
|
||||||
|
for(i=0;i<qp->opt_n;i++) {
|
||||||
|
dist=-model_b*qp->opt_ce[i];
|
||||||
|
dist+=(qp->opt_g0[i]+1.0);
|
||||||
|
for(j=0;j<i;j++) {
|
||||||
|
dist+=(primal[j]*qp->opt_g[j*qp->opt_n+i]);
|
||||||
|
}
|
||||||
|
for(j=i;j<qp->opt_n;j++) {
|
||||||
|
dist+=(primal[j]*qp->opt_g[i*qp->opt_n+j]);
|
||||||
|
}
|
||||||
|
/* printf("LOQO: a[%d]=%f, dist=%f, b=%f\n",i,primal[i],dist,dual[0]); */
|
||||||
|
if((primal[i]<(qp->opt_up[i]-epsilon_loqo)) && (dist < (1.0-(*epsilon_crit)))) {
|
||||||
|
epsilon_loqo=(qp->opt_up[i]-primal[i])*2.0;
|
||||||
|
}
|
||||||
|
else if((primal[i]>(0+epsilon_loqo)) && (dist > (1.0+(*epsilon_crit)))) {
|
||||||
|
epsilon_loqo=primal[i]*2.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for(i=0;i<qp->opt_n;i++) { /* clip alphas to bounds */
|
||||||
|
if(primal[i]<=(0+epsilon_loqo)) {
|
||||||
|
primal[i]=0;
|
||||||
|
}
|
||||||
|
else if(primal[i]>=(qp->opt_up[i]-epsilon_loqo)) {
|
||||||
|
primal[i]=qp->opt_up[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
obj_after=0; /* calculate objective after optimization */
|
||||||
|
for(i=0;i<qp->opt_n;i++) {
|
||||||
|
obj_after+=(qp->opt_g0[i]*primal[i]);
|
||||||
|
obj_after+=(0.5*primal[i]*primal[i]*qp->opt_g[i*qp->opt_n+i]);
|
||||||
|
for(j=0;j<i;j++) {
|
||||||
|
obj_after+=(primal[j]*primal[i]*qp->opt_g[j*qp->opt_n+i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* if optimizer returned NAN values, reset and retry with smaller */
|
||||||
|
/* working set. */
|
||||||
|
if(isnan(obj_after) || isnan(model_b)) {
|
||||||
|
for(i=0;i<qp->opt_n;i++) {
|
||||||
|
primal[i]=qp->opt_xinit[i];
|
||||||
|
}
|
||||||
|
model_b=0;
|
||||||
|
if(learn_parm->svm_maxqpsize>2) {
|
||||||
|
learn_parm->svm_maxqpsize--; /* decrease size of qp-subproblems */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(obj_after >= obj_before) { /* check whether there was progress */
|
||||||
|
(opt_precision)/=100.0;
|
||||||
|
precision_violations++;
|
||||||
|
if(verbosity>=2) {
|
||||||
|
printf("NOTICE: Increasing Precision of PR_LOQO.\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(precision_violations > 500) {
|
||||||
|
(*epsilon_crit)*=10.0;
|
||||||
|
precision_violations=0;
|
||||||
|
if(verbosity>=1) {
|
||||||
|
printf("\nWARNING: Relaxing epsilon on KT-Conditions.\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(*threshold)=model_b;
|
||||||
|
|
||||||
|
if(result!=OPTIMAL_SOLUTION) {
|
||||||
|
printf("\nERROR: PR_LOQO did not converge. \n");
|
||||||
|
return(qp->opt_xinit);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return(primal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
70
src/classifier/svm/svm_multiclass_classifier.cpp
Normal file
70
src/classifier/svm/svm_multiclass_classifier.cpp
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
#include "svm_multiclass_classifier.hpp"
|
||||||
|
#include "svm_struct/svm_struct_common.h"
|
||||||
|
#include "svm_struct_api.h"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
SVMMultiClassClassifier::SVMMultiClassClassifier() {}
|
||||||
|
SVMMultiClassClassifier::~SVMMultiClassClassifier() {
|
||||||
|
if (model_ != NULL) {
|
||||||
|
free_struct_model(*model_);
|
||||||
|
free(model_);
|
||||||
|
model_ = NULL;
|
||||||
|
}
|
||||||
|
if (sparm_ != NULL) {
|
||||||
|
free(sparm_);
|
||||||
|
sparm_ = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int SVMMultiClassClassifier::LoadModel(const char *modelfile) {
|
||||||
|
if (model_ != NULL) {
|
||||||
|
free_struct_model(*model_);
|
||||||
|
}
|
||||||
|
if (sparm_ != NULL) {
|
||||||
|
free(sparm_);
|
||||||
|
}
|
||||||
|
model_ = (STRUCTMODEL *)my_malloc(sizeof(STRUCTMODEL));
|
||||||
|
sparm_ = (STRUCT_LEARN_PARM *)my_malloc(sizeof(STRUCT_LEARN_PARM));
|
||||||
|
(*model_) = read_struct_model((char *)modelfile, sparm_);
|
||||||
|
if (model_->svm_model->kernel_parm.kernel_type ==
|
||||||
|
LINEAR) { /* linear kernel */
|
||||||
|
/* compute weight vector */
|
||||||
|
add_weight_vector_to_linear_model(model_->svm_model);
|
||||||
|
model_->w = model_->svm_model->lin_weights;
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
double SVMMultiClassClassifier::Predict(const float *vec) { return 0; }
|
||||||
|
|
||||||
|
int SVMMultiClassClassifier::Classify(const float *vec,
|
||||||
|
std::vector<float> &scores) {
|
||||||
|
if (model_ == NULL || sparm_ == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
struct_verbosity = 5;
|
||||||
|
int feats = sparm_->num_features;
|
||||||
|
WORD *words = (WORD *)malloc(sizeof(WORD) * (feats + 10));
|
||||||
|
for (int i = 0; i < (feats + 10); ++i) {
|
||||||
|
if (i >= feats) {
|
||||||
|
words[i].wnum = 0;
|
||||||
|
words[i].weight = 0;
|
||||||
|
} else {
|
||||||
|
words[i].wnum = i + 1;
|
||||||
|
words[i].weight = vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DOC *doc =
|
||||||
|
create_example(-1, 0, 0, 0.0, create_svector(words, (char *)"", 1.0));
|
||||||
|
free(words);
|
||||||
|
PATTERN pattern;
|
||||||
|
pattern.doc = doc;
|
||||||
|
LABEL y = classify_struct_example(pattern, model_, sparm_);
|
||||||
|
free_pattern(pattern);
|
||||||
|
scores.clear();
|
||||||
|
for (int i = 1; i <= y.num_classes_; ++i) {
|
||||||
|
scores.push_back(y.scores[i]);
|
||||||
|
}
|
||||||
|
free_label(y);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} // namespace ovclassifier
|
21
src/classifier/svm/svm_multiclass_classifier.hpp
Normal file
21
src/classifier/svm/svm_multiclass_classifier.hpp
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
#ifndef _CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_
|
||||||
|
#define _CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_
|
||||||
|
|
||||||
|
#include "svm_classifier.hpp"
|
||||||
|
#include "svm_struct_api_types.h"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMMultiClassClassifier : public SVMClassifier {
|
||||||
|
public:
|
||||||
|
SVMMultiClassClassifier();
|
||||||
|
~SVMMultiClassClassifier();
|
||||||
|
int LoadModel(const char *modelfile);
|
||||||
|
double Predict(const float *vec);
|
||||||
|
int Classify(const float *vec, std::vector<float> &scores);
|
||||||
|
|
||||||
|
private:
|
||||||
|
STRUCTMODEL *model_ = NULL;
|
||||||
|
STRUCT_LEARN_PARM *sparm_ = NULL;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // !_CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_
|
151
src/classifier/svm/svm_multiclass_trainer.cpp
Normal file
151
src/classifier/svm/svm_multiclass_trainer.cpp
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
#include "svm_multiclass_trainer.hpp"
|
||||||
|
#include "svm_light/svm_learn.h"
|
||||||
|
#include "svm_struct/svm_struct_learn.h"
|
||||||
|
#include "svm_struct_api.h"
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
|
||||||
|
SVMMultiClassTrainer::SVMMultiClassTrainer() {
|
||||||
|
alg_type = DEFAULT_ALG_TYPE;
|
||||||
|
struct_parm = (STRUCT_LEARN_PARM *)malloc(sizeof(STRUCT_LEARN_PARM));
|
||||||
|
struct_parm->C = 10000;
|
||||||
|
struct_parm->slack_norm = 1;
|
||||||
|
struct_parm->epsilon = DEFAULT_EPS;
|
||||||
|
struct_parm->custom_argc = 0;
|
||||||
|
struct_parm->loss_function = DEFAULT_LOSS_FCT;
|
||||||
|
struct_parm->loss_type = DEFAULT_RESCALING;
|
||||||
|
struct_parm->newconstretrain = 100;
|
||||||
|
struct_parm->ccache_size = 5;
|
||||||
|
struct_parm->batch_size = 100;
|
||||||
|
|
||||||
|
learn_parm = (LEARN_PARM *)malloc(sizeof(LEARN_PARM));
|
||||||
|
strcpy(learn_parm->predfile, "trans_predictions");
|
||||||
|
strcpy(learn_parm->alphafile, "");
|
||||||
|
learn_parm->biased_hyperplane = 1;
|
||||||
|
learn_parm->remove_inconsistent = 0;
|
||||||
|
learn_parm->skip_final_opt_check = 0;
|
||||||
|
learn_parm->svm_maxqpsize = 10;
|
||||||
|
learn_parm->svm_newvarsinqp = 0;
|
||||||
|
// learn_parm->svm_iter_to_shrink = -9999;
|
||||||
|
learn_parm->svm_iter_to_shrink = 100;
|
||||||
|
learn_parm->maxiter = 100000;
|
||||||
|
learn_parm->kernel_cache_size = 40;
|
||||||
|
learn_parm->svm_c = 99999999; /* overridden by struct_parm->C */
|
||||||
|
learn_parm->eps = 0.001; /* overridden by struct_parm->epsilon */
|
||||||
|
learn_parm->transduction_posratio = -1.0;
|
||||||
|
learn_parm->svm_costratio = 1.0;
|
||||||
|
learn_parm->svm_costratio_unlab = 1.0;
|
||||||
|
learn_parm->svm_unlabbound = 1E-5;
|
||||||
|
learn_parm->epsilon_crit = 0.001;
|
||||||
|
learn_parm->epsilon_a = 1E-10; /* changed from 1e-15 */
|
||||||
|
learn_parm->compute_loo = 0;
|
||||||
|
learn_parm->rho = 1.0;
|
||||||
|
learn_parm->xa_depth = 0;
|
||||||
|
kernel_parm = (KERNEL_PARM *)malloc(sizeof(KERNEL_PARM));
|
||||||
|
kernel_parm->kernel_type = 0;
|
||||||
|
kernel_parm->poly_degree = 3;
|
||||||
|
kernel_parm->rbf_gamma = 1.0;
|
||||||
|
kernel_parm->coef_lin = 1;
|
||||||
|
kernel_parm->coef_const = 1;
|
||||||
|
strcpy(kernel_parm->custom, "empty");
|
||||||
|
|
||||||
|
parse_struct_parameters(struct_parm);
|
||||||
|
}
|
||||||
|
|
||||||
|
SVMMultiClassTrainer::~SVMMultiClassTrainer() {
|
||||||
|
if (learn_parm != NULL) {
|
||||||
|
free(learn_parm);
|
||||||
|
learn_parm = NULL;
|
||||||
|
}
|
||||||
|
if (kernel_parm != NULL) {
|
||||||
|
free(kernel_parm);
|
||||||
|
kernel_parm = NULL;
|
||||||
|
}
|
||||||
|
if (learn_parm != NULL) {
|
||||||
|
free(learn_parm);
|
||||||
|
learn_parm = NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SVMMultiClassTrainer::Reset() {
|
||||||
|
labels_ = 0;
|
||||||
|
feats_ = 0;
|
||||||
|
items_.clear();
|
||||||
|
}
|
||||||
|
void SVMMultiClassTrainer::SetLabels(int labels) { labels_ = labels; }
|
||||||
|
void SVMMultiClassTrainer::SetFeatures(int feats) { feats_ = feats; }
|
||||||
|
void SVMMultiClassTrainer::AddData(int label, const float *vec) {
|
||||||
|
LabelItem itm;
|
||||||
|
itm.label = label;
|
||||||
|
for (int i = 0; i < feats_; ++i) {
|
||||||
|
itm.vec.push_back(vec[i]);
|
||||||
|
}
|
||||||
|
items_.push_back(itm);
|
||||||
|
}
|
||||||
|
|
||||||
|
int SVMMultiClassTrainer::Train(const char *modelfile) {
|
||||||
|
struct_verbosity = 2;
|
||||||
|
int totdoc = items_.size();
|
||||||
|
if (totdoc == 0 || feats_ == 0 || labels_ == 0) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
EXAMPLE *examples = (EXAMPLE *)my_malloc(sizeof(EXAMPLE) * totdoc);
|
||||||
|
WORD *words = (WORD *)my_malloc(sizeof(WORD) * (feats_ * 10));
|
||||||
|
for (int dnum = 0; dnum < totdoc; ++dnum) {
|
||||||
|
const int docFeats = items_[dnum].vec.size();
|
||||||
|
for (int i = 0; i < (feats_ + 10); ++i) {
|
||||||
|
if (i >= feats_) {
|
||||||
|
words[i].wnum = 0;
|
||||||
|
} else {
|
||||||
|
(words[i]).wnum = i + 1;
|
||||||
|
}
|
||||||
|
if (i >= docFeats) {
|
||||||
|
(words[i]).weight = 0;
|
||||||
|
} else {
|
||||||
|
(words[i]).weight = (FVAL)items_[dnum].vec[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DOC *doc =
|
||||||
|
create_example(dnum, 0, 0, 0, create_svector(words, (char *)"", 1.0));
|
||||||
|
examples[dnum].x.doc = doc;
|
||||||
|
examples[dnum].y.class_ = (double)items_[dnum].label + 0.1;
|
||||||
|
examples[dnum].y.scores = NULL;
|
||||||
|
examples[dnum].y.num_classes_ = (double)labels_ + 0.1;
|
||||||
|
}
|
||||||
|
free(words);
|
||||||
|
|
||||||
|
SAMPLE sample;
|
||||||
|
sample.n = totdoc;
|
||||||
|
sample.examples = examples;
|
||||||
|
STRUCTMODEL structmodel;
|
||||||
|
/* Do the learning and return structmodel. */
|
||||||
|
if (alg_type == 0)
|
||||||
|
svm_learn_struct(sample, struct_parm, learn_parm, kernel_parm, &structmodel,
|
||||||
|
NSLACK_ALG);
|
||||||
|
else if (alg_type == 1)
|
||||||
|
svm_learn_struct(sample, struct_parm, learn_parm, kernel_parm, &structmodel,
|
||||||
|
NSLACK_SHRINK_ALG);
|
||||||
|
else if (alg_type == 2)
|
||||||
|
svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
|
||||||
|
&structmodel, ONESLACK_PRIMAL_ALG);
|
||||||
|
else if (alg_type == 3)
|
||||||
|
svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
|
||||||
|
&structmodel, ONESLACK_DUAL_ALG);
|
||||||
|
else if (alg_type == 4)
|
||||||
|
svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
|
||||||
|
&structmodel, ONESLACK_DUAL_CACHE_ALG);
|
||||||
|
else if (alg_type == 9)
|
||||||
|
svm_learn_struct_joint_custom(sample, struct_parm, learn_parm, kernel_parm,
|
||||||
|
&structmodel);
|
||||||
|
else
|
||||||
|
return -1;
|
||||||
|
|
||||||
|
write_struct_model((char *)modelfile, &structmodel, struct_parm);
|
||||||
|
|
||||||
|
free_struct_sample(sample);
|
||||||
|
free_struct_model(structmodel);
|
||||||
|
|
||||||
|
svm_struct_learn_api_exit();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
} // namespace ovclassifier
|
31
src/classifier/svm/svm_multiclass_trainer.hpp
Normal file
31
src/classifier/svm/svm_multiclass_trainer.hpp
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
#ifndef _SVM_MULTCLASS_TRAINER_H_
|
||||||
|
#define _SVM_MULTCLASS_TRAINER_H_
|
||||||
|
|
||||||
|
#include "svm_common.hpp"
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include "svm_struct_api_types.h"
|
||||||
|
#include "svm_trainer.hpp"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMMultiClassTrainer : public SVMTrainer {
|
||||||
|
public:
|
||||||
|
SVMMultiClassTrainer();
|
||||||
|
~SVMMultiClassTrainer();
|
||||||
|
void Reset();
|
||||||
|
void SetLabels(int labels);
|
||||||
|
void SetFeatures(int feats);
|
||||||
|
void AddData(int label, const float *vec);
|
||||||
|
int Train(const char *modelfile);
|
||||||
|
|
||||||
|
private:
|
||||||
|
KERNEL_PARM *kernel_parm = NULL;
|
||||||
|
LEARN_PARM *learn_parm = NULL;
|
||||||
|
STRUCT_LEARN_PARM *struct_parm = NULL;
|
||||||
|
int alg_type;
|
||||||
|
int feats_;
|
||||||
|
int labels_;
|
||||||
|
std::vector<LabelItem> items_;
|
||||||
|
};
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // _SVM_MULTICLASS_TRAINER_H_
|
186
src/classifier/svm/svm_struct/svm_struct_classify.c
Executable file
186
src/classifier/svm/svm_struct/svm_struct_classify.c
Executable file
@@ -0,0 +1,186 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_classify.c */
|
||||||
|
/* */
|
||||||
|
/* Classification module of SVM-struct. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/************************************************************************/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
#include "../svm_light/svm_common.h"
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#include "../svm_struct_api.h"
|
||||||
|
#include "svm_struct_common.h"
|
||||||
|
|
||||||
|
char testfile[200];
|
||||||
|
char modelfile[200];
|
||||||
|
char predictionsfile[200];
|
||||||
|
|
||||||
|
void read_input_parameters(int, char **, char *, char *, char *,
|
||||||
|
STRUCT_LEARN_PARM *, long*, long *);
|
||||||
|
void print_help(void);
|
||||||
|
|
||||||
|
|
||||||
|
int main (int argc, char* argv[])
|
||||||
|
{
|
||||||
|
long correct=0,incorrect=0,no_accuracy=0;
|
||||||
|
long i;
|
||||||
|
double t1,runtime=0;
|
||||||
|
double avgloss=0,l;
|
||||||
|
FILE *predfl;
|
||||||
|
STRUCTMODEL model;
|
||||||
|
STRUCT_LEARN_PARM sparm;
|
||||||
|
STRUCT_TEST_STATS teststats;
|
||||||
|
SAMPLE testsample;
|
||||||
|
LABEL y;
|
||||||
|
|
||||||
|
svm_struct_classify_api_init(argc,argv);
|
||||||
|
|
||||||
|
read_input_parameters(argc,argv,testfile,modelfile,predictionsfile,&sparm,
|
||||||
|
&verbosity,&struct_verbosity);
|
||||||
|
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("Reading model..."); fflush(stdout);
|
||||||
|
}
|
||||||
|
model=read_struct_model(modelfile,&sparm);
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
fprintf(stdout, "done.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if(model.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
|
||||||
|
/* compute weight vector */
|
||||||
|
add_weight_vector_to_linear_model(model.svm_model);
|
||||||
|
model.w=model.svm_model->lin_weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("Reading test examples..."); fflush(stdout);
|
||||||
|
}
|
||||||
|
testsample=read_struct_examples(testfile,&sparm);
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("done.\n"); fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("Classifying test examples..."); fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((predfl = fopen (predictionsfile, "w")) == NULL)
|
||||||
|
{ perror (predictionsfile); exit (1); }
|
||||||
|
|
||||||
|
for(i=0;i<testsample.n;i++) {
|
||||||
|
t1=get_runtime();
|
||||||
|
y=classify_struct_example(testsample.examples[i].x,&model,&sparm);
|
||||||
|
runtime+=(get_runtime()-t1);
|
||||||
|
|
||||||
|
write_label(predfl,y);
|
||||||
|
l=loss(testsample.examples[i].y,y,&sparm);
|
||||||
|
avgloss+=l;
|
||||||
|
if(l == 0)
|
||||||
|
correct++;
|
||||||
|
else
|
||||||
|
incorrect++;
|
||||||
|
eval_prediction(i,testsample.examples[i],y,&model,&sparm,&teststats);
|
||||||
|
|
||||||
|
if(empty_label(testsample.examples[i].y))
|
||||||
|
{ no_accuracy=1; } /* test data is not labeled */
|
||||||
|
if(struct_verbosity>=2) {
|
||||||
|
if((i+1) % 100 == 0) {
|
||||||
|
printf("%ld..",i+1); fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
free_label(y);
|
||||||
|
}
|
||||||
|
avgloss/=testsample.n;
|
||||||
|
fclose(predfl);
|
||||||
|
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("done\n");
|
||||||
|
printf("Runtime (without IO) in cpu-seconds: %.2f\n",
|
||||||
|
(float)(runtime/100.0));
|
||||||
|
}
|
||||||
|
if((!no_accuracy) && (struct_verbosity>=1)) {
|
||||||
|
printf("Average loss on test set: %.4f\n",(float)avgloss);
|
||||||
|
printf("Zero/one-error on test set: %.2f%% (%ld correct, %ld incorrect, %d total)\n",(float)100.0*incorrect/testsample.n,correct,incorrect,testsample.n);
|
||||||
|
}
|
||||||
|
print_struct_testing_stats(testsample,&model,&sparm,&teststats);
|
||||||
|
free_struct_sample(testsample);
|
||||||
|
free_struct_model(model);
|
||||||
|
|
||||||
|
svm_struct_classify_api_exit();
|
||||||
|
|
||||||
|
return(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_input_parameters(int argc,char *argv[],char *testfile,
|
||||||
|
char *modelfile,char *predictionsfile,
|
||||||
|
STRUCT_LEARN_PARM *struct_parm,
|
||||||
|
long *verbosity,long *struct_verbosity)
|
||||||
|
{
|
||||||
|
long i;
|
||||||
|
|
||||||
|
/* set default */
|
||||||
|
strcpy (modelfile, "svm_model");
|
||||||
|
strcpy (predictionsfile, "svm_predictions");
|
||||||
|
(*verbosity)=0;/*verbosity for svm_light*/
|
||||||
|
(*struct_verbosity)=1; /*verbosity for struct learning portion*/
|
||||||
|
struct_parm->custom_argc=0;
|
||||||
|
|
||||||
|
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||||
|
switch ((argv[i])[1])
|
||||||
|
{
|
||||||
|
case 'h': print_help(); exit(0);
|
||||||
|
case '?': print_help(); exit(0);
|
||||||
|
case '-': strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);i++; strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);break;
|
||||||
|
case 'v': i++; (*struct_verbosity)=atol(argv[i]); break;
|
||||||
|
case 'y': i++; (*verbosity)=atol(argv[i]); break;
|
||||||
|
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if((i+1)>=argc) {
|
||||||
|
printf("\nNot enough input parameters!\n\n");
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
strcpy (testfile, argv[i]);
|
||||||
|
strcpy (modelfile, argv[i+1]);
|
||||||
|
if((i+2)<argc) {
|
||||||
|
strcpy (predictionsfile, argv[i+2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
parse_struct_parameters_classify(struct_parm);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_help(void)
|
||||||
|
{
|
||||||
|
printf("\nSVM-struct classification module: %s, %s, %s\n",INST_NAME,INST_VERSION,INST_VERSION_DATE);
|
||||||
|
printf(" includes SVM-struct %s for learning complex outputs, %s\n",STRUCT_VERSION,STRUCT_VERSION_DATE);
|
||||||
|
printf(" includes SVM-light %s quadratic optimizer, %s\n",VERSION,VERSION_DATE);
|
||||||
|
copyright_notice();
|
||||||
|
printf(" usage: svm_struct_classify [options] example_file model_file output_file\n\n");
|
||||||
|
printf("options: -h -> this help\n");
|
||||||
|
printf(" -v [0..3] -> verbosity level (default 2)\n\n");
|
||||||
|
|
||||||
|
print_struct_help_classify();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
66
src/classifier/svm/svm_struct/svm_struct_common.c
Normal file
66
src/classifier/svm/svm_struct/svm_struct_common.c
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_common.h */
|
||||||
|
/* */
|
||||||
|
/* Functions and types used by multiple components of SVM-struct. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "svm_struct_common.h"
|
||||||
|
|
||||||
|
long struct_verbosity; /* verbosity level (0-4) */
|
||||||
|
|
||||||
|
void printIntArray(int* x, int n)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i=0;i<n;i++)
|
||||||
|
printf("%i:",x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void printDoubleArray(double* x, int n)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
for(i=0;i<n;i++)
|
||||||
|
printf("%f:",x[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
void printWordArray(WORD* x)
|
||||||
|
{
|
||||||
|
int i=0;
|
||||||
|
for(;x[i].wnum!=0;i++)
|
||||||
|
if(x[i].weight != 0)
|
||||||
|
printf(" %i:%.2f ",(int)x[i].wnum,x[i].weight);
|
||||||
|
}
|
||||||
|
|
||||||
|
void printW(double *w, long sizePhi, long n,double C)
|
||||||
|
{
|
||||||
|
int i;
|
||||||
|
printf("---- w ----\n");
|
||||||
|
for(i=0;i<sizePhi;i++)
|
||||||
|
{
|
||||||
|
printf("%f ",w[i]);
|
||||||
|
}
|
||||||
|
printf("\n----- xi ----\n");
|
||||||
|
for(;i<sizePhi+2*n;i++)
|
||||||
|
{
|
||||||
|
printf("%f ",1/sqrt(2*C)*w[i]);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
}
|
||||||
|
/**** end print methods ****/
|
||||||
|
|
61
src/classifier/svm/svm_struct/svm_struct_common.h
Executable file
61
src/classifier/svm/svm_struct/svm_struct_common.h
Executable file
@@ -0,0 +1,61 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_common.h */
|
||||||
|
/* */
|
||||||
|
/* Functions and types used by multiple components of SVM-struct. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 31.10.05 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2005 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#ifndef svm_struct_common
|
||||||
|
#define svm_struct_common
|
||||||
|
|
||||||
|
# define STRUCT_VERSION "V3.10"
|
||||||
|
# define STRUCT_VERSION_DATE "14.08.08"
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
#include "../svm_light/svm_common.h"
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#include "../svm_struct_api_types.h"
|
||||||
|
|
||||||
|
typedef struct example { /* an example is a pair of pattern and label */
|
||||||
|
PATTERN x;
|
||||||
|
LABEL y;
|
||||||
|
} EXAMPLE;
|
||||||
|
|
||||||
|
typedef struct sample { /* a sample is a set of examples */
|
||||||
|
int n; /* n is the total number of examples */
|
||||||
|
EXAMPLE *examples;
|
||||||
|
} SAMPLE;
|
||||||
|
|
||||||
|
typedef struct constset { /* a set of linear inequality constrains of
|
||||||
|
for lhs[i]*w >= rhs[i] */
|
||||||
|
int m; /* m is the total number of constrains */
|
||||||
|
DOC **lhs;
|
||||||
|
double *rhs;
|
||||||
|
} CONSTSET;
|
||||||
|
|
||||||
|
|
||||||
|
/**** print methods ****/
|
||||||
|
void printIntArray(int*,int);
|
||||||
|
void printDoubleArray(double*,int);
|
||||||
|
void printWordArray(WORD*);
|
||||||
|
void printModel(MODEL *);
|
||||||
|
void printW(double *, long, long, double);
|
||||||
|
|
||||||
|
extern long struct_verbosity; /* verbosity level (0-4) */
|
||||||
|
|
||||||
|
#endif
|
1289
src/classifier/svm/svm_struct/svm_struct_learn.c
Executable file
1289
src/classifier/svm/svm_struct/svm_struct_learn.c
Executable file
File diff suppressed because it is too large
Load Diff
101
src/classifier/svm/svm_struct/svm_struct_learn.h
Executable file
101
src/classifier/svm/svm_struct/svm_struct_learn.h
Executable file
@@ -0,0 +1,101 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_learn.h */
|
||||||
|
/* */
|
||||||
|
/* Basic algorithm for learning structured outputs (e.g. parses, */
|
||||||
|
/* sequences, multi-label classification) with a Support Vector */
|
||||||
|
/* Machine. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#ifndef SVM_STRUCT_LEARN
|
||||||
|
#define SVM_STRUCT_LEARN
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
#include "../svm_light/svm_common.h"
|
||||||
|
#include "../svm_light/svm_learn.h"
|
||||||
|
#include "../svm_struct_api_types.h"
|
||||||
|
#include "svm_struct_common.h"
|
||||||
|
|
||||||
|
#define SLACK_RESCALING 1
|
||||||
|
#define MARGIN_RESCALING 2
|
||||||
|
|
||||||
|
#define NSLACK_ALG 0
|
||||||
|
#define NSLACK_SHRINK_ALG 1
|
||||||
|
#define ONESLACK_PRIMAL_ALG 2
|
||||||
|
#define ONESLACK_DUAL_ALG 3
|
||||||
|
#define ONESLACK_DUAL_CACHE_ALG 4
|
||||||
|
|
||||||
|
typedef struct ccacheelem {
|
||||||
|
SVECTOR *fydelta; /* left hand side of constraint */
|
||||||
|
double rhs; /* right hand side of constraint */
|
||||||
|
double viol; /* violation score under current model */
|
||||||
|
struct ccacheelem *next; /* next in linked list */
|
||||||
|
} CCACHEELEM;
|
||||||
|
|
||||||
|
typedef struct ccache {
|
||||||
|
int n; /* number of examples */
|
||||||
|
CCACHEELEM **constlist; /* array of pointers to constraint lists
|
||||||
|
- one list per example. The first
|
||||||
|
element of the list always points to
|
||||||
|
the most violated constraint under the
|
||||||
|
current model for each example. */
|
||||||
|
STRUCTMODEL *sm; /* pointer to model */
|
||||||
|
double *avg_viol_gain; /* array of average values by which
|
||||||
|
violation of globally most violated
|
||||||
|
constraint exceeds that of most violated
|
||||||
|
constraint in cache */
|
||||||
|
int *changed; /* array of boolean indicating whether the
|
||||||
|
most violated ybar change compared to
|
||||||
|
last iter? */
|
||||||
|
} CCACHE;
|
||||||
|
|
||||||
|
void find_most_violated_constraint(SVECTOR **fydelta, double *lossval,
|
||||||
|
EXAMPLE *ex, SVECTOR *fycached, long n,
|
||||||
|
STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
|
||||||
|
double *rt_viol, double *rt_psi,
|
||||||
|
long *argmax_count);
|
||||||
|
CCACHE *create_constraint_cache(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
STRUCTMODEL *sm);
|
||||||
|
void free_constraint_cache(CCACHE *ccache);
|
||||||
|
double add_constraint_to_constraint_cache(CCACHE *ccache, MODEL *svmModel,
|
||||||
|
int exnum, SVECTOR *fydelta,
|
||||||
|
double rhs, double gainthresh,
|
||||||
|
int maxconst, double *rt_cachesum);
|
||||||
|
void update_constraint_cache_for_model(CCACHE *ccache, MODEL *svmModel);
|
||||||
|
double compute_violation_of_constraint_in_cache(CCACHE *ccache, double thresh);
|
||||||
|
double find_most_violated_joint_constraint_in_cache(CCACHE *ccache,
|
||||||
|
double thresh,
|
||||||
|
double *lhs_n,
|
||||||
|
SVECTOR **lhs, double *rhs);
|
||||||
|
void svm_learn_struct(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm, STRUCTMODEL *sm,
|
||||||
|
int alg_type);
|
||||||
|
void svm_learn_struct_joint(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||||
|
STRUCTMODEL *sm, int alg_type);
|
||||||
|
void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||||
|
STRUCTMODEL *sm);
|
||||||
|
void remove_inactive_constraints(CONSTSET *cset, double *alpha, long i,
|
||||||
|
long *alphahist, long mininactive);
|
||||||
|
MATRIX *init_kernel_matrix(CONSTSET *cset, KERNEL_PARM *kparm);
|
||||||
|
MATRIX *update_kernel_matrix(MATRIX *matrix, int newpos, CONSTSET *cset,
|
||||||
|
KERNEL_PARM *kparm);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
417
src/classifier/svm/svm_struct/svm_struct_main.c
Executable file
417
src/classifier/svm/svm_struct/svm_struct_main.c
Executable file
@@ -0,0 +1,417 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_main.c */
|
||||||
|
/* */
|
||||||
|
/* Command line interface to the alignment learning module of the */
|
||||||
|
/* Support Vector Machine. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
|
||||||
|
/* the following enables you to use svm-learn out of C++ */
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
#include "../svm_light/svm_common.h"
|
||||||
|
#include "../svm_light/svm_learn.h"
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
# include "svm_struct_learn.h"
|
||||||
|
# include "svm_struct_common.h"
|
||||||
|
# include "../svm_struct_api.h"
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <assert.h>
|
||||||
|
/* } */
|
||||||
|
|
||||||
|
char trainfile[200]; /* file with training examples */
|
||||||
|
char modelfile[200]; /* file for resulting classifier */
|
||||||
|
|
||||||
|
void read_input_parameters(int, char **, char *, char *,long *, long *,
|
||||||
|
STRUCT_LEARN_PARM *, LEARN_PARM *, KERNEL_PARM *,
|
||||||
|
int *);
|
||||||
|
void wait_any_key();
|
||||||
|
void print_help();
|
||||||
|
|
||||||
|
|
||||||
|
int main (int argc, char* argv[])
|
||||||
|
{
|
||||||
|
SAMPLE sample; /* training sample */
|
||||||
|
LEARN_PARM learn_parm;
|
||||||
|
KERNEL_PARM kernel_parm;
|
||||||
|
STRUCT_LEARN_PARM struct_parm;
|
||||||
|
STRUCTMODEL structmodel;
|
||||||
|
int alg_type;
|
||||||
|
|
||||||
|
svm_struct_learn_api_init(argc,argv);
|
||||||
|
|
||||||
|
read_input_parameters(argc,argv,trainfile,modelfile,&verbosity,
|
||||||
|
&struct_verbosity,&struct_parm,&learn_parm,
|
||||||
|
&kernel_parm,&alg_type);
|
||||||
|
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("Reading training examples..."); fflush(stdout);
|
||||||
|
}
|
||||||
|
/* read the training examples */
|
||||||
|
sample=read_struct_examples(trainfile,&struct_parm);
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("done\n"); fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Do the learning and return structmodel. */
|
||||||
|
if(alg_type == 0)
|
||||||
|
svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
|
||||||
|
else if(alg_type == 1)
|
||||||
|
svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
|
||||||
|
else if(alg_type == 2)
|
||||||
|
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
|
||||||
|
else if(alg_type == 3)
|
||||||
|
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
|
||||||
|
else if(alg_type == 4)
|
||||||
|
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
|
||||||
|
else if(alg_type == 9)
|
||||||
|
svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
|
||||||
|
else
|
||||||
|
exit(1);
|
||||||
|
|
||||||
|
/* Warning: The model contains references to the original data 'docs'.
|
||||||
|
If you want to free the original data, and only keep the model, you
|
||||||
|
have to make a deep copy of 'model'. */
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("Writing learned model...");fflush(stdout);
|
||||||
|
}
|
||||||
|
write_struct_model(modelfile,&structmodel,&struct_parm);
|
||||||
|
if(struct_verbosity>=1) {
|
||||||
|
printf("done\n");fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
free_struct_sample(sample);
|
||||||
|
free_struct_model(structmodel);
|
||||||
|
|
||||||
|
svm_struct_learn_api_exit();
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*---------------------------------------------------------------------------*/
|
||||||
|
|
||||||
|
void read_input_parameters(int argc,char *argv[],char *trainfile,
|
||||||
|
char *modelfile,
|
||||||
|
long *verbosity,long *struct_verbosity,
|
||||||
|
STRUCT_LEARN_PARM *struct_parm,
|
||||||
|
LEARN_PARM *learn_parm, KERNEL_PARM *kernel_parm,
|
||||||
|
int *alg_type)
|
||||||
|
{
|
||||||
|
long i;
|
||||||
|
char type[100];
|
||||||
|
|
||||||
|
/* set default */
|
||||||
|
(*alg_type)=DEFAULT_ALG_TYPE;
|
||||||
|
struct_parm->C=-0.01;
|
||||||
|
struct_parm->slack_norm=1;
|
||||||
|
struct_parm->epsilon=DEFAULT_EPS;
|
||||||
|
struct_parm->custom_argc=0;
|
||||||
|
struct_parm->loss_function=DEFAULT_LOSS_FCT;
|
||||||
|
struct_parm->loss_type=DEFAULT_RESCALING;
|
||||||
|
struct_parm->newconstretrain=100;
|
||||||
|
struct_parm->ccache_size=5;
|
||||||
|
struct_parm->batch_size=100;
|
||||||
|
|
||||||
|
strcpy (modelfile, "svm_struct_model");
|
||||||
|
strcpy (learn_parm->predfile, "trans_predictions");
|
||||||
|
strcpy (learn_parm->alphafile, "");
|
||||||
|
(*verbosity)=0;/*verbosity for svm_light*/
|
||||||
|
(*struct_verbosity)=1; /*verbosity for struct learning portion*/
|
||||||
|
learn_parm->biased_hyperplane=1;
|
||||||
|
learn_parm->remove_inconsistent=0;
|
||||||
|
learn_parm->skip_final_opt_check=0;
|
||||||
|
learn_parm->svm_maxqpsize=10;
|
||||||
|
learn_parm->svm_newvarsinqp=0;
|
||||||
|
learn_parm->svm_iter_to_shrink=-9999;
|
||||||
|
learn_parm->maxiter=100000;
|
||||||
|
learn_parm->kernel_cache_size=40;
|
||||||
|
learn_parm->svm_c=99999999; /* overridden by struct_parm->C */
|
||||||
|
learn_parm->eps=0.001; /* overridden by struct_parm->epsilon */
|
||||||
|
learn_parm->transduction_posratio=-1.0;
|
||||||
|
learn_parm->svm_costratio=1.0;
|
||||||
|
learn_parm->svm_costratio_unlab=1.0;
|
||||||
|
learn_parm->svm_unlabbound=1E-5;
|
||||||
|
learn_parm->epsilon_crit=0.001;
|
||||||
|
learn_parm->epsilon_a=1E-10; /* changed from 1e-15 */
|
||||||
|
learn_parm->compute_loo=0;
|
||||||
|
learn_parm->rho=1.0;
|
||||||
|
learn_parm->xa_depth=0;
|
||||||
|
kernel_parm->kernel_type=0;
|
||||||
|
kernel_parm->poly_degree=3;
|
||||||
|
kernel_parm->rbf_gamma=1.0;
|
||||||
|
kernel_parm->coef_lin=1;
|
||||||
|
kernel_parm->coef_const=1;
|
||||||
|
strcpy(kernel_parm->custom,"empty");
|
||||||
|
strcpy(type,"c");
|
||||||
|
|
||||||
|
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||||
|
switch ((argv[i])[1])
|
||||||
|
{
|
||||||
|
case '?': print_help(); exit(0);
|
||||||
|
case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break;
|
||||||
|
case 'c': i++; struct_parm->C=atof(argv[i]); break;
|
||||||
|
case 'p': i++; struct_parm->slack_norm=atol(argv[i]); break;
|
||||||
|
case 'e': i++; struct_parm->epsilon=atof(argv[i]); break;
|
||||||
|
case 'k': i++; struct_parm->newconstretrain=atol(argv[i]); break;
|
||||||
|
case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break;
|
||||||
|
case '#': i++; learn_parm->maxiter=atol(argv[i]); break;
|
||||||
|
case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break;
|
||||||
|
case 'w': i++; (*alg_type)=atol(argv[i]); break;
|
||||||
|
case 'o': i++; struct_parm->loss_type=atol(argv[i]); break;
|
||||||
|
case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break;
|
||||||
|
case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break;
|
||||||
|
case 'l': i++; struct_parm->loss_function=atol(argv[i]); break;
|
||||||
|
case 'f': i++; struct_parm->ccache_size=atol(argv[i]); break;
|
||||||
|
case 'b': i++; struct_parm->batch_size=atof(argv[i]); break;
|
||||||
|
case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break;
|
||||||
|
case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break;
|
||||||
|
case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break;
|
||||||
|
case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break;
|
||||||
|
case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break;
|
||||||
|
case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break;
|
||||||
|
case '-': strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);i++; strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);break;
|
||||||
|
case 'v': i++; (*struct_verbosity)=atol(argv[i]); break;
|
||||||
|
case 'y': i++; (*verbosity)=atol(argv[i]); break;
|
||||||
|
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(i>=argc) {
|
||||||
|
printf("\nNot enough input parameters!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
strcpy (trainfile, argv[i]);
|
||||||
|
if((i+1)<argc) {
|
||||||
|
strcpy (modelfile, argv[i+1]);
|
||||||
|
}
|
||||||
|
if(learn_parm->svm_iter_to_shrink == -9999) {
|
||||||
|
learn_parm->svm_iter_to_shrink=100;
|
||||||
|
}
|
||||||
|
|
||||||
|
if((learn_parm->skip_final_opt_check)
|
||||||
|
&& (kernel_parm->kernel_type == LINEAR)) {
|
||||||
|
printf("\nIt does not make sense to skip the final optimality check for linear kernels.\n\n");
|
||||||
|
learn_parm->skip_final_opt_check=0;
|
||||||
|
}
|
||||||
|
if((learn_parm->skip_final_opt_check)
|
||||||
|
&& (learn_parm->remove_inconsistent)) {
|
||||||
|
printf("\nIt is necessary to do the final optimality check when removing inconsistent \nexamples.\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((learn_parm->svm_maxqpsize<2)) {
|
||||||
|
printf("\nMaximum size of QP-subproblems not in valid range: %ld [2..]\n",learn_parm->svm_maxqpsize);
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((learn_parm->svm_maxqpsize<learn_parm->svm_newvarsinqp)) {
|
||||||
|
printf("\nMaximum size of QP-subproblems [%ld] must be larger than the number of\n",learn_parm->svm_maxqpsize);
|
||||||
|
printf("new variables [%ld] entering the working set in each iteration.\n",learn_parm->svm_newvarsinqp);
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(learn_parm->svm_iter_to_shrink<1) {
|
||||||
|
printf("\nMaximum number of iterations for shrinking not in valid range: %ld [1,..]\n",learn_parm->svm_iter_to_shrink);
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(struct_parm->C<0) {
|
||||||
|
printf("\nYou have to specify a value for the parameter '-c' (C>0)!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(((*alg_type) < 0) || (((*alg_type) > 5) && ((*alg_type) != 9))) {
|
||||||
|
printf("\nAlgorithm type must be either '0', '1', '2', '3', '4', or '9'!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(learn_parm->transduction_posratio>1) {
|
||||||
|
printf("\nThe fraction of unlabeled examples to classify as positives must\n");
|
||||||
|
printf("be less than 1.0 !!!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(learn_parm->svm_costratio<=0) {
|
||||||
|
printf("\nThe COSTRATIO parameter must be greater than zero!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(struct_parm->epsilon<=0) {
|
||||||
|
printf("\nThe epsilon parameter must be greater than zero!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((struct_parm->ccache_size<=0) && ((*alg_type) == 4)) {
|
||||||
|
printf("\nThe cache size must be at least 1!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(((struct_parm->batch_size<=0) || (struct_parm->batch_size>100))
|
||||||
|
&& ((*alg_type) == 4)) {
|
||||||
|
printf("\nThe batch size must be in the interval ]0,100]!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((struct_parm->slack_norm<1) || (struct_parm->slack_norm>2)) {
|
||||||
|
printf("\nThe norm of the slacks must be either 1 (L1-norm) or 2 (L2-norm)!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((struct_parm->loss_type != SLACK_RESCALING)
|
||||||
|
&& (struct_parm->loss_type != MARGIN_RESCALING)) {
|
||||||
|
printf("\nThe loss type must be either 1 (slack rescaling) or 2 (margin rescaling)!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if(learn_parm->rho<0) {
|
||||||
|
printf("\nThe parameter rho for xi/alpha-estimates and leave-one-out pruning must\n");
|
||||||
|
printf("be greater than zero (typically 1.0 or 2.0, see T. Joachims, Estimating the\n");
|
||||||
|
printf("Generalization Performance of an SVM Efficiently, ICML, 2000.)!\n\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
if((learn_parm->xa_depth<0) || (learn_parm->xa_depth>100)) {
|
||||||
|
printf("\nThe parameter depth for ext. xi/alpha-estimates must be in [0..100] (zero\n");
|
||||||
|
printf("for switching to the conventional xa/estimates described in T. Joachims,\n");
|
||||||
|
printf("Estimating the Generalization Performance of an SVM Efficiently, ICML, 2000.)\n");
|
||||||
|
wait_any_key();
|
||||||
|
print_help();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
parse_struct_parameters(struct_parm);
|
||||||
|
}
|
||||||
|
|
||||||
|
void wait_any_key()
|
||||||
|
{
|
||||||
|
printf("\n(more)\n");
|
||||||
|
(void)getc(stdin);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_help()
|
||||||
|
{
|
||||||
|
printf("\nSVM-struct learning module: %s, %s, %s\n",INST_NAME,INST_VERSION,INST_VERSION_DATE);
|
||||||
|
printf(" includes SVM-struct %s for learning complex outputs, %s\n",STRUCT_VERSION,STRUCT_VERSION_DATE);
|
||||||
|
printf(" includes SVM-light %s quadratic optimizer, %s\n",VERSION,VERSION_DATE);
|
||||||
|
copyright_notice();
|
||||||
|
printf(" usage: svm_struct_learn [options] example_file model_file\n\n");
|
||||||
|
printf("Arguments:\n");
|
||||||
|
printf(" example_file-> file with training data\n");
|
||||||
|
printf(" model_file -> file to store learned decision rule in\n");
|
||||||
|
|
||||||
|
printf("General Options:\n");
|
||||||
|
printf(" -? -> this help\n");
|
||||||
|
printf(" -v [0..3] -> verbosity level (default 1)\n");
|
||||||
|
printf(" -y [0..3] -> verbosity level for svm_light (default 0)\n");
|
||||||
|
printf("Learning Options:\n");
|
||||||
|
printf(" -c float -> C: trade-off between training error\n");
|
||||||
|
printf(" and margin (default 0.01)\n");
|
||||||
|
printf(" -p [1,2] -> L-norm to use for slack variables. Use 1 for L1-norm,\n");
|
||||||
|
printf(" use 2 for squared slacks. (default 1)\n");
|
||||||
|
printf(" -o [1,2] -> Rescaling method to use for loss.\n");
|
||||||
|
printf(" 1: slack rescaling\n");
|
||||||
|
printf(" 2: margin rescaling\n");
|
||||||
|
printf(" (default %d)\n",DEFAULT_RESCALING);
|
||||||
|
printf(" -l [0..] -> Loss function to use.\n");
|
||||||
|
printf(" 0: zero/one loss\n");
|
||||||
|
printf(" ?: see below in application specific options\n");
|
||||||
|
printf(" (default %d)\n",DEFAULT_LOSS_FCT);
|
||||||
|
printf("Optimization Options (see [2][5]):\n");
|
||||||
|
printf(" -w [0,..,9] -> choice of structural learning algorithm (default %d):\n",(int)DEFAULT_ALG_TYPE);
|
||||||
|
printf(" 0: n-slack algorithm described in [2]\n");
|
||||||
|
printf(" 1: n-slack algorithm with shrinking heuristic\n");
|
||||||
|
printf(" 2: 1-slack algorithm (primal) described in [5]\n");
|
||||||
|
printf(" 3: 1-slack algorithm (dual) described in [5]\n");
|
||||||
|
printf(" 4: 1-slack algorithm (dual) with constraint cache [5]\n");
|
||||||
|
printf(" 9: custom algorithm in svm_struct_learn_custom.c\n");
|
||||||
|
printf(" -e float -> epsilon: allow that tolerance for termination\n");
|
||||||
|
printf(" criterion (default %f)\n",DEFAULT_EPS);
|
||||||
|
printf(" -k [1..] -> number of new constraints to accumulate before\n");
|
||||||
|
printf(" recomputing the QP solution (default 100) (-w 0 and 1 only)\n");
|
||||||
|
printf(" -f [5..] -> number of constraints to cache for each example\n");
|
||||||
|
printf(" (default 5) (used with -w 4)\n");
|
||||||
|
printf(" -b [1..100] -> percentage of training set for which to refresh cache\n");
|
||||||
|
printf(" when no epsilon violated constraint can be constructed\n");
|
||||||
|
printf(" from current cache (default 100%%) (used with -w 4)\n");
|
||||||
|
printf("SVM-light Options for Solving QP Subproblems (see [3]):\n");
|
||||||
|
printf(" -n [2..q] -> number of new variables entering the working set\n");
|
||||||
|
printf(" in each svm-light iteration (default n = q). \n");
|
||||||
|
printf(" Set n < q to prevent zig-zagging.\n");
|
||||||
|
printf(" -m [5..] -> size of svm-light cache for kernel evaluations in MB\n");
|
||||||
|
printf(" (default 40) (used only for -w 1 with kernels)\n");
|
||||||
|
printf(" -h [5..] -> number of svm-light iterations a variable needs to be\n");
|
||||||
|
printf(" optimal before considered for shrinking (default 100)\n");
|
||||||
|
printf(" -# int -> terminate svm-light QP subproblem optimization, if no\n");
|
||||||
|
printf(" progress after this number of iterations.\n");
|
||||||
|
printf(" (default 100000)\n");
|
||||||
|
printf("Kernel Options:\n");
|
||||||
|
printf(" -t int -> type of kernel function:\n");
|
||||||
|
printf(" 0: linear (default)\n");
|
||||||
|
printf(" 1: polynomial (s a*b+c)^d\n");
|
||||||
|
printf(" 2: radial basis function exp(-gamma ||a-b||^2)\n");
|
||||||
|
printf(" 3: sigmoid tanh(s a*b + c)\n");
|
||||||
|
printf(" 4: user defined kernel from kernel.h\n");
|
||||||
|
printf(" -d int -> parameter d in polynomial kernel\n");
|
||||||
|
printf(" -g float -> parameter gamma in rbf kernel\n");
|
||||||
|
printf(" -s float -> parameter s in sigmoid/poly kernel\n");
|
||||||
|
printf(" -r float -> parameter c in sigmoid/poly kernel\n");
|
||||||
|
printf(" -u string -> parameter of user defined kernel\n");
|
||||||
|
printf("Output Options:\n");
|
||||||
|
printf(" -a string -> write all alphas to this file after learning\n");
|
||||||
|
printf(" (in the same order as in the training set)\n");
|
||||||
|
printf("Application-Specific Options:\n");
|
||||||
|
print_struct_help();
|
||||||
|
wait_any_key();
|
||||||
|
|
||||||
|
printf("\nMore details in:\n");
|
||||||
|
printf("[1] T. Joachims, Learning to Align Sequences: A Maximum Margin Aproach.\n");
|
||||||
|
printf(" Technical Report, September, 2003.\n");
|
||||||
|
printf("[2] I. Tsochantaridis, T. Joachims, T. Hofmann, and Y. Altun, Large Margin\n");
|
||||||
|
printf(" Methods for Structured and Interdependent Output Variables, Journal\n");
|
||||||
|
printf(" of Machine Learning Research (JMLR), Vol. 6(Sep):1453-1484, 2005.\n");
|
||||||
|
printf("[3] T. Joachims, Making Large-Scale SVM Learning Practical. Advances in\n");
|
||||||
|
printf(" Kernel Methods - Support Vector Learning, B. Sch<63>lkopf and C. Burges and\n");
|
||||||
|
printf(" A. Smola (ed.), MIT Press, 1999.\n");
|
||||||
|
printf("[4] T. Joachims, Learning to Classify Text Using Support Vector\n");
|
||||||
|
printf(" Machines: Methods, Theory, and Algorithms. Dissertation, Kluwer,\n");
|
||||||
|
printf(" 2002.\n");
|
||||||
|
printf("[5] T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of Structural\n");
|
||||||
|
printf(" SVMs, Machine Learning Journal, to appear.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
615
src/classifier/svm/svm_struct_api.c
Executable file
615
src/classifier/svm/svm_struct_api.c
Executable file
@@ -0,0 +1,615 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_api.c */
|
||||||
|
/* */
|
||||||
|
/* Definition of API for attaching implementing SVM learning of */
|
||||||
|
/* structures (e.g. parsing, multi-label classification, HMM) */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#include "svm_struct_api.h"
|
||||||
|
#include "svm_struct/svm_struct_common.h"
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
void svm_struct_learn_api_init(int argc, char *argv[]) {
|
||||||
|
/* Called in learning part before anything else is done to allow
|
||||||
|
any initializations that might be necessary. */
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_struct_learn_api_exit() {
|
||||||
|
/* Called in learning part at the very end to allow any clean-up
|
||||||
|
that might be necessary. */
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_struct_classify_api_init(int argc, char *argv[]) {
|
||||||
|
/* Called in prediction part before anything else is done to allow
|
||||||
|
any initializations that might be necessary. */
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_struct_classify_api_exit() {
|
||||||
|
/* Called in prediction part at the very end to allow any clean-up
|
||||||
|
that might be necessary. */
|
||||||
|
}
|
||||||
|
|
||||||
|
SAMPLE read_struct_examples(char *file, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Reads training examples and returns them in sample. The number of
|
||||||
|
examples must be written into sample.n */
|
||||||
|
SAMPLE sample; /* sample */
|
||||||
|
EXAMPLE *examples;
|
||||||
|
long n; /* number of examples */
|
||||||
|
DOC **docs; /* examples in original SVM-light format */
|
||||||
|
double *target;
|
||||||
|
long totwords, i, num_classes = 0;
|
||||||
|
|
||||||
|
/* Using the read_documents function from SVM-light */
|
||||||
|
read_documents(file, &docs, &target, &totwords, &n);
|
||||||
|
examples = (EXAMPLE *)my_malloc(sizeof(EXAMPLE) * n);
|
||||||
|
for (i = 0; i < n; i++) /* find highest class label */
|
||||||
|
if (num_classes < (target[i] + 0.1))
|
||||||
|
num_classes = target[i] + 0.1;
|
||||||
|
for (i = 0; i < n; i++) /* make sure all class labels are positive */
|
||||||
|
if (target[i] < 1) {
|
||||||
|
printf("\nERROR: The class label '%lf' of example number %ld is not "
|
||||||
|
"greater than '1'!\n",
|
||||||
|
target[i], i + 1);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
for (i = 0; i < n; i++) { /* copy docs over into new datastructure */
|
||||||
|
examples[i].x.doc = docs[i];
|
||||||
|
examples[i].y.class_ = target[i] + 0.1;
|
||||||
|
examples[i].y.scores = NULL;
|
||||||
|
examples[i].y.num_classes_ = num_classes;
|
||||||
|
}
|
||||||
|
free(target);
|
||||||
|
free(docs);
|
||||||
|
sample.n = n;
|
||||||
|
sample.examples = examples;
|
||||||
|
|
||||||
|
if (struct_verbosity >= 0)
|
||||||
|
printf(" (%d examples) ", sample.n);
|
||||||
|
return (sample);
|
||||||
|
}
|
||||||
|
|
||||||
|
void init_struct_model(SAMPLE sample, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm) {
|
||||||
|
/* Initialize structmodel sm. The weight vector w does not need to be
|
||||||
|
initialized, but you need to provide the maximum size of the
|
||||||
|
feature space in sizePsi. This is the maximum number of different
|
||||||
|
weights that can be learned. Later, the weight vector w will
|
||||||
|
contain the learned weights for the model. */
|
||||||
|
long i, totwords = 0;
|
||||||
|
WORD *w;
|
||||||
|
|
||||||
|
sparm->num_classes_ = 1;
|
||||||
|
for (i = 0; i < sample.n; i++) /* find highest class label */
|
||||||
|
if (sparm->num_classes_ < (sample.examples[i].y.class_ + 0.1))
|
||||||
|
sparm->num_classes_ = sample.examples[i].y.class_ + 0.1;
|
||||||
|
for (i = 0; i < sample.n; i++) /* find highest feature number */
|
||||||
|
for (w = sample.examples[i].x.doc->fvec->words; w->wnum; w++)
|
||||||
|
if (totwords < w->wnum)
|
||||||
|
totwords = w->wnum;
|
||||||
|
sparm->num_features = totwords;
|
||||||
|
if (struct_verbosity >= 0)
|
||||||
|
printf("Training set properties: %d features, %d classes\n",
|
||||||
|
sparm->num_features, sparm->num_classes_);
|
||||||
|
sm->sizePsi = sparm->num_features * sparm->num_classes_;
|
||||||
|
if (struct_verbosity >= 2)
|
||||||
|
printf("Size of Phi: %ld\n", sm->sizePsi);
|
||||||
|
}
|
||||||
|
|
||||||
|
CONSTSET init_struct_constraints(SAMPLE sample, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Initializes the optimization problem. Typically, you do not need
|
||||||
|
to change this function, since you want to start with an empty
|
||||||
|
set of constraints. However, if for example you have constraints
|
||||||
|
that certain weights need to be positive, you might put that in
|
||||||
|
here. The constraints are represented as lhs[i]*w >= rhs[i]. lhs
|
||||||
|
is an array of feature vectors, rhs is an array of doubles. m is
|
||||||
|
the number of constraints. The function returns the initial
|
||||||
|
set of constraints. */
|
||||||
|
CONSTSET c;
|
||||||
|
long sizePsi = sm->sizePsi;
|
||||||
|
long i;
|
||||||
|
WORD words[2];
|
||||||
|
|
||||||
|
if (1) { /* normal case: start with empty set of constraints */
|
||||||
|
c.lhs = NULL;
|
||||||
|
c.rhs = NULL;
|
||||||
|
c.m = 0;
|
||||||
|
} else { /* add constraints so that all learned weights are
|
||||||
|
positive. WARNING: Currently, they are positive only up to
|
||||||
|
precision epsilon set by -e. */
|
||||||
|
c.lhs = my_malloc(sizeof(DOC *) * sizePsi);
|
||||||
|
c.rhs = my_malloc(sizeof(double) * sizePsi);
|
||||||
|
for (i = 0; i < sizePsi; i++) {
|
||||||
|
words[0].wnum = i + 1;
|
||||||
|
words[0].weight = 1.0;
|
||||||
|
words[1].wnum = 0;
|
||||||
|
/* the following slackid is a hack. we will run into problems,
|
||||||
|
if we have move than 1000000 slack sets (ie examples) */
|
||||||
|
c.lhs[i] = create_example(i, 0, 1000000 + i, 1,
|
||||||
|
create_svector(words, NULL, 1.0));
|
||||||
|
c.rhs[i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return (c);
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL classify_struct_example(PATTERN x, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Finds the label yhat for pattern x that scores the highest
|
||||||
|
according to the linear evaluation function in sm, especially the
|
||||||
|
weights sm.w. The returned label is taken as the prediction of sm
|
||||||
|
for the pattern x. The weights correspond to the features defined
|
||||||
|
by psi() and range from index 1 to index sm->sizePsi. If the
|
||||||
|
function cannot find a label, it shall return an empty label as
|
||||||
|
recognized by the function empty_label(y). */
|
||||||
|
LABEL y;
|
||||||
|
DOC doc;
|
||||||
|
long class_, bestclass = -1, first = 1, j;
|
||||||
|
double score, bestscore = -1;
|
||||||
|
WORD *words;
|
||||||
|
|
||||||
|
doc = *(x.doc);
|
||||||
|
y.scores = (double *)my_malloc(sizeof(double) * (sparm->num_classes_ + 1));
|
||||||
|
y.num_classes_ = sparm->num_classes_;
|
||||||
|
words = doc.fvec->words;
|
||||||
|
for (j = 0; (words[j]).wnum != 0; j++) { /* Check if feature numbers */
|
||||||
|
if ((words[j]).wnum > sparm->num_features) /* are not larger than in */
|
||||||
|
(words[j]).wnum = 0; /* model. Remove feature if */
|
||||||
|
} /* necessary. */
|
||||||
|
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||||
|
y.class_ = class_;
|
||||||
|
doc.fvec = psi(x, y, sm, sparm);
|
||||||
|
score = classify_example(sm->svm_model, &doc);
|
||||||
|
free_svector(doc.fvec);
|
||||||
|
y.scores[class_] = score;
|
||||||
|
if ((bestscore < score) || (first)) {
|
||||||
|
bestscore = score;
|
||||||
|
bestclass = class_;
|
||||||
|
first = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
y.class_ = bestclass;
|
||||||
|
return (y);
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y,
|
||||||
|
STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Finds the label ybar for pattern x that that is responsible for
|
||||||
|
the most violated constraint for the slack rescaling
|
||||||
|
formulation. It has to take into account the scoring function in
|
||||||
|
sm, especially the weights sm.w, as well as the loss
|
||||||
|
function. The weights in sm.w correspond to the features defined
|
||||||
|
by psi() and range from index 1 to index sm->sizePsi. Most simple
|
||||||
|
is the case of the zero/one loss function. For the zero/one loss,
|
||||||
|
this function should return the highest scoring label ybar, if
|
||||||
|
ybar is unequal y; if it is equal to the correct label y, then
|
||||||
|
the function shall return the second highest scoring label. If
|
||||||
|
the function cannot find a label, it shall return an empty label
|
||||||
|
as recognized by the function empty_label(y). */
|
||||||
|
LABEL ybar;
|
||||||
|
DOC doc;
|
||||||
|
long class_, bestclass = -1, first = 1;
|
||||||
|
double score, score_y, score_ybar, bestscore = -1;
|
||||||
|
|
||||||
|
/* NOTE: This function could be made much more efficient by not
|
||||||
|
always computing a new PSI vector. */
|
||||||
|
doc = *(x.doc);
|
||||||
|
doc.fvec = psi(x, y, sm, sparm);
|
||||||
|
score_y = classify_example(sm->svm_model, &doc);
|
||||||
|
free_svector(doc.fvec);
|
||||||
|
|
||||||
|
ybar.scores = NULL;
|
||||||
|
ybar.num_classes_ = sparm->num_classes_;
|
||||||
|
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||||
|
ybar.class_ = class_;
|
||||||
|
doc.fvec = psi(x, ybar, sm, sparm);
|
||||||
|
score_ybar = classify_example(sm->svm_model, &doc);
|
||||||
|
free_svector(doc.fvec);
|
||||||
|
score = loss(y, ybar, sparm) * (1.0 - score_y + score_ybar);
|
||||||
|
if ((bestscore < score) || (first)) {
|
||||||
|
bestscore = score;
|
||||||
|
bestclass = class_;
|
||||||
|
first = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bestclass == -1)
|
||||||
|
printf("ERROR: Only one class\n");
|
||||||
|
ybar.class_ = bestclass;
|
||||||
|
if (struct_verbosity >= 3)
|
||||||
|
printf("[%ld:%.2f] ", bestclass, bestscore);
|
||||||
|
return (ybar);
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
|
||||||
|
STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Finds the label ybar for pattern x that that is responsible for
|
||||||
|
the most violated constraint for the margin rescaling
|
||||||
|
formulation. It has to take into account the scoring function in
|
||||||
|
sm, especially the weights sm.w, as well as the loss
|
||||||
|
function. The weights in sm.w correspond to the features defined
|
||||||
|
by psi() and range from index 1 to index sm->sizePsi. Most simple
|
||||||
|
is the case of the zero/one loss function. For the zero/one loss,
|
||||||
|
this function should return the highest scoring label ybar, if
|
||||||
|
ybar is unequal y; if it is equal to the correct label y, then
|
||||||
|
the function shall return the second highest scoring label. If
|
||||||
|
the function cannot find a label, it shall return an empty label
|
||||||
|
as recognized by the function empty_label(y). */
|
||||||
|
LABEL ybar;
|
||||||
|
DOC doc;
|
||||||
|
long class_, bestclass = -1, first = 1;
|
||||||
|
double score, bestscore = -1;
|
||||||
|
|
||||||
|
/* NOTE: This function could be made much more efficient by not
|
||||||
|
always computing a new PSI vector. */
|
||||||
|
doc = *(x.doc);
|
||||||
|
ybar.scores = NULL;
|
||||||
|
ybar.num_classes_ = sparm->num_classes_;
|
||||||
|
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||||
|
ybar.class_ = class_;
|
||||||
|
doc.fvec = psi(x, ybar, sm, sparm);
|
||||||
|
score = classify_example(sm->svm_model, &doc);
|
||||||
|
free_svector(doc.fvec);
|
||||||
|
score += loss(y, ybar, sparm);
|
||||||
|
if ((bestscore < score) || (first)) {
|
||||||
|
bestscore = score;
|
||||||
|
bestclass = class_;
|
||||||
|
first = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (bestclass == -1)
|
||||||
|
printf("ERROR: Only one class\n");
|
||||||
|
ybar.class_ = bestclass;
|
||||||
|
if (struct_verbosity >= 3)
|
||||||
|
printf("[%ld:%.2f] ", bestclass, bestscore);
|
||||||
|
return (ybar);
|
||||||
|
}
|
||||||
|
|
||||||
|
int empty_label(LABEL y) {
|
||||||
|
/* Returns true, if y is an empty label. An empty label might be
|
||||||
|
returned by find_most_violated_constraint_???(x, y, sm) if there
|
||||||
|
is no incorrect label that can be found for x, or if it is unable
|
||||||
|
to label x at all */
|
||||||
|
return (y.class_ < 0.9);
|
||||||
|
}
|
||||||
|
|
||||||
|
SVECTOR *psi(PATTERN x, LABEL y, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Returns a feature vector describing the match between pattern x and
|
||||||
|
label y. The feature vector is returned as an SVECTOR
|
||||||
|
(i.e. pairs <featurenumber:featurevalue>), where the last pair has
|
||||||
|
featurenumber 0 as a terminator. Featurenumbers start with 1 and end with
|
||||||
|
sizePsi. This feature vector determines the linear evaluation
|
||||||
|
function that is used to score labels. There will be one weight in
|
||||||
|
sm.w for each feature. Note that psi has to match
|
||||||
|
find_most_violated_constraint_???(x, y, sm) and vice versa. In
|
||||||
|
particular, find_most_violated_constraint_???(x, y, sm) finds that
|
||||||
|
ybar!=y that maximizes psi(x,ybar,sm)*sm.w (where * is the inner
|
||||||
|
vector product) and the appropriate function of the loss. */
|
||||||
|
SVECTOR *fvec;
|
||||||
|
|
||||||
|
/* shift the feature numbers to the position of weight vector of class y */
|
||||||
|
fvec = shift_s(x.doc->fvec, (y.class_ - 1) * sparm->num_features);
|
||||||
|
|
||||||
|
/* The following makes sure that the weight vectors for each class
|
||||||
|
are treated separately when kernels are used . */
|
||||||
|
fvec->kernel_id = y.class_;
|
||||||
|
|
||||||
|
return (fvec);
|
||||||
|
}
|
||||||
|
|
||||||
|
double loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* loss for correct label y and predicted label ybar. The loss for
|
||||||
|
y==ybar has to be zero. sparm->loss_function is set with the -l option. */
|
||||||
|
if (sparm->loss_function == 0) { /* type 0 loss: 0/1 loss */
|
||||||
|
if (y.class_ == ybar.class_) /* return 0, if y==ybar. return 100 else */
|
||||||
|
return (0);
|
||||||
|
else
|
||||||
|
return (100);
|
||||||
|
}
|
||||||
|
if (sparm->loss_function == 1) { /* type 1 loss: squared difference */
|
||||||
|
return ((y.class_ - ybar.class_) * (y.class_ - ybar.class_));
|
||||||
|
} else {
|
||||||
|
/* Put your code for different loss functions here. But then
|
||||||
|
find_most_violated_constraint_???(x, y, sm) has to return the
|
||||||
|
highest scoring label with the largest loss. */
|
||||||
|
printf("Unkown loss function\n");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int finalize_iteration(double ceps, int cached_constraint, SAMPLE sample,
|
||||||
|
STRUCTMODEL *sm, CONSTSET cset, double *alpha,
|
||||||
|
STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* This function is called just before the end of each cutting plane
|
||||||
|
* iteration. ceps is the amount by which the most violated constraint found
|
||||||
|
* in the current iteration was violated. cached_constraint is true if the
|
||||||
|
* added constraint was constructed from the cache. If the return value is
|
||||||
|
* FALSE, then the algorithm is allowed to terminate. If it is TRUE, the
|
||||||
|
* algorithm will keep iterating even if the desired precision sparm->epsilon
|
||||||
|
* is already reached. */
|
||||||
|
return (0);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm, CONSTSET cset,
|
||||||
|
double *alpha, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* This function is called after training and allows final touches to
|
||||||
|
the model sm. But primarly it allows computing and printing any
|
||||||
|
kind of statistic (e.g. training error) you might want. */
|
||||||
|
|
||||||
|
/* Replace SV with single weight vector */
|
||||||
|
MODEL *model = sm->svm_model;
|
||||||
|
if (model->kernel_parm.kernel_type == LINEAR) {
|
||||||
|
if (struct_verbosity >= 1) {
|
||||||
|
printf("Compacting linear model...");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
sm->svm_model = compact_linear_model(model);
|
||||||
|
sm->w = sm->svm_model->lin_weights; /* short cut to weight vector */
|
||||||
|
free_model(model, 1);
|
||||||
|
if (struct_verbosity >= 1) {
|
||||||
|
printf("done\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_struct_model(char *file, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Writes structural model sm to file file. */
|
||||||
|
FILE *modelfl;
|
||||||
|
long j, i, sv_num;
|
||||||
|
MODEL *model = sm->svm_model;
|
||||||
|
SVECTOR *v;
|
||||||
|
|
||||||
|
if ((modelfl = fopen(file, "w")) == NULL) {
|
||||||
|
perror(file);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
fprintf(modelfl, "SVM-multiclass Version %s\n", INST_VERSION);
|
||||||
|
fprintf(modelfl, "%d # number of classes\n", sparm->num_classes_);
|
||||||
|
fprintf(modelfl, "%d # number of base features\n", sparm->num_features);
|
||||||
|
fprintf(modelfl, "%d # loss function\n", sparm->loss_function);
|
||||||
|
fprintf(modelfl, "%ld # kernel type\n", model->kernel_parm.kernel_type);
|
||||||
|
fprintf(modelfl, "%ld # kernel parameter -d \n",
|
||||||
|
model->kernel_parm.poly_degree);
|
||||||
|
fprintf(modelfl, "%.8g # kernel parameter -g \n",
|
||||||
|
model->kernel_parm.rbf_gamma);
|
||||||
|
fprintf(modelfl, "%.8g # kernel parameter -s \n",
|
||||||
|
model->kernel_parm.coef_lin);
|
||||||
|
fprintf(modelfl, "%.8g # kernel parameter -r \n",
|
||||||
|
model->kernel_parm.coef_const);
|
||||||
|
fprintf(modelfl, "%s# kernel parameter -u \n", model->kernel_parm.custom);
|
||||||
|
fprintf(modelfl, "%ld # highest feature index \n", model->totwords);
|
||||||
|
fprintf(modelfl, "%ld # number of training documents \n", model->totdoc);
|
||||||
|
|
||||||
|
sv_num = 1;
|
||||||
|
for (i = 1; i < model->sv_num; i++) {
|
||||||
|
for (v = model->supvec[i]->fvec; v; v = v->next)
|
||||||
|
sv_num++;
|
||||||
|
}
|
||||||
|
fprintf(modelfl, "%ld # number of support vectors plus 1 \n", sv_num);
|
||||||
|
fprintf(modelfl,
|
||||||
|
"%.8g # threshold b, each following line is a SV (starting with "
|
||||||
|
"alpha*y)\n",
|
||||||
|
model->b);
|
||||||
|
|
||||||
|
for (i = 1; i < model->sv_num; i++) {
|
||||||
|
for (v = model->supvec[i]->fvec; v; v = v->next) {
|
||||||
|
fprintf(modelfl, "%.32g ", model->alpha[i] * v->factor);
|
||||||
|
fprintf(modelfl, "qid:%ld ", v->kernel_id);
|
||||||
|
for (j = 0; (v->words[j]).wnum; j++) {
|
||||||
|
fprintf(modelfl, "%ld:%.8g ", (long)(v->words[j]).wnum,
|
||||||
|
(double)(v->words[j]).weight);
|
||||||
|
}
|
||||||
|
if (v->userdefined)
|
||||||
|
fprintf(modelfl, "#%s\n", v->userdefined);
|
||||||
|
else
|
||||||
|
fprintf(modelfl, "#\n");
|
||||||
|
/* NOTE: this could be made more efficient by summing the
|
||||||
|
alpha's of identical vectors before writing them to the
|
||||||
|
file. */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fclose(modelfl);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm,
|
||||||
|
STRUCT_TEST_STATS *teststats) {
|
||||||
|
/* This function is called after making all test predictions in
|
||||||
|
svm_struct_classify and allows computing and printing any kind of
|
||||||
|
evaluation (e.g. precision/recall) you might want. You can use
|
||||||
|
the function eval_prediction to accumulate the necessary
|
||||||
|
statistics for each prediction. */
|
||||||
|
}
|
||||||
|
|
||||||
|
void eval_prediction(long exnum, EXAMPLE ex, LABEL ypred, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm, STRUCT_TEST_STATS *teststats) {
|
||||||
|
/* This function allows you to accumlate statistic for how well the
|
||||||
|
predicition matches the labeled example. It is called from
|
||||||
|
svm_struct_classify. See also the function
|
||||||
|
print_struct_testing_stats. */
|
||||||
|
if (exnum == 0) { /* this is the first time the function is
|
||||||
|
called. So initialize the teststats */
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Reads structural model sm from file file. This function is used
|
||||||
|
only in the prediction module, not in the learning module. */
|
||||||
|
FILE *modelfl;
|
||||||
|
STRUCTMODEL sm;
|
||||||
|
long i, queryid, slackid;
|
||||||
|
double costfactor;
|
||||||
|
long max_sv, max_words, ll, wpos;
|
||||||
|
char *line, *comment;
|
||||||
|
WORD *words;
|
||||||
|
char version_buffer[100];
|
||||||
|
MODEL *model;
|
||||||
|
|
||||||
|
nol_ll(file, &max_sv, &max_words, &ll); /* scan size of model file */
|
||||||
|
max_words += 2;
|
||||||
|
ll += 2;
|
||||||
|
|
||||||
|
words = (WORD *)my_malloc(sizeof(WORD) * (max_words + 10));
|
||||||
|
line = (char *)my_malloc(sizeof(char) * ll);
|
||||||
|
model = (MODEL *)my_malloc(sizeof(MODEL));
|
||||||
|
|
||||||
|
if ((modelfl = fopen(file, "r")) == NULL) {
|
||||||
|
perror(file);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
fscanf(modelfl, "SVM-multiclass Version %s\n", version_buffer);
|
||||||
|
if (strcmp(version_buffer, INST_VERSION)) {
|
||||||
|
perror(
|
||||||
|
"Version of model-file does not match version of svm_struct_classify!");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
fscanf(modelfl, "%d%*[^\n]\n", &sparm->num_classes_);
|
||||||
|
fscanf(modelfl, "%d%*[^\n]\n", &sparm->num_features);
|
||||||
|
fscanf(modelfl, "%d%*[^\n]\n", &sparm->loss_function);
|
||||||
|
fscanf(modelfl, "%ld%*[^\n]\n", &model->kernel_parm.kernel_type);
|
||||||
|
fscanf(modelfl, "%ld%*[^\n]\n", &model->kernel_parm.poly_degree);
|
||||||
|
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.rbf_gamma);
|
||||||
|
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.coef_lin);
|
||||||
|
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.coef_const);
|
||||||
|
fscanf(modelfl, "%[^#]%*[^\n]\n", model->kernel_parm.custom);
|
||||||
|
|
||||||
|
fscanf(modelfl, "%ld%*[^\n]\n", &model->totwords);
|
||||||
|
fscanf(modelfl, "%ld%*[^\n]\n", &model->totdoc);
|
||||||
|
fscanf(modelfl, "%ld%*[^\n]\n", &model->sv_num);
|
||||||
|
fscanf(modelfl, "%lf%*[^\n]\n", &model->b);
|
||||||
|
|
||||||
|
model->supvec = (DOC **)my_malloc(sizeof(DOC *) * model->sv_num);
|
||||||
|
model->alpha = (double *)my_malloc(sizeof(double) * model->sv_num);
|
||||||
|
model->index = NULL;
|
||||||
|
model->lin_weights = NULL;
|
||||||
|
|
||||||
|
for (i = 1; i < model->sv_num; i++) {
|
||||||
|
fgets(line, (int)ll, modelfl);
|
||||||
|
if (!parse_document(line, words, &(model->alpha[i]), &queryid, &slackid,
|
||||||
|
&costfactor, &wpos, max_words, &comment)) {
|
||||||
|
printf("\nParsing error while reading model file in SV %ld!\n%s", i,
|
||||||
|
line);
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
model->supvec[i] =
|
||||||
|
create_example(-1, 0, 0, 0.0, create_svector(words, comment, 1.0));
|
||||||
|
model->supvec[i]->fvec->kernel_id = queryid;
|
||||||
|
}
|
||||||
|
fclose(modelfl);
|
||||||
|
free(line);
|
||||||
|
free(words);
|
||||||
|
if (verbosity >= 1) {
|
||||||
|
fprintf(stdout, " (%d support vectors read) ", (int)(model->sv_num - 1));
|
||||||
|
}
|
||||||
|
sm.svm_model = model;
|
||||||
|
sm.sizePsi = model->totwords;
|
||||||
|
sm.w = NULL;
|
||||||
|
return (sm);
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_label(FILE *fp, LABEL y) {
|
||||||
|
/* Writes label y to file handle fp. */
|
||||||
|
int i;
|
||||||
|
fprintf(fp, "%d", y.class_);
|
||||||
|
if (y.scores)
|
||||||
|
for (i = 1; i <= y.num_classes_; i++)
|
||||||
|
fprintf(fp, " %f", y.scores[i]);
|
||||||
|
fprintf(fp, "\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_pattern(PATTERN x) {
|
||||||
|
/* Frees the memory of x. */
|
||||||
|
free_example(x.doc, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_label(LABEL y) {
|
||||||
|
/* Frees the memory of y. */
|
||||||
|
if (y.scores)
|
||||||
|
free(y.scores);
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_struct_model(STRUCTMODEL sm) {
|
||||||
|
/* Frees the memory of model. */
|
||||||
|
/* if(sm.w) free(sm.w); */ /* this is free'd in free_model */
|
||||||
|
if (sm.svm_model)
|
||||||
|
free_model(sm.svm_model, 1);
|
||||||
|
/* add free calls for user defined data here */
|
||||||
|
}
|
||||||
|
|
||||||
|
void free_struct_sample(SAMPLE s) {
|
||||||
|
/* Frees the memory of sample s. */
|
||||||
|
int i;
|
||||||
|
for (i = 0; i < s.n; i++) {
|
||||||
|
free_pattern(s.examples[i].x);
|
||||||
|
free_label(s.examples[i].y);
|
||||||
|
}
|
||||||
|
free(s.examples);
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_struct_help() {
|
||||||
|
/* Prints a help text that is appended to the common help text of
|
||||||
|
svm_struct_learn. */
|
||||||
|
|
||||||
|
printf(" none\n\n");
|
||||||
|
printf("Based on multi-class SVM formulation described in:\n");
|
||||||
|
printf(" K. Crammer and Y. Singer. On the Algorithmic "
|
||||||
|
"Implementation of\n");
|
||||||
|
printf(" Multi-class SVMs, JMLR, 2001.\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
void parse_struct_parameters(STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Parses the command line parameters that start with -- */
|
||||||
|
int i;
|
||||||
|
|
||||||
|
for (i = 0; (i < sparm->custom_argc) && ((sparm->custom_argv[i])[0] == '-');
|
||||||
|
i++) {
|
||||||
|
switch ((sparm->custom_argv[i])[2]) {
|
||||||
|
case 'a':
|
||||||
|
i++; /* strcpy(learn_parm->alphafile,argv[i]); */
|
||||||
|
break;
|
||||||
|
case 'e':
|
||||||
|
i++; /* sparm->epsilon=atof(sparm->custom_argv[i]); */
|
||||||
|
break;
|
||||||
|
case 'k':
|
||||||
|
i++; /* sparm->newconstretrain=atol(sparm->custom_argv[i]); */
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void print_struct_help_classify() {
|
||||||
|
/* Prints a help text that is appended to the common help text of
|
||||||
|
svm_struct_classify. */
|
||||||
|
}
|
||||||
|
|
||||||
|
void parse_struct_parameters_classify(STRUCT_LEARN_PARM *sparm) {
|
||||||
|
/* Parses the command line parameters that start with -- for the
|
||||||
|
classification module */
|
||||||
|
int i;
|
||||||
|
|
||||||
|
for (i = 0; (i < sparm->custom_argc) && ((sparm->custom_argv[i])[0] == '-');
|
||||||
|
i++) {
|
||||||
|
switch ((sparm->custom_argv[i])[2]) {
|
||||||
|
/* case 'x': i++; strcpy(xvalue,sparm->custom_argv[i]); break; */
|
||||||
|
default:
|
||||||
|
printf("\nUnrecognized option %s!\n\n", sparm->custom_argv[i]);
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
76
src/classifier/svm/svm_struct_api.h
Executable file
76
src/classifier/svm/svm_struct_api.h
Executable file
@@ -0,0 +1,76 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_api.h */
|
||||||
|
/* */
|
||||||
|
/* Definition of API for attaching implementing SVM learning of */
|
||||||
|
/* structures (e.g. parsing, multi-label classification, HMM) */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 03.07.04 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#include "svm_struct/svm_struct_common.h"
|
||||||
|
#include "svm_struct_api_types.h"
|
||||||
|
|
||||||
|
#ifndef svm_struct_api
|
||||||
|
#define svm_struct_api
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
void svm_struct_learn_api_init(int argc, char *argv[]);
|
||||||
|
void svm_struct_learn_api_exit();
|
||||||
|
void svm_struct_classify_api_init(int argc, char *argv[]);
|
||||||
|
void svm_struct_classify_api_exit();
|
||||||
|
SAMPLE read_struct_examples(char *file, STRUCT_LEARN_PARM *sparm);
|
||||||
|
void init_struct_model(SAMPLE sample, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm);
|
||||||
|
CONSTSET init_struct_constraints(SAMPLE sample, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm);
|
||||||
|
LABEL find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y,
|
||||||
|
STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm);
|
||||||
|
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
|
||||||
|
STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm);
|
||||||
|
LABEL classify_struct_example(PATTERN x, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm);
|
||||||
|
int empty_label(LABEL y);
|
||||||
|
SVECTOR *psi(PATTERN x, LABEL y, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm);
|
||||||
|
double loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm);
|
||||||
|
int finalize_iteration(double ceps, int cached_constraint, SAMPLE sample,
|
||||||
|
STRUCTMODEL *sm, CONSTSET cset, double *alpha,
|
||||||
|
STRUCT_LEARN_PARM *sparm);
|
||||||
|
void print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm, CONSTSET cset,
|
||||||
|
double *alpha, STRUCT_LEARN_PARM *sparm);
|
||||||
|
void print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm,
|
||||||
|
STRUCT_TEST_STATS *teststats);
|
||||||
|
void eval_prediction(long exnum, EXAMPLE ex, LABEL prediction, STRUCTMODEL *sm,
|
||||||
|
STRUCT_LEARN_PARM *sparm, STRUCT_TEST_STATS *teststats);
|
||||||
|
void write_struct_model(char *file, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm);
|
||||||
|
STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm);
|
||||||
|
void write_label(FILE *fp, LABEL y);
|
||||||
|
void free_pattern(PATTERN x);
|
||||||
|
void free_label(LABEL y);
|
||||||
|
void free_struct_model(STRUCTMODEL sm);
|
||||||
|
void free_struct_sample(SAMPLE s);
|
||||||
|
void print_struct_help();
|
||||||
|
void parse_struct_parameters(STRUCT_LEARN_PARM *sparm);
|
||||||
|
void print_struct_help_classify();
|
||||||
|
void parse_struct_parameters_classify(STRUCT_LEARN_PARM *sparm);
|
||||||
|
void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||||
|
STRUCTMODEL *sm);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif
|
114
src/classifier/svm/svm_struct_api_types.h
Executable file
114
src/classifier/svm/svm_struct_api_types.h
Executable file
@@ -0,0 +1,114 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_api.h */
|
||||||
|
/* */
|
||||||
|
/* Definition of API for attaching implementing SVM learning of */
|
||||||
|
/* structures (e.g. parsing, multi-label classification, HMM) */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 13.10.03 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2003 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#ifndef svm_struct_api_types
|
||||||
|
#define svm_struct_api_types
|
||||||
|
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include "svm_light/svm_learn.h"
|
||||||
|
|
||||||
|
#define INST_NAME "Multi-Class SVM"
|
||||||
|
#define INST_VERSION "V2.20"
|
||||||
|
#define INST_VERSION_DATE "14.08.08"
|
||||||
|
|
||||||
|
/* default precision for solving the optimization problem */
|
||||||
|
#define DEFAULT_EPS 0.1
|
||||||
|
/* default loss rescaling method: 1=slack_rescaling, 2=margin_rescaling */
|
||||||
|
#define DEFAULT_RESCALING 2
|
||||||
|
/* default loss function: */
|
||||||
|
#define DEFAULT_LOSS_FCT 0
|
||||||
|
/* default optimization algorithm to use: */
|
||||||
|
#define DEFAULT_ALG_TYPE 4
|
||||||
|
/* store Psi(x,y) once instead of recomputing it every time: */
|
||||||
|
#define USE_FYCACHE 0
|
||||||
|
/* decide whether to evaluate sum before storing vectors in constraint
|
||||||
|
cache:
|
||||||
|
0 = NO,
|
||||||
|
1 = YES (best, if sparse vectors and long vector lists),
|
||||||
|
2 = YES (best, if short vector lists),
|
||||||
|
3 = YES (best, if dense vectors and long vector lists) */
|
||||||
|
#define COMPACT_CACHED_VECTORS 2
|
||||||
|
/* minimum absolute value below which values in sparse vectors are
|
||||||
|
rounded to zero. Values are stored in the FVAL type defined in svm_common.h
|
||||||
|
RECOMMENDATION: assuming you use FVAL=float, use
|
||||||
|
10E-15 if COMPACT_CACHED_VECTORS is 1
|
||||||
|
10E-10 if COMPACT_CACHED_VECTORS is 2 or 3
|
||||||
|
*/
|
||||||
|
#define COMPACT_ROUNDING_THRESH 10E-15
|
||||||
|
|
||||||
|
typedef struct pattern {
|
||||||
|
/* this defines the x-part of a training example, e.g. the structure
|
||||||
|
for storing a natural language sentence in NLP parsing */
|
||||||
|
DOC *doc;
|
||||||
|
} PATTERN;
|
||||||
|
|
||||||
|
typedef struct label {
|
||||||
|
/* this defines the y-part (the label) of a training example,
|
||||||
|
e.g. the parse tree of the corresponding sentence. */
|
||||||
|
int class_; /* class label */
|
||||||
|
int num_classes_; /* total number of classes */
|
||||||
|
double *scores; /* value of linear function of each class */
|
||||||
|
} LABEL;
|
||||||
|
|
||||||
|
typedef struct structmodel {
|
||||||
|
double *w; /* pointer to the learned weights */
|
||||||
|
MODEL *svm_model; /* the learned SVM model */
|
||||||
|
long sizePsi; /* maximum number of weights in w */
|
||||||
|
double walpha;
|
||||||
|
/* other information that is needed for the stuctural model can be
|
||||||
|
added here, e.g. the grammar rules for NLP parsing */
|
||||||
|
} STRUCTMODEL;
|
||||||
|
|
||||||
|
typedef struct struct_learn_parm {
|
||||||
|
double epsilon; /* precision for which to solve
|
||||||
|
quadratic program */
|
||||||
|
double newconstretrain; /* number of new constraints to
|
||||||
|
accumulate before recomputing the QP
|
||||||
|
solution */
|
||||||
|
int ccache_size; /* maximum number of constraints to
|
||||||
|
cache for each example (used in w=4
|
||||||
|
algorithm) */
|
||||||
|
double batch_size; /* size of the mini batches in percent
|
||||||
|
of training set size (used in w=4
|
||||||
|
algorithm) */
|
||||||
|
double C; /* trade-off between margin and loss */
|
||||||
|
char custom_argv[20][300]; /* string set with the -u command line option */
|
||||||
|
int custom_argc; /* number of -u command line options */
|
||||||
|
int slack_norm; /* norm to use in objective function
|
||||||
|
for slack variables; 1 -> L1-norm,
|
||||||
|
2 -> L2-norm */
|
||||||
|
int loss_type; /* selected loss function from -r
|
||||||
|
command line option. Select between
|
||||||
|
slack rescaling (1) and margin
|
||||||
|
rescaling (2) */
|
||||||
|
int loss_function; /* select between different loss
|
||||||
|
functions via -l command line
|
||||||
|
option */
|
||||||
|
/* further parameters that are passed to init_struct_model() */
|
||||||
|
int num_classes_;
|
||||||
|
int num_features;
|
||||||
|
} STRUCT_LEARN_PARM;
|
||||||
|
|
||||||
|
typedef struct struct_test_stats {
|
||||||
|
/* you can add variables for keeping statistics when evaluating the
|
||||||
|
test predictions in svm_struct_classify. This can be used in the
|
||||||
|
function eval_prediction and print_struct_testing_stats. */
|
||||||
|
} STRUCT_TEST_STATS;
|
||||||
|
|
||||||
|
#endif
|
42
src/classifier/svm/svm_struct_learn_custom.c
Executable file
42
src/classifier/svm/svm_struct_learn_custom.c
Executable file
@@ -0,0 +1,42 @@
|
|||||||
|
/***********************************************************************/
|
||||||
|
/* */
|
||||||
|
/* svm_struct_learn_custom.c (instantiated for SVM-perform) */
|
||||||
|
/* */
|
||||||
|
/* Allows implementing a custom/alternate algorithm for solving */
|
||||||
|
/* the structual SVM optimization problem. The algorithm can use */
|
||||||
|
/* full access to the SVM-struct API and to SVM-light. */
|
||||||
|
/* */
|
||||||
|
/* Author: Thorsten Joachims */
|
||||||
|
/* Date: 09.01.08 */
|
||||||
|
/* */
|
||||||
|
/* Copyright (c) 2008 Thorsten Joachims - All rights reserved */
|
||||||
|
/* */
|
||||||
|
/* This software is available for non-commercial use only. It must */
|
||||||
|
/* not be modified and distributed without prior permission of the */
|
||||||
|
/* author. The author is not responsible for implications from the */
|
||||||
|
/* use of this software. */
|
||||||
|
/* */
|
||||||
|
/***********************************************************************/
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "svm_struct_api.h"
|
||||||
|
#include "svm_light/svm_common.h"
|
||||||
|
#include "svm_struct/svm_struct_common.h"
|
||||||
|
#include "svm_struct/svm_struct_learn.h"
|
||||||
|
|
||||||
|
|
||||||
|
void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||||
|
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||||
|
STRUCTMODEL *sm)
|
||||||
|
/* Input: sample (training examples)
|
||||||
|
sparm (structural learning parameters)
|
||||||
|
lparm (svm learning parameters)
|
||||||
|
kparm (kernel parameters)
|
||||||
|
Output: sm (learned model) */
|
||||||
|
{
|
||||||
|
/* Put your algorithm here. See svm_struct_learn.c for an example of
|
||||||
|
how to access this API. */
|
||||||
|
}
|
||||||
|
|
34
src/classifier/svm/svm_trainer.cpp
Normal file
34
src/classifier/svm/svm_trainer.cpp
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
#include "../svm_trainer.h"
|
||||||
|
#include "svm_binary_trainer.hpp"
|
||||||
|
#include "svm_multiclass_trainer.hpp"
|
||||||
|
|
||||||
|
ISVMTrainer new_svm_binary_trainer() {
|
||||||
|
return new ovclassifier::SVMBinaryTrainer();
|
||||||
|
}
|
||||||
|
ISVMTrainer new_svm_multiclass_trainer() {
|
||||||
|
return new ovclassifier::SVMMultiClassTrainer();
|
||||||
|
}
|
||||||
|
|
||||||
|
void destroy_svm_trainer(ISVMTrainer trainer) {
|
||||||
|
delete static_cast<ovclassifier::SVMTrainer *>(trainer);
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_trainer_reset(ISVMTrainer trainer) {
|
||||||
|
static_cast<ovclassifier::SVMTrainer *>(trainer)->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_trainer_set_labels(ISVMTrainer trainer, int labels) {
|
||||||
|
static_cast<ovclassifier::SVMTrainer *>(trainer)->SetLabels(labels);
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_trainer_set_features(ISVMTrainer trainer, int feats) {
|
||||||
|
static_cast<ovclassifier::SVMTrainer *>(trainer)->SetFeatures(feats);
|
||||||
|
}
|
||||||
|
|
||||||
|
void svm_trainer_add_data(ISVMTrainer trainer, int label, const float *vec) {
|
||||||
|
static_cast<ovclassifier::SVMTrainer *>(trainer)->AddData(label, vec);
|
||||||
|
}
|
||||||
|
|
||||||
|
int svm_train(ISVMTrainer trainer, const char *modelfile) {
|
||||||
|
return static_cast<ovclassifier::SVMTrainer *>(trainer)->Train(modelfile);
|
||||||
|
}
|
16
src/classifier/svm/svm_trainer.hpp
Normal file
16
src/classifier/svm/svm_trainer.hpp
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
#ifndef _SVM_TRAINER_H_
|
||||||
|
#define _SVM_TRAINER_H_
|
||||||
|
|
||||||
|
namespace ovclassifier {
|
||||||
|
class SVMTrainer {
|
||||||
|
public:
|
||||||
|
virtual ~SVMTrainer(){};
|
||||||
|
virtual void Reset() = 0;
|
||||||
|
virtual void SetLabels(int labels) = 0;
|
||||||
|
virtual void SetFeatures(int feats) = 0;
|
||||||
|
virtual void AddData(int label, const float *vec) = 0;
|
||||||
|
virtual int Train(const char *modelfile) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace ovclassifier
|
||||||
|
#endif // _SVM_TRAINER_H_
|
19
src/classifier/svm_classifier.h
Normal file
19
src/classifier/svm_classifier.h
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
#ifndef _CLASSIFIER_SVM_CLASSIFIER_C_H_
|
||||||
|
#define _CLASSIFIER_SVM_CLASSIFIER_C_H_
|
||||||
|
|
||||||
|
#include "../common/common.h"
|
||||||
|
#ifdef __cplusplus
|
||||||
|
#include "svm/svm_classifier.hpp"
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
typedef void *ISVMClassifier;
|
||||||
|
ISVMClassifier new_svm_binary_classifier();
|
||||||
|
ISVMClassifier new_svm_multiclass_classifier();
|
||||||
|
void destroy_svm_classifier(ISVMClassifier e);
|
||||||
|
int svm_classifier_load_model(ISVMClassifier e, const char *modelfile);
|
||||||
|
double svm_predict(ISVMClassifier e, const float *vec);
|
||||||
|
int svm_classify(ISVMClassifier e, const float *vec, FloatVector *scores);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // !_CLASSIFER_SVM_CLASSIFIER_C_H_
|
20
src/classifier/svm_trainer.h
Normal file
20
src/classifier/svm_trainer.h
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
#ifndef _CLASSIFIER_SVM_TRAINER_C_H_
|
||||||
|
#define _CLASSIFIER_SVM_TRAINER_C_H_
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
#include "svm/svm_trainer.hpp"
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
typedef void *ISVMTrainer;
|
||||||
|
ISVMTrainer new_svm_binary_trainer();
|
||||||
|
ISVMTrainer new_svm_multiclass_trainer();
|
||||||
|
void destroy_svm_trainer(ISVMTrainer trainer);
|
||||||
|
void svm_trainer_reset(ISVMTrainer trainer);
|
||||||
|
void svm_trainer_set_labels(ISVMTrainer trainer, int labels);
|
||||||
|
void svm_trainer_set_features(ISVMTrainer trainer, int feats);
|
||||||
|
void svm_trainer_add_data(ISVMTrainer trainer, int label, const float *vec);
|
||||||
|
int svm_train(ISVMTrainer trainer, const char *modelfile);
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
#endif // !_CLASSIFER_SVM_TRAINER_C_H_
|
@@ -1,9 +1,9 @@
|
|||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
#include "cpu.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <math.h>
|
|
||||||
#include <float.h>
|
#include <float.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "cpu.h"
|
#include <math.h>
|
||||||
|
|
||||||
#ifdef OV_VULKAN
|
#ifdef OV_VULKAN
|
||||||
#include "gpu.h"
|
#include "gpu.h"
|
||||||
@@ -11,423 +11,407 @@
|
|||||||
|
|
||||||
int get_gpu_count() {
|
int get_gpu_count() {
|
||||||
#ifdef OV_VULKAN
|
#ifdef OV_VULKAN
|
||||||
return ncnn::get_gpu_count();
|
return ncnn::get_gpu_count();
|
||||||
#endif // OV_VULKAN
|
#endif // OV_VULKAN
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int create_gpu_instance() {
|
int create_gpu_instance() {
|
||||||
#ifdef OV_VULKAN
|
#ifdef OV_VULKAN
|
||||||
return ncnn::create_gpu_instance();
|
return ncnn::create_gpu_instance();
|
||||||
#endif // OV_VULKAN
|
#endif // OV_VULKAN
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void destroy_gpu_instance() {
|
void destroy_gpu_instance() {
|
||||||
#ifdef OV_VULKAN
|
#ifdef OV_VULKAN
|
||||||
ncnn::destroy_gpu_instance();
|
ncnn::destroy_gpu_instance();
|
||||||
#endif // OV_VULKAN
|
#endif // OV_VULKAN
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_big_cpu_count() {
|
int get_big_cpu_count() { return ncnn::get_big_cpu_count(); }
|
||||||
return ncnn::get_big_cpu_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_omp_num_threads(int n) {
|
void set_omp_num_threads(int n) {
|
||||||
#ifdef OV_OPENMP
|
#ifdef OV_OPENMP
|
||||||
ncnn::set_omp_num_threads(n);
|
ncnn::set_omp_num_threads(n);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
int load_model(IEstimator d, const char *root_path) {
|
int load_model(IEstimator d, const char *root_path) {
|
||||||
return static_cast<ov::Estimator*>(d)->LoadModel(root_path);
|
return static_cast<ov::Estimator *>(d)->LoadModel(root_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
void destroy_estimator(IEstimator d) {
|
void destroy_estimator(IEstimator d) { delete static_cast<ov::Estimator *>(d); }
|
||||||
delete static_cast<ov::Estimator*>(d);
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_num_threads(IEstimator d, int n) {
|
void set_num_threads(IEstimator d, int n) {
|
||||||
static_cast<ov::Estimator*>(d)->set_num_threads(n);
|
static_cast<ov::Estimator *>(d)->set_num_threads(n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_light_mode(IEstimator d, bool mode) {
|
void set_light_mode(IEstimator d, bool mode) {
|
||||||
static_cast<ov::Estimator*>(d)->set_light_mode(mode);
|
static_cast<ov::Estimator *>(d)->set_light_mode(mode);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreePoint2fVector(Point2fVector* p) {
|
void FreePoint2fVector(Point2fVector *p) {
|
||||||
if (p->points != NULL) {
|
if (p->points != NULL) {
|
||||||
free(p->points);
|
free(p->points);
|
||||||
p->points = NULL;
|
p->points = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Point2fVectorSetValue(Point2fVector *p, int i, const Point2f* val) {
|
void FreePoint3dVector(Point3dVector *p) {
|
||||||
if (p->points == NULL || i >= p->length) {
|
if (p->points != NULL) {
|
||||||
return;
|
free(p->points);
|
||||||
}
|
p->points = NULL;
|
||||||
p->points[i] = *val;
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Point2fVectorSetValue(Point2fVector *p, int i, const Point2f *val) {
|
||||||
|
if (p->points == NULL || i >= p->length) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
p->points[i] = *val;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeFloatVector(FloatVector *p) {
|
void FreeFloatVector(FloatVector *p) {
|
||||||
if (p->values != NULL) {
|
if (p->values != NULL) {
|
||||||
free(p->values);
|
free(p->values);
|
||||||
p->values = NULL;
|
p->values = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeBytes(Bytes *p) {
|
void FreeBytes(Bytes *p) {
|
||||||
if (p->values != NULL) {
|
if (p->values != NULL) {
|
||||||
free(p->values);
|
free(p->values);
|
||||||
p->values = NULL;
|
p->values = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeKeypointVector(KeypointVector *p) {
|
void FreeKeypointVector(KeypointVector *p) {
|
||||||
if (p->points != NULL) {
|
if (p->points != NULL) {
|
||||||
free(p->points);
|
free(p->points);
|
||||||
p->points = NULL;
|
p->points = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void KeypointVectorSetValue(KeypointVector *p, int i, const Keypoint* val) {
|
void KeypointVectorSetValue(KeypointVector *p, int i, const Keypoint *val) {
|
||||||
if (p->points == NULL || i >= p->length) {
|
if (p->points == NULL || i >= p->length) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
p->points[i] = *val;
|
p->points[i] = *val;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeObjectInfo(ObjectInfo *p) {
|
void FreeObjectInfo(ObjectInfo *p) {
|
||||||
if (p->pts != NULL) {
|
if (p->pts != NULL) {
|
||||||
FreeKeypointVector(p->pts);
|
FreeKeypointVector(p->pts);
|
||||||
free(p->pts);
|
free(p->pts);
|
||||||
p->pts = NULL;
|
p->pts = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeObjectInfoVector(ObjectInfoVector *p) {
|
void FreeObjectInfoVector(ObjectInfoVector *p) {
|
||||||
if (p->items!=NULL) {
|
if (p->items != NULL) {
|
||||||
for (int i=0; i < p->length; i ++) {
|
for (int i = 0; i < p->length; i++) {
|
||||||
FreeObjectInfo(&p->items[i]);
|
FreeObjectInfo(&p->items[i]);
|
||||||
}
|
|
||||||
free(p->items);
|
|
||||||
p->items= NULL;
|
|
||||||
}
|
}
|
||||||
|
free(p->items);
|
||||||
|
p->items = NULL;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void FreeImage(Image* p) {
|
void FreeImage(Image *p) {
|
||||||
if (p->data != NULL) {
|
if (p->data != NULL) {
|
||||||
free(p->data);
|
free(p->data);
|
||||||
p->data = NULL;
|
p->data = NULL;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
|
|
||||||
Estimator::Estimator() : EstimatorBase() {
|
Estimator::Estimator() : EstimatorBase() {
|
||||||
blob_allocator_.set_size_compare_ratio(0.f);
|
blob_allocator_.set_size_compare_ratio(0.f);
|
||||||
workspace_allocator_.set_size_compare_ratio(0.f);
|
workspace_allocator_.set_size_compare_ratio(0.f);
|
||||||
net_ = new ncnn::Net();
|
net_ = new ncnn::Net();
|
||||||
initialized_ = false;
|
initialized_ = false;
|
||||||
if (num_threads > 0) {
|
if (num_threads > 0) {
|
||||||
net_->opt.num_threads = num_threads;
|
net_->opt.num_threads = num_threads;
|
||||||
}
|
}
|
||||||
net_->opt.blob_allocator = &blob_allocator_;
|
net_->opt.blob_allocator = &blob_allocator_;
|
||||||
net_->opt.workspace_allocator = &workspace_allocator_;
|
net_->opt.workspace_allocator = &workspace_allocator_;
|
||||||
|
net_->opt.lightmode = light_mode_;
|
||||||
#ifdef OV_VULKAN
|
#ifdef OV_VULKAN
|
||||||
net_->opt.use_vulkan_compute = true;
|
net_->opt.use_vulkan_compute = true;
|
||||||
#endif // OV_VULKAN
|
#endif // OV_VULKAN
|
||||||
}
|
}
|
||||||
|
|
||||||
Estimator::~Estimator() {
|
Estimator::~Estimator() {
|
||||||
if (net_) {
|
if (net_) {
|
||||||
net_->clear();
|
net_->clear();
|
||||||
}
|
}
|
||||||
workspace_allocator_.clear();
|
workspace_allocator_.clear();
|
||||||
blob_allocator_.clear();
|
blob_allocator_.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
int Estimator::LoadModel(const char * root_path) {
|
int Estimator::LoadModel(const char *root_path) {
|
||||||
std::string param_file = std::string(root_path) + "/param";
|
std::string param_file = std::string(root_path) + "/param";
|
||||||
std::string bin_file = std::string(root_path) + "/bin";
|
std::string bin_file = std::string(root_path) + "/bin";
|
||||||
if (net_->load_param(param_file.c_str()) == -1 ||
|
if (net_->load_param(param_file.c_str()) == -1 ||
|
||||||
net_->load_model(bin_file.c_str()) == -1) {
|
net_->load_model(bin_file.c_str()) == -1) {
|
||||||
return 10000;
|
return 10000;
|
||||||
}
|
}
|
||||||
|
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
EstimatorBase::EstimatorBase() {
|
EstimatorBase::EstimatorBase() { num_threads = ncnn::get_big_cpu_count(); }
|
||||||
num_threads = ncnn::get_big_cpu_count();
|
|
||||||
}
|
|
||||||
|
|
||||||
EstimatorBase::~EstimatorBase() {}
|
EstimatorBase::~EstimatorBase() {}
|
||||||
|
|
||||||
void EstimatorBase::set_num_threads(int n) {
|
void EstimatorBase::set_num_threads(int n) { num_threads = n; }
|
||||||
num_threads = n;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Estimator::set_num_threads(int n) {
|
void Estimator::set_num_threads(int n) {
|
||||||
EstimatorBase::set_num_threads(n);
|
EstimatorBase::set_num_threads(n);
|
||||||
if (net_) {
|
if (net_) {
|
||||||
net_->opt.num_threads = n;
|
net_->opt.num_threads = n;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Estimator::set_light_mode(bool mode) {
|
void Estimator::set_light_mode(bool mode) {
|
||||||
if (net_) {
|
if (net_) {
|
||||||
net_->opt.lightmode = mode;
|
net_->opt.lightmode = mode;
|
||||||
light_mode_ = mode;
|
light_mode_ = mode;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int RatioAnchors(const Rect & anchor,
|
int RatioAnchors(const Rect &anchor, const std::vector<float> &ratios,
|
||||||
const std::vector<float>& ratios,
|
std::vector<Rect> *anchors, int threads_num) {
|
||||||
std::vector<Rect>* anchors, int threads_num) {
|
anchors->clear();
|
||||||
anchors->clear();
|
Point center = Point(anchor.x + (anchor.width - 1) * 0.5f,
|
||||||
Point center = Point(anchor.x + (anchor.width - 1) * 0.5f,
|
anchor.y + (anchor.height - 1) * 0.5f);
|
||||||
anchor.y + (anchor.height - 1) * 0.5f);
|
float anchor_size = anchor.width * anchor.height;
|
||||||
float anchor_size = anchor.width * anchor.height;
|
#ifdef OV_OPENMP
|
||||||
#ifdef OV_OPENMP
|
|
||||||
#pragma omp parallel for num_threads(threads_num)
|
#pragma omp parallel for num_threads(threads_num)
|
||||||
#endif
|
#endif
|
||||||
for (int i = 0; i < static_cast<int>(ratios.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(ratios.size()); ++i) {
|
||||||
float ratio = ratios.at(i);
|
float ratio = ratios.at(i);
|
||||||
float anchor_size_ratio = anchor_size / ratio;
|
float anchor_size_ratio = anchor_size / ratio;
|
||||||
float curr_anchor_width = sqrt(anchor_size_ratio);
|
float curr_anchor_width = sqrt(anchor_size_ratio);
|
||||||
float curr_anchor_height = curr_anchor_width * ratio;
|
float curr_anchor_height = curr_anchor_width * ratio;
|
||||||
float curr_x = center.x - (curr_anchor_width - 1)* 0.5f;
|
float curr_x = center.x - (curr_anchor_width - 1) * 0.5f;
|
||||||
float curr_y = center.y - (curr_anchor_height - 1)* 0.5f;
|
float curr_y = center.y - (curr_anchor_height - 1) * 0.5f;
|
||||||
|
|
||||||
Rect curr_anchor = Rect(curr_x, curr_y,
|
Rect curr_anchor =
|
||||||
curr_anchor_width - 1, curr_anchor_height - 1);
|
Rect(curr_x, curr_y, curr_anchor_width - 1, curr_anchor_height - 1);
|
||||||
anchors->push_back(curr_anchor);
|
anchors->push_back(curr_anchor);
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int ScaleAnchors(const std::vector<Rect>& ratio_anchors,
|
int ScaleAnchors(const std::vector<Rect> &ratio_anchors,
|
||||||
const std::vector<float>& scales, std::vector<Rect>* anchors, int threads_num) {
|
const std::vector<float> &scales, std::vector<Rect> *anchors,
|
||||||
anchors->clear();
|
int threads_num) {
|
||||||
|
anchors->clear();
|
||||||
#if defined(_OPENMP)
|
#if defined(_OPENMP)
|
||||||
#pragma omp parallel for num_threads(threads_num)
|
#pragma omp parallel for num_threads(threads_num)
|
||||||
#endif
|
#endif
|
||||||
for (int i = 0; i < static_cast<int>(ratio_anchors.size()); ++i) {
|
for (int i = 0; i < static_cast<int>(ratio_anchors.size()); ++i) {
|
||||||
Rect anchor = ratio_anchors.at(i);
|
Rect anchor = ratio_anchors.at(i);
|
||||||
Point2f center = Point2f(anchor.x + anchor.width * 0.5f,
|
Point2f center = Point2f(anchor.x + anchor.width * 0.5f,
|
||||||
anchor.y + anchor.height * 0.5f);
|
anchor.y + anchor.height * 0.5f);
|
||||||
for (int j = 0; j < static_cast<int>(scales.size()); ++j) {
|
for (int j = 0; j < static_cast<int>(scales.size()); ++j) {
|
||||||
float scale = scales.at(j);
|
float scale = scales.at(j);
|
||||||
float curr_width = scale * (anchor.width + 1);
|
float curr_width = scale * (anchor.width + 1);
|
||||||
float curr_height = scale * (anchor.height + 1);
|
float curr_height = scale * (anchor.height + 1);
|
||||||
float curr_x = center.x - curr_width * 0.5f;
|
float curr_x = center.x - curr_width * 0.5f;
|
||||||
float curr_y = center.y - curr_height * 0.5f;
|
float curr_y = center.y - curr_height * 0.5f;
|
||||||
Rect curr_anchor = Rect(curr_x, curr_y,
|
Rect curr_anchor = Rect(curr_x, curr_y, curr_width, curr_height);
|
||||||
curr_width, curr_height);
|
anchors->push_back(curr_anchor);
|
||||||
anchors->push_back(curr_anchor);
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int GenerateAnchors(const int & base_size,
|
int GenerateAnchors(const int &base_size, const std::vector<float> &ratios,
|
||||||
const std::vector<float>& ratios,
|
const std::vector<float> scales, std::vector<Rect> *anchors,
|
||||||
const std::vector<float> scales,
|
int threads_num) {
|
||||||
std::vector<Rect>* anchors,
|
anchors->clear();
|
||||||
int threads_num) {
|
Rect anchor = Rect(0, 0, base_size, base_size);
|
||||||
anchors->clear();
|
std::vector<Rect> ratio_anchors;
|
||||||
Rect anchor = Rect(0, 0, base_size, base_size);
|
RatioAnchors(anchor, ratios, &ratio_anchors, threads_num);
|
||||||
std::vector<Rect> ratio_anchors;
|
ScaleAnchors(ratio_anchors, scales, anchors, threads_num);
|
||||||
RatioAnchors(anchor, ratios, &ratio_anchors, threads_num);
|
|
||||||
ScaleAnchors(ratio_anchors, scales, anchors, threads_num);
|
return 0;
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
float InterRectArea(const Rect & a, const Rect & b) {
|
float InterRectArea(const Rect &a, const Rect &b) {
|
||||||
Point left_top = Point(std::max(a.x, b.x), std::max(a.y, b.y));
|
Point left_top = Point(std::max(a.x, b.x), std::max(a.y, b.y));
|
||||||
Point right_bottom = Point(std::min(a.br().x, b.br().x), std::min(a.br().y, b.br().y));
|
Point right_bottom =
|
||||||
Point diff = right_bottom - left_top;
|
Point(std::min(a.br().x, b.br().x), std::min(a.br().y, b.br().y));
|
||||||
return (std::max(diff.x + 1, 0) * std::max(diff.y + 1, 0));
|
Point diff = right_bottom - left_top;
|
||||||
|
return (std::max(diff.x + 1, 0) * std::max(diff.y + 1, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
int ComputeIOU(const Rect & rect1,
|
int ComputeIOU(const Rect &rect1, const Rect &rect2, float *iou,
|
||||||
const Rect & rect2, float * iou,
|
const std::string &type) {
|
||||||
const std::string& type) {
|
|
||||||
|
|
||||||
float inter_area = InterRectArea(rect1, rect2);
|
float inter_area = InterRectArea(rect1, rect2);
|
||||||
if (type == "UNION") {
|
if (type == "UNION") {
|
||||||
*iou = inter_area / (rect1.area() + rect2.area() - inter_area);
|
*iou = inter_area / (rect1.area() + rect2.area() - inter_area);
|
||||||
}
|
} else {
|
||||||
else {
|
*iou = inter_area / std::min(rect1.area(), rect2.area());
|
||||||
*iou = inter_area / std::min(rect1.area(), rect2.area());
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EnlargeRect(const float &scale, Rect *rect) {
|
||||||
void EnlargeRect(const float& scale, Rect* rect) {
|
float offset_x = (scale - 1.f) / 2.f * rect->width;
|
||||||
float offset_x = (scale - 1.f) / 2.f * rect->width;
|
float offset_y = (scale - 1.f) / 2.f * rect->height;
|
||||||
float offset_y = (scale - 1.f) / 2.f * rect->height;
|
rect->x -= offset_x;
|
||||||
rect->x -= offset_x;
|
rect->y -= offset_y;
|
||||||
rect->y -= offset_y;
|
rect->width = scale * rect->width;
|
||||||
rect->width = scale * rect->width;
|
rect->height = scale * rect->height;
|
||||||
rect->height = scale * rect->height;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void RectifyRect(Rect* rect) {
|
void RectifyRect(Rect *rect) {
|
||||||
int max_side = std::max(rect->width, rect->height);
|
int max_side = std::max(rect->width, rect->height);
|
||||||
int offset_x = (max_side - rect->width) / 2;
|
int offset_x = (max_side - rect->width) / 2;
|
||||||
int offset_y = (max_side - rect->height) / 2;
|
int offset_y = (max_side - rect->height) / 2;
|
||||||
|
|
||||||
rect->x -= offset_x;
|
rect->x -= offset_x;
|
||||||
rect->y -= offset_y;
|
rect->y -= offset_y;
|
||||||
rect->width = max_side;
|
rect->width = max_side;
|
||||||
rect->height = max_side;
|
rect->height = max_side;
|
||||||
}
|
}
|
||||||
|
|
||||||
void qsort_descent_inplace(std::vector<ObjectInfo>& objects, int left, int right)
|
void qsort_descent_inplace(std::vector<ObjectInfo> &objects, int left,
|
||||||
{
|
int right) {
|
||||||
int i = left;
|
int i = left;
|
||||||
int j = right;
|
int j = right;
|
||||||
float p = objects[(left + right) / 2].score;
|
float p = objects[(left + right) / 2].score;
|
||||||
|
|
||||||
while (i <= j)
|
while (i <= j) {
|
||||||
|
while (objects[i].score > p)
|
||||||
|
i++;
|
||||||
|
|
||||||
|
while (objects[j].score < p)
|
||||||
|
j--;
|
||||||
|
|
||||||
|
if (i <= j) {
|
||||||
|
// swap
|
||||||
|
std::swap(objects[i], objects[j]);
|
||||||
|
|
||||||
|
i++;
|
||||||
|
j--;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma omp parallel sections
|
||||||
|
{
|
||||||
|
#pragma omp section
|
||||||
{
|
{
|
||||||
while (objects[i].score > p)
|
if (left < j)
|
||||||
i++;
|
qsort_descent_inplace(objects, left, j);
|
||||||
|
}
|
||||||
|
#pragma omp section
|
||||||
|
{
|
||||||
|
if (i < right)
|
||||||
|
qsort_descent_inplace(objects, i, right);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
while (objects[j].score < p)
|
void qsort_descent_inplace(std::vector<ObjectInfo> &objects) {
|
||||||
j--;
|
if (objects.empty())
|
||||||
|
return;
|
||||||
|
|
||||||
if (i <= j)
|
qsort_descent_inplace(objects, 0, objects.size() - 1);
|
||||||
{
|
}
|
||||||
// swap
|
|
||||||
std::swap(objects[i], objects[j]);
|
|
||||||
|
|
||||||
i++;
|
void nms_sorted_bboxes(const std::vector<ObjectInfo> &objects,
|
||||||
j--;
|
std::vector<int> &picked, float nms_threshold) {
|
||||||
}
|
picked.clear();
|
||||||
|
|
||||||
|
const int n = objects.size();
|
||||||
|
|
||||||
|
std::vector<float> areas(n);
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
areas[i] = objects[i].rect.area();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < n; i++) {
|
||||||
|
const ObjectInfo &a = objects[i];
|
||||||
|
|
||||||
|
int keep = 1;
|
||||||
|
for (int j = 0; j < (int)picked.size(); j++) {
|
||||||
|
const ObjectInfo &b = objects[picked[j]];
|
||||||
|
|
||||||
|
// intersection over union
|
||||||
|
float inter_area = InterRectArea(a.rect, b.rect);
|
||||||
|
float union_area = areas[i] + areas[picked[j]] - inter_area;
|
||||||
|
// float IoU = inter_area / union_area
|
||||||
|
if (inter_area / union_area > nms_threshold)
|
||||||
|
keep = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
#pragma omp parallel sections
|
if (keep)
|
||||||
{
|
picked.push_back(i);
|
||||||
#pragma omp section
|
}
|
||||||
{
|
|
||||||
if (left < j) qsort_descent_inplace(objects, left, j);
|
|
||||||
}
|
|
||||||
#pragma omp section
|
|
||||||
{
|
|
||||||
if (i < right) qsort_descent_inplace(objects, i, right);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void qsort_descent_inplace(std::vector<ObjectInfo>& objects)
|
|
||||||
{
|
|
||||||
if (objects.empty())
|
|
||||||
return;
|
|
||||||
|
|
||||||
qsort_descent_inplace(objects, 0, objects.size() - 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void nms_sorted_bboxes(const std::vector<ObjectInfo>& objects, std::vector<int>& picked, float nms_threshold)
|
|
||||||
{
|
|
||||||
picked.clear();
|
|
||||||
|
|
||||||
const int n = objects.size();
|
|
||||||
|
|
||||||
std::vector<float> areas(n);
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
{
|
|
||||||
areas[i] = objects[i].rect.area();
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < n; i++)
|
|
||||||
{
|
|
||||||
const ObjectInfo& a = objects[i];
|
|
||||||
|
|
||||||
int keep = 1;
|
|
||||||
for (int j = 0; j < (int)picked.size(); j++)
|
|
||||||
{
|
|
||||||
const ObjectInfo& b = objects[picked[j]];
|
|
||||||
|
|
||||||
// intersection over union
|
|
||||||
float inter_area = InterRectArea(a.rect, b.rect);
|
|
||||||
float union_area = areas[i] + areas[picked[j]] - inter_area;
|
|
||||||
// float IoU = inter_area / union_area
|
|
||||||
if (inter_area / union_area > nms_threshold)
|
|
||||||
keep = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (keep)
|
|
||||||
picked.push_back(i);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
//
|
//
|
||||||
// insightface/detection/scrfd/mmdet/core/anchor/anchor_generator.py gen_single_level_base_anchors()
|
// insightface/detection/scrfd/mmdet/core/anchor/anchor_generator.py
|
||||||
ncnn::Mat generate_anchors(int base_size, const ncnn::Mat& ratios, const ncnn::Mat& scales)
|
// gen_single_level_base_anchors()
|
||||||
{
|
ncnn::Mat generate_anchors(int base_size, const ncnn::Mat &ratios,
|
||||||
int num_ratio = ratios.w;
|
const ncnn::Mat &scales) {
|
||||||
int num_scale = scales.w;
|
int num_ratio = ratios.w;
|
||||||
|
int num_scale = scales.w;
|
||||||
|
|
||||||
ncnn::Mat anchors;
|
ncnn::Mat anchors;
|
||||||
anchors.create(4, num_ratio * num_scale);
|
anchors.create(4, num_ratio * num_scale);
|
||||||
|
|
||||||
const float cx = 0;
|
const float cx = 0;
|
||||||
const float cy = 0;
|
const float cy = 0;
|
||||||
|
|
||||||
for (int i = 0; i < num_ratio; i++)
|
for (int i = 0; i < num_ratio; i++) {
|
||||||
{
|
float ar = ratios[i];
|
||||||
float ar = ratios[i];
|
|
||||||
|
|
||||||
int r_w = round(base_size / sqrt(ar));
|
int r_w = round(base_size / sqrt(ar));
|
||||||
int r_h = round(r_w * ar); //round(base_size * sqrt(ar));
|
int r_h = round(r_w * ar); // round(base_size * sqrt(ar));
|
||||||
|
|
||||||
for (int j = 0; j < num_scale; j++)
|
for (int j = 0; j < num_scale; j++) {
|
||||||
{
|
float scale = scales[j];
|
||||||
float scale = scales[j];
|
|
||||||
|
|
||||||
float rs_w = r_w * scale;
|
float rs_w = r_w * scale;
|
||||||
float rs_h = r_h * scale;
|
float rs_h = r_h * scale;
|
||||||
|
|
||||||
float* anchor = anchors.row(i * num_scale + j);
|
float *anchor = anchors.row(i * num_scale + j);
|
||||||
|
|
||||||
anchor[0] = cx - rs_w * 0.5f;
|
anchor[0] = cx - rs_w * 0.5f;
|
||||||
anchor[1] = cy - rs_h * 0.5f;
|
anchor[1] = cy - rs_h * 0.5f;
|
||||||
anchor[2] = cx + rs_w * 0.5f;
|
anchor[2] = cx + rs_w * 0.5f;
|
||||||
anchor[3] = cy + rs_h * 0.5f;
|
anchor[3] = cy + rs_h * 0.5f;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return anchors;
|
return anchors;
|
||||||
}
|
}
|
||||||
|
|
||||||
int generate_grids_and_stride(const int target_size, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
|
int generate_grids_and_stride(const int target_size, std::vector<int> &strides,
|
||||||
{
|
std::vector<GridAndStride> &grid_strides) {
|
||||||
for (auto stride : strides)
|
for (auto stride : strides) {
|
||||||
{
|
int num_grid = target_size / stride;
|
||||||
int num_grid = target_size / stride;
|
for (int g1 = 0; g1 < num_grid; g1++) {
|
||||||
for (int g1 = 0; g1 < num_grid; g1++)
|
for (int g0 = 0; g0 < num_grid; g0++) {
|
||||||
{
|
grid_strides.push_back((GridAndStride){g0, g1, stride});
|
||||||
for (int g0 = 0; g0 < num_grid; g0++)
|
}
|
||||||
{
|
|
||||||
grid_strides.push_back((GridAndStride){g0, g1, stride});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sigmoid(float x)
|
float sigmoid(float x) { return static_cast<float>(1.f / (1.f + exp(-x))); }
|
||||||
{
|
|
||||||
return static_cast<float>(1.f / (1.f + exp(-x)));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
} // namespace ov
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user