Mirror of https://github.com/bububa/openvision.git (synced 2025-09-26 17:51:13 +08:00)
Merge branch 'release/v1.0.1'
2  .gitignore (vendored)
@@ -70,3 +70,5 @@ _testmain.go
test
.vim
dist/

libtorch/
@@ -33,6 +33,9 @@ cmake .. # optional -DNCNN_VULKAN=OFF -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COM
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
- tracker (for face IOU calculation between frames)
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
- hair (for hair segmentation) [Google Drive](https://drive.google.com/drive/folders/14DOBaFrxTL1k4T1ved5qfRUUziurItT8?usp=sharing)
- eye
  - lenet (eye status detector) [Google Drive](https://drive.google.com/drive/folders/1jaonx6PeXFLA8gBKo4eQGuxsncVnqS7o?usp=sharing)
- pose
  - detector (for pose detection/estimation)
    - ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
@@ -50,10 +53,14 @@ cmake .. # optional -DNCNN_VULKAN=OFF -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COM
- nanodet [Google Drive](https://drive.google.com/drive/folders/1ywH7r_clqqA_BAOFSzA92Q0lxJtWlN3z?usp=sharing)
- pose (for hand pose estimation)
  - handnet [Google Drive](https://drive.google.com/drive/folders/1DsCGmiVaZobbMWRp5Oec8GbIpeg7CsNR?usp=sharing)
- pose3d (for 3D hand pose detection)
  - mediapipe [Google Drive](https://drive.google.com/drive/folders/1LsqIGB55dusZJqmP1uhnQUnNE2tLzifp?usp=sharing)
- styletransfer
  - animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
- tracker
  - lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
- counter
  - p2pnet [Google Drive](https://drive.google.com/drive/folders/1kmtBsPIS79C3hMAwm_Tv9tAPvJLV9k35?usp=sharing)
- golang binding (github.com/bububa/openvision/go)

## Reference
3  data/font/.gitignore (vendored, new file)
@@ -0,0 +1,3 @@
*
*/
!.gitignore
@@ -45,6 +45,9 @@ make -j 4
- scrfd [Google Drive](https://drive.google.com/drive/folders/1XPjfsuXGj9rXqAmo1K70BsqWmHvoYQv_?usp=sharing)
- tracker (for face IOU calculation between frames)
- hopenet (for head pose detection) [Google Drive](https://drive.google.com/drive/folders/1zLam-8s9ZMPDUxUEtNU2F9yFTDRM5fk-?usp=sharing)
- hair (for hair segmentation) [Google Drive](https://drive.google.com/drive/folders/14DOBaFrxTL1k4T1ved5qfRUUziurItT8?usp=sharing)
- eye
  - lenet (eye status detector) [Google Drive](https://drive.google.com/drive/folders/1jaonx6PeXFLA8gBKo4eQGuxsncVnqS7o?usp=sharing)
- pose
  - detector (for pose detection/estimation)
    - ultralight [Google Drive](https://drive.google.com/drive/folders/15b-I5HDyGe2WLb-TO85SJYmnYONvGOKh?usp=sharing)
@@ -66,3 +69,5 @@ make -j 4
- animegan2 [Google Drive](https://drive.google.com/drive/folders/1K6ZScENPHVbxupHkwl5WcpG8PPECtD8e?usp=sharing)
- tracker
  - lighttrack [Google Drive](https://drive.google.com/drive/folders/16cxns_xzSOABHn6UcY1OXyf4MFcSSbEf?usp=sharing)
- counter
  - p2pnet [Google Drive](https://drive.google.com/drive/folders/1kmtBsPIS79C3hMAwm_Tv9tAPvJLV9k35?usp=sharing)
2  go/classifier/doc.go (new file)
@@ -0,0 +1,2 @@
// Package classifier implements different classifiers
package classifier
40  go/classifier/svm/binary_classifier.go (new file)
@@ -0,0 +1,40 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_classifier.h"
*/
import "C"

// BinaryClassifier represents an SVM binary classifier
type BinaryClassifier struct {
	d C.ISVMClassifier
}

// NewBinaryClassifier returns a new BinaryClassifier
func NewBinaryClassifier() *BinaryClassifier {
	return &BinaryClassifier{
		d: C.new_svm_binary_classifier(),
	}
}

// Destroy destroys the underlying C.ISVMClassifier
func (t *BinaryClassifier) Destroy() {
	DestroyClassifier(t.d)
}

// LoadModel loads the classifier model
func (t *BinaryClassifier) LoadModel(modelPath string) {
	LoadClassifierModel(t.d, modelPath)
}

// Predict returns the predicted score
func (t *BinaryClassifier) Predict(vec []float32) float64 {
	return Predict(t.d, vec)
}

// Classify returns the classified scores
func (t *BinaryClassifier) Classify(vec []float32) ([]float64, error) {
	return Classify(t.d, vec)
}
50  go/classifier/svm/binary_trainer.go (new file)
@@ -0,0 +1,50 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_trainer.h"
*/
import "C"

// BinaryTrainer represents an SVM binary trainer
type BinaryTrainer struct {
	d C.ISVMTrainer
}

// NewBinaryTrainer returns a new BinaryTrainer
func NewBinaryTrainer() *BinaryTrainer {
	return &BinaryTrainer{
		d: C.new_svm_binary_trainer(),
	}
}

// Destroy destroys the underlying C.ISVMTrainer
func (t *BinaryTrainer) Destroy() {
	DestroyTrainer(t.d)
}

// Reset resets the C.ISVMTrainer
func (t *BinaryTrainer) Reset() {
	ResetTrainer(t.d)
}

// SetLabels sets the total number of labels
func (t *BinaryTrainer) SetLabels(labels int) {
	SetLabels(t.d, labels)
}

// SetFeatures sets the total number of features
func (t *BinaryTrainer) SetFeatures(feats int) {
	SetFeatures(t.d, feats)
}

// AddData adds a feature vector with its label
func (t *BinaryTrainer) AddData(labelID int, feats []float32) {
	AddData(t.d, labelID, feats)
}

// Train trains the model and writes it to modelPath
func (t *BinaryTrainer) Train(modelPath string) error {
	return Train(t.d, modelPath)
}
12  go/classifier/svm/cgo.go (new file)
@@ -0,0 +1,12 @@
//go:build !vulkan
// +build !vulkan

package svm

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
12  go/classifier/svm/cgo_vulkan.go (new file)
@@ -0,0 +1,12 @@
//go:build vulkan
// +build vulkan

package svm

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
58  go/classifier/svm/classifier.go (new file)
@@ -0,0 +1,58 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_classifier.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
	"github.com/bububa/openvision/go/common"
)

// Classifier represents an SVM classifier
type Classifier interface {
	LoadModel(string)
	Destroy()
	Predict(vec []float32) float64
	Classify(vec []float32) (scores []float64, err error)
}

// DestroyClassifier destroys a C.ISVMClassifier
func DestroyClassifier(d C.ISVMClassifier) {
	C.destroy_svm_classifier(d)
}

// LoadClassifierModel loads the classifier model
func LoadClassifierModel(d C.ISVMClassifier, modelPath string) {
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	C.svm_classifier_load_model(d, cPath)
}

// Predict returns the predicted score for a feature vector
func Predict(d C.ISVMClassifier, vec []float32) float64 {
	cvals := make([]C.float, 0, len(vec))
	for _, v := range vec {
		cvals = append(cvals, C.float(v))
	}
	score := C.svm_predict(d, &cvals[0])
	return float64(score)
}

// Classify returns the per-class scores
func Classify(d C.ISVMClassifier, vec []float32) ([]float64, error) {
	cvals := make([]C.float, 0, len(vec))
	for _, v := range vec {
		cvals = append(cvals, C.float(v))
	}
	cScores := common.NewCFloatVector()
	defer common.FreeCFloatVector(cScores)
	errCode := C.svm_classify(d, &cvals[0], (*C.FloatVector)(unsafe.Pointer(cScores)))
	if errCode != 0 {
		return nil, openvision.ClassifyError(int(errCode))
	}
	return common.GoFloatVector(cScores), nil
}
2  go/classifier/svm/doc.go (new file)
@@ -0,0 +1,2 @@
// Package svm implements an SVM classifier
package svm
40  go/classifier/svm/multiclass_classifier.go (new file)
@@ -0,0 +1,40 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_classifier.h"
*/
import "C"

// MultiClassClassifier represents an SVM multi-class classifier
type MultiClassClassifier struct {
	d C.ISVMClassifier
}

// NewMultiClassClassifier returns a new MultiClassClassifier
func NewMultiClassClassifier() *MultiClassClassifier {
	return &MultiClassClassifier{
		d: C.new_svm_multiclass_classifier(),
	}
}

// Destroy destroys the underlying C.ISVMClassifier
func (t *MultiClassClassifier) Destroy() {
	DestroyClassifier(t.d)
}

// LoadModel loads the classifier model
func (t *MultiClassClassifier) LoadModel(modelPath string) {
	LoadClassifierModel(t.d, modelPath)
}

// Predict returns the predicted score
func (t *MultiClassClassifier) Predict(vec []float32) float64 {
	return Predict(t.d, vec)
}

// Classify returns the classified scores
func (t *MultiClassClassifier) Classify(vec []float32) ([]float64, error) {
	return Classify(t.d, vec)
}
50  go/classifier/svm/multiclass_trainer.go (new file)
@@ -0,0 +1,50 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_trainer.h"
*/
import "C"

// MultiClassTrainer represents an SVM multi-class trainer
type MultiClassTrainer struct {
	d C.ISVMTrainer
}

// NewMultiClassTrainer returns a new MultiClassTrainer
func NewMultiClassTrainer() *MultiClassTrainer {
	return &MultiClassTrainer{
		d: C.new_svm_multiclass_trainer(),
	}
}

// Destroy destroys the underlying C.ISVMTrainer
func (t *MultiClassTrainer) Destroy() {
	DestroyTrainer(t.d)
}

// Reset resets the C.ISVMTrainer
func (t *MultiClassTrainer) Reset() {
	ResetTrainer(t.d)
}

// SetLabels sets the total number of labels
func (t *MultiClassTrainer) SetLabels(labels int) {
	SetLabels(t.d, labels)
}

// SetFeatures sets the total number of features
func (t *MultiClassTrainer) SetFeatures(feats int) {
	SetFeatures(t.d, feats)
}

// AddData adds a feature vector with its label
func (t *MultiClassTrainer) AddData(labelID int, feats []float32) {
	AddData(t.d, labelID, feats)
}

// Train trains the model and writes it to modelPath
func (t *MultiClassTrainer) Train(modelPath string) error {
	return Train(t.d, modelPath)
}
63  go/classifier/svm/trainer.go (new file)
@@ -0,0 +1,63 @@
package svm

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/classifier/svm_trainer.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
)

// Trainer represents an SVM trainer
type Trainer interface {
	Reset()
	Destroy()
	SetLabels(int)
	SetFeatures(int)
	AddData(int, []float32)
	Train(modelPath string) error
}

// DestroyTrainer destroys a C.ISVMTrainer
func DestroyTrainer(d C.ISVMTrainer) {
	C.destroy_svm_trainer(d)
}

// ResetTrainer resets a C.ISVMTrainer
func ResetTrainer(d C.ISVMTrainer) {
	C.svm_trainer_reset(d)
}

// SetLabels sets the total number of labels
func SetLabels(d C.ISVMTrainer, labels int) {
	C.svm_trainer_set_labels(d, C.int(labels))
}

// SetFeatures sets the total number of features
func SetFeatures(d C.ISVMTrainer, feats int) {
	C.svm_trainer_set_features(d, C.int(feats))
}

// AddData adds a feature vector with its label
func AddData(d C.ISVMTrainer, labelID int, feats []float32) {
	vec := make([]C.float, 0, len(feats))
	for _, v := range feats {
		vec = append(vec, C.float(v))
	}
	C.svm_trainer_add_data(d, C.int(labelID), &vec[0])
}

// Train trains the model and writes it to modelPath
func Train(d C.ISVMTrainer, modelPath string) error {
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	errCode := C.svm_train(d, cPath)
	if errCode != 0 {
		return openvision.TrainingError(int(errCode))
	}
	return nil
}
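No example program for the new svm package ships in this commit, so here is a minimal end-to-end sketch of the intended flow based on the API above; the feature vectors, label count, and model path are made up for illustration:

package main

import (
	"log"

	"github.com/bububa/openvision/go/classifier/svm"
)

func main() {
	// train a binary model from two hypothetical samples
	trainer := svm.NewBinaryTrainer()
	defer trainer.Destroy()
	trainer.SetLabels(2)
	trainer.SetFeatures(3)
	trainer.AddData(0, []float32{0.1, 0.2, 0.3})
	trainer.AddData(1, []float32{0.9, 0.8, 0.7})
	if err := trainer.Train("/tmp/svm.model"); err != nil {
		log.Fatalln(err)
	}

	// load the model back and score a new vector
	clf := svm.NewBinaryClassifier()
	defer clf.Destroy()
	clf.LoadModel("/tmp/svm.model")
	log.Println(clf.Predict([]float32{0.2, 0.2, 0.4}))
}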
@@ -38,6 +38,9 @@ func parseHexColor(x string) (r, g, b, a uint32) {
}

const (
	White = "#FFFFFF"
	Black = "#000000"
	Gray  = "#333333"
	Green = "#64DD17"
	Pink  = "#E91E63"
	Red   = "#FF1744"
45  go/common/font.go (new file)
@@ -0,0 +1,45 @@
package common

import (
	"errors"

	"github.com/golang/freetype/truetype"
	"github.com/llgcode/draw2d"
)

// Font holds the font info used by the drawers
type Font struct {
	// Cache FontCache
	Cache draw2d.FontCache
	// Size font size
	Size float64 `json:"size,omitempty"`
	// Data font setting
	Data *draw2d.FontData `json:"data,omitempty"`
	// Font the loaded truetype font
	Font *truetype.Font `json:"-"`
}

// Load loads the font from a font cache
func (f *Font) Load(cache draw2d.FontCache) error {
	if f.Font != nil {
		return nil
	}
	if f.Data == nil {
		return nil
	}
	if cache == nil {
		return errors.New("missing font cache")
	}
	ft, err := cache.Load(*f.Data)
	if err != nil {
		return err
	}
	f.Cache = cache
	f.Font = ft
	return nil
}

// NewFontCache loads a font cache from a folder
func NewFontCache(fontFolder string) *draw2d.SyncFolderFontCache {
	return draw2d.NewSyncFolderFontCache(fontFolder)
}
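The drawers updated later in this diff take a *common.Font option; loading one follows the same pattern the updated face and hand examples use. A small sketch (the folder and font name are assumptions — they must match a .ttf actually present under your font folder, e.g. data/font):

package main

import (
	"log"

	"github.com/bububa/openvision/go/common"
	"github.com/llgcode/draw2d"
)

func loadFont() *common.Font {
	// data/font is the folder this commit adds a .gitignore for
	cache := common.NewFontCache("./data/font")
	fnt := &common.Font{
		Size: 9,
		Data: &draw2d.FontData{
			Name:   "NotoSansCJKsc",
			Family: draw2d.FontFamilySans,
			Style:  draw2d.FontStyleNormal,
		},
	}
	if err := fnt.Load(cache); err != nil {
		log.Fatalln(err)
	}
	return fnt
}

func main() {
	log.Println(loadFont().Size)
}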
@@ -60,6 +60,8 @@ func GoRect(c *C.Rect, w float64, h float64) Rectangle {

var ZR = Rectangle{}

var FullRect = Rect(0, 0, 1, 1)

// Point represents a Point
type Point struct {
	X float64
@@ -88,6 +90,9 @@ func NewCPoint2fVector() *C.Point2fVector {

// GoPoint2fVector convert C.Point2fVector to []Point
func GoPoint2fVector(cVector *C.Point2fVector, w float64, h float64) []Point {
	if cVector == nil {
		return nil
	}
	l := int(cVector.length)
	ret := make([]Point, 0, l)
	ptr := unsafe.Pointer(cVector.points)
@@ -103,3 +108,52 @@ func FreeCPoint2fVector(c *C.Point2fVector) {
	C.FreePoint2fVector(c)
	C.free(unsafe.Pointer(c))
}

// Point3d represents a 3D point
type Point3d struct {
	X float64
	Y float64
	Z float64
}

// Pt3d returns a new Point3d
func Pt3d(x, y, z float64) Point3d {
	return Point3d{x, y, z}
}

var ZP3d = Point3d{}

// GoPoint3d converts a C.Point3d to Point3d
func GoPoint3d(c *C.Point3d) Point3d {
	return Pt3d(
		float64(c.x),
		float64(c.y),
		float64(c.z),
	)
}

// NewCPoint3dVector returns a C.Point3dVector pointer
func NewCPoint3dVector() *C.Point3dVector {
	return (*C.Point3dVector)(C.malloc(C.sizeof_Point3dVector))
}

// GoPoint3dVector converts a C.Point3dVector to []Point3d
func GoPoint3dVector(cVector *C.Point3dVector) []Point3d {
	if cVector == nil {
		return nil
	}
	l := int(cVector.length)
	ret := make([]Point3d, 0, l)
	ptr := unsafe.Pointer(cVector.points)
	for i := 0; i < l; i++ {
		cPoint3d := (*C.Point3d)(unsafe.Pointer(uintptr(ptr) + uintptr(C.sizeof_Point3d*C.int(i))))
		ret = append(ret, GoPoint3d(cPoint3d))
	}
	return ret
}

// FreeCPoint3dVector releases the C.Point3dVector memory
func FreeCPoint3dVector(c *C.Point3dVector) {
	C.FreePoint3dVector(c)
	C.free(unsafe.Pointer(c))
}
@@ -26,6 +26,9 @@ type Image struct {
// NewImage returns a new Image
func NewImage(img image.Image) *Image {
	buf := new(bytes.Buffer)
	if img == nil {
		return &Image{buffer: buf}
	}
	Image2RGB(buf, img)
	return &Image{
		Image: img,
@@ -33,6 +36,21 @@ func NewImage(img image.Image) *Image {
	}
}

// Reset clears the image and its buffer
func (i *Image) Reset() {
	i.Image = nil
	if i.buffer != nil {
		i.buffer.Reset()
	}
}

// Write writes bytes to the buffer
func (i *Image) Write(b []byte) {
	if i.buffer == nil {
		return
	}
	i.buffer.Write(b)
}

// Bytes returns the image bytes in RGB
func (i Image) Bytes() []byte {
	if i.buffer == nil {
@@ -74,20 +92,23 @@ func NewCImage() *C.Image {
	return ret
}

// FreeCImage free C.Image
func FreeCImage(c *C.Image) {
	C.FreeImage(c)
	C.free(unsafe.Pointer(c))
}

-func GoImage(c *C.Image) (image.Image, error) {
+// GoImage returns Image from C.Image
+func GoImage(c *C.Image, out *Image) {
	w := int(c.width)
	h := int(c.height)
	channels := int(c.channels)
	data := C.GoBytes(unsafe.Pointer(c.data), C.int(w*h*channels)*C.sizeof_uchar)
-	return NewImageFromBytes(data, w, h, channels)
+	NewImageFromBytes(data, w, h, channels, out)
}

-func NewImageFromBytes(data []byte, w int, h int, channels int) (image.Image, error) {
+// NewImageFromBytes returns Image by []byte
+func NewImageFromBytes(data []byte, w int, h int, channels int, out *Image) {
	img := image.NewRGBA(image.Rect(0, 0, w, h))
	for y := 0; y < h; y++ {
		for x := 0; x < w; x++ {
@@ -97,9 +118,10 @@ func NewImageFromBytes(data []byte, w int, h int, channels int) (image.Image, er
				alpha = data[pos+3]
			}
			img.SetRGBA(x, y, color.RGBA{uint8(data[pos]), uint8(data[pos+1]), uint8(data[pos+2]), uint8(alpha)})
+			out.Write([]byte{byte(data[pos]), byte(data[pos+1]), byte(data[pos+2]), byte(alpha)})
		}
	}
-	return img, nil
+	out.Image = img
}

// Image2RGB write image rgbdata to buffer
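GoImage and NewImageFromBytes now fill a caller-supplied *Image instead of returning a fresh one, so callers can reuse the output buffer across frames. The aligner, hair, segmentor, and styletransfer changes later in this diff all switch to the same pattern; a sketch of how the bindings use it internally (the C call in the middle is whatever API produced the C-side image):

	outImgC := common.NewCImage()   // C-side output image
	defer common.FreeCImage(outImgC)
	// ... pass outImgC to the C API that renders into it ...
	out := common.NewImage(nil)     // caller-owned, reusable Go-side image
	common.GoImage(outImgC, out)    // fills out.Image instead of returning a new image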
@@ -170,3 +192,44 @@ func DrawCircle(gc *draw2dimg.GraphicContext, pt Point, r float64, borderColor s
		gc.Stroke()
	}
}

// DrawLabel draws label text onto the image
func DrawLabel(gc *draw2dimg.GraphicContext, font *Font, label string, pt Point, txtColor string, bgColor string, scale float64) {
	if font == nil || font.Cache == nil || font.Data == nil {
		return
	}
	gc.FontCache = font.Cache
	gc.SetFontData(*font.Data)
	gc.SetFontSize(font.Size * scale)
	var (
		x       = float64(pt.X)
		y       = float64(pt.Y)
		padding = 2.0 * scale
	)
	left, top, right, bottom := gc.GetStringBounds(label)
	height := bottom - top
	width := right - left
	if bgColor != "" {
		gc.SetFillColor(ColorFromHex(bgColor))
		draw2dkit.Rectangle(gc, x, y, x+width+padding*2, y+height+padding*2)
		gc.Fill()
	}
	gc.SetFillColor(ColorFromHex(txtColor))
	gc.FillStringAt(label, x-left+padding, y-top+padding)
}

// DrawLabelInWidth draws label text scaled to fit within boundWidth
func DrawLabelInWidth(gc *draw2dimg.GraphicContext, font *Font, label string, pt Point, txtColor string, bgColor string, boundWidth float64) {
	if font == nil || font.Cache == nil || font.Data == nil {
		return
	}
	gc.FontCache = font.Cache
	gc.SetFontData(*font.Data)
	gc.SetFontSize(font.Size)
	left, _, right, _ := gc.GetStringBounds(label)
	padding := 2.0
	width := right - left
	fontWidth := width + padding*2
	scale := boundWidth / fontWidth
	DrawLabel(gc, font, label, pt, txtColor, bgColor, scale)
}
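DrawLabelInWidth measures the label at the base font size and derives the scale so the rendered text spans exactly boundWidth; the face drawer below uses it to fit each face's score label to its bounding box. A small sketch of direct use (the canvas size, anchor point, width, and loadFont helper from the font sketch above are illustrative assumptions):

package main

import (
	"image"
	"log"

	"github.com/bububa/openvision/go/common"
	"github.com/llgcode/draw2d/draw2dimg"
)

func main() {
	canvas := image.NewRGBA(image.Rect(0, 0, 320, 240))
	gc := draw2dimg.NewGraphicContext(canvas)
	fnt := loadFont() // see the font-loading sketch above
	// scale the label so it spans exactly 120px starting at (40, 80)
	common.DrawLabelInWidth(gc, fnt, "0.9876", common.Pt(40, 80), common.White, common.Green, 120)
	log.Println("label drawn")
}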
@@ -20,6 +20,8 @@ type ObjectInfo struct {
	Rect Rectangle
	// Points keypoints
	Keypoints []Keypoint
	// Name
	Name string
}

// GoObjectInfo convert C.ObjectInfo to go type
57  go/common/palmobject.go (new file)
@@ -0,0 +1,57 @@
package common

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/common/common.h"
#include "openvision/hand/pose3d.h"
*/
import "C"
import (
	"unsafe"
)

// PalmObject represents a detected palm
type PalmObject struct {
	Name       string
	Score      float64
	Rotation   float64
	RectPoints []Point
	Landmarks  []Point
	Skeleton   []Point
	Skeleton3d []Point3d
}

// NewCPalmObjectVector returns a *C.PalmObjectVector
func NewCPalmObjectVector() *C.PalmObjectVector {
	return (*C.PalmObjectVector)(C.malloc(C.sizeof_PalmObjectVector))
}

// FreeCPalmObjectVector releases the *C.PalmObjectVector memory
func FreeCPalmObjectVector(p *C.PalmObjectVector) {
	C.FreePalmObjectVector(p)
	C.free(unsafe.Pointer(p))
}

// GoPalmObject converts a C.PalmObject to the Go type
func GoPalmObject(cObj *C.PalmObject, w float64, h float64) PalmObject {
	return PalmObject{
		Score:      float64(cObj.score),
		Rotation:   float64(cObj.rotation),
		RectPoints: GoPoint2fVector(cObj.rect, w, h),
		Landmarks:  GoPoint2fVector(cObj.landmarks, w, h),
		Skeleton:   GoPoint2fVector(cObj.skeleton, w, h),
		Skeleton3d: GoPoint3dVector(cObj.skeleton3d),
	}
}

// GoPalmObjectVector converts a C.PalmObjectVector to []PalmObject
func GoPalmObjectVector(c *C.PalmObjectVector, w float64, h float64) []PalmObject {
	l := int(c.length)
	ret := make([]PalmObject, 0, l)
	ptr := unsafe.Pointer(c.items)
	for i := 0; i < l; i++ {
		cObj := (*C.PalmObject)(unsafe.Pointer(uintptr(ptr) + uintptr(C.sizeof_PalmObject*C.int(i))))
		ret = append(ret, GoPalmObject(cObj, w, h))
	}
	return ret
}
12  go/counter/cgo.go (new file)
@@ -0,0 +1,12 @@
//go:build !vulkan
// +build !vulkan

package counter

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../lib
*/
import "C"
11  go/counter/cgo_vulkan.go (new file)
@@ -0,0 +1,11 @@
// +build vulkan

package counter

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../lib
*/
import "C"
40  go/counter/counter.go (new file)
@@ -0,0 +1,40 @@
package counter

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/common/common.h"
#include "openvision/counter/counter.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
	"github.com/bububa/openvision/go/common"
)

// Counter represents the object counter interface
type Counter interface {
	common.Estimator
	CrowdCount(img *common.Image) ([]common.Keypoint, error)
}

// CrowdCount runs the counter and returns one keypoint per counted object
func CrowdCount(d Counter, img *common.Image) ([]common.Keypoint, error) {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
	ptsC := common.NewCKeypointVector()
	defer common.FreeCKeypointVector(ptsC)
	errCode := C.crowd_count(
		(C.ICounter)(d.Pointer()),
		(*C.uchar)(unsafe.Pointer(&data[0])),
		C.int(imgWidth),
		C.int(imgHeight),
		(*C.KeypointVector)(unsafe.Pointer(ptsC)))
	if errCode != 0 {
		return nil, openvision.CounterError(int(errCode))
	}
	return common.GoKeypointVector(ptsC, imgWidth, imgHeight), nil
}
2  go/counter/doc.go (new file)
@@ -0,0 +1,2 @@
// Package counter includes object counters
package counter
45  go/counter/p2pnet.go (new file)
@@ -0,0 +1,45 @@
package counter

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/counter/counter.h"
*/
import "C"
import (
	"unsafe"

	"github.com/bububa/openvision/go/common"
)

// P2PNet represents the p2pnet crowd counter
type P2PNet struct {
	d C.ICounter
}

// NewP2PNet returns a new P2PNet
func NewP2PNet() *P2PNet {
	return &P2PNet{
		d: C.new_p2pnet_crowd_counter(),
	}
}

// Destroy frees the counter
func (d *P2PNet) Destroy() {
	common.DestroyEstimator(d)
}

// Pointer implements the Estimator interface
func (d *P2PNet) Pointer() unsafe.Pointer {
	return unsafe.Pointer(d.d)
}

// LoadModel loads the counter model
func (d *P2PNet) LoadModel(modelPath string) error {
	return common.EstimatorLoadModel(d, modelPath)
}

// CrowdCount implements the Counter interface
func (d *P2PNet) CrowdCount(img *common.Image) ([]common.Keypoint, error) {
	return CrowdCount(d, img)
}
24  go/error.go
@@ -56,6 +56,12 @@ var (
			Message: "detect head pose failed",
		}
	}
	HairMattingError = func(code int) Error {
		return Error{
			Code:    code,
			Message: "hair matting failed",
		}
	}
	DetectHandError = func(code int) Error {
		return Error{
			Code: code,
@@ -74,10 +80,28 @@ var (
			Message: "object tracker error",
		}
	}
	CounterError = func(code int) Error {
		return Error{
			Code:    code,
			Message: "object counter error",
		}
	}
	RealsrError = func(code int) Error {
		return Error{
			Code:    code,
			Message: "super-resolution process error",
		}
	}
	TrainingError = func(code int) Error {
		return Error{
			Code:    code,
			Message: "training process failed",
		}
	}
	ClassifyError = func(code int) Error {
		return Error{
			Code:    code,
			Message: "classify process failed",
		}
	}
)
@@ -50,8 +50,9 @@ func align(d detecter.Detecter, a *aligner.Aligner, imgPath string, filename str
	if err != nil {
		log.Fatalln(err)
	}
+	aligned := common.NewImage(nil)
	for idx, face := range faces {
-		aligned, err := a.Align(img, face)
+		err := a.Align(img, face, aligned)
		if err != nil {
			log.Fatalln(err)
		}
92  go/examples/counter/main.go (new file)
@@ -0,0 +1,92 @@
package main

import (
	"bytes"
	"image"
	"image/jpeg"
	"log"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/counter"
)

func main() {
	wd, _ := os.Getwd()
	dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
	imgPath := filepath.Join(dataPath, "./images")
	modelPath := filepath.Join(dataPath, "./models")
	common.CreateGPUInstance()
	defer common.DestroyGPUInstance()
	cpuCores := common.GetBigCPUCount()
	common.SetOMPThreads(cpuCores)
	log.Printf("CPU big cores:%d\n", cpuCores)
	d := p2pnet(modelPath)
	defer d.Destroy()
	common.SetEstimatorThreads(d, cpuCores)
	crowdCount(d, imgPath, "congested2.jpg")
}

func p2pnet(modelPath string) counter.Counter {
	modelPath = filepath.Join(modelPath, "p2pnet")
	d := counter.NewP2PNet()
	if err := d.LoadModel(modelPath); err != nil {
		log.Fatalln(err)
	}
	return d
}

func crowdCount(d counter.Counter, imgPath string, filename string) {
	inPath := filepath.Join(imgPath, filename)
	imgLoaded, err := loadImage(inPath)
	if err != nil {
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
	pts, err := d.CrowdCount(img)
	if err != nil {
		log.Fatalln(err)
	}
	log.Printf("count: %d\n", len(pts))
}

func loadImage(filePath string) (image.Image, error) {
	fn, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer fn.Close()
	img, _, err := image.Decode(fn)
	if err != nil {
		return nil, err
	}
	return img, nil
}

func saveImage(img image.Image, filePath string) error {
	buf := new(bytes.Buffer)
	if err := jpeg.Encode(buf, img, nil); err != nil {
		return err
	}
	fn, err := os.Create(filePath)
	if err != nil {
		return err
	}
	defer fn.Close()
	fn.Write(buf.Bytes())
	return nil
}

func cleanPath(wd string, path string) string {
	usr, _ := user.Current()
	dir := usr.HomeDir
	if path == "~" {
		return dir
	} else if strings.HasPrefix(path, "~/") {
		return filepath.Join(dir, path[2:])
	}
	return filepath.Join(wd, path)
}
@@ -9,10 +9,14 @@ import (
	"os"
	"os/user"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/llgcode/draw2d"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/face/detecter"
	"github.com/bububa/openvision/go/face/drawer"
	facedrawer "github.com/bububa/openvision/go/face/drawer"
)

@@ -21,16 +25,35 @@ func main() {
	dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
	imgPath := filepath.Join(dataPath, "./images")
	modelPath := filepath.Join(dataPath, "./models")
	fontPath := filepath.Join(dataPath, "./font")
	common.CreateGPUInstance()
	defer common.DestroyGPUInstance()
	cpuCores := common.GetBigCPUCount()
	common.SetOMPThreads(cpuCores)
	log.Printf("CPU big cores:%d\n", cpuCores)
	test_detect(imgPath, modelPath, cpuCores)
	test_mask(imgPath, modelPath, cpuCores)
	test_detect(imgPath, modelPath, fontPath, cpuCores)
	test_mask(imgPath, modelPath, fontPath, cpuCores)
}

func test_detect(imgPath string, modelPath string, threads int) {
func load_font(fontPath string) *common.Font {
	fontCache := common.NewFontCache(fontPath)
	fnt := &common.Font{
		Size: 9,
		Data: &draw2d.FontData{
			Name: "NotoSansCJKsc",
			//Name: "Roboto",
			Family: draw2d.FontFamilySans,
			Style:  draw2d.FontStyleNormal,
		},
	}
	if err := fnt.Load(fontCache); err != nil {
		log.Fatalln(err)
	}
	return fnt
}
func test_detect(imgPath string, modelPath string, fontPath string, threads int) {
	fnt := load_font(fontPath)
	drawer := facedrawer.New(drawer.WithFont(fnt))
	for idx, d := range []detecter.Detecter{
		retinaface(modelPath),
		centerface(modelPath),
@@ -39,16 +62,22 @@ func test_detect(imgPath string, modelPath string, threads int) {
		scrfd(modelPath),
	} {
		common.SetEstimatorThreads(d, threads)
		detect(d, imgPath, idx, "4.jpg", false)
		detect(d, drawer, imgPath, idx, "4.jpg")
		d.Destroy()
	}
}

func test_mask(imgPath string, modelPath string, threads int) {
func test_mask(imgPath string, modelPath string, fontPath string, threads int) {
	fnt := load_font(fontPath)
	drawer := facedrawer.New(
		facedrawer.WithBorderColor(common.Red),
		facedrawer.WithMaskColor(common.Green),
		facedrawer.WithFont(fnt),
	)
	d := anticonv(modelPath)
	common.SetEstimatorThreads(d, threads)
	defer d.Destroy()
	detect(d, imgPath, 0, "mask3.jpg", true)
	detect(d, drawer, imgPath, 0, "mask3.jpg")
}

func retinaface(modelPath string) detecter.Detecter {
@@ -105,7 +134,7 @@ func anticonv(modelPath string) detecter.Detecter {
	return d
}

func detect(d detecter.Detecter, imgPath string, idx int, filename string, mask bool) {
func detect(d detecter.Detecter, drawer *facedrawer.Drawer, imgPath string, idx int, filename string) {
	inPath := filepath.Join(imgPath, filename)
	img, err := loadImage(inPath)
	if err != nil {
@@ -116,18 +145,11 @@ func detect(d detecter.Detecter, imgPath string, idx int, filename string, mask
		log.Fatalln(err)
	}

	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%d-%s", idx, filename))

	var drawer *facedrawer.Drawer
	if mask {
		drawer = facedrawer.New(
			facedrawer.WithBorderColor(common.Red),
			facedrawer.WithMaskColor(common.Green),
		)
	} else {
		drawer = facedrawer.New()
	for idx, face := range faces {
		faces[idx].Label = strconv.FormatFloat(float64(face.Score), 'f', 4, 64)
	}

	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%d-%s", idx, filename))
	out := drawer.Draw(img, faces)

	if err := saveImage(out, outPath); err != nil {
115  go/examples/eye/main.go (new file)
@@ -0,0 +1,115 @@
package main

import (
	"bytes"
	"fmt"
	"image"
	"image/jpeg"
	"log"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/face/detecter"
	"github.com/bububa/openvision/go/face/eye"
)

func main() {
	wd, _ := os.Getwd()
	dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
	imgPath := filepath.Join(dataPath, "./images")
	modelPath := filepath.Join(dataPath, "./models")
	common.CreateGPUInstance()
	defer common.DestroyGPUInstance()
	cpuCores := common.GetBigCPUCount()
	common.SetOMPThreads(cpuCores)
	log.Printf("CPU big cores:%d\n", cpuCores)
	d := retinaface(modelPath)
	defer d.Destroy()
	common.SetEstimatorThreads(d, cpuCores)
	e := lenet(modelPath)
	defer e.Destroy()
	common.SetEstimatorThreads(e, cpuCores)
	for _, fn := range []string{"eye-open.jpg", "eye-close.jpg", "eye-half.jpg"} {
		detect(d, e, imgPath, fn)
	}
}

func retinaface(modelPath string) detecter.Detecter {
	modelPath = filepath.Join(modelPath, "fd")
	d := detecter.NewRetinaFace()
	if err := d.LoadModel(modelPath); err != nil {
		log.Fatalln(err)
	}
	return d
}

func lenet(modelPath string) eye.Detecter {
	modelPath = filepath.Join(modelPath, "eye/lenet")
	d := eye.NewLenet()
	if err := d.LoadModel(modelPath); err != nil {
		log.Fatalln(err)
	}
	return d
}

func detect(d detecter.Detecter, e eye.Detecter, imgPath string, filename string) {
	inPath := filepath.Join(imgPath, filename)
	imgLoaded, err := loadImage(inPath)
	if err != nil {
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
	faces, err := d.Detect(img)
	if err != nil {
		log.Fatalln(err)
	}
	for _, face := range faces {
		rect := face.Rect
		closed, err := e.IsClosed(img, rect)
		if err != nil {
			log.Fatalln(err)
		}
		fmt.Printf("fn: %s, closed: %+v\n", filename, closed)
	}
}

func loadImage(filePath string) (image.Image, error) {
	fn, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer fn.Close()
	img, _, err := image.Decode(fn)
	if err != nil {
		return nil, err
	}
	return img, nil
}

func saveImage(img image.Image, filePath string) error {
	buf := new(bytes.Buffer)
	if err := jpeg.Encode(buf, img, nil); err != nil {
		return err
	}
	fn, err := os.Create(filePath)
	if err != nil {
		return err
	}
	defer fn.Close()
	fn.Write(buf.Bytes())
	return nil
}

func cleanPath(wd string, path string) string {
	usr, _ := user.Current()
	dir := usr.HomeDir
	if path == "~" {
		return dir
	} else if strings.HasPrefix(path, "~/") {
		return filepath.Join(dir, path[2:])
	}
	return filepath.Join(wd, path)
}
94  go/examples/hair/main.go (new file)
@@ -0,0 +1,94 @@
package main

import (
	"bytes"
	"fmt"
	"image"
	"image/jpeg"
	"log"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/face/hair"
)

func main() {
	wd, _ := os.Getwd()
	dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
	imgPath := filepath.Join(dataPath, "./images")
	modelPath := filepath.Join(dataPath, "./models")
	common.CreateGPUInstance()
	defer common.DestroyGPUInstance()
	d := estimator(modelPath)
	defer d.Destroy()
	matting(d, imgPath, "hair1.jpg")
}

func estimator(modelPath string) *hair.Hair {
	modelPath = filepath.Join(modelPath, "hair")
	d := hair.NewHair()
	if err := d.LoadModel(modelPath); err != nil {
		log.Fatalln(err)
	}
	return d
}

func matting(d *hair.Hair, imgPath string, filename string) {
	inPath := filepath.Join(imgPath, filename)
	imgLoaded, err := loadImage(inPath)
	if err != nil {
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
	out := common.NewImage(nil)
	if err := d.Matting(img, out); err != nil {
		log.Fatalln(err)
	}
	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("hair-matting-%s", filename))

	if err := saveImage(out, outPath); err != nil {
		log.Fatalln(err)
	}

}

func loadImage(filePath string) (image.Image, error) {
	fn, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}
	defer fn.Close()
	img, _, err := image.Decode(fn)
	if err != nil {
		return nil, err
	}
	return img, nil
}

func saveImage(img image.Image, filePath string) error {
	buf := new(bytes.Buffer)
	if err := jpeg.Encode(buf, img, nil); err != nil {
		return err
	}
	fn, err := os.Create(filePath)
	if err != nil {
		return err
	}
	defer fn.Close()
	fn.Write(buf.Bytes())
	return nil
}

func cleanPath(wd string, path string) string {
	usr, _ := user.Current()
	dir := usr.HomeDir
	if path == "~" {
		return dir
	} else if strings.HasPrefix(path, "~/") {
		return filepath.Join(dir, path[2:])
	}
	return filepath.Join(wd, path)
}
@@ -15,6 +15,8 @@ import (
	"github.com/bububa/openvision/go/hand/detecter"
	handdrawer "github.com/bububa/openvision/go/hand/drawer"
	"github.com/bububa/openvision/go/hand/pose"
	"github.com/bububa/openvision/go/hand/pose3d"
	"github.com/llgcode/draw2d"
)

func main() {
@@ -22,22 +24,26 @@ func main() {
	dataPath := cleanPath(wd, "~/go/src/github.com/bububa/openvision/data")
	imgPath := filepath.Join(dataPath, "./images")
	modelPath := filepath.Join(dataPath, "./models")
	fontPath := filepath.Join(dataPath, "./font")
	common.CreateGPUInstance()
	defer common.DestroyGPUInstance()
	cpuCores := common.GetBigCPUCount()
	common.SetOMPThreads(cpuCores)
	log.Printf("CPU big cores:%d\n", cpuCores)
	estimator := handpose(modelPath)
	defer estimator.Destroy()
	common.SetEstimatorThreads(estimator, cpuCores)
	for idx, d := range []detecter.Detecter{
		yolox(modelPath),
		nanodet(modelPath),
	} {
		defer d.Destroy()
		common.SetEstimatorThreads(d, cpuCores)
		detect(d, estimator, imgPath, "hand1.jpg", idx)
	}
	// estimator := handpose(modelPath)
	// defer estimator.Destroy()
	// common.SetEstimatorThreads(estimator, cpuCores)
	// for idx, d := range []detecter.Detecter{
	// 	yolox(modelPath),
	// 	nanodet(modelPath),
	// } {
	// 	defer d.Destroy()
	// 	common.SetEstimatorThreads(d, cpuCores)
	// 	detect(d, estimator, imgPath, "hand2.jpg", idx)
	// }
	d3d := mediapipe(modelPath)
	detect3d(d3d, imgPath, fontPath, "hand1.jpg")
	detect3d(d3d, imgPath, fontPath, "hand2.jpg")
}

func yolox(modelPath string) detecter.Detecter {
@@ -67,6 +73,16 @@ func handpose(modelPath string) pose.Estimator {
	return d
}

func mediapipe(modelPath string) *pose3d.Mediapipe {
	palmPath := filepath.Join(modelPath, "mediapipe/palm/full")
	handPath := filepath.Join(modelPath, "mediapipe/hand/full")
	d := pose3d.NewMediapipe()
	if err := d.LoadModel(palmPath, handPath); err != nil {
		log.Fatalln(err)
	}
	return d
}

func detect(d detecter.Detecter, e pose.Estimator, imgPath string, filename string, idx int) {
	inPath := filepath.Join(imgPath, filename)
	imgSrc, err := loadImage(inPath)
@@ -104,6 +120,37 @@ func detect(d detecter.Detecter, e pose.Estimator, imgPath string, filename stri
	if err := saveImage(out, outPath); err != nil {
		log.Fatalln(err)
	}
}

func detect3d(d *pose3d.Mediapipe, imgPath string, fontPath string, filename string) {
	inPath := filepath.Join(imgPath, filename)
	imgSrc, err := loadImage(inPath)
	if err != nil {
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgSrc)
	rois, err := d.Detect(img)
	if err != nil {
		log.Fatalln(err)
	}
	// log.Printf("%+v\n", rois)
	fnt := load_font(fontPath)
	drawer := handdrawer.New(handdrawer.WithFont(fnt))
	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("pose3d-hand-%s", filename))
	out := drawer.DrawPalm(img, rois)

	if err := saveImage(out, outPath); err != nil {
		log.Fatalln(err)
	}

	for idx, roi := range rois {
		outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("pose3d-palm3d-%d-%s", idx, filename))
		out := drawer.DrawPalm3D(roi, 400, "#442519")

		if err := saveImage(out, outPath); err != nil {
			log.Fatalln(err)
		}
	}

}

@@ -144,3 +191,20 @@ func cleanPath(wd string, path string) string {
	}
	return filepath.Join(wd, path)
}

func load_font(fontPath string) *common.Font {
	fontCache := common.NewFontCache(fontPath)
	fnt := &common.Font{
		Size: 9,
		Data: &draw2d.FontData{
			Name: "NotoSansCJKsc",
			//Name: "Roboto",
			Family: draw2d.FontFamilySans,
			Style:  draw2d.FontStyleNormal,
		},
	}
	if err := fnt.Load(fontCache); err != nil {
		log.Fatalln(err)
	}
	return fnt
}
@@ -77,8 +77,8 @@ func videomatting(seg segmentor.Segmentor, imgPath string, filename string, idx
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
-	out, err := seg.Matting(img)
-	if err != nil {
+	out := common.NewImage(nil)
+	if err := seg.Matting(img, out); err != nil {
		log.Fatalln(err)
	}
	outPath := filepath.Join(imgPath, "./results/videomatting", fmt.Sprintf("%d.jpeg", idx))
@@ -95,8 +95,8 @@ func matting(seg segmentor.Segmentor, imgPath string, filename string, idx int)
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
-	out, err := seg.Matting(img)
-	if err != nil {
+	out := common.NewImage(nil)
+	if err := seg.Matting(img, out); err != nil {
		log.Fatalln(err)
	}
	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-matting-%d-%s", idx, filename))
@@ -119,8 +119,8 @@ func merge(seg segmentor.Segmentor, imgPath string, filename string, bgFilename
		log.Fatalln("load bg image failed,", err)
	}
	bg := common.NewImage(bgLoaded)
-	out, err := seg.Merge(img, bg)
-	if err != nil {
+	out := common.NewImage(nil)
+	if err := seg.Merge(img, bg, out); err != nil {
		log.Fatalln(err)
	}
	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("poseseg-merge-%d-%s", idx, filename))
@@ -49,15 +49,14 @@ func transform(transfer styletransfer.StyleTransfer, imgPath string, filename st
		log.Fatalln("load image failed,", err)
	}
	img := common.NewImage(imgLoaded)
-	out, err := transfer.Transform(img)
-	if err != nil {
+	out := common.NewImage(nil)
+	if err := transfer.Transform(img, out); err != nil {
		log.Fatalln(err)
	}
	outPath := filepath.Join(imgPath, "./results", fmt.Sprintf("%s-%s", modelName, filename))
	if err := saveImage(out, outPath); err != nil {
		log.Fatalln(err)
	}

}

func loadImage(filePath string) (image.Image, error) {
@@ -8,7 +8,6 @@ package aligner
*/
import "C"
import (
-	"image"
	"unsafe"

	openvision "github.com/bububa/openvision/go"
@@ -39,7 +38,7 @@ func (a *Aligner) SetThreads(n int) {
}

// Align face
-func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo) (image.Image, error) {
+func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo, out *common.Image) error {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
@@ -61,7 +60,8 @@ func (a *Aligner) Align(img *common.Image, faceInfo face.FaceInfo) (image.Image,
		(*C.Image)(unsafe.Pointer(outImgC)),
	)
	if errCode != 0 {
-		return nil, openvision.AlignFaceError(int(errCode))
+		return openvision.AlignFaceError(int(errCode))
	}
-	return common.GoImage(outImgC)
+	common.GoImage(outImgC, out)
+	return nil
}
@@ -17,4 +17,6 @@ const (
	DefaultKeypointStrokeWidth = 2
	// DefaultInvalidBorderColor default drawer invalid border color
	DefaultInvalidBorderColor = common.Red
	// DefaultLabelColor default label color
	DefaultLabelColor = common.White
)
@@ -26,6 +26,10 @@ type Drawer struct {
	MaskColor string
	// InvalidBorderColor
	InvalidBorderColor string
	// LabelColor string
	LabelColor string
	// Font
	Font *common.Font
}

// New returns a new Drawer
@@ -38,6 +42,7 @@ func New(options ...Option) *Drawer {
		KeypointRadius:     DefaultKeypointRadius,
		InvalidBorderColor: DefaultInvalidBorderColor,
		MaskColor:          DefaultBorderColor,
		LabelColor:         DefaultLabelColor,
	}
	for _, opt := range options {
		opt.apply(d)
@@ -69,6 +74,9 @@ func (d *Drawer) Draw(img image.Image, faces []face.FaceInfo) image.Image {
		for _, pt := range face.Keypoints {
			common.DrawCircle(gc, common.Pt(pt.X*imgW, pt.Y*imgH), d.KeypointRadius, d.KeypointColor, "", d.KeypointStrokeWidth)
		}
		if face.Label != "" {
			common.DrawLabelInWidth(gc, d.Font, face.Label, common.Pt(rect.X, rect.MaxY()), d.LabelColor, borderColor, rect.Width)
		}
	}
	return out
}
@@ -1,5 +1,7 @@
package drawer

import "github.com/bububa/openvision/go/common"

// Option represents Drawer option interface
type Option interface {
	apply(*Drawer)
@@ -59,3 +61,17 @@ func WithMaskColor(color string) Option {
		d.MaskColor = color
	})
}

// WithLabelColor sets the Drawer LabelColor
func WithLabelColor(color string) Option {
	return optionFunc(func(d *Drawer) {
		d.LabelColor = color
	})
}

// WithFont sets the Drawer Font
func WithFont(font *common.Font) Option {
	return optionFunc(func(d *Drawer) {
		d.Font = font
	})
}
11  go/face/eye/cgo.go (new file)
@@ -0,0 +1,11 @@
// +build !vulkan

package eye

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
11  go/face/eye/cgo_vulkan.go (new file)
@@ -0,0 +1,11 @@
// +build vulkan

package eye

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
44  go/face/eye/detecter.go (new file)
@@ -0,0 +1,44 @@
package eye

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/common/common.h"
#include "openvision/face/eye.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
	"github.com/bububa/openvision/go/common"
)

// Detecter represents the eye detector interface
type Detecter interface {
	common.Estimator
	IsClosed(img *common.Image, face common.Rectangle) (bool, error)
}

// IsClosed checks whether the eyes inside faceRect are closed
func IsClosed(r Detecter, img *common.Image, faceRect common.Rectangle) (bool, error) {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
	scoresC := common.NewCFloatVector()
	defer common.FreeCFloatVector(scoresC)
	CRect := faceRect.CRect(imgWidth, imgHeight)
	errCode := C.eye_status(
		(C.IEye)(r.Pointer()),
		(*C.uchar)(unsafe.Pointer(&data[0])),
		C.int(imgWidth), C.int(imgHeight),
		(*C.Rect)(unsafe.Pointer(CRect)),
		(*C.FloatVector)(unsafe.Pointer(scoresC)),
	)
	C.free(unsafe.Pointer(CRect))
	if errCode != 0 {
		return false, openvision.RecognizeFaceError(int(errCode))
	}
	scores := common.GoFloatVector(scoresC)
	return len(scores) > 0 && scores[0] == 1, nil
}
2  go/face/eye/doc.go (new file)
@@ -0,0 +1,2 @@
// Package eye includes the eye status detector
package eye
45  go/face/eye/lenet.go (new file)
@@ -0,0 +1,45 @@
package eye

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/face/eye.h"
*/
import "C"
import (
	"unsafe"

	"github.com/bububa/openvision/go/common"
)

// Lenet represents the Lenet eye status detecter
type Lenet struct {
	d C.IEye
}

// NewLenet returns a new Lenet detecter
func NewLenet() *Lenet {
	return &Lenet{
		d: C.new_lenet_eye(),
	}
}

// Pointer implements the Estimator interface
func (d *Lenet) Pointer() unsafe.Pointer {
	return unsafe.Pointer(d.d)
}

// LoadModel loads the detecter model
func (d *Lenet) LoadModel(modelPath string) error {
	return common.EstimatorLoadModel(d, modelPath)
}

// Destroy destroys the detecter
func (d *Lenet) Destroy() {
	common.DestroyEstimator(d)
}

// IsClosed implements the eye Detecter interface
func (d *Lenet) IsClosed(img *common.Image, faceRect common.Rectangle) (bool, error) {
	return IsClosed(d, img, faceRect)
}
@@ -22,6 +22,8 @@ type FaceInfo struct {
	Keypoints [5]common.Point
	// Mask has mask or not
	Mask bool
	// Label
	Label string
}

// GoFaceInfo convert c FaceInfo to go type
11  go/face/hair/cgo.go (new file)
@@ -0,0 +1,11 @@
// +build !vulkan

package hair

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
11  go/face/hair/cgo_vulkan.go (new file)
@@ -0,0 +1,11 @@
// +build vulkan

package hair

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
2  go/face/hair/doc.go (new file)
@@ -0,0 +1,2 @@
// Package hair includes hair segmentation
package hair
61
go/face/hair/hair.go
Normal file
@@ -0,0 +1,61 @@
package hair

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/face/hair.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
	"github.com/bububa/openvision/go/common"
)

// Hair represents Hair segmentor
type Hair struct {
	d C.IHair
}

// NewHair returns a new Hair
func NewHair() *Hair {
	return &Hair{
		d: C.new_hair(),
	}
}

// Pointer implement Estimator interface
func (h *Hair) Pointer() unsafe.Pointer {
	return unsafe.Pointer(h.d)
}

// LoadModel load detecter model
func (h *Hair) LoadModel(modelPath string) error {
	return common.EstimatorLoadModel(h, modelPath)
}

// Destroy destroy C.IHair
func (h *Hair) Destroy() {
	common.DestroyEstimator(h)
}

// Matting returns hair matting image
func (h *Hair) Matting(img *common.Image, out *common.Image) error {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
	outImgC := common.NewCImage()
	defer common.FreeCImage(outImgC)
	errCode := C.hair_matting(
		(C.IHair)(h.Pointer()),
		(*C.uchar)(unsafe.Pointer(&data[0])),
		C.int(imgWidth), C.int(imgHeight),
		(*C.Image)(unsafe.Pointer(outImgC)),
	)
	if errCode != 0 {
		return openvision.HairMattingError(int(errCode))
	}
	common.GoImage(outImgC, out)
	return nil
}
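As a rough usage sketch of the hair segmentor above (not part of the commit): the model directory is a placeholder, and the output image is caller-provided, matching the out-parameter style used elsewhere in this changeset.

package example

import (
	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/face/hair"
)

// runHairMatting loads the hair segmentation model (modelDir is a placeholder)
// and writes the matting result into the caller-provided out image.
func runHairMatting(modelDir string, img *common.Image, out *common.Image) error {
	seg := hair.NewHair()
	defer seg.Destroy()
	if err := seg.LoadModel(modelDir); err != nil {
		return err
	}
	return seg.Matting(img, out)
}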
@@ -34,7 +34,6 @@ func (h *Hopenet) Pointer() unsafe.Pointer {
// LoadModel load detecter model
func (h *Hopenet) LoadModel(modelPath string) error {
	return common.EstimatorLoadModel(h, modelPath)
	return nil
}

// Destroy destroy C.IHopeNet
@@ -7,71 +7,16 @@ import (
|
||||
const (
|
||||
// DefaultBorderColor default drawer border color
|
||||
DefaultBorderColor = common.Green
|
||||
// DefaultKeypointColor default drawer keypoint color
|
||||
DefaultKeypointColor = common.Pink
|
||||
// DefaultBorderStrokeWidth default drawer border stroke width
|
||||
DefaultBorderStrokeWidth = 3
|
||||
// DefaultKeypointRadius default drawer keypoint radius
|
||||
DefaultKeypointRadius = 3
|
||||
// DefaultKeypointStrokeWidth default drawer keypoint stroke width
|
||||
DefaultKeypointStrokeWidth = 1
|
||||
)
|
||||
|
||||
// CocoPart coco part define
|
||||
type CocoPart = int
|
||||
|
||||
const (
|
||||
// CocoPartNose nose
|
||||
CocoPartNose CocoPart = iota
|
||||
// CocoPartLEye left eye
|
||||
CocoPartLEye
|
||||
// CocoPartREye right eye
|
||||
CocoPartREye
|
||||
// CocoPartLEar left ear
|
||||
CocoPartLEar
|
||||
// CocoPartREar right ear
|
||||
CocoPartREar
|
||||
// CocoPartLShoulder left shoulder
|
||||
CocoPartLShoulder
|
||||
// CocoPartRShoulder right shoulder
|
||||
CocoPartRShoulder
|
||||
// CocoPartLElbow left elbow
|
||||
CocoPartLElbow
|
||||
// CocoPartRElbow right elbow
|
||||
CocoPartRElbow
|
||||
// CocoPartLWrist left wrist
|
||||
CocoPartLWrist
|
||||
// CocoPartRWrist right wrist
|
||||
CocoPartRWrist
|
||||
// CocoPartLHip left hip
|
||||
CocoPartLHip
|
||||
// CocoPartRHip right hip
|
||||
CocoPartRHip
|
||||
// CocoPartLKnee left knee
|
||||
CocoPartLKnee
|
||||
// CocoPartRKnee right knee
|
||||
CocoPartRKnee
|
||||
// CocoPartRAnkle right ankle
|
||||
CocoPartRAnkle
|
||||
// CocoPartLAnkle left ankle
|
||||
CocoPartLAnkle
|
||||
// CocoPartNeck neck
|
||||
CocoPartNeck
|
||||
// CocoPartBackground background
|
||||
CocoPartBackground
|
||||
)
|
||||
|
||||
var (
|
||||
// CocoPair represents joints pair
|
||||
CocoPair = [16][2]CocoPart{
|
||||
{0, 1}, {1, 3}, {0, 2}, {2, 4}, {5, 6}, {5, 7}, {7, 9}, {6, 8}, {8, 10}, {5, 11}, {6, 12}, {11, 12}, {11, 13}, {12, 14}, {13, 15}, {14, 16},
|
||||
}
|
||||
// CocoColors represents color for coco parts
|
||||
CocoColors = [17]string{
|
||||
"#ff0000", "#ff5500", "#ffaa00", "#ffff00",
|
||||
"#aaff00", "#55ff00", "#00ff00", "#00ff55", "#00ffaa",
|
||||
"#00ffff", "#00aaff", "#0055ff",
|
||||
"#0000ff", "#aa00ff", "#ff00ff",
|
||||
"#ff00aa", "#ff0055",
|
||||
}
|
||||
// DefaultLabelColor default label color
|
||||
DefaultLabelColor = common.White
|
||||
)
|
||||
|
||||
var (
|
||||
|
@@ -2,8 +2,12 @@ package drawer
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"math"
|
||||
|
||||
"github.com/llgcode/draw2d"
|
||||
"github.com/llgcode/draw2d/draw2dimg"
|
||||
"github.com/llgcode/draw2d/draw2dkit"
|
||||
|
||||
"github.com/bububa/openvision/go/common"
|
||||
)
|
||||
@@ -18,6 +22,12 @@ type Drawer struct {
|
||||
KeypointStrokeWidth float64
|
||||
// KeypointRadius represents keypoints circle radius
|
||||
KeypointRadius float64
|
||||
// KeypointColor represents keypoint color
|
||||
KeypointColor string
|
||||
// LabelColor string
|
||||
LabelColor string
|
||||
// Font
|
||||
Font *common.Font
|
||||
}
|
||||
|
||||
// New returns a new Drawer
|
||||
@@ -27,6 +37,8 @@ func New(options ...Option) *Drawer {
|
||||
BorderStrokeWidth: DefaultBorderStrokeWidth,
|
||||
KeypointStrokeWidth: DefaultKeypointStrokeWidth,
|
||||
KeypointRadius: DefaultKeypointRadius,
|
||||
KeypointColor: DefaultKeypointColor,
|
||||
LabelColor: DefaultLabelColor,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt.apply(d)
|
||||
@@ -42,8 +54,6 @@ func (d *Drawer) Draw(img image.Image, rois []common.ObjectInfo, drawBorder bool
|
||||
gc := draw2dimg.NewGraphicContext(out)
|
||||
gc.DrawImage(img)
|
||||
for _, roi := range rois {
|
||||
if drawBorder {
|
||||
// draw rect
|
||||
rect := common.Rect(
|
||||
roi.Rect.X*imgW,
|
||||
roi.Rect.Y*imgH,
|
||||
@@ -51,6 +61,8 @@ func (d *Drawer) Draw(img image.Image, rois []common.ObjectInfo, drawBorder bool
|
||||
roi.Rect.Height*imgH,
|
||||
)
|
||||
borderColor := d.BorderColor
|
||||
if drawBorder {
|
||||
// draw rect
|
||||
common.DrawRectangle(gc, rect, borderColor, "", d.BorderStrokeWidth)
|
||||
}
|
||||
l := len(roi.Keypoints)
|
||||
@@ -95,6 +107,127 @@ func (d *Drawer) Draw(img image.Image, rois []common.ObjectInfo, drawBorder bool
|
||||
poseColor := PoseColors[colorIdx]
|
||||
common.DrawCircle(gc, common.Pt(pt.Point.X*imgW, pt.Point.Y*imgH), d.KeypointRadius, poseColor, "", d.KeypointStrokeWidth)
|
||||
}
|
||||
// draw name
|
||||
if roi.Name != "" {
|
||||
common.DrawLabelInWidth(gc, d.Font, roi.Name, common.Pt(rect.X, rect.MaxY()), d.LabelColor, borderColor, rect.Width)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DrawPalm draw PalmObject
|
||||
func (d *Drawer) DrawPalm(img image.Image, rois []common.PalmObject) image.Image {
|
||||
imgW := float64(img.Bounds().Dx())
|
||||
imgH := float64(img.Bounds().Dy())
|
||||
out := image.NewRGBA(img.Bounds())
|
||||
gc := draw2dimg.NewGraphicContext(out)
|
||||
gc.DrawImage(img)
|
||||
for _, roi := range rois {
|
||||
gc.SetLineWidth(d.BorderStrokeWidth)
|
||||
gc.SetStrokeColor(common.ColorFromHex(d.BorderColor))
|
||||
gc.BeginPath()
|
||||
for idx, pt := range roi.RectPoints {
|
||||
gc.MoveTo(pt.X*imgW, pt.Y*imgH)
|
||||
if idx == 3 {
|
||||
gc.LineTo(roi.RectPoints[0].X*imgW, roi.RectPoints[0].Y*imgH)
|
||||
} else {
|
||||
gc.LineTo(roi.RectPoints[idx+1].X*imgW, roi.RectPoints[idx+1].Y*imgH)
|
||||
}
|
||||
}
|
||||
gc.Close()
|
||||
gc.Stroke()
|
||||
|
||||
l := len(roi.Skeleton)
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
// draw skeleton
|
||||
for idx := range roi.Skeleton[:l-1] {
|
||||
var (
|
||||
p0 common.Point
|
||||
p1 common.Point
|
||||
poseColor = PoseColors[idx/4]
|
||||
)
|
||||
gc.SetStrokeColor(common.ColorFromHex(poseColor))
|
||||
if idx == 5 || idx == 9 || idx == 13 || idx == 17 {
|
||||
p0 = roi.Skeleton[0]
|
||||
p1 = roi.Skeleton[idx]
|
||||
gc.BeginPath()
|
||||
gc.MoveTo(p0.X*imgW, p0.Y*imgH)
|
||||
gc.LineTo(p1.X*imgW, p1.Y*imgH)
|
||||
gc.Close()
|
||||
gc.Stroke()
|
||||
} else if idx == 4 || idx == 8 || idx == 12 || idx == 16 {
|
||||
continue
|
||||
}
|
||||
p0 = roi.Skeleton[idx]
|
||||
p1 = roi.Skeleton[idx+1]
|
||||
gc.BeginPath()
|
||||
gc.MoveTo(p0.X*imgW, p0.Y*imgH)
|
||||
gc.LineTo(p1.X*imgW, p1.Y*imgH)
|
||||
gc.Close()
|
||||
gc.Stroke()
|
||||
}
|
||||
for _, pt := range roi.Landmarks {
|
||||
common.DrawCircle(gc, common.Pt(pt.X*imgW, pt.Y*imgH), d.KeypointRadius, d.KeypointColor, "", d.KeypointStrokeWidth)
|
||||
}
|
||||
// draw name
|
||||
if roi.Name != "" {
|
||||
deltaX := (roi.RectPoints[2].X - roi.RectPoints[3].X) * imgW
|
||||
deltaY := (roi.RectPoints[2].Y - roi.RectPoints[3].Y) * imgH
|
||||
width := math.Sqrt(math.Abs(deltaX*deltaX) + math.Abs(deltaY*deltaY))
|
||||
metrix := draw2d.NewRotationMatrix(roi.Rotation)
|
||||
ptX, ptY := metrix.InverseTransformPoint(roi.RectPoints[3].X*imgW, roi.RectPoints[3].Y*imgH)
|
||||
gc.Save()
|
||||
gc.Rotate(roi.Rotation)
|
||||
common.DrawLabelInWidth(gc, d.Font, roi.Name, common.Pt(ptX, ptY), d.LabelColor, d.BorderColor, width)
|
||||
gc.Restore()
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DrawPalm3D draw 3d PalmObject
|
||||
func (d *Drawer) DrawPalm3D(roi common.PalmObject, size float64, bg string) image.Image {
|
||||
out := image.NewRGBA(image.Rect(0, 0, int(size), int(size)))
|
||||
gc := draw2dimg.NewGraphicContext(out)
|
||||
l := len(roi.Skeleton3d)
|
||||
if l == 0 {
|
||||
return out
|
||||
}
|
||||
if bg != "" {
|
||||
bgColor := common.ColorFromHex(bg)
|
||||
gc.SetFillColor(bgColor)
|
||||
draw2dkit.Rectangle(gc, 0, 0, size, size)
|
||||
gc.Fill()
|
||||
gc.SetFillColor(color.Transparent)
|
||||
}
|
||||
// draw skeleton3d
|
||||
for idx := range roi.Skeleton3d[:l-1] {
|
||||
var (
|
||||
p0 common.Point3d
|
||||
p1 common.Point3d
|
||||
poseColor = PoseColors[idx/4]
|
||||
)
|
||||
gc.SetStrokeColor(common.ColorFromHex(poseColor))
|
||||
if idx == 5 || idx == 9 || idx == 13 || idx == 17 {
|
||||
p0 = roi.Skeleton3d[0]
|
||||
p1 = roi.Skeleton3d[idx]
|
||||
gc.BeginPath()
|
||||
gc.MoveTo(p0.X*size, p0.Y*size)
|
||||
gc.LineTo(p1.X*size, p1.Y*size)
|
||||
gc.Close()
|
||||
gc.Stroke()
|
||||
} else if idx == 4 || idx == 8 || idx == 12 || idx == 16 {
|
||||
continue
|
||||
}
|
||||
p0 = roi.Skeleton3d[idx]
|
||||
p1 = roi.Skeleton3d[idx+1]
|
||||
gc.BeginPath()
|
||||
gc.MoveTo(p0.X*size, p0.Y*size)
|
||||
gc.LineTo(p1.X*size, p1.Y*size)
|
||||
gc.Close()
|
||||
gc.Stroke()
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
@@ -1,5 +1,9 @@
package drawer

import (
	"github.com/bububa/openvision/go/common"
)

// Option represents Drawer option interface
type Option interface {
	apply(*Drawer)
@@ -38,3 +42,17 @@ func WithKeypointStrokeWidth(w float64) Option {
		d.KeypointStrokeWidth = w
	})
}

// WithKeypointColor set Drawer KeypointColor
func WithKeypointColor(color string) Option {
	return optionFunc(func(d *Drawer) {
		d.KeypointColor = color
	})
}

// WithFont set Drawer Font
func WithFont(font *common.Font) Option {
	return optionFunc(func(d *Drawer) {
		d.Font = font
	})
}
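A minimal sketch of how the new drawer options might be combined with the DrawPalm method added earlier in this diff; it is not part of the commit. The drawer import path is a guess from the surrounding files, and the font is assumed to be loaded by the caller through the common package.

package example

import (
	"image"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/hand/drawer"
)

// renderPalms builds a Drawer with the options added in this commit and
// renders palm detections onto a copy of the source image.
func renderPalms(src image.Image, rois []common.PalmObject, font *common.Font) image.Image {
	d := drawer.New(
		drawer.WithKeypointColor("#00ffff"),
		drawer.WithFont(font),
	)
	return d.DrawPalm(src, rois)
}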
11
go/hand/pose3d/cgo.go
Normal file
@@ -0,0 +1,11 @@
// +build !vulkan

package pose3d

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
11
go/hand/pose3d/cgo_vulkan.go
Normal file
@@ -0,0 +1,11 @@
// +build vulkan

package pose3d

/*
#cgo CXXFLAGS: --std=c++11 -fopenmp
#cgo CPPFLAGS: -I ${SRCDIR}/../../../include -I /usr/local/include
#cgo LDFLAGS: -lstdc++ -lncnn -lomp -lopenvision -lglslang -lvulkan -lSPIRV -lOGLCompiler -lMachineIndependent -lGenericCodeGen -lOSDependent
#cgo LDFLAGS: -L /usr/local/lib -L ${SRCDIR}/../../../lib
*/
import "C"
2
go/hand/pose3d/doc.go
Normal file
@@ -0,0 +1,2 @@
// Package pose3d includes the hand 3d pose estimator
package pose3d
62
go/hand/pose3d/mediapipe.go
Normal file
@@ -0,0 +1,62 @@
package pose3d

/*
#include <stdlib.h>
#include <stdbool.h>
#include "openvision/common/common.h"
#include "openvision/hand/pose3d.h"
*/
import "C"
import (
	"unsafe"

	openvision "github.com/bububa/openvision/go"
	"github.com/bububa/openvision/go/common"
)

// Mediapipe represents the mediapipe hand 3d pose estimator
type Mediapipe struct {
	d C.IHandPose3DEstimator
}

func NewMediapipe() *Mediapipe {
	return &Mediapipe{
		d: C.new_mediapipe_hand(),
	}
}

func (m *Mediapipe) Destroy() {
	C.destroy_mediapipe_hand(m.d)
}

func (m *Mediapipe) LoadModel(palmPath string, handPath string) error {
	cPalm := C.CString(palmPath)
	defer C.free(unsafe.Pointer(cPalm))
	cHand := C.CString(handPath)
	defer C.free(unsafe.Pointer(cHand))
	retCode := C.mediapipe_hand_load_model(m.d, cPalm, cHand)
	if retCode != 0 {
		return openvision.LoadModelError(int(retCode))
	}
	return nil
}

// Detect detect hand 3d pose
func (m *Mediapipe) Detect(img *common.Image) ([]common.PalmObject, error) {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
	cObjs := common.NewCPalmObjectVector()
	defer common.FreeCPalmObjectVector(cObjs)
	errCode := C.mediapipe_hand_detect(
		m.d,
		(*C.uchar)(unsafe.Pointer(&data[0])),
		C.int(imgWidth), C.int(imgHeight),
		(*C.PalmObjectVector)(unsafe.Pointer(cObjs)),
	)
	if errCode != 0 {
		return nil, openvision.DetectHandError(int(errCode))
	}
	return common.GoPalmObjectVector(cObjs, imgWidth, imgHeight), nil
}
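A minimal usage sketch of the new 3d hand-pose binding (not part of the commit); the two model file paths are placeholders for the palm detection and 3d hand pose models listed in the README.

package example

import (
	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/hand/pose3d"
)

// detectHands runs the mediapipe 3d hand-pose estimator on a prepared image
// and returns the detected palm objects with their 3d skeletons.
func detectHands(img *common.Image) ([]common.PalmObject, error) {
	mp := pose3d.NewMediapipe()
	defer mp.Destroy()
	if err := mp.LoadModel("./models/palm-full", "./models/hand-full"); err != nil {
		return nil, err
	}
	return mp.Detect(img)
}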
@@ -7,7 +7,6 @@ package segmentor
*/
import "C"
import (
	"image"
	"unsafe"

	"github.com/bububa/openvision/go/common"
@@ -41,11 +40,11 @@ func (d *Deeplabv3plus) LoadModel(modelPath string) error {
}

// Matting implement Segmentor interface
func (d *Deeplabv3plus) Matting(img *common.Image) (image.Image, error) {
	return Matting(d, img)
func (d *Deeplabv3plus) Matting(img *common.Image, out *common.Image) error {
	return Matting(d, img, out)
}

// Merge implement Segmentor interface
func (d *Deeplabv3plus) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
	return Merge(d, img, bg)
func (d *Deeplabv3plus) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
	return Merge(d, img, bg, out)
}
@@ -7,7 +7,6 @@ package segmentor
*/
import "C"
import (
	"image"
	"unsafe"

	"github.com/bububa/openvision/go/common"
@@ -41,11 +40,11 @@ func (d *ERDNet) LoadModel(modelPath string) error {
}

// Matting implement Segmentor interface
func (d *ERDNet) Matting(img *common.Image) (image.Image, error) {
	return Matting(d, img)
func (d *ERDNet) Matting(img *common.Image, out *common.Image) error {
	return Matting(d, img, out)
}

// Merge implement Segmentor interface
func (d *ERDNet) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
	return Merge(d, img, bg)
func (d *ERDNet) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
	return Merge(d, img, bg, out)
}
@@ -7,7 +7,6 @@ package segmentor
*/
import "C"
import (
	"image"
	"unsafe"

	"github.com/bububa/openvision/go/common"
@@ -44,11 +43,11 @@ func (d *RVM) LoadModel(modelPath string) error {
}

// Matting implement Segmentor interface
func (d *RVM) Matting(img *common.Image) (image.Image, error) {
	return Matting(d, img)
func (d *RVM) Matting(img *common.Image, out *common.Image) error {
	return Matting(d, img, out)
}

// Merge implement Segmentor interface
func (d *RVM) Merge(img *common.Image, bg *common.Image) (image.Image, error) {
	return Merge(d, img, bg)
func (d *RVM) Merge(img *common.Image, bg *common.Image, out *common.Image) error {
	return Merge(d, img, bg, out)
}
@@ -8,7 +8,6 @@ package segmentor
*/
import "C"
import (
	"image"
	"unsafe"

	openvision "github.com/bububa/openvision/go"
@@ -18,12 +17,12 @@ import (
// Segmentor represents segmentor interface
type Segmentor interface {
	common.Estimator
	Matting(img *common.Image) (image.Image, error)
	Merge(img *common.Image, bg *common.Image) (image.Image, error)
	Matting(img *common.Image, out *common.Image) error
	Merge(img *common.Image, bg *common.Image, out *common.Image) error
}

// Matting returns pose segment matting image
func Matting(d Segmentor, img *common.Image) (image.Image, error) {
func Matting(d Segmentor, img *common.Image, out *common.Image) error {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
@@ -36,13 +35,14 @@ func Matting(d Segmentor, img *common.Image) (image.Image, error) {
		C.int(imgHeight),
		(*C.Image)(unsafe.Pointer(outImgC)))
	if errCode != 0 {
		return nil, openvision.DetectPoseError(int(errCode))
		return openvision.DetectPoseError(int(errCode))
	}
	return common.GoImage(outImgC)
	common.GoImage(outImgC, out)
	return nil
}

// Merge merge pose with background
func Merge(d Segmentor, img *common.Image, bg *common.Image) (image.Image, error) {
func Merge(d Segmentor, img *common.Image, bg *common.Image, out *common.Image) error {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
@@ -59,7 +59,8 @@ func Merge(d Segmentor, img *common.Image, bg *common.Image) (image.Image, error
		C.int(bgWidth), C.int(bgHeight),
		(*C.Image)(unsafe.Pointer(outImgC)))
	if errCode != 0 {
		return nil, openvision.DetectPoseError(int(errCode))
		return openvision.DetectPoseError(int(errCode))
	}
	return common.GoImage(outImgC)
	common.GoImage(outImgC, out)
	return nil
}
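A minimal sketch of the reworked Segmentor interface (not part of the commit): results are now written into caller-provided output images and only an error is returned. The segmentor import path is assumed from the repository layout; any of the Deeplabv3plus, ERDNet, or RVM implementations touched above satisfies the interface.

package example

import (
	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/pose/segmentor"
)

// runSegmentor mats the subject out of img and then merges it onto bg,
// writing into the caller-provided matted and merged images.
func runSegmentor(seg segmentor.Segmentor, img, bg, matted, merged *common.Image) error {
	if err := seg.Matting(img, matted); err != nil {
		return err
	}
	return seg.Merge(img, bg, merged)
}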
@@ -7,7 +7,6 @@ package styletransfer
*/
import "C"
import (
	"image"
	"unsafe"

	"github.com/bububa/openvision/go/common"
@@ -41,6 +40,6 @@ func (d *AnimeGan2) LoadModel(modelPath string) error {
}

// Transform implement StyleTransfer interface
func (d *AnimeGan2) Transform(img *common.Image) (image.Image, error) {
	return Transform(d, img)
func (d *AnimeGan2) Transform(img *common.Image, out *common.Image) error {
	return Transform(d, img, out)
}
@@ -8,7 +8,6 @@ package styletransfer
*/
import "C"
import (
	"image"
	"unsafe"

	openvision "github.com/bububa/openvision/go"
@@ -18,11 +17,11 @@ import (
// StyleTransfer represents Style Transfer interface
type StyleTransfer interface {
	common.Estimator
	Transform(img *common.Image) (image.Image, error)
	Transform(img *common.Image, out *common.Image) error
}

// Transform returns style transform image
func Transform(d StyleTransfer, img *common.Image) (image.Image, error) {
func Transform(d StyleTransfer, img *common.Image, out *common.Image) error {
	imgWidth := img.WidthF64()
	imgHeight := img.HeightF64()
	data := img.Bytes()
@@ -35,7 +34,8 @@ func Transform(d StyleTransfer, img *common.Image) (image.Image, error) {
		C.int(imgHeight),
		(*C.Image)(unsafe.Pointer(outImgC)))
	if errCode != 0 {
		return nil, openvision.DetectPoseError(int(errCode))
		return openvision.DetectPoseError(int(errCode))
	}
	return common.GoImage(outImgC)
	common.GoImage(outImgC, out)
	return nil
}
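A minimal sketch of the changed StyleTransfer signature (not part of the commit): the stylized frame is written into a caller-provided output image instead of being returned as an image.Image. The styletransfer import path is assumed from the repository layout; AnimeGan2 above is one implementation of the interface.

package example

import (
	"fmt"

	"github.com/bububa/openvision/go/common"
	"github.com/bububa/openvision/go/styletransfer"
)

// stylize wraps the new out-parameter Transform call and annotates failures.
func stylize(st styletransfer.StyleTransfer, img, out *common.Image) error {
	if err := st.Transform(img, out); err != nil {
		return fmt.Errorf("style transfer: %w", err)
	}
	return nil
}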
@@ -1,6 +1,7 @@
file(GLOB_RECURSE SRC_FILES
  ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/*.cxx
  ${CMAKE_CURRENT_SOURCE_DIR}/*.c
)

message(${SRC_FILES})
@@ -10,6 +11,11 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -O2 -fPIC -std=c++11 -fopenmp")
add_library(openvision STATIC ${SRC_FILES})
target_link_libraries(openvision PUBLIC ncnn)

# set(CMAKE_PREFIX_PATH "${CMAKE_PREFIX_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/../libtorch/share/cmake/Torch")
# find_package(Torch REQUIRED)
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
# target_link_libraries(openvision PUBLIC ${TORCH_LIBRARIES})

if(OV_OPENMP)
  find_package(OpenMP)
  if(NOT TARGET OpenMP::OpenMP_CXX AND (OpenMP_CXX_FOUND OR OPENMP_FOUND))
@@ -57,13 +63,15 @@ target_include_directories(openvision
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/recognizer/mobilefacenet>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/tracker>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/hair>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/eye>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/face/aligner>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/pose>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/hand/pose3d>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/pose/detecter>
@@ -73,6 +81,11 @@ target_include_directories(openvision
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/styletransfer>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/tracker>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/counter>

  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/classifier>
  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm>
)

#install(TARGETS openvision EXPORT openvision ARCHIVE DESTINATION ${LIBRARY_OUTPUT_PATH})
@@ -89,12 +102,15 @@ file(COPY
  ${CMAKE_CURRENT_SOURCE_DIR}/face/tracker.h
  ${CMAKE_CURRENT_SOURCE_DIR}/face/hopenet.h
  ${CMAKE_CURRENT_SOURCE_DIR}/face/aligner.h
  ${CMAKE_CURRENT_SOURCE_DIR}/face/hair.h
  ${CMAKE_CURRENT_SOURCE_DIR}/face/eye.h
  DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/face
)

file(COPY
  ${CMAKE_CURRENT_SOURCE_DIR}/hand/detecter.h
  ${CMAKE_CURRENT_SOURCE_DIR}/hand/pose.h
  ${CMAKE_CURRENT_SOURCE_DIR}/hand/pose3d.h
  DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/hand
)

@@ -115,3 +131,13 @@ file(COPY
  DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/tracker
)

file(COPY
  ${CMAKE_CURRENT_SOURCE_DIR}/counter/counter.h
  DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/counter
)

file(COPY
  ${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm_trainer.h
  ${CMAKE_CURRENT_SOURCE_DIR}/classifier/svm_classifier.h
  DESTINATION ${INCLUDE_OUTPUT_PATH}/openvision/classifier
)
55
src/classifier/svm/svm_binary_classifier.cpp
Normal file
55
src/classifier/svm/svm_binary_classifier.cpp
Normal file
@@ -0,0 +1,55 @@
|
||||
#include "svm_binary_classifier.hpp"
|
||||
#include "svm_light/svm_common.h"
|
||||
|
||||
namespace ovclassifier {
|
||||
SVMBinaryClassifier::SVMBinaryClassifier() {}
|
||||
|
||||
SVMBinaryClassifier::~SVMBinaryClassifier() {
|
||||
if (model_ != NULL) {
|
||||
free_model(model_, 1);
|
||||
model_ = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMBinaryClassifier::LoadModel(const char *modelfile) {
|
||||
if (model_ != NULL) {
|
||||
free_model(model_, 1);
|
||||
}
|
||||
model_ = (MODEL *)my_malloc(sizeof(MODEL));
|
||||
model_ = read_model((char *)modelfile);
|
||||
if (model_->kernel_parm.kernel_type == 0) { /* linear kernel */
|
||||
/* compute weight vector */
|
||||
add_weight_vector_to_linear_model(model_);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
double SVMBinaryClassifier::Predict(const float *vec) {
|
||||
WORD *words = (WORD *)malloc(sizeof(WORD) * (model_->totwords + 10));
|
||||
for (int i = 0; i < (model_->totwords + 10); ++i) {
|
||||
if (i >= model_->totwords) {
|
||||
words[i].wnum = 0;
|
||||
words[i].weight = 0;
|
||||
} else {
|
||||
words[i].wnum = i + 1;
|
||||
words[i].weight = vec[i];
|
||||
}
|
||||
}
|
||||
DOC *doc =
|
||||
create_example(-1, 0, 0, 0.0, create_svector(words, (char *)"", 1.0));
|
||||
free(words);
|
||||
double dist;
|
||||
if (model_->kernel_parm.kernel_type == 0) {
|
||||
dist = classify_example_linear(model_, doc);
|
||||
} else {
|
||||
dist = classify_example(model_, doc);
|
||||
}
|
||||
free_example(doc, 1);
|
||||
return dist;
|
||||
}
|
||||
|
||||
int SVMBinaryClassifier::Classify(const float *vec,
|
||||
std::vector<float> &scores) {
|
||||
return -1;
|
||||
}
|
||||
} // namespace ovclassifier
|
19
src/classifier/svm/svm_binary_classifier.hpp
Normal file
19
src/classifier/svm/svm_binary_classifier.hpp
Normal file
@@ -0,0 +1,19 @@
|
||||
#ifndef _CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
||||
#define _CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
||||
|
||||
#include "svm_classifier.hpp"
|
||||
|
||||
namespace ovclassifier {
|
||||
class SVMBinaryClassifier : public SVMClassifier {
|
||||
public:
|
||||
SVMBinaryClassifier();
|
||||
~SVMBinaryClassifier();
|
||||
int LoadModel(const char *modelfile);
|
||||
double Predict(const float *vec);
|
||||
int Classify(const float *vec, std::vector<float> &scores);
|
||||
|
||||
private:
|
||||
MODEL *model_ = NULL;
|
||||
};
|
||||
} // namespace ovclassifier
|
||||
#endif // !_CLASSIFIER_SVM_BINARY_CLASSIFIER_H_
|
126
src/classifier/svm/svm_binary_trainer.cpp
Normal file
126
src/classifier/svm/svm_binary_trainer.cpp
Normal file
@@ -0,0 +1,126 @@
|
||||
#include "svm_binary_trainer.hpp"
|
||||
#include "svm_light/svm_common.h"
|
||||
#include "svm_light/svm_learn.h"
|
||||
|
||||
namespace ovclassifier {
|
||||
|
||||
SVMBinaryTrainer::SVMBinaryTrainer() {
|
||||
learn_parm = (LEARN_PARM *)malloc(sizeof(LEARN_PARM));
|
||||
strcpy(learn_parm->predfile, "trans_predictions");
|
||||
strcpy(learn_parm->alphafile, "");
|
||||
learn_parm->biased_hyperplane = 1;
|
||||
learn_parm->sharedslack = 0;
|
||||
learn_parm->remove_inconsistent = 0;
|
||||
learn_parm->skip_final_opt_check = 0;
|
||||
learn_parm->svm_maxqpsize = 10;
|
||||
learn_parm->svm_newvarsinqp = 0;
|
||||
learn_parm->svm_iter_to_shrink = -9999;
|
||||
learn_parm->maxiter = 100000;
|
||||
learn_parm->kernel_cache_size = 40;
|
||||
learn_parm->svm_c = 0.0;
|
||||
learn_parm->eps = 0.1;
|
||||
learn_parm->transduction_posratio = -1.0;
|
||||
learn_parm->svm_costratio = 1.0;
|
||||
learn_parm->svm_costratio_unlab = 1.0;
|
||||
learn_parm->svm_unlabbound = 1E-5;
|
||||
learn_parm->epsilon_crit = 0.001;
|
||||
learn_parm->epsilon_a = 1E-15;
|
||||
learn_parm->compute_loo = 0;
|
||||
learn_parm->rho = 1.0;
|
||||
learn_parm->xa_depth = 0;
|
||||
kernel_parm = (KERNEL_PARM *)malloc(sizeof(KERNEL_PARM));
|
||||
kernel_parm->kernel_type = 0;
|
||||
kernel_parm->poly_degree = 3;
|
||||
kernel_parm->rbf_gamma = 1.0;
|
||||
kernel_parm->coef_lin = 1;
|
||||
kernel_parm->coef_const = 1;
|
||||
strcpy(kernel_parm->custom, "empty");
|
||||
}
|
||||
|
||||
SVMBinaryTrainer::~SVMBinaryTrainer() {
|
||||
if (learn_parm != NULL) {
|
||||
free(learn_parm);
|
||||
learn_parm = NULL;
|
||||
}
|
||||
if (kernel_parm != NULL) {
|
||||
free(kernel_parm);
|
||||
kernel_parm = NULL;
|
||||
}
|
||||
if (kernel_cache != NULL) {
|
||||
kernel_cache_cleanup(kernel_cache);
|
||||
kernel_cache = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void SVMBinaryTrainer::Reset() {
|
||||
feats_ = 0;
|
||||
items_.clear();
|
||||
if (kernel_cache != NULL) {
|
||||
kernel_cache_cleanup(kernel_cache);
|
||||
}
|
||||
}
|
||||
|
||||
void SVMBinaryTrainer::SetLabels(int labels) {}
|
||||
|
||||
void SVMBinaryTrainer::SetFeatures(int feats) { feats_ = feats; }
|
||||
|
||||
void SVMBinaryTrainer::AddData(int label, const float *vec) {
|
||||
if (label != 1 && label != -1) {
|
||||
return;
|
||||
}
|
||||
LabelItem itm;
|
||||
itm.label = label;
|
||||
for (int i = 0; i < feats_; ++i) {
|
||||
itm.vec.push_back(vec[i]);
|
||||
}
|
||||
items_.push_back(itm);
|
||||
}
|
||||
|
||||
int SVMBinaryTrainer::Train(const char *modelfile) {
|
||||
int totdoc = items_.size();
|
||||
if (totdoc == 0 || feats_ == 0) {
|
||||
return -1;
|
||||
}
|
||||
kernel_cache = kernel_cache_init(totdoc, learn_parm->kernel_cache_size);
|
||||
double *labels = (double *)malloc(sizeof(double) * totdoc);
|
||||
double *alphas = (double *)malloc(sizeof(double) * totdoc);
|
||||
DOC **docs = (DOC **)malloc(sizeof(DOC *) * totdoc);
|
||||
WORD *words = (WORD *)malloc(sizeof(WORD) * (feats_ + 10));
|
||||
for (int dnum = 0; dnum < totdoc; ++dnum) {
|
||||
const int docFeats = items_[dnum].vec.size();
|
||||
for (int i = 0; i < (feats_ + 10); ++i) {
|
||||
if (i >= feats_) {
|
||||
words[i].wnum = 0;
|
||||
} else {
|
||||
words[i].wnum = i + 1;
|
||||
}
|
||||
if (i >= docFeats) {
|
||||
words[i].weight = 0;
|
||||
} else {
|
||||
words[i].weight = items_[dnum].vec[i];
|
||||
}
|
||||
}
|
||||
labels[dnum] = items_[dnum].label;
|
||||
docs[dnum] =
|
||||
create_example(dnum, 0, 0, 0, create_svector(words, (char *)"", 1.0));
|
||||
}
|
||||
free(words);
|
||||
|
||||
MODEL *model_ = (MODEL *)malloc(sizeof(MODEL));
|
||||
svm_learn_classification(docs, labels, (long int)totdoc, (long int)feats_,
|
||||
learn_parm, kernel_parm, kernel_cache, model_,
|
||||
alphas);
|
||||
write_model((char *)modelfile, model_);
|
||||
free(labels);
|
||||
labels = NULL;
|
||||
free(alphas);
|
||||
alphas = NULL;
|
||||
for (int i = 0; i < totdoc; i++) {
|
||||
free_example(docs[i], 1);
|
||||
}
|
||||
free_model(model_, 0);
|
||||
model_ = NULL;
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace ovclassifier
|
28
src/classifier/svm/svm_binary_trainer.hpp
Normal file
28
src/classifier/svm/svm_binary_trainer.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#ifndef _SVM_BINARY_TRAINER_H_
|
||||
#define _SVM_BINARY_TRAINER_H_
|
||||
|
||||
#include "svm_common.hpp"
|
||||
#include "svm_light/svm_common.h"
|
||||
#include "svm_trainer.hpp"
|
||||
#include <vector>
|
||||
|
||||
namespace ovclassifier {
|
||||
class SVMBinaryTrainer : public SVMTrainer {
|
||||
public:
|
||||
SVMBinaryTrainer();
|
||||
~SVMBinaryTrainer();
|
||||
void Reset();
|
||||
void SetLabels(int labels);
|
||||
void SetFeatures(int feats);
|
||||
void AddData(int label, const float *vec);
|
||||
int Train(const char *modelfile);
|
||||
|
||||
private:
|
||||
KERNEL_PARM *kernel_parm = NULL;
|
||||
LEARN_PARM *learn_parm = NULL;
|
||||
KERNEL_CACHE *kernel_cache = NULL;
|
||||
int feats_;
|
||||
std::vector<LabelItem> items_;
|
||||
};
|
||||
} // namespace ovclassifier
|
||||
#endif // _SVM_BINARY_TRAINER_H_
|
33
src/classifier/svm/svm_classifier.cpp
Normal file
33
src/classifier/svm/svm_classifier.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
#include "../svm_classifier.h"
|
||||
#include "svm_binary_classifier.hpp"
|
||||
#include "svm_multiclass_classifier.hpp"
|
||||
|
||||
ISVMClassifier new_svm_binary_classifier() {
|
||||
return new ovclassifier::SVMBinaryClassifier();
|
||||
}
|
||||
ISVMClassifier new_svm_multiclass_classifier() {
|
||||
return new ovclassifier::SVMMultiClassClassifier();
|
||||
}
|
||||
void destroy_svm_classifier(ISVMClassifier e) {
|
||||
delete static_cast<ovclassifier::SVMClassifier *>(e);
|
||||
}
|
||||
int svm_classifier_load_model(ISVMClassifier e, const char *modelfile) {
|
||||
return static_cast<ovclassifier::SVMClassifier *>(e)->LoadModel(modelfile);
|
||||
}
|
||||
double svm_predict(ISVMClassifier e, const float *vec) {
|
||||
return static_cast<ovclassifier::SVMClassifier *>(e)->Predict(vec);
|
||||
}
|
||||
int svm_classify(ISVMClassifier e, const float *vec, FloatVector *scores) {
|
||||
std::vector<float> scores_;
|
||||
int ret =
|
||||
static_cast<ovclassifier::SVMClassifier *>(e)->Classify(vec, scores_);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
scores->length = scores_.size();
|
||||
scores->values = (float *)malloc(sizeof(float) * scores->length);
|
||||
for (int i = 0; i < scores->length; ++i) {
|
||||
scores->values[i] = scores_[i];
|
||||
}
|
||||
return 0;
|
||||
}
|
16
src/classifier/svm/svm_classifier.hpp
Normal file
16
src/classifier/svm/svm_classifier.hpp
Normal file
@@ -0,0 +1,16 @@
|
||||
#ifndef _CLASSIFIER_SVM_CLASSIFIER_H_
|
||||
#define _CLASSIFIER_SVM_CLASSIFIER_H_
|
||||
|
||||
#include "svm_light/svm_common.h"
|
||||
#include <vector>
|
||||
|
||||
namespace ovclassifier {
|
||||
class SVMClassifier {
|
||||
public:
|
||||
virtual ~SVMClassifier(){};
|
||||
virtual int LoadModel(const char *modelfile) = 0;
|
||||
virtual double Predict(const float *vec) = 0;
|
||||
virtual int Classify(const float *vec, std::vector<float> &scores) = 0;
|
||||
};
|
||||
} // namespace ovclassifier
|
||||
#endif // !_CLASSIFIER_SVM_CLASSIFIER_H_
|
11
src/classifier/svm/svm_common.hpp
Normal file
11
src/classifier/svm/svm_common.hpp
Normal file
@@ -0,0 +1,11 @@
|
||||
#ifndef _SVM_COMMON_HPP_
|
||||
#define _SVM_COMMON_HPP_
|
||||
|
||||
#include <vector>
|
||||
namespace ovclassifier {
|
||||
struct LabelItem {
|
||||
int label;
|
||||
std::vector<float> vec;
|
||||
};
|
||||
} // namespace ovclassifier
|
||||
#endif // !_SVM_COMMON_HPP_
|
40
src/classifier/svm/svm_light/kernel.h
Normal file
40
src/classifier/svm/svm_light/kernel.h
Normal file
@@ -0,0 +1,40 @@
|
||||
/************************************************************************/
|
||||
/* */
|
||||
/* kernel.h */
|
||||
/* */
|
||||
/* User defined kernel function. Feel free to plug in your own. */
|
||||
/* */
|
||||
/* Copyright: Thorsten Joachims */
|
||||
/* Date: 16.12.97 */
|
||||
/* */
|
||||
/************************************************************************/
|
||||
|
||||
/* KERNEL_PARM is defined in svm_common.h The field 'custom' is reserved for */
|
||||
/* parameters of the user defined kernel. You can also access and use */
|
||||
/* the parameters of the other kernels. Just replace the line
|
||||
return((double)(1.0));
|
||||
with your own kernel. */
|
||||
|
||||
/* Example: The following computes the polynomial kernel. sprod_ss
|
||||
computes the inner product between two sparse vectors.
|
||||
|
||||
return((CFLOAT)pow(kernel_parm->coef_lin*sprod_ss(a,b)
|
||||
+kernel_parm->coef_const,(double)kernel_parm->poly_degree));
|
||||
*/
|
||||
|
||||
/* If you are implementing a kernel that is not based on a
|
||||
feature/value representation, you might want to make use of the
|
||||
field "userdefined" in SVECTOR. By default, this field will contain
|
||||
whatever string you put behind a # sign in the example file. So, if
|
||||
a line in your training file looks like
|
||||
|
||||
-1 1:3 5:6 #abcdefg
|
||||
|
||||
then the SVECTOR field "words" will contain the vector 1:3 5:6, and
|
||||
"userdefined" will contain the string "abcdefg". */
|
||||
|
||||
double custom_kernel(KERNEL_PARM *kernel_parm, SVECTOR *a, SVECTOR *b)
|
||||
/* plug in you favorite kernel */
|
||||
{
|
||||
return((double)(1.0));
|
||||
}
|
619
src/classifier/svm/svm_light/pr_loqo/pr_loqo.c
Normal file
619
src/classifier/svm/svm_light/pr_loqo/pr_loqo.c
Normal file
@@ -0,0 +1,619 @@
|
||||
/*
|
||||
* File: pr_loqo.c
|
||||
* Purpose: solves quadratic programming problem for pattern recognition
|
||||
* for support vectors
|
||||
*
|
||||
* Author: Alex J. Smola
|
||||
* Created: 10/14/97
|
||||
* Updated: 11/08/97
|
||||
* Updated: 13/08/98 (removed exit(1) as it crashes svm lite when the margin
|
||||
* in a not sufficiently conservative manner)
|
||||
*
|
||||
*
|
||||
* Copyright (c) 1997 GMD Berlin - All rights reserved
|
||||
* THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE of GMD Berlin
|
||||
* The copyright notice above does not evidence any
|
||||
* actual or intended publication of this work.
|
||||
*
|
||||
* Unauthorized commercial use of this software is not allowed
|
||||
*/
|
||||
|
||||
#include "pr_loqo.h"
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <time.h>
|
||||
|
||||
#define max(A, B) ((A) > (B) ? (A) : (B))
|
||||
#define min(A, B) ((A) < (B) ? (A) : (B))
|
||||
#define sqr(A) ((A) * (A))
|
||||
#define ABS(A) ((A) > 0 ? (A) : (-(A)))
|
||||
|
||||
#define PREDICTOR 1
|
||||
#define CORRECTOR 2
|
||||
|
||||
/*****************************************************************
|
||||
replace this by any other function that will exit gracefully
|
||||
in a larger system
|
||||
***************************************************************/
|
||||
|
||||
void nrerror(char error_text[]) {
|
||||
printf("ERROR: terminating optimizer - %s\n", error_text);
|
||||
/* exit(1); */
|
||||
}
|
||||
|
||||
/*****************************************************************
|
||||
taken from numerical recipes and modified to accept pointers
|
||||
moreover numerical recipes code seems to be buggy (at least the
|
||||
ones on the web)
|
||||
|
||||
cholesky solver and backsubstitution
|
||||
leaves upper right triangle intact (rows first order)
|
||||
***************************************************************/
|
||||
|
||||
void choldc(double a[], int n, double p[]) {
|
||||
void nrerror(char error_text[]);
|
||||
int i, j, k;
|
||||
double sum;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
for (j = i; j < n; j++) {
|
||||
sum = a[n * i + j];
|
||||
for (k = i - 1; k >= 0; k--)
|
||||
sum -= a[n * i + k] * a[n * j + k];
|
||||
if (i == j) {
|
||||
if (sum <= 0.0) {
|
||||
nrerror((char *)"choldc failed, matrix not positive definite");
|
||||
sum = 0.0;
|
||||
}
|
||||
p[i] = sqrt(sum);
|
||||
} else
|
||||
a[n * j + i] = sum / p[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cholsb(double a[], int n, double p[], double b[], double x[]) {
|
||||
int i, k;
|
||||
double sum;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
sum = b[i];
|
||||
for (k = i - 1; k >= 0; k--)
|
||||
sum -= a[n * i + k] * x[k];
|
||||
x[i] = sum / p[i];
|
||||
}
|
||||
|
||||
for (i = n - 1; i >= 0; i--) {
|
||||
sum = x[i];
|
||||
for (k = i + 1; k < n; k++)
|
||||
sum -= a[n * k + i] * x[k];
|
||||
x[i] = sum / p[i];
|
||||
}
|
||||
}
|
||||
|
||||
/*****************************************************************
|
||||
sometimes we only need the forward or backward pass of the
|
||||
backsubstitution, hence we provide these two routines separately
|
||||
***************************************************************/
|
||||
|
||||
void chol_forward(double a[], int n, double p[], double b[], double x[]) {
|
||||
int i, k;
|
||||
double sum;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
sum = b[i];
|
||||
for (k = i - 1; k >= 0; k--)
|
||||
sum -= a[n * i + k] * x[k];
|
||||
x[i] = sum / p[i];
|
||||
}
|
||||
}
|
||||
|
||||
void chol_backward(double a[], int n, double p[], double b[], double x[]) {
|
||||
int i, k;
|
||||
double sum;
|
||||
|
||||
for (i = n - 1; i >= 0; i--) {
|
||||
sum = b[i];
|
||||
for (k = i + 1; k < n; k++)
|
||||
sum -= a[n * k + i] * x[k];
|
||||
x[i] = sum / p[i];
|
||||
}
|
||||
}
|
||||
|
||||
/*****************************************************************
|
||||
solves the system | -H_x A' | |x_x| = |c_x|
|
||||
| A H_y| |x_y| |c_y|
|
||||
|
||||
with H_x (and H_y) positive (semidefinite) matrices
|
||||
and n, m the respective sizes of H_x and H_y
|
||||
|
||||
for variables see pg. 48 of notebook or do the calculations on a
|
||||
sheet of paper again
|
||||
|
||||
predictor solves the whole thing, corrector assumes that H_x didn't
|
||||
change and relies on the results of the predictor. therefore do
|
||||
_not_ modify workspace
|
||||
|
||||
if you want to speed tune anything in the code here's the right
|
||||
place to do so: about 95% of the time is being spent in
|
||||
here. something like an iterative refinement would be nice,
|
||||
especially when switching from double to single precision. if you
|
||||
have a fast parallel cholesky use it instead of the numrec
|
||||
implementations.
|
||||
|
||||
side effects: changes H_y (but this is just the unit matrix or zero anyway
|
||||
in our case)
|
||||
***************************************************************/
|
||||
|
||||
void solve_reduced(int n, int m, double h_x[], double h_y[], double a[],
|
||||
double x_x[], double x_y[], double c_x[], double c_y[],
|
||||
double workspace[], int step) {
|
||||
int i, j, k;
|
||||
|
||||
double *p_x;
|
||||
double *p_y;
|
||||
double *t_a;
|
||||
double *t_c;
|
||||
double *t_y;
|
||||
|
||||
p_x = workspace; /* together n + m + n*m + n + m = n*(m+2)+2*m */
|
||||
p_y = p_x + n;
|
||||
t_a = p_y + m;
|
||||
t_c = t_a + n * m;
|
||||
t_y = t_c + n;
|
||||
|
||||
if (step == PREDICTOR) {
|
||||
choldc(h_x, n, p_x); /* do cholesky decomposition */
|
||||
|
||||
for (i = 0; i < m; i++) /* forward pass for A' */
|
||||
chol_forward(h_x, n, p_x, a + i * n, t_a + i * n);
|
||||
|
||||
for (i = 0; i < m; i++) /* compute (h_y + a h_x^-1A') */
|
||||
for (j = i; j < m; j++)
|
||||
for (k = 0; k < n; k++)
|
||||
h_y[m * i + j] += t_a[n * j + k] * t_a[n * i + k];
|
||||
|
||||
choldc(h_y, m, p_y); /* and cholesky decomposition */
|
||||
}
|
||||
|
||||
chol_forward(h_x, n, p_x, c_x, t_c);
|
||||
/* forward pass for c */
|
||||
|
||||
for (i = 0; i < m; i++) { /* and solve for x_y */
|
||||
t_y[i] = c_y[i];
|
||||
for (j = 0; j < n; j++)
|
||||
t_y[i] += t_a[i * n + j] * t_c[j];
|
||||
}
|
||||
|
||||
cholsb(h_y, m, p_y, t_y, x_y);
|
||||
|
||||
for (i = 0; i < n; i++) { /* finally solve for x_x */
|
||||
t_c[i] = -t_c[i];
|
||||
for (j = 0; j < m; j++)
|
||||
t_c[i] += t_a[j * n + i] * x_y[j];
|
||||
}
|
||||
|
||||
chol_backward(h_x, n, p_x, t_c, x_x);
|
||||
}
|
||||
|
||||
/*****************************************************************
|
||||
matrix vector multiplication (symmetric matrix but only one triangle
|
||||
given). computes m*x = y
|
||||
no need to tune it as it's only of O(n^2) but cholesky is of
|
||||
O(n^3). so don't waste your time _here_ although it isn't very
|
||||
elegant.
|
||||
***************************************************************/
|
||||
|
||||
void matrix_vector(int n, double m[], double x[], double y[]) {
|
||||
int i, j;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
y[i] = m[(n + 1) * i] * x[i];
|
||||
|
||||
for (j = 0; j < i; j++)
|
||||
y[i] += m[i + n * j] * x[j];
|
||||
|
||||
for (j = i + 1; j < n; j++)
|
||||
y[i] += m[n * i + j] * x[j];
|
||||
}
|
||||
}
|
||||
|
||||
/*****************************************************************
|
||||
call only this routine; this is the only one you're interested in
|
||||
for doing quadratical optimization
|
||||
|
||||
the restart feature exists but it may not be of much use due to the
|
||||
fact that an initial setting, although close but not very close the
|
||||
the actual solution will result in very good starting diagnostics
|
||||
(primal and dual feasibility and small infeasibility gap) but incur
|
||||
later stalling of the optimizer afterwards as we have to enforce
|
||||
positivity of the slacks.
|
||||
***************************************************************/
|
||||
|
||||
int pr_loqo(int n, int m, double c[], double h_x[], double a[], double b[],
|
||||
double l[], double u[], double primal[], double dual[], int verb,
|
||||
double sigfig_max, int counter_max, double margin, double bound,
|
||||
int restart) {
|
||||
/* the knobs to be tuned ... */
|
||||
/* double margin = -0.95; we will go up to 95% of the
|
||||
distance between old variables and zero */
|
||||
/* double bound = 10; preset value for the start. small
|
||||
values give good initial
|
||||
feasibility but may result in slow
|
||||
convergence afterwards: we're too
|
||||
close to zero */
|
||||
/* to be allocated */
|
||||
double *workspace;
|
||||
double *diag_h_x;
|
||||
double *h_y;
|
||||
double *c_x;
|
||||
double *c_y;
|
||||
double *h_dot_x;
|
||||
double *rho;
|
||||
double *nu;
|
||||
double *tau;
|
||||
double *sigma;
|
||||
double *gamma_z;
|
||||
double *gamma_s;
|
||||
|
||||
double *hat_nu;
|
||||
double *hat_tau;
|
||||
|
||||
double *delta_x;
|
||||
double *delta_y;
|
||||
double *delta_s;
|
||||
double *delta_z;
|
||||
double *delta_g;
|
||||
double *delta_t;
|
||||
|
||||
double *d;
|
||||
|
||||
/* from the header - pointers into primal and dual */
|
||||
double *x;
|
||||
double *y;
|
||||
double *g;
|
||||
double *z;
|
||||
double *s;
|
||||
double *t;
|
||||
|
||||
/* auxiliary variables */
|
||||
double b_plus_1;
|
||||
double c_plus_1;
|
||||
|
||||
double x_h_x;
|
||||
double primal_inf;
|
||||
double dual_inf;
|
||||
|
||||
double sigfig;
|
||||
double primal_obj, dual_obj;
|
||||
double mu;
|
||||
double alfa, step;
|
||||
int counter = 0;
|
||||
|
||||
int status = STILL_RUNNING;
|
||||
|
||||
int i, j, k;
|
||||
|
||||
/* memory allocation */
|
||||
workspace = malloc((n * (m + 2) + 2 * m) * sizeof(double));
|
||||
diag_h_x = malloc(n * sizeof(double));
|
||||
h_y = malloc(m * m * sizeof(double));
|
||||
c_x = malloc(n * sizeof(double));
|
||||
c_y = malloc(m * sizeof(double));
|
||||
h_dot_x = malloc(n * sizeof(double));
|
||||
|
||||
rho = malloc(m * sizeof(double));
|
||||
nu = malloc(n * sizeof(double));
|
||||
tau = malloc(n * sizeof(double));
|
||||
sigma = malloc(n * sizeof(double));
|
||||
|
||||
gamma_z = malloc(n * sizeof(double));
|
||||
gamma_s = malloc(n * sizeof(double));
|
||||
|
||||
hat_nu = malloc(n * sizeof(double));
|
||||
hat_tau = malloc(n * sizeof(double));
|
||||
|
||||
delta_x = malloc(n * sizeof(double));
|
||||
delta_y = malloc(m * sizeof(double));
|
||||
delta_s = malloc(n * sizeof(double));
|
||||
delta_z = malloc(n * sizeof(double));
|
||||
delta_g = malloc(n * sizeof(double));
|
||||
delta_t = malloc(n * sizeof(double));
|
||||
|
||||
d = malloc(n * sizeof(double));
|
||||
|
||||
/* pointers into the external variables */
|
||||
x = primal; /* n */
|
||||
g = x + n; /* n */
|
||||
t = g + n; /* n */
|
||||
|
||||
y = dual; /* m */
|
||||
z = y + m; /* n */
|
||||
s = z + n; /* n */
|
||||
|
||||
/* initial settings */
|
||||
b_plus_1 = 1;
|
||||
c_plus_1 = 0;
|
||||
for (i = 0; i < n; i++)
|
||||
c_plus_1 += c[i];
|
||||
|
||||
/* get diagonal terms */
|
||||
for (i = 0; i < n; i++)
|
||||
diag_h_x[i] = h_x[(n + 1) * i];
|
||||
|
||||
/* starting point */
|
||||
if (restart == 1) {
|
||||
/* x, y already preset */
|
||||
for (i = 0; i < n; i++) { /* compute g, t for primal feasibility */
|
||||
g[i] = max(ABS(x[i] - l[i]), bound);
|
||||
t[i] = max(ABS(u[i] - x[i]), bound);
|
||||
}
|
||||
|
||||
matrix_vector(n, h_x, x, h_dot_x); /* h_dot_x = h_x * x */
|
||||
|
||||
for (i = 0; i < n; i++) { /* sigma is a dummy variable to calculate z, s */
|
||||
sigma[i] = c[i] + h_dot_x[i];
|
||||
for (j = 0; j < m; j++)
|
||||
sigma[i] -= a[n * j + i] * y[j];
|
||||
|
||||
if (sigma[i] > 0) {
|
||||
s[i] = bound;
|
||||
z[i] = sigma[i] + bound;
|
||||
} else {
|
||||
s[i] = bound - sigma[i];
|
||||
z[i] = bound;
|
||||
}
|
||||
}
|
||||
} else { /* use default start settings */
|
||||
for (i = 0; i < m; i++)
|
||||
for (j = i; j < m; j++)
|
||||
h_y[i * m + j] = (i == j) ? 1 : 0;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
c_x[i] = c[i];
|
||||
h_x[(n + 1) * i] += 1;
|
||||
}
|
||||
|
||||
for (i = 0; i < m; i++)
|
||||
c_y[i] = b[i];
|
||||
|
||||
/* and solve the system [-H_x A'; A H_y] [x, y] = [c_x; c_y] */
|
||||
solve_reduced(n, m, h_x, h_y, a, x, y, c_x, c_y, workspace, PREDICTOR);
|
||||
|
||||
/* initialize the other variables */
|
||||
for (i = 0; i < n; i++) {
|
||||
g[i] = max(ABS(x[i] - l[i]), bound);
|
||||
z[i] = max(ABS(x[i]), bound);
|
||||
t[i] = max(ABS(u[i] - x[i]), bound);
|
||||
s[i] = max(ABS(x[i]), bound);
|
||||
}
|
||||
}
|
||||
|
||||
for (i = 0, mu = 0; i < n; i++)
|
||||
mu += z[i] * g[i] + s[i] * t[i];
|
||||
mu = mu / (2 * n);
|
||||
|
||||
/* the main loop */
|
||||
if (verb >= STATUS) {
|
||||
printf("counter | pri_inf | dual_inf | pri_obj | dual_obj | ");
|
||||
printf("sigfig | alpha | nu \n");
|
||||
printf("-------------------------------------------------------");
|
||||
printf("---------------------------\n");
|
||||
}
|
||||
|
||||
while (status == STILL_RUNNING) {
|
||||
/* predictor */
|
||||
|
||||
/* put back original diagonal values */
|
||||
for (i = 0; i < n; i++)
|
||||
h_x[(n + 1) * i] = diag_h_x[i];
|
||||
|
||||
matrix_vector(n, h_x, x, h_dot_x); /* compute h_dot_x = h_x * x */
|
||||
|
||||
for (i = 0; i < m; i++) {
|
||||
rho[i] = b[i];
|
||||
for (j = 0; j < n; j++)
|
||||
rho[i] -= a[n * i + j] * x[j];
|
||||
}
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
nu[i] = l[i] - x[i] + g[i];
|
||||
tau[i] = u[i] - x[i] - t[i];
|
||||
|
||||
sigma[i] = c[i] - z[i] + s[i] + h_dot_x[i];
|
||||
for (j = 0; j < m; j++)
|
||||
sigma[i] -= a[n * j + i] * y[j];
|
||||
|
||||
gamma_z[i] = -z[i];
|
||||
gamma_s[i] = -s[i];
|
||||
}
|
||||
|
||||
/* instrumentation */
|
||||
x_h_x = 0;
|
||||
primal_inf = 0;
|
||||
dual_inf = 0;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
x_h_x += h_dot_x[i] * x[i];
|
||||
primal_inf += sqr(tau[i]);
|
||||
primal_inf += sqr(nu[i]);
|
||||
dual_inf += sqr(sigma[i]);
|
||||
}
|
||||
for (i = 0; i < m; i++)
|
||||
primal_inf += sqr(rho[i]);
|
||||
primal_inf = sqrt(primal_inf) / b_plus_1;
|
||||
dual_inf = sqrt(dual_inf) / c_plus_1;
|
||||
|
||||
primal_obj = 0.5 * x_h_x;
|
||||
dual_obj = -0.5 * x_h_x;
|
||||
for (i = 0; i < n; i++) {
|
||||
primal_obj += c[i] * x[i];
|
||||
dual_obj += l[i] * z[i] - u[i] * s[i];
|
||||
}
|
||||
for (i = 0; i < m; i++)
|
||||
dual_obj += b[i] * y[i];
|
||||
|
||||
sigfig = log10(ABS(primal_obj) + 1) - log10(ABS(primal_obj - dual_obj));
|
||||
sigfig = max(sigfig, 0);
|
||||
|
||||
/* the diagnostics - after we computed our results we will
|
||||
analyze them */
|
||||
|
||||
if (counter > counter_max)
|
||||
status = ITERATION_LIMIT;
|
||||
if (sigfig > sigfig_max)
|
||||
status = OPTIMAL_SOLUTION;
|
||||
if (primal_inf > 10e100)
|
||||
status = PRIMAL_INFEASIBLE;
|
||||
if (dual_inf > 10e100)
|
||||
status = DUAL_INFEASIBLE;
|
||||
if ((primal_inf > 10e100) & (dual_inf > 10e100))
|
||||
status = PRIMAL_AND_DUAL_INFEASIBLE;
|
||||
if (ABS(primal_obj) > 10e100)
|
||||
status = PRIMAL_UNBOUNDED;
|
||||
if (ABS(dual_obj) > 10e100)
|
||||
status = DUAL_UNBOUNDED;
|
||||
|
||||
/* write some nice routine to enforce the time limit if you
|
||||
_really_ want, however it's quite useless as you can compute
|
||||
the time from the maximum number of iterations as every
|
||||
iteration costs one cholesky decomposition plus a couple of
|
||||
backsubstitutions */
|
||||
|
||||
/* generate report */
|
||||
if ((verb >= FLOOD) | ((verb == STATUS) & (status != 0)))
|
||||
printf("%7i | %.2e | %.2e | % .2e | % .2e | %6.3f | %.4f | %.2e\n",
|
||||
counter, primal_inf, dual_inf, primal_obj, dual_obj, sigfig, alfa,
|
||||
mu);
|
||||
|
||||
counter++;
|
||||
|
||||
if (status == 0) { /* we may keep on going, otherwise
|
||||
it'll cost one loop extra plus a
|
||||
messed up main diagonal of h_x */
|
||||
/* intermediate variables (the ones with hat) */
|
||||
for (i = 0; i < n; i++) {
|
||||
hat_nu[i] = nu[i] + g[i] * gamma_z[i] / z[i];
|
||||
hat_tau[i] = tau[i] - t[i] * gamma_s[i] / s[i];
|
||||
/* diagonal terms */
|
||||
d[i] = z[i] / g[i] + s[i] / t[i];
|
||||
}
|
||||
|
||||
/* initialization before the cholesky solver */
|
||||
for (i = 0; i < n; i++) {
|
||||
h_x[(n + 1) * i] = diag_h_x[i] + d[i];
|
||||
c_x[i] = sigma[i] - z[i] * hat_nu[i] / g[i] - s[i] * hat_tau[i] / t[i];
|
||||
}
|
||||
for (i = 0; i < m; i++) {
|
||||
c_y[i] = rho[i];
|
||||
for (j = i; j < m; j++)
|
||||
h_y[m * i + j] = 0;
|
||||
}
|
||||
|
||||
/* and do it */
|
||||
solve_reduced(n, m, h_x, h_y, a, delta_x, delta_y, c_x, c_y, workspace,
|
||||
PREDICTOR);
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
/* backsubstitution */
|
||||
delta_s[i] = s[i] * (delta_x[i] - hat_tau[i]) / t[i];
|
||||
delta_z[i] = z[i] * (hat_nu[i] - delta_x[i]) / g[i];
|
||||
|
||||
delta_g[i] = g[i] * (gamma_z[i] - delta_z[i]) / z[i];
|
||||
delta_t[i] = t[i] * (gamma_s[i] - delta_s[i]) / s[i];
|
||||
|
||||
/* central path (corrector) */
|
||||
gamma_z[i] = mu / g[i] - z[i] - delta_z[i] * delta_g[i] / g[i];
|
||||
gamma_s[i] = mu / t[i] - s[i] - delta_s[i] * delta_t[i] / t[i];
|
||||
|
||||
/* (some more intermediate variables) the hat variables */
|
||||
hat_nu[i] = nu[i] + g[i] * gamma_z[i] / z[i];
|
||||
hat_tau[i] = tau[i] - t[i] * gamma_s[i] / s[i];
|
||||
|
||||
/* initialization before the cholesky */
|
||||
c_x[i] = sigma[i] - z[i] * hat_nu[i] / g[i] - s[i] * hat_tau[i] / t[i];
|
||||
}
|
||||
|
||||
for (i = 0; i < m; i++) { /* compute c_y and rho */
|
||||
c_y[i] = rho[i];
|
||||
for (j = i; j < m; j++)
|
||||
h_y[m * i + j] = 0;
|
||||
}
|
||||
|
||||
/* and do it */
|
||||
solve_reduced(n, m, h_x, h_y, a, delta_x, delta_y, c_x, c_y, workspace,
|
||||
CORRECTOR);
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
/* backsubstitution */
|
||||
delta_s[i] = s[i] * (delta_x[i] - hat_tau[i]) / t[i];
|
||||
delta_z[i] = z[i] * (hat_nu[i] - delta_x[i]) / g[i];
|
||||
|
||||
delta_g[i] = g[i] * (gamma_z[i] - delta_z[i]) / z[i];
|
||||
delta_t[i] = t[i] * (gamma_s[i] - delta_s[i]) / s[i];
|
||||
}
|
||||
|
||||
alfa = -1;
|
||||
for (i = 0; i < n; i++) {
|
||||
alfa = min(alfa, delta_g[i] / g[i]);
|
||||
alfa = min(alfa, delta_t[i] / t[i]);
|
||||
alfa = min(alfa, delta_s[i] / s[i]);
|
||||
alfa = min(alfa, delta_z[i] / z[i]);
|
||||
}
|
||||
alfa = (margin - 1) / alfa;
|
||||
|
||||
/* compute mu */
|
||||
for (i = 0, mu = 0; i < n; i++)
|
||||
mu += z[i] * g[i] + s[i] * t[i];
|
||||
mu = mu / (2 * n);
|
||||
mu = mu * sqr((alfa - 1) / (alfa + 10));
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
x[i] += alfa * delta_x[i];
|
||||
g[i] += alfa * delta_g[i];
|
||||
t[i] += alfa * delta_t[i];
|
||||
z[i] += alfa * delta_z[i];
|
||||
s[i] += alfa * delta_s[i];
|
||||
}
|
||||
|
||||
for (i = 0; i < m; i++)
|
||||
y[i] += alfa * delta_y[i];
|
||||
}
|
||||
}
|
||||
if ((status == 1) && (verb >= STATUS)) {
|
||||
printf("-------------------------------------------------------------------"
|
||||
"---------------\n");
|
||||
printf("optimization converged\n");
|
||||
}
|
||||
|
||||
/* free memory */
|
||||
free(workspace);
|
||||
free(diag_h_x);
|
||||
free(h_y);
|
||||
free(c_x);
|
||||
free(c_y);
|
||||
free(h_dot_x);
|
||||
|
||||
free(rho);
|
||||
free(nu);
|
||||
free(tau);
|
||||
free(sigma);
|
||||
free(gamma_z);
|
||||
free(gamma_s);
|
||||
|
||||
free(hat_nu);
|
||||
free(hat_tau);
|
||||
|
||||
free(delta_x);
|
||||
free(delta_y);
|
||||
free(delta_s);
|
||||
free(delta_z);
|
||||
free(delta_g);
|
||||
free(delta_t);
|
||||
|
||||
free(d);
|
||||
|
||||
/* and return to sender */
|
||||
return status;
|
||||
}
|
93
src/classifier/svm/svm_light/pr_loqo/pr_loqo.h
Normal file
93
src/classifier/svm/svm_light/pr_loqo/pr_loqo.h
Normal file
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
* File: pr_loqo.h
|
||||
* Purpose: solves quadratic programming problem for pattern recognition
|
||||
* for support vectors
|
||||
*
|
||||
* Author: Alex J. Smola
|
||||
* Created: 10/14/97
|
||||
* Updated: 11/08/97
|
||||
*
|
||||
*
|
||||
* Copyright (c) 1997 GMD Berlin - All rights reserved
|
||||
* THIS IS UNPUBLISHED PROPRIETARY SOURCE CODE of GMD Berlin
|
||||
* The copyright notice above does not evidence any
|
||||
* actual or intended publication of this work.
|
||||
*
|
||||
* Unauthorized commercial use of this software is not allowed
|
||||
*/
|
||||
|
||||
/* verbosity levels */
|
||||
#ifndef _PR_LOQO_H_
|
||||
#define _PR_LOQO_H_
|
||||
|
||||
#define QUIET 0
|
||||
#define STATUS 1
|
||||
#define FLOOD 2
|
||||
|
||||
/* status outputs */
|
||||
|
||||
#define STILL_RUNNING 0
|
||||
#define OPTIMAL_SOLUTION 1
|
||||
#define SUBOPTIMAL_SOLUTION 2
|
||||
#define ITERATION_LIMIT 3
|
||||
#define PRIMAL_INFEASIBLE 4
|
||||
#define DUAL_INFEASIBLE 5
|
||||
#define PRIMAL_AND_DUAL_INFEASIBLE 6
|
||||
#define INCONSISTENT 7
|
||||
#define PRIMAL_UNBOUNDED 8
|
||||
#define DUAL_UNBOUNDED 9
|
||||
#define TIME_LIMIT 10
|
||||
|
||||
/*
|
||||
* solve the quadratic programming problem
|
||||
*
|
||||
* minimize c' * x + 1/2 x' * H * x
|
||||
* subject to A*x = b
|
||||
* l <= x <= u
|
||||
*
|
||||
* for a documentation see R. Vanderbei, LOQO: an Interior Point Code
|
||||
* for Quadratic Programming
|
||||
*/
|
||||
|
||||
/*
|
||||
* n : number of primal variables
|
||||
* m : number of constraints (typically 1)
|
||||
* h_x : dot product matrix (n.n)
|
||||
* a : constraint matrix (n.m)
|
||||
* b : constant term (m)
|
||||
* l : lower bound (n)
|
||||
* u : upper bound (n)
|
||||
*
|
||||
* primal : workspace for primal variables, has to be of size 3 n
|
||||
*
|
||||
* x = primal; n
|
||||
* g = x + n; n
|
||||
* t = g + n; n
|
||||
*
|
||||
* dual : workspace for dual variables, has to be of size m + 2 n
|
||||
*
|
||||
* y = dual; m
|
||||
* z = y + m; n
|
||||
* s = z + n; n
|
||||
*
|
||||
* verb : verbosity level
|
||||
* sigfig_max : number of significant digits
|
||||
* counter_max: stopping criterion
|
||||
* restart : 1 if restart desired
|
||||
*
|
||||
*/
|
||||
|
||||
int pr_loqo(int n, int m, double c[], double h_x[], double a[], double b[],
|
||||
double l[], double u[], double primal[], double dual[], int verb,
|
||||
double sigfig_max, int counter_max, double margin, double bound,
|
||||
int restart);
|
||||
|
||||
/*
|
||||
* compile with
|
||||
cc -O4 -c pr_loqo.c
|
||||
cc -xO4 -fast -xarch=v8plus -xchip=ultra -xparallel -c pr_loqo.c
|
||||
mex pr_loqo_c.c pr_loqo.o
|
||||
cmex4 pr_loqo_c.c pr_loqo.o -DMATLAB4 -o pr_loqo_c4
|
||||
*
|
||||
*/
|
||||
#endif // !_PR_LOQO_H_
|
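
As an illustration of the interface above, here is a minimal, self-contained sketch (not part of this diff) that feeds a tiny hand-made two-variable QP to pr_loqo, using the workspace sizes documented in the header comment (primal of size 3n, dual of size m + 2n); the sigfig, iteration, margin, and bound values are illustrative only.

/* sketch.c -- illustrative only, not part of the repository:
 * minimize 0.5*(x0^2 + x1^2) - x0 - x1  subject to  x0 + x1 = 1,  0 <= x <= 1 */
#include <stdio.h>
#include <stdlib.h>
#include "pr_loqo.h"

int main(void) {
  int n = 2, m = 1;
  double c[2]   = {-1.0, -1.0};
  double h_x[4] = {1.0, 0.0, 0.0, 1.0};            /* dot product / Hessian matrix (n x n) */
  double a[2]   = {1.0, 1.0};                      /* single equality-constraint row */
  double b[1]   = {1.0};
  double l[2]   = {0.0, 0.0}, u[2] = {1.0, 1.0};   /* box constraints */
  double *primal = (double *)calloc(3 * n, sizeof(double));     /* x, g, t */
  double *dual   = (double *)calloc(m + 2 * n, sizeof(double)); /* y, z, s */

  int status = pr_loqo(n, m, c, h_x, a, b, l, u, primal, dual,
                       QUIET, 7.0 /* sigfig_max */, 50 /* counter_max */,
                       0.15 /* margin */, 1.0 /* bound, illustrative */, 0 /* restart */);
  printf("status=%d  x=(%g, %g)  y=%g\n", status, primal[0], primal[1], dual[0]);
  free(primal);
  free(dual);
  return 0;
}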
198
src/classifier/svm/svm_light/svm_classify.c
Normal file
@@ -0,0 +1,198 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_classify.c */
|
||||
/* */
|
||||
/* Classification module of Support Vector Machine. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 02.07.02 */
|
||||
/* */
|
||||
/* Copyright (c) 2002 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/************************************************************************/
|
||||
|
||||
# include "svm_common.h"
|
||||
|
||||
char docfile[200];
|
||||
char modelfile[200];
|
||||
char predictionsfile[200];
|
||||
|
||||
void read_input_parameters(int, char **, char *, char *, char *, long *,
|
||||
long *);
|
||||
void print_help(void);
|
||||
|
||||
|
||||
int main (int argc, char* argv[])
|
||||
{
|
||||
DOC *doc; /* test example */
|
||||
WORD *words;
|
||||
long max_docs,max_words_doc,lld;
|
||||
long totdoc=0,queryid,slackid;
|
||||
long correct=0,incorrect=0,no_accuracy=0;
|
||||
long res_a=0,res_b=0,res_c=0,res_d=0,wnum,pred_format;
|
||||
long j;
|
||||
double t1,runtime=0;
|
||||
double dist,doc_label,costfactor;
|
||||
char *line,*comment;
|
||||
FILE *predfl,*docfl;
|
||||
MODEL *model;
|
||||
|
||||
read_input_parameters(argc,argv,docfile,modelfile,predictionsfile,
|
||||
&verbosity,&pred_format);
|
||||
|
||||
nol_ll(docfile,&max_docs,&max_words_doc,&lld); /* scan size of input file */
|
||||
max_words_doc+=2;
|
||||
lld+=2;
|
||||
|
||||
line = (char *)my_malloc(sizeof(char)*lld);
|
||||
words = (WORD *)my_malloc(sizeof(WORD)*(max_words_doc+10));
|
||||
|
||||
model=read_model(modelfile);
|
||||
|
||||
if(model->kernel_parm.kernel_type == 0) { /* linear kernel */
|
||||
/* compute weight vector */
|
||||
add_weight_vector_to_linear_model(model);
|
||||
}
|
||||
|
||||
if(verbosity>=2) {
|
||||
printf("Classifying test examples.."); fflush(stdout);
|
||||
}
|
||||
|
||||
if ((docfl = fopen (docfile, "r")) == NULL)
|
||||
{ perror (docfile); exit (1); }
|
||||
if ((predfl = fopen (predictionsfile, "w")) == NULL)
|
||||
{ perror (predictionsfile); exit (1); }
|
||||
|
||||
while((!feof(docfl)) && fgets(line,(int)lld,docfl)) {
|
||||
if(line[0] == '#') continue; /* line contains comments */
|
||||
parse_document(line,words,&doc_label,&queryid,&slackid,&costfactor,&wnum,
|
||||
max_words_doc,&comment);
|
||||
totdoc++;
|
||||
if(model->kernel_parm.kernel_type == LINEAR) {/* For linear kernel, */
|
||||
for(j=0;(words[j]).wnum != 0;j++) { /* check if feature numbers */
|
||||
if((words[j]).wnum>model->totwords) /* are not larger than in */
|
||||
(words[j]).wnum=0; /* model. Remove feature if */
|
||||
} /* necessary. */
|
||||
}
|
||||
doc = create_example(-1,0,0,0.0,create_svector(words,comment,1.0));
|
||||
t1=get_runtime();
|
||||
|
||||
if(model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
|
||||
dist=classify_example_linear(model,doc);
|
||||
}
|
||||
else { /* non-linear kernel */
|
||||
dist=classify_example(model,doc);
|
||||
}
|
||||
|
||||
runtime+=(get_runtime()-t1);
|
||||
free_example(doc,1);
|
||||
|
||||
if(dist>0) {
|
||||
if(pred_format==0) { /* old weird output format */
|
||||
fprintf(predfl,"%.8g:+1 %.8g:-1\n",dist,-dist);
|
||||
}
|
||||
if(doc_label>0) correct++; else incorrect++;
|
||||
if(doc_label>0) res_a++; else res_b++;
|
||||
}
|
||||
else {
|
||||
if(pred_format==0) { /* old weird output format */
|
||||
fprintf(predfl,"%.8g:-1 %.8g:+1\n",-dist,dist);
|
||||
}
|
||||
if(doc_label<0) correct++; else incorrect++;
|
||||
if(doc_label>0) res_c++; else res_d++;
|
||||
}
|
||||
if(pred_format==1) { /* output the value of decision function */
|
||||
fprintf(predfl,"%.8g\n",dist);
|
||||
}
|
||||
if((int)(0.01+(doc_label*doc_label)) != 1)
|
||||
{ no_accuracy=1; } /* test data is not binary labeled */
|
||||
if(verbosity>=2) {
|
||||
if(totdoc % 100 == 0) {
|
||||
printf("%ld..",totdoc); fflush(stdout);
|
||||
}
|
||||
}
|
||||
}
|
||||
fclose(predfl);
|
||||
fclose(docfl);
|
||||
free(line);
|
||||
free(words);
|
||||
free_model(model,1);
|
||||
|
||||
if(verbosity>=2) {
|
||||
printf("done\n");
|
||||
|
||||
/* Note by Gary Boone Date: 29 April 2000 */
|
||||
/* o Timing is inaccurate. The timer has 0.01 second resolution. */
|
||||
/* Because classification of a single vector takes less than */
|
||||
/* 0.01 secs, the timer was underflowing. */
|
||||
printf("Runtime (without IO) in cpu-seconds: %.2f\n",
|
||||
(float)(runtime/100.0));
|
||||
|
||||
}
|
||||
if((!no_accuracy) && (verbosity>=1)) {
|
||||
printf("Accuracy on test set: %.2f%% (%ld correct, %ld incorrect, %ld total)\n",(float)(correct)*100.0/totdoc,correct,incorrect,totdoc);
|
||||
printf("Precision/recall on test set: %.2f%%/%.2f%%\n",(float)(res_a)*100.0/(res_a+res_b),(float)(res_a)*100.0/(res_a+res_c));
|
||||
}
|
||||
|
||||
return(0);
|
||||
}
|
||||
|
||||
void read_input_parameters(int argc, char **argv, char *docfile,
|
||||
char *modelfile, char *predictionsfile,
|
||||
long int *verbosity, long int *pred_format)
|
||||
{
|
||||
long i;
|
||||
|
||||
/* set default */
|
||||
strcpy (modelfile, "svm_model");
|
||||
strcpy (predictionsfile, "svm_predictions");
|
||||
(*verbosity)=2;
|
||||
(*pred_format)=1;
|
||||
|
||||
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||
switch ((argv[i])[1])
|
||||
{
|
||||
case 'h': print_help(); exit(0);
|
||||
case 'v': i++; (*verbosity)=atol(argv[i]); break;
|
||||
case 'f': i++; (*pred_format)=atol(argv[i]); break;
|
||||
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
if((i+1)>=argc) {
|
||||
printf("\nNot enough input parameters!\n\n");
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
strcpy (docfile, argv[i]);
|
||||
strcpy (modelfile, argv[i+1]);
|
||||
if((i+2)<argc) {
|
||||
strcpy (predictionsfile, argv[i+2]);
|
||||
}
|
||||
if(((*pred_format) != 0) && ((*pred_format) != 1)) {
|
||||
printf("\nOutput format can only take the values 0 or 1!\n\n");
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
void print_help(void)
|
||||
{
|
||||
printf("\nSVM-light %s: Support Vector Machine, classification module %s\n",VERSION,VERSION_DATE);
|
||||
copyright_notice();
|
||||
printf(" usage: svm_classify [options] example_file model_file output_file\n\n");
|
||||
printf("options: -h -> this help\n");
|
||||
printf(" -v [0..3] -> verbosity level (default 2)\n");
|
||||
printf(" -f [0,1] -> 0: old output format of V1.0\n");
|
||||
printf(" -> 1: output the value of decision function (default)\n\n");
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
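
The example_file read above is expected in the usual SVM-light sparse format (stated here as background, not shown in this diff): one example per line, a numeric target followed by feature:value pairs with strictly increasing feature numbers, and an optional trailing comment after '#', which matches the '#' handling and the parse_document call in the classification loop. For instance:

+1 1:0.43 3:0.12 9284:0.2 # a positive example
-1 2:0.8 10:1.0 # a negative example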
1854
src/classifier/svm/svm_light/svm_common.c
Normal file
File diff suppressed because it is too large
385
src/classifier/svm/svm_light/svm_common.h
Normal file
@@ -0,0 +1,385 @@
|
||||
/************************************************************************/
|
||||
/* */
|
||||
/* svm_common.h */
|
||||
/* */
|
||||
/* Definitions and functions used in both svm_learn and svm_classify. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 31.10.05 */
|
||||
/* */
|
||||
/* Copyright (c) 2005 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/************************************************************************/
|
||||
|
||||
#ifndef SVM_COMMON
|
||||
#define SVM_COMMON
|
||||
|
||||
#include <ctype.h>
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <time.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define VERSION "V6.20"
|
||||
#define VERSION_DATE "14.08.08"
|
||||
|
||||
#define CFLOAT float /* the type of float to use for caching */
|
||||
/* kernel evaluations. Using float saves */
|
||||
/* us some memory, but you can use double, too */
|
||||
#define FNUM int32_t /* the type used for storing feature ids */
|
||||
#define FNUM_MAX 2147483647 /* maximum value that FNUM type can take */
|
||||
#define FVAL float /* the type used for storing feature values */
|
||||
#define MAXFEATNUM \
|
||||
99999999 /* maximum feature number (must be in \
|
||||
valid range of FNUM type and long int!) */
|
||||
|
||||
#define LINEAR 0 /* linear kernel type */
|
||||
#define POLY 1 /* polynomial kernel type */
|
||||
#define RBF 2 /* rbf kernel type */
|
||||
#define SIGMOID 3 /* sigmoid kernel type */
|
||||
#define CUSTOM 4 /* userdefined kernel function from kernel.h */
|
||||
#define GRAM 5 /* use explicit gram matrix from kernel_parm */
|
||||
|
||||
#define CLASSIFICATION 1 /* train classification model */
|
||||
#define REGRESSION 2 /* train regression model */
|
||||
#define RANKING 3 /* train ranking model */
|
||||
#define OPTIMIZATION 4 /* train on general set of constraints */
|
||||
|
||||
#define MAXSHRINK 50000 /* maximum number of shrinking rounds */
|
||||
|
||||
typedef struct word {
|
||||
FNUM wnum; /* word number */
|
||||
FVAL weight; /* word weight */
|
||||
} WORD;
|
||||
|
||||
typedef struct svector {
|
||||
WORD *words; /* The features/values in the vector by
|
||||
increasing feature-number. Feature
|
||||
numbers that are skipped are
|
||||
interpreted as having value zero. */
|
||||
double twonorm_sq; /* The squared euclidean length of the
|
||||
vector. Used to speed up the RBF kernel. */
|
||||
char *userdefined; /* You can put additional information
|
||||
here. This can be useful, if you are
|
||||
implementing your own kernel that
|
||||
does not work with feature/values
|
||||
representations (for example a
|
||||
string kernel). By default,
|
||||
svm-light will put here the string
|
||||
after the # sign from each line of
|
||||
the input file. */
|
||||
long kernel_id; /* Feature vectors with different
|
||||
kernel_id's are orthogonal (ie. the
|
||||
feature numbers do not match). This
|
||||
is used for computing component
|
||||
kernels for linear constraints which
|
||||
are a sum of several different
|
||||
weight vectors. (currently not
|
||||
implemented). */
|
||||
struct svector *next; /* Lets you set up a list of SVECTORs
|
||||
for linear constraints which are a
|
||||
sum of multiple feature
|
||||
vectors. List is terminated by
|
||||
NULL. */
|
||||
double factor; /* Factor by which this feature vector
|
||||
is multiplied in the sum. */
|
||||
} SVECTOR;
|
||||
|
||||
typedef struct doc {
|
||||
long docnum; /* Document ID. This has to be the position of
|
||||
the document in the training set array. */
|
||||
long queryid; /* for learning rankings, constraints are
|
||||
generated for documents with the same
|
||||
queryID. */
|
||||
double costfactor; /* Scales the cost of misclassifying this
|
||||
document by this factor. The effect of this
|
||||
value is, that the upper bound on the alpha
|
||||
for this example is scaled by this factor.
|
||||
The factors are set by the feature
|
||||
'cost:<val>' in the training data. */
|
||||
long slackid; /* Index of the slack variable
|
||||
corresponding to this
|
||||
constraint. All constraints with the
|
||||
same slackid share the same slack
|
||||
variable. This can only be used for
|
||||
svm_learn_optimization. */
|
||||
long kernelid; /* Position in gram matrix where kernel
|
||||
value can be found when using an
|
||||
explicit gram matrix
|
||||
(i.e. kernel_type=GRAM). */
|
||||
SVECTOR *fvec; /* Feature vector of the example. The
|
||||
feature vector can actually be a
|
||||
list of feature vectors. For
|
||||
example, the list will have two
|
||||
elements, if this DOC is a
|
||||
preference constraint. The one
|
||||
vector that is supposed to be ranked
|
||||
higher, will have a factor of +1,
|
||||
the lower ranked one should have a
|
||||
factor of -1. */
|
||||
} DOC;
|
||||
|
||||
typedef struct learn_parm {
|
||||
long type; /* selects between regression and
|
||||
classification */
|
||||
double svm_c; /* upper bound C on alphas */
|
||||
double eps; /* regression epsilon (eps=1.0 for
|
||||
classification) */
|
||||
double svm_costratio; /* factor to multiply C for positive examples */
|
||||
double transduction_posratio; /* fraction of unlabeled examples to be */
|
||||
/* classified as positives */
|
||||
long biased_hyperplane; /* if nonzero, use hyperplane w*x+b=0
|
||||
otherwise w*x=0 */
|
||||
long sharedslack; /* if nonzero, it will use the shared
|
||||
slack variable mode in
|
||||
svm_learn_optimization. It requires
|
||||
that the slackid is set for every
|
||||
training example */
|
||||
long svm_maxqpsize; /* size q of working set */
|
||||
long svm_newvarsinqp; /* new variables to enter the working set
|
||||
in each iteration */
|
||||
long kernel_cache_size; /* size of kernel cache in megabytes */
|
||||
double epsilon_crit; /* tolerable error for distances used
|
||||
in stopping criterion */
|
||||
double epsilon_shrink; /* how much a multiplier should be above
|
||||
zero for shrinking */
|
||||
long svm_iter_to_shrink; /* iterations h after which an example can
|
||||
be removed by shrinking */
|
||||
long maxiter; /* number of iterations after which the
|
||||
optimizer terminates, if there was
|
||||
no progress in maxdiff */
|
||||
long remove_inconsistent; /* exclude examples with alpha at C and
|
||||
retrain */
|
||||
long skip_final_opt_check; /* do not check KT-Conditions at the end of
|
||||
optimization for examples removed by
|
||||
shrinking. WARNING: This might lead to
|
||||
sub-optimal solutions! */
|
||||
long compute_loo; /* if nonzero, computes leave-one-out
|
||||
estimates */
|
||||
double rho; /* parameter in xi/alpha-estimates and for
|
||||
pruning leave-one-out range [1..2] */
|
||||
long xa_depth; /* parameter in xi/alpha-estimates upper
|
||||
bounding the number of SV the current
|
||||
alpha_t is distributed over */
|
||||
char predfile[200]; /* file for predictions on unlabeled examples
|
||||
in transduction */
|
||||
char alphafile[200]; /* file to store optimal alphas in. use
|
||||
empty string if alphas should not be
|
||||
output */
|
||||
|
||||
/* you probably do not want to touch the following */
|
||||
double epsilon_const; /* tolerable error on eq-constraint */
|
||||
double epsilon_a; /* tolerable error on alphas at bounds */
|
||||
double opt_precision; /* precision of solver, set to e.g. 1e-21
|
||||
if you get convergence problems */
|
||||
|
||||
/* the following are only for internal use */
|
||||
long svm_c_steps; /* do so many steps for finding optimal C */
|
||||
double svm_c_factor; /* increase C by this factor every step */
|
||||
double svm_costratio_unlab;
|
||||
double svm_unlabbound;
|
||||
double *svm_cost; /* individual upper bounds for each var */
|
||||
long totwords; /* number of features */
|
||||
} LEARN_PARM;
|
||||
|
||||
typedef struct matrix {
|
||||
int n; /* number of rows */
|
||||
int m; /* number of columns */
|
||||
double **element;
|
||||
} MATRIX;
|
||||
|
||||
typedef struct kernel_parm {
|
||||
long kernel_type; /* 0=linear, 1=poly, 2=rbf, 3=sigmoid,
|
||||
4=custom, 5=matrix */
|
||||
long poly_degree;
|
||||
double rbf_gamma;
|
||||
double coef_lin;
|
||||
double coef_const;
|
||||
char custom[50]; /* for user supplied kernel */
|
||||
MATRIX *gram_matrix; /* here one can directly supply the kernel
|
||||
matrix. The matrix is accessed if
|
||||
kernel_type=5 is selected. */
|
||||
} KERNEL_PARM;
|
||||
|
||||
typedef struct model {
|
||||
long sv_num;
|
||||
long at_upper_bound;
|
||||
double b;
|
||||
DOC **supvec;
|
||||
double *alpha;
|
||||
long *index; /* index from docnum to position in model */
|
||||
long totwords; /* number of features */
|
||||
long totdoc; /* number of training documents */
|
||||
KERNEL_PARM kernel_parm; /* kernel */
|
||||
|
||||
/* the following values are not written to file */
|
||||
double loo_error, loo_recall, loo_precision; /* leave-one-out estimates */
|
||||
double xa_error, xa_recall, xa_precision; /* xi/alpha estimates */
|
||||
double *lin_weights; /* weights for linear case using
|
||||
folding */
|
||||
double maxdiff; /* precision, up to which this
|
||||
model is accurate */
|
||||
} MODEL;
|
||||
|
||||
/* The following specifies a quadratic problem of the following form
|
||||
|
||||
minimize g0 * x + 1/2 x' * G * x
|
||||
subject to ce*x - ce0 = 0
|
||||
l <= x <= u
|
||||
*/
|
||||
typedef struct quadratic_program {
|
||||
long opt_n; /* number of variables */
|
||||
long opt_m; /* number of linear equality constraints */
|
||||
double *opt_ce, *opt_ce0; /* linear equality constraints
|
||||
opt_ce[i]*x - opt_ceo[i]=0 */
|
||||
double *opt_g; /* hessian of objective */
|
||||
double *opt_g0; /* linear part of objective */
|
||||
double *opt_xinit; /* initial value for variables */
|
||||
double *opt_low, *opt_up; /* box constraints */
|
||||
} QP;
|
||||
|
||||
typedef struct kernel_cache {
|
||||
long *index; /* cache some kernel evaluations */
|
||||
CFLOAT *buffer; /* to improve speed */
|
||||
long *invindex;
|
||||
long *active2totdoc;
|
||||
long *totdoc2active;
|
||||
long *lru;
|
||||
long *occu;
|
||||
long elems;
|
||||
long max_elems;
|
||||
long time;
|
||||
long activenum;
|
||||
long buffsize;
|
||||
} KERNEL_CACHE;
|
||||
|
||||
typedef struct timing_profile {
|
||||
double time_kernel;
|
||||
double time_opti;
|
||||
double time_shrink;
|
||||
double time_update;
|
||||
double time_model;
|
||||
double time_check;
|
||||
double time_select;
|
||||
} TIMING;
|
||||
|
||||
typedef struct shrink_state {
|
||||
long *active;
|
||||
long *inactive_since;
|
||||
long deactnum;
|
||||
double **a_history; /* for shrinking with non-linear kernel */
|
||||
long maxhistory;
|
||||
double *last_a; /* for shrinking with linear kernel */
|
||||
double *last_lin; /* for shrinking with linear kernel */
|
||||
} SHRINK_STATE;
|
||||
|
||||
typedef struct randpair {
|
||||
long val, sort;
|
||||
} RANDPAIR;
|
||||
|
||||
double classify_example(MODEL *, DOC *);
|
||||
double classify_example_linear(MODEL *, DOC *);
|
||||
double kernel(KERNEL_PARM *, DOC *, DOC *);
|
||||
double single_kernel(KERNEL_PARM *, SVECTOR *, SVECTOR *);
|
||||
double custom_kernel(KERNEL_PARM *, SVECTOR *, SVECTOR *);
|
||||
SVECTOR *create_svector(WORD *, char *, double);
|
||||
SVECTOR *create_svector_shallow(WORD *, char *, double);
|
||||
SVECTOR *create_svector_n(double *, long, char *, double);
|
||||
SVECTOR *create_svector_n_r(double *, long, char *, double, double);
|
||||
SVECTOR *copy_svector(SVECTOR *);
|
||||
SVECTOR *copy_svector_shallow(SVECTOR *);
|
||||
void free_svector(SVECTOR *);
|
||||
void free_svector_shallow(SVECTOR *);
|
||||
double sprod_ss(SVECTOR *, SVECTOR *);
|
||||
SVECTOR *sub_ss(SVECTOR *, SVECTOR *);
|
||||
SVECTOR *sub_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||
SVECTOR *add_ss(SVECTOR *, SVECTOR *);
|
||||
SVECTOR *add_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||
SVECTOR *multadd_ss(SVECTOR *a, SVECTOR *b, double fa, double fb);
|
||||
SVECTOR *multadd_ss_r(SVECTOR *a, SVECTOR *b, double fa, double fb,
|
||||
double min_non_zero);
|
||||
SVECTOR *add_list_ns(SVECTOR *a);
|
||||
SVECTOR *add_dual_list_ns_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||
SVECTOR *add_list_ns_r(SVECTOR *a, double min_non_zero);
|
||||
SVECTOR *add_list_ss(SVECTOR *);
|
||||
SVECTOR *add_dual_list_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||
SVECTOR *add_list_ss_r(SVECTOR *, double min_non_zero);
|
||||
SVECTOR *add_list_sort_ss(SVECTOR *);
|
||||
SVECTOR *add_dual_list_sort_ss_r(SVECTOR *, SVECTOR *, double min_non_zero);
|
||||
SVECTOR *add_list_sort_ss_r(SVECTOR *, double min_non_zero);
|
||||
void add_list_n_ns(double *vec_n, SVECTOR *vec_s, double faktor);
|
||||
void append_svector_list(SVECTOR *a, SVECTOR *b);
|
||||
void mult_svector_list(SVECTOR *a, double factor);
|
||||
void setfactor_svector_list(SVECTOR *a, double factor);
|
||||
SVECTOR *smult_s(SVECTOR *, double);
|
||||
SVECTOR *shift_s(SVECTOR *a, long shift);
|
||||
int featvec_eq(SVECTOR *, SVECTOR *);
|
||||
double model_length_s(MODEL *);
|
||||
double model_length_n(MODEL *);
|
||||
void mult_vector_ns(double *, SVECTOR *, double);
|
||||
void add_vector_ns(double *, SVECTOR *, double);
|
||||
double sprod_ns(double *, SVECTOR *);
|
||||
void add_weight_vector_to_linear_model(MODEL *);
|
||||
DOC *create_example(long, long, long, double, SVECTOR *);
|
||||
void free_example(DOC *, long);
|
||||
long *random_order(long n);
|
||||
void print_percent_progress(long *progress, long maximum, long percentperdot,
|
||||
char *symbol);
|
||||
MATRIX *create_matrix(int n, int m);
|
||||
MATRIX *realloc_matrix(MATRIX *matrix, int n, int m);
|
||||
double *create_nvector(int n);
|
||||
void clear_nvector(double *vec, long int n);
|
||||
MATRIX *copy_matrix(MATRIX *matrix);
|
||||
void free_matrix(MATRIX *matrix);
|
||||
void free_nvector(double *vector);
|
||||
MATRIX *transpose_matrix(MATRIX *matrix);
|
||||
MATRIX *cholesky_matrix(MATRIX *A);
|
||||
double *find_indep_subset_of_matrix(MATRIX *A, double epsilon);
|
||||
MATRIX *invert_ltriangle_matrix(MATRIX *L);
|
||||
double *prod_nvector_matrix(double *v, MATRIX *A);
|
||||
double *prod_matrix_nvector(MATRIX *A, double *v);
|
||||
double *prod_nvector_ltmatrix(double *v, MATRIX *A);
|
||||
double *prod_ltmatrix_nvector(MATRIX *A, double *v);
|
||||
MATRIX *prod_matrix_matrix(MATRIX *A, MATRIX *B);
|
||||
void print_matrix(MATRIX *matrix);
|
||||
MODEL *read_model(char *);
|
||||
MODEL *copy_model(MODEL *);
|
||||
MODEL *compact_linear_model(MODEL *model);
|
||||
void free_model(MODEL *, int);
|
||||
void read_documents(char *, DOC ***, double **, long *, long *);
|
||||
int parse_document(char *, WORD *, double *, long *, long *, double *, long *,
|
||||
long, char **);
|
||||
int read_word(char *in, char *out);
|
||||
double *read_alphas(char *, long);
|
||||
void set_learning_defaults(LEARN_PARM *, KERNEL_PARM *);
|
||||
int check_learning_parms(LEARN_PARM *, KERNEL_PARM *);
|
||||
void nol_ll(char *, long *, long *, long *);
|
||||
long minl(long, long);
|
||||
long maxl(long, long);
|
||||
double get_runtime(void);
|
||||
int space_or_null(int);
|
||||
void *my_malloc(size_t);
|
||||
void copyright_notice(void);
|
||||
#ifdef _MSC_VER
|
||||
int isnan(double);
|
||||
#endif
|
||||
|
||||
extern long verbosity; /* verbosity level (0-4) */
|
||||
extern long kernel_cache_statistic;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
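
To make the intended call sequence concrete, the following is a minimal sketch (assumed usage, not code from this repository) that combines the prototypes above -- read_model, add_weight_vector_to_linear_model, create_svector, create_example, and classify_example -- to score one dense feature vector. The function name classify_dense and the dense-to-sparse conversion are illustrative.

/* classify_dense.c -- illustrative sketch built on svm_common.h */
#include "svm_common.h"

double classify_dense(const char *modelfile, const float *vec, long nfeat) {
  MODEL *model = read_model((char *)modelfile);
  if (model->kernel_parm.kernel_type == LINEAR)
    add_weight_vector_to_linear_model(model); /* fold SVs into one weight vector */

  /* build a sparse WORD array, terminated by wnum == 0 */
  WORD *words = (WORD *)my_malloc(sizeof(WORD) * (nfeat + 1));
  long i;
  for (i = 0; i < nfeat; i++) {
    words[i].wnum = (FNUM)(i + 1);
    words[i].weight = (FVAL)vec[i];
  }
  words[nfeat].wnum = 0;

  DOC *doc = create_example(-1, 0, 0, 0.0, create_svector(words, (char *)"", 1.0));
  double dist = (model->kernel_parm.kernel_type == LINEAR)
                    ? classify_example_linear(model, doc)
                    : classify_example(model, doc);

  free(words);          /* create_svector keeps its own copy */
  free_example(doc, 1);
  free_model(model, 1);
  return dist;
}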
1054
src/classifier/svm/svm_light/svm_hideo.c
Normal file
File diff suppressed because it is too large
4223
src/classifier/svm/svm_light/svm_learn.c
Normal file
File diff suppressed because it is too large
169
src/classifier/svm/svm_light/svm_learn.h
Normal file
@@ -0,0 +1,169 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_learn.h */
|
||||
/* */
|
||||
/* Declarations for learning module of Support Vector Machine. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 02.07.02 */
|
||||
/* */
|
||||
/* Copyright (c) 2002 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
#ifndef SVM_LEARN
|
||||
#define SVM_LEARN
|
||||
|
||||
#include "svm_common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void svm_learn_classification(DOC **, double *, long, long, LEARN_PARM *,
|
||||
KERNEL_PARM *, KERNEL_CACHE *, MODEL *, double *);
|
||||
void svm_learn_regression(DOC **, double *, long, long, LEARN_PARM *,
|
||||
KERNEL_PARM *, KERNEL_CACHE **, MODEL *);
|
||||
void svm_learn_ranking(DOC **, double *, long, long, LEARN_PARM *,
|
||||
KERNEL_PARM *, KERNEL_CACHE **, MODEL *);
|
||||
void svm_learn_optimization(DOC **, double *, long, long, LEARN_PARM *,
|
||||
KERNEL_PARM *, KERNEL_CACHE *, MODEL *, double *);
|
||||
long optimize_to_convergence(DOC **, long *, long, long, LEARN_PARM *,
|
||||
KERNEL_PARM *, KERNEL_CACHE *, SHRINK_STATE *,
|
||||
MODEL *, long *, long *, double *, double *,
|
||||
double *, TIMING *, double *, long, long);
|
||||
long optimize_to_convergence_sharedslack(DOC **, long *, long, long,
|
||||
LEARN_PARM *, KERNEL_PARM *,
|
||||
KERNEL_CACHE *, SHRINK_STATE *,
|
||||
MODEL *, double *, double *, double *,
|
||||
TIMING *, double *);
|
||||
double compute_objective_function(double *, double *, double *, double, long *,
|
||||
long *);
|
||||
void clear_index(long *);
|
||||
void add_to_index(long *, long);
|
||||
long compute_index(long *, long, long *);
|
||||
void optimize_svm(DOC **, long *, long *, long *, double, long *, long *,
|
||||
MODEL *, long, long *, long, double *, double *, double *,
|
||||
LEARN_PARM *, CFLOAT *, KERNEL_PARM *, QP *, double *);
|
||||
void compute_matrices_for_optimization(DOC **, long *, long *, long *, double,
|
||||
long *, long *, long *, MODEL *,
|
||||
double *, double *, double *, long, long,
|
||||
LEARN_PARM *, CFLOAT *, KERNEL_PARM *,
|
||||
QP *);
|
||||
long calculate_svm_model(DOC **, long *, long *, double *, double *, double *,
|
||||
double *, LEARN_PARM *, long *, long *, MODEL *);
|
||||
long check_optimality(MODEL *, long *, long *, double *, double *, double *,
|
||||
long, LEARN_PARM *, double *, double, long *, long *,
|
||||
long *, long *, long, KERNEL_PARM *);
|
||||
long check_optimality_sharedslack(
|
||||
DOC **docs, MODEL *model, long int *label, double *a, double *lin,
|
||||
double *c, double *slack, double *alphaslack, long int totdoc,
|
||||
LEARN_PARM *learn_parm, double *maxdiff, double epsilon_crit_org,
|
||||
long int *misclassified, long int *active2dnum,
|
||||
long int *last_suboptimal_at, long int iteration, KERNEL_PARM *kernel_parm);
|
||||
void compute_shared_slacks(DOC **docs, long int *label, double *a, double *lin,
|
||||
double *c, long int *active2dnum,
|
||||
LEARN_PARM *learn_parm, double *slack,
|
||||
double *alphaslack);
|
||||
long identify_inconsistent(double *, long *, long *, long, LEARN_PARM *, long *,
|
||||
long *);
|
||||
long identify_misclassified(double *, long *, long *, long, MODEL *, long *,
|
||||
long *);
|
||||
long identify_one_misclassified(double *, long *, long *, long, MODEL *, long *,
|
||||
long *);
|
||||
long incorporate_unlabeled_examples(MODEL *, long *, long *, long *, double *,
|
||||
double *, long, double *, long *, long *,
|
||||
long, KERNEL_PARM *, LEARN_PARM *);
|
||||
void update_linear_component(DOC **, long *, long *, double *, double *, long *,
|
||||
long, long, KERNEL_PARM *, KERNEL_CACHE *,
|
||||
double *, CFLOAT *, double *);
|
||||
long select_next_qp_subproblem_grad(long *, long *, double *, double *,
|
||||
double *, long, long, LEARN_PARM *, long *,
|
||||
long *, long *, double *, long *,
|
||||
KERNEL_CACHE *, long, long *, long *);
|
||||
long select_next_qp_subproblem_rand(long *, long *, double *, double *,
|
||||
double *, long, long, LEARN_PARM *, long *,
|
||||
long *, long *, double *, long *,
|
||||
KERNEL_CACHE *, long *, long *, long);
|
||||
long select_next_qp_slackset(DOC **docs, long int *label, double *a,
|
||||
double *lin, double *slack, double *alphaslack,
|
||||
double *c, LEARN_PARM *learn_parm,
|
||||
long int *active2dnum, double *maxviol);
|
||||
void select_top_n(double *, long, long *, long);
|
||||
void init_shrink_state(SHRINK_STATE *, long, long);
|
||||
void shrink_state_cleanup(SHRINK_STATE *);
|
||||
long shrink_problem(DOC **, LEARN_PARM *, SHRINK_STATE *, KERNEL_PARM *, long *,
|
||||
long *, long, long, long, double *, long *);
|
||||
void reactivate_inactive_examples(long *, long *, double *, SHRINK_STATE *,
|
||||
double *, double *, long, long, long,
|
||||
LEARN_PARM *, long *, DOC **, KERNEL_PARM *,
|
||||
KERNEL_CACHE *, MODEL *, CFLOAT *, double *,
|
||||
double *);
|
||||
|
||||
/* cache kernel evaluations to improve speed */
|
||||
KERNEL_CACHE *kernel_cache_init(long, long);
|
||||
void kernel_cache_cleanup(KERNEL_CACHE *);
|
||||
void get_kernel_row(KERNEL_CACHE *, DOC **, long, long, long *, CFLOAT *,
|
||||
KERNEL_PARM *);
|
||||
void cache_kernel_row(KERNEL_CACHE *, DOC **, long, KERNEL_PARM *);
|
||||
void cache_multiple_kernel_rows(KERNEL_CACHE *, DOC **, long *, long,
|
||||
KERNEL_PARM *);
|
||||
void kernel_cache_shrink(KERNEL_CACHE *, long, long, long *);
|
||||
void kernel_cache_reset_lru(KERNEL_CACHE *);
|
||||
long kernel_cache_malloc(KERNEL_CACHE *);
|
||||
void kernel_cache_free(KERNEL_CACHE *, long);
|
||||
long kernel_cache_free_lru(KERNEL_CACHE *);
|
||||
CFLOAT *kernel_cache_clean_and_malloc(KERNEL_CACHE *, long);
|
||||
long kernel_cache_touch(KERNEL_CACHE *, long);
|
||||
long kernel_cache_check(KERNEL_CACHE *, long);
|
||||
long kernel_cache_space_available(KERNEL_CACHE *);
|
||||
|
||||
void compute_xa_estimates(MODEL *, long *, long *, long, DOC **, double *,
|
||||
double *, KERNEL_PARM *, LEARN_PARM *, double *,
|
||||
double *, double *);
|
||||
double xa_estimate_error(MODEL *, long *, long *, long, DOC **, double *,
|
||||
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||
double xa_estimate_recall(MODEL *, long *, long *, long, DOC **, double *,
|
||||
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||
double xa_estimate_precision(MODEL *, long *, long *, long, DOC **, double *,
|
||||
double *, KERNEL_PARM *, LEARN_PARM *);
|
||||
void avg_similarity_of_sv_of_one_class(MODEL *, DOC **, double *, long *,
|
||||
KERNEL_PARM *, double *, double *);
|
||||
double most_similar_sv_of_same_class(MODEL *, DOC **, double *, long, long *,
|
||||
KERNEL_PARM *, LEARN_PARM *);
|
||||
double distribute_alpha_t_greedily(long *, long, DOC **, double *, long, long *,
|
||||
KERNEL_PARM *, LEARN_PARM *, double);
|
||||
double distribute_alpha_t_greedily_noindex(MODEL *, DOC **, double *, long,
|
||||
long *, KERNEL_PARM *, LEARN_PARM *,
|
||||
double);
|
||||
void estimate_transduction_quality(MODEL *, long *, long *, long, DOC **,
|
||||
double *);
|
||||
double estimate_margin_vcdim(MODEL *, double, double);
|
||||
double estimate_sphere(MODEL *);
|
||||
double estimate_r_delta_average(DOC **, long, KERNEL_PARM *);
|
||||
double estimate_r_delta(DOC **, long, KERNEL_PARM *);
|
||||
double length_of_longest_document_vector(DOC **, long, KERNEL_PARM *);
|
||||
|
||||
void write_model(char *, MODEL *);
|
||||
void write_prediction(char *, MODEL *, double *, double *, long *, long *, long,
|
||||
LEARN_PARM *);
|
||||
void write_alphas(char *, double *, long *, long);
|
||||
|
||||
typedef struct cache_parm_s {
|
||||
KERNEL_CACHE *kernel_cache;
|
||||
CFLOAT *cache;
|
||||
DOC **docs;
|
||||
long m;
|
||||
KERNEL_PARM *kernel_parm;
|
||||
long offset, stepsize;
|
||||
} cache_parm_t;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
303
src/classifier/svm/svm_light/svm_learn_main.c
Normal file
@@ -0,0 +1,303 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_learn_main.c */
|
||||
/* */
|
||||
/* Command line interface to the learning module of the */
|
||||
/* Support Vector Machine. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 02.07.02 */
|
||||
/* */
|
||||
/* Copyright (c) 2000 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
|
||||
/* if svm-learn is used from C++, declare it as extern "C" */
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
# include "svm_common.h"
|
||||
# include "svm_learn.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
char docfile[200]; /* file with training examples */
|
||||
char modelfile[200]; /* file for resulting classifier */
|
||||
char restartfile[200]; /* file with initial alphas */
|
||||
|
||||
void read_input_parameters(int, char **, char *, char *, char *, long *,
|
||||
LEARN_PARM *, KERNEL_PARM *);
|
||||
void wait_any_key();
|
||||
void print_help();
|
||||
|
||||
|
||||
|
||||
int main (int argc, char* argv[])
|
||||
{
|
||||
DOC **docs; /* training examples */
|
||||
long totwords,totdoc,i;
|
||||
double *target;
|
||||
double *alpha_in=NULL;
|
||||
KERNEL_CACHE *kernel_cache;
|
||||
LEARN_PARM learn_parm;
|
||||
KERNEL_PARM kernel_parm;
|
||||
MODEL *model=(MODEL *)my_malloc(sizeof(MODEL));
|
||||
|
||||
read_input_parameters(argc,argv,docfile,modelfile,restartfile,&verbosity,
|
||||
&learn_parm,&kernel_parm);
|
||||
read_documents(docfile,&docs,&target,&totwords,&totdoc);
|
||||
if(restartfile[0]) alpha_in=read_alphas(restartfile,totdoc);
|
||||
|
||||
if(kernel_parm.kernel_type == LINEAR) { /* don't need the cache */
|
||||
kernel_cache=NULL;
|
||||
}
|
||||
else {
|
||||
/* Always get a new kernel cache. It is not possible to use the
|
||||
same cache for two different training runs */
|
||||
kernel_cache=kernel_cache_init(totdoc,learn_parm.kernel_cache_size);
|
||||
}
|
||||
|
||||
if(learn_parm.type == CLASSIFICATION) {
|
||||
svm_learn_classification(docs,target,totdoc,totwords,&learn_parm,
|
||||
&kernel_parm,kernel_cache,model,alpha_in);
|
||||
}
|
||||
else if(learn_parm.type == REGRESSION) {
|
||||
svm_learn_regression(docs,target,totdoc,totwords,&learn_parm,
|
||||
&kernel_parm,&kernel_cache,model);
|
||||
}
|
||||
else if(learn_parm.type == RANKING) {
|
||||
svm_learn_ranking(docs,target,totdoc,totwords,&learn_parm,
|
||||
&kernel_parm,&kernel_cache,model);
|
||||
}
|
||||
else if(learn_parm.type == OPTIMIZATION) {
|
||||
svm_learn_optimization(docs,target,totdoc,totwords,&learn_parm,
|
||||
&kernel_parm,kernel_cache,model,alpha_in);
|
||||
}
|
||||
|
||||
if(kernel_cache) {
|
||||
/* Free the memory used for the cache. */
|
||||
kernel_cache_cleanup(kernel_cache);
|
||||
}
|
||||
|
||||
/* Warning: The model contains references to the original data 'docs'.
|
||||
If you want to free the original data, and only keep the model, you
|
||||
have to make a deep copy of 'model'. */
|
||||
/* deep_copy_of_model=copy_model(model); */
|
||||
write_model(modelfile,model);
|
||||
|
||||
free(alpha_in);
|
||||
free_model(model,0);
|
||||
for(i=0;i<totdoc;i++)
|
||||
free_example(docs[i],1);
|
||||
free(docs);
|
||||
free(target);
|
||||
|
||||
return(0);
|
||||
}
|
||||
|
||||
/*---------------------------------------------------------------------------*/
|
||||
|
||||
void read_input_parameters(int argc,char *argv[],char *docfile,char *modelfile,
|
||||
char *restartfile,long *verbosity,
|
||||
LEARN_PARM *learn_parm,KERNEL_PARM *kernel_parm)
|
||||
{
|
||||
long i;
|
||||
char type[100];
|
||||
|
||||
/* set default */
|
||||
set_learning_defaults(learn_parm, kernel_parm);
|
||||
strcpy (modelfile, "svm_model");
|
||||
strcpy (restartfile, "");
|
||||
(*verbosity)=1;
|
||||
strcpy(type,"c");
|
||||
|
||||
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||
switch ((argv[i])[1])
|
||||
{
|
||||
case '?': print_help(); exit(0);
|
||||
case 'z': i++; strcpy(type,argv[i]); break;
|
||||
case 'v': i++; (*verbosity)=atol(argv[i]); break;
|
||||
case 'b': i++; learn_parm->biased_hyperplane=atol(argv[i]); break;
|
||||
case 'i': i++; learn_parm->remove_inconsistent=atol(argv[i]); break;
|
||||
case 'f': i++; learn_parm->skip_final_opt_check=!atol(argv[i]); break;
|
||||
case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break;
|
||||
case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break;
|
||||
case '#': i++; learn_parm->maxiter=atol(argv[i]); break;
|
||||
case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break;
|
||||
case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break;
|
||||
case 'c': i++; learn_parm->svm_c=atof(argv[i]); break;
|
||||
case 'w': i++; learn_parm->eps=atof(argv[i]); break;
|
||||
case 'p': i++; learn_parm->transduction_posratio=atof(argv[i]); break;
|
||||
case 'j': i++; learn_parm->svm_costratio=atof(argv[i]); break;
|
||||
case 'e': i++; learn_parm->epsilon_crit=atof(argv[i]); break;
|
||||
case 'o': i++; learn_parm->rho=atof(argv[i]); break;
|
||||
case 'k': i++; learn_parm->xa_depth=atol(argv[i]); break;
|
||||
case 'x': i++; learn_parm->compute_loo=atol(argv[i]); break;
|
||||
case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break;
|
||||
case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break;
|
||||
case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break;
|
||||
case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break;
|
||||
case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break;
|
||||
case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break;
|
||||
case 'l': i++; strcpy(learn_parm->predfile,argv[i]); break;
|
||||
case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break;
|
||||
case 'y': i++; strcpy(restartfile,argv[i]); break;
|
||||
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
if(i>=argc) {
|
||||
printf("\nNot enough input parameters!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
strcpy (docfile, argv[i]);
|
||||
if((i+1)<argc) {
|
||||
strcpy (modelfile, argv[i+1]);
|
||||
}
|
||||
if(learn_parm->svm_iter_to_shrink == -9999) {
|
||||
if(kernel_parm->kernel_type == LINEAR)
|
||||
learn_parm->svm_iter_to_shrink=2;
|
||||
else
|
||||
learn_parm->svm_iter_to_shrink=100;
|
||||
}
|
||||
if(strcmp(type,"c")==0) {
|
||||
learn_parm->type=CLASSIFICATION;
|
||||
}
|
||||
else if(strcmp(type,"r")==0) {
|
||||
learn_parm->type=REGRESSION;
|
||||
}
|
||||
else if(strcmp(type,"p")==0) {
|
||||
learn_parm->type=RANKING;
|
||||
}
|
||||
else if(strcmp(type,"o")==0) {
|
||||
learn_parm->type=OPTIMIZATION;
|
||||
}
|
||||
else if(strcmp(type,"s")==0) {
|
||||
learn_parm->type=OPTIMIZATION;
|
||||
learn_parm->sharedslack=1;
|
||||
}
|
||||
else {
|
||||
printf("\nUnknown type '%s': Valid types are 'c' (classification), 'r' regession, and 'p' preference ranking.\n",type);
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if (!check_learning_parms(learn_parm, kernel_parm)) {
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
void wait_any_key()
|
||||
{
|
||||
printf("\n(more)\n");
|
||||
(void)getc(stdin);
|
||||
}
|
||||
|
||||
void print_help()
|
||||
{
|
||||
printf("\nSVM-light %s: Support Vector Machine, learning module %s\n",VERSION,VERSION_DATE);
|
||||
copyright_notice();
|
||||
printf(" usage: svm_learn [options] example_file model_file\n\n");
|
||||
printf("Arguments:\n");
|
||||
printf(" example_file-> file with training data\n");
|
||||
printf(" model_file -> file to store learned decision rule in\n");
|
||||
|
||||
printf("General options:\n");
|
||||
printf(" -? -> this help\n");
|
||||
printf(" -v [0..3] -> verbosity level (default 1)\n");
|
||||
printf("Learning options:\n");
|
||||
printf(" -z {c,r,p} -> select between classification (c), regression (r),\n");
|
||||
printf(" and preference ranking (p) (default classification)\n");
|
||||
printf(" -c float -> C: trade-off between training error\n");
|
||||
printf(" and margin (default [avg. x*x]^-1)\n");
|
||||
printf(" -w [0..] -> epsilon width of tube for regression\n");
|
||||
printf(" (default 0.1)\n");
|
||||
printf(" -j float -> Cost: cost-factor, by which training errors on\n");
|
||||
printf(" positive examples outweight errors on negative\n");
|
||||
printf(" examples (default 1) (see [4])\n");
|
||||
printf(" -b [0,1] -> use biased hyperplane (i.e. x*w+b>0) instead\n");
|
||||
printf(" of unbiased hyperplane (i.e. x*w>0) (default 1)\n");
|
||||
printf(" -i [0,1] -> remove inconsistent training examples\n");
|
||||
printf(" and retrain (default 0)\n");
|
||||
printf("Performance estimation options:\n");
|
||||
printf(" -x [0,1] -> compute leave-one-out estimates (default 0)\n");
|
||||
printf(" (see [5])\n");
|
||||
printf(" -o ]0..2] -> value of rho for XiAlpha-estimator and for pruning\n");
|
||||
printf(" leave-one-out computation (default 1.0) (see [2])\n");
|
||||
printf(" -k [0..100] -> search depth for extended XiAlpha-estimator \n");
|
||||
printf(" (default 0)\n");
|
||||
printf("Transduction options (see [3]):\n");
|
||||
printf(" -p [0..1] -> fraction of unlabeled examples to be classified\n");
|
||||
printf(" into the positive class (default is the ratio of\n");
|
||||
printf(" positive and negative examples in the training data)\n");
|
||||
printf("Kernel options:\n");
|
||||
printf(" -t int -> type of kernel function:\n");
|
||||
printf(" 0: linear (default)\n");
|
||||
printf(" 1: polynomial (s a*b+c)^d\n");
|
||||
printf(" 2: radial basis function exp(-gamma ||a-b||^2)\n");
|
||||
printf(" 3: sigmoid tanh(s a*b + c)\n");
|
||||
printf(" 4: user defined kernel from kernel.h\n");
|
||||
printf(" -d int -> parameter d in polynomial kernel\n");
|
||||
printf(" -g float -> parameter gamma in rbf kernel\n");
|
||||
printf(" -s float -> parameter s in sigmoid/poly kernel\n");
|
||||
printf(" -r float -> parameter c in sigmoid/poly kernel\n");
|
||||
printf(" -u string -> parameter of user defined kernel\n");
|
||||
printf("Optimization options (see [1]):\n");
|
||||
printf(" -q [2..] -> maximum size of QP-subproblems (default 10)\n");
|
||||
printf(" -n [2..q] -> number of new variables entering the working set\n");
|
||||
printf(" in each iteration (default n = q). Set n < q to \n");
|
||||
printf(" prevent zig-zagging.\n");
|
||||
printf(" -m [5..] -> size of cache for kernel evaluations in MB (default 40)\n");
|
||||
printf(" The larger the faster...\n");
|
||||
printf(" -e float -> eps: Allow that error for termination criterion\n");
|
||||
printf(" [y [w*x+b] - 1] >= eps (default 0.001)\n");
|
||||
printf(" -y [0,1] -> restart the optimization from alpha values in file\n");
|
||||
printf(" specified by -a option. (default 0)\n");
|
||||
printf(" -h [5..] -> number of iterations a variable needs to be\n");
|
||||
printf(" optimal before considered for shrinking (default 100)\n");
|
||||
printf(" -f [0,1] -> do final optimality check for variables removed\n");
|
||||
printf(" by shrinking. Although this test is usually \n");
|
||||
printf(" positive, there is no guarantee that the optimum\n");
|
||||
printf(" was found if the test is omitted. (default 1)\n");
|
||||
printf(" -y string -> if option is given, reads alphas from file with given\n");
|
||||
printf(" and uses them as starting point. (default 'disabled')\n");
|
||||
printf(" -# int -> terminate optimization, if no progress after this\n");
|
||||
printf(" number of iterations. (default 100000)\n");
|
||||
printf("Output options:\n");
|
||||
printf(" -l string -> file to write predicted labels of unlabeled\n");
|
||||
printf(" examples into after transductive learning\n");
|
||||
printf(" -a string -> write all alphas to this file after learning\n");
|
||||
printf(" (in the same order as in the training set)\n");
|
||||
wait_any_key();
|
||||
printf("\nMore details in:\n");
|
||||
printf("[1] T. Joachims, Making Large-Scale SVM Learning Practical. Advances in\n");
|
||||
printf(" Kernel Methods - Support Vector Learning, B. Sch<63>lkopf and C. Burges and\n");
|
||||
printf(" A. Smola (ed.), MIT Press, 1999.\n");
|
||||
printf("[2] T. Joachims, Estimating the Generalization performance of an SVM\n");
|
||||
printf(" Efficiently. International Conference on Machine Learning (ICML), 2000.\n");
|
||||
printf("[3] T. Joachims, Transductive Inference for Text Classification using Support\n");
|
||||
printf(" Vector Machines. International Conference on Machine Learning (ICML),\n");
|
||||
printf(" 1999.\n");
|
||||
printf("[4] K. Morik, P. Brockhausen, and T. Joachims, Combining statistical learning\n");
|
||||
printf(" with a knowledge-based approach - A case study in intensive care \n");
|
||||
printf(" monitoring. International Conference on Machine Learning (ICML), 1999.\n");
|
||||
printf("[5] T. Joachims, Learning to Classify Text Using Support Vector\n");
|
||||
printf(" Machines: Methods, Theory, and Algorithms. Dissertation, Kluwer,\n");
|
||||
printf(" 2002.\n\n");
|
||||
}
|
||||
|
||||
|
211
src/classifier/svm/svm_light/svm_loqo.c
Normal file
@@ -0,0 +1,211 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_loqo.c */
|
||||
/* */
|
||||
/* Interface to the PR_LOQO optimization package for SVM. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 19.07.99 */
|
||||
/* */
|
||||
/* Copyright (c) 1999 Universitaet Dortmund - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
# include <math.h>
|
||||
# include "pr_loqo/pr_loqo.h"
|
||||
# include "svm_common.h"
|
||||
|
||||
/* Common Block Declarations */
|
||||
|
||||
long verbosity;
|
||||
|
||||
/* /////////////////////////////////////////////////////////////// */
|
||||
|
||||
# define DEF_PRECISION_LINEAR 1E-8
|
||||
# define DEF_PRECISION_NONLINEAR 1E-14
|
||||
|
||||
double *optimize_qp();
|
||||
double *primal=0,*dual=0;
|
||||
double init_margin=0.15;
|
||||
long init_iter=500,precision_violations=0;
|
||||
double model_b;
|
||||
double opt_precision=DEF_PRECISION_LINEAR;
|
||||
|
||||
/* /////////////////////////////////////////////////////////////// */
|
||||
|
||||
void *my_malloc();
|
||||
|
||||
double *optimize_qp(qp,epsilon_crit,nx,threshold,learn_parm)
|
||||
QP *qp;
|
||||
double *epsilon_crit;
|
||||
long nx; /* Maximum number of variables in QP */
|
||||
double *threshold;
|
||||
LEARN_PARM *learn_parm;
|
||||
/* start the optimizer and return the optimal values */
|
||||
{
|
||||
register long i,j,result;
|
||||
double margin,obj_before,obj_after;
|
||||
double sigdig,dist,epsilon_loqo;
|
||||
int iter;
|
||||
|
||||
if(!primal) { /* allocate memory at first call */
|
||||
primal=(double *)my_malloc(sizeof(double)*nx*3);
|
||||
dual=(double *)my_malloc(sizeof(double)*(nx*2+1));
|
||||
}
|
||||
|
||||
if(verbosity>=4) { /* really verbose */
|
||||
printf("\n\n");
|
||||
for(i=0;i<qp->opt_n;i++) {
|
||||
printf("%f: ",qp->opt_g0[i]);
|
||||
for(j=0;j<qp->opt_n;j++) {
|
||||
printf("%f ",qp->opt_g[i*qp->opt_n+j]);
|
||||
}
|
||||
printf(": a%ld=%.10f < %f",i,qp->opt_xinit[i],qp->opt_up[i]);
|
||||
printf(": y=%f\n",qp->opt_ce[i]);
|
||||
}
|
||||
for(j=0;j<qp->opt_m;j++) {
|
||||
printf("EQ-%ld: %f*a0",j,qp->opt_ce[j]);
|
||||
for(i=1;i<qp->opt_n;i++) {
|
||||
printf(" + %f*a%ld",qp->opt_ce[i],i);
|
||||
}
|
||||
printf(" = %f\n\n",-qp->opt_ce0[0]);
|
||||
}
|
||||
}
|
||||
|
||||
obj_before=0; /* calculate objective before optimization */
|
||||
for(i=0;i<qp->opt_n;i++) {
|
||||
obj_before+=(qp->opt_g0[i]*qp->opt_xinit[i]);
|
||||
obj_before+=(0.5*qp->opt_xinit[i]*qp->opt_xinit[i]*qp->opt_g[i*qp->opt_n+i]);
|
||||
for(j=0;j<i;j++) {
|
||||
obj_before+=(qp->opt_xinit[j]*qp->opt_xinit[i]*qp->opt_g[j*qp->opt_n+i]);
|
||||
}
|
||||
}
|
||||
|
||||
result=STILL_RUNNING;
|
||||
qp->opt_ce0[0]*=(-1.0);
|
||||
/* Run pr_loqo. If a run fails, try again with parameters which lead */
|
||||
/* to a slower, but more robust setting. */
|
||||
for(margin=init_margin,iter=init_iter;
|
||||
(margin<=0.9999999) && (result!=OPTIMAL_SOLUTION);) {
|
||||
sigdig=-log10(opt_precision);
|
||||
|
||||
result=pr_loqo((int)qp->opt_n,(int)qp->opt_m,
|
||||
(double *)qp->opt_g0,(double *)qp->opt_g,
|
||||
(double *)qp->opt_ce,(double *)qp->opt_ce0,
|
||||
(double *)qp->opt_low,(double *)qp->opt_up,
|
||||
(double *)primal,(double *)dual,
|
||||
(int)(verbosity-2),
|
||||
(double)sigdig,(int)iter,
|
||||
(double)margin,(double)(qp->opt_up[0])/4.0,(int)0);
|
||||
|
||||
if(isnan(dual[0])) { /* check for choldc problem */
|
||||
if(verbosity>=2) {
|
||||
printf("NOTICE: Restarting PR_LOQO with more conservative parameters.\n");
|
||||
}
|
||||
if(init_margin<0.80) { /* become more conservative in general */
|
||||
init_margin=(4.0*margin+1.0)/5.0;
|
||||
}
|
||||
margin=(margin+1.0)/2.0;
|
||||
(opt_precision)*=10.0; /* reduce precision */
|
||||
if(verbosity>=2) {
|
||||
printf("NOTICE: Reducing precision of PR_LOQO.\n");
|
||||
}
|
||||
}
|
||||
else if(result!=OPTIMAL_SOLUTION) {
|
||||
iter+=2000;
|
||||
init_iter+=10;
|
||||
(opt_precision)*=10.0; /* reduce precision */
|
||||
if(verbosity>=2) {
|
||||
printf("NOTICE: Reducing precision of PR_LOQO due to (%ld).\n",result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(qp->opt_m) /* Thanks to Alex Smola for this hint */
|
||||
model_b=dual[0];
|
||||
else
|
||||
model_b=0;
|
||||
|
||||
/* Check the precision of the alphas. If results of current optimization */
|
||||
/* violate KT-Conditions, relax the epsilon on the bounds on alphas. */
|
||||
epsilon_loqo=1E-10;
|
||||
for(i=0;i<qp->opt_n;i++) {
|
||||
dist=-model_b*qp->opt_ce[i];
|
||||
dist+=(qp->opt_g0[i]+1.0);
|
||||
for(j=0;j<i;j++) {
|
||||
dist+=(primal[j]*qp->opt_g[j*qp->opt_n+i]);
|
||||
}
|
||||
for(j=i;j<qp->opt_n;j++) {
|
||||
dist+=(primal[j]*qp->opt_g[i*qp->opt_n+j]);
|
||||
}
|
||||
/* printf("LOQO: a[%d]=%f, dist=%f, b=%f\n",i,primal[i],dist,dual[0]); */
|
||||
if((primal[i]<(qp->opt_up[i]-epsilon_loqo)) && (dist < (1.0-(*epsilon_crit)))) {
|
||||
epsilon_loqo=(qp->opt_up[i]-primal[i])*2.0;
|
||||
}
|
||||
else if((primal[i]>(0+epsilon_loqo)) && (dist > (1.0+(*epsilon_crit)))) {
|
||||
epsilon_loqo=primal[i]*2.0;
|
||||
}
|
||||
}
|
||||
|
||||
for(i=0;i<qp->opt_n;i++) { /* clip alphas to bounds */
|
||||
if(primal[i]<=(0+epsilon_loqo)) {
|
||||
primal[i]=0;
|
||||
}
|
||||
else if(primal[i]>=(qp->opt_up[i]-epsilon_loqo)) {
|
||||
primal[i]=qp->opt_up[i];
|
||||
}
|
||||
}
|
||||
|
||||
obj_after=0; /* calculate objective after optimization */
|
||||
for(i=0;i<qp->opt_n;i++) {
|
||||
obj_after+=(qp->opt_g0[i]*primal[i]);
|
||||
obj_after+=(0.5*primal[i]*primal[i]*qp->opt_g[i*qp->opt_n+i]);
|
||||
for(j=0;j<i;j++) {
|
||||
obj_after+=(primal[j]*primal[i]*qp->opt_g[j*qp->opt_n+i]);
|
||||
}
|
||||
}
|
||||
|
||||
/* if optimizer returned NAN values, reset and retry with smaller */
|
||||
/* working set. */
|
||||
if(isnan(obj_after) || isnan(model_b)) {
|
||||
for(i=0;i<qp->opt_n;i++) {
|
||||
primal[i]=qp->opt_xinit[i];
|
||||
}
|
||||
model_b=0;
|
||||
if(learn_parm->svm_maxqpsize>2) {
|
||||
learn_parm->svm_maxqpsize--; /* decrease size of qp-subproblems */
|
||||
}
|
||||
}
|
||||
|
||||
if(obj_after >= obj_before) { /* check whether there was progress */
|
||||
(opt_precision)/=100.0;
|
||||
precision_violations++;
|
||||
if(verbosity>=2) {
|
||||
printf("NOTICE: Increasing Precision of PR_LOQO.\n");
|
||||
}
|
||||
}
|
||||
|
||||
if(precision_violations > 500) {
|
||||
(*epsilon_crit)*=10.0;
|
||||
precision_violations=0;
|
||||
if(verbosity>=1) {
|
||||
printf("\nWARNING: Relaxing epsilon on KT-Conditions.\n");
|
||||
}
|
||||
}
|
||||
|
||||
(*threshold)=model_b;
|
||||
|
||||
if(result!=OPTIMAL_SOLUTION) {
|
||||
printf("\nERROR: PR_LOQO did not converge. \n");
|
||||
return(qp->opt_xinit);
|
||||
}
|
||||
else {
|
||||
return(primal);
|
||||
}
|
||||
}
|
||||
|
70
src/classifier/svm/svm_multiclass_classifier.cpp
Normal file
@@ -0,0 +1,70 @@
#include "svm_multiclass_classifier.hpp"
#include "svm_struct/svm_struct_common.h"
#include "svm_struct_api.h"

namespace ovclassifier {
SVMMultiClassClassifier::SVMMultiClassClassifier() {}
SVMMultiClassClassifier::~SVMMultiClassClassifier() {
  if (model_ != NULL) {
    free_struct_model(*model_);
    free(model_);
    model_ = NULL;
  }
  if (sparm_ != NULL) {
    free(sparm_);
    sparm_ = NULL;
  }
}
int SVMMultiClassClassifier::LoadModel(const char *modelfile) {
  if (model_ != NULL) {
    free_struct_model(*model_);
  }
  if (sparm_ != NULL) {
    free(sparm_);
  }
  model_ = (STRUCTMODEL *)my_malloc(sizeof(STRUCTMODEL));
  sparm_ = (STRUCT_LEARN_PARM *)my_malloc(sizeof(STRUCT_LEARN_PARM));
  (*model_) = read_struct_model((char *)modelfile, sparm_);
  if (model_->svm_model->kernel_parm.kernel_type ==
      LINEAR) { /* linear kernel */
    /* compute weight vector */
    add_weight_vector_to_linear_model(model_->svm_model);
    model_->w = model_->svm_model->lin_weights;
  }
  return 0;
}

double SVMMultiClassClassifier::Predict(const float *vec) { return 0; }

int SVMMultiClassClassifier::Classify(const float *vec,
                                      std::vector<float> &scores) {
  if (model_ == NULL || sparm_ == NULL) {
    return -1;
  }
  struct_verbosity = 5;
  int feats = sparm_->num_features;
  WORD *words = (WORD *)malloc(sizeof(WORD) * (feats + 10));
  for (int i = 0; i < (feats + 10); ++i) {
    if (i >= feats) {
      words[i].wnum = 0;
      words[i].weight = 0;
    } else {
      words[i].wnum = i + 1;
      words[i].weight = vec[i];
    }
  }
  DOC *doc =
      create_example(-1, 0, 0, 0.0, create_svector(words, (char *)"", 1.0));
  free(words);
  PATTERN pattern;
  pattern.doc = doc;
  LABEL y = classify_struct_example(pattern, model_, sparm_);
  free_pattern(pattern);
  scores.clear();
  for (int i = 1; i <= y.num_classes_; ++i) {
    scores.push_back(y.scores[i]);
  }
  free_label(y);
  return 0;
}
} // namespace ovclassifier
21
src/classifier/svm/svm_multiclass_classifier.hpp
Normal file
@@ -0,0 +1,21 @@
#ifndef _CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_
#define _CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_

#include "svm_classifier.hpp"
#include "svm_struct_api_types.h"

namespace ovclassifier {
class SVMMultiClassClassifier : public SVMClassifier {
public:
  SVMMultiClassClassifier();
  ~SVMMultiClassClassifier();
  int LoadModel(const char *modelfile);
  double Predict(const float *vec);
  int Classify(const float *vec, std::vector<float> &scores);

private:
  STRUCTMODEL *model_ = NULL;
  STRUCT_LEARN_PARM *sparm_ = NULL;
};
} // namespace ovclassifier
#endif // !_CLASSIFIER_SVM_MULTICLASS_CLASSIFIER_H_
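As a usage sketch (not part of the diff): the class above wraps an SVM-multiclass model, so the intended call sequence is LoadModel followed by Classify, which fills one linear score per class starting at class 1. The model path and feature dimensionality below are hypothetical.

```cpp
// Sketch only: assumes a model trained with SVMMultiClassTrainer and a
// hypothetical 64-dimensional feature vector; error handling kept minimal.
#include "svm_multiclass_classifier.hpp"

#include <cstdio>
#include <vector>

int main() {
  ovclassifier::SVMMultiClassClassifier clf;
  if (clf.LoadModel("multiclass.model") != 0) { // hypothetical path
    return 1;
  }
  std::vector<float> feat(64, 0.5f); // hypothetical feature vector
  std::vector<float> scores;
  if (clf.Classify(feat.data(), scores) != 0) {
    return 1;
  }
  // scores[k] holds the score of class k+1; the argmax is the prediction.
  int best = 0;
  for (size_t k = 1; k < scores.size(); ++k) {
    if (scores[k] > scores[best]) best = (int)k;
  }
  std::printf("predicted class: %d\n", best + 1);
  return 0;
}
```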
151
src/classifier/svm/svm_multiclass_trainer.cpp
Normal file
@@ -0,0 +1,151 @@
#include "svm_multiclass_trainer.hpp"
#include "svm_light/svm_learn.h"
#include "svm_struct/svm_struct_learn.h"
#include "svm_struct_api.h"

namespace ovclassifier {

SVMMultiClassTrainer::SVMMultiClassTrainer() {
  alg_type = DEFAULT_ALG_TYPE;
  struct_parm = (STRUCT_LEARN_PARM *)malloc(sizeof(STRUCT_LEARN_PARM));
  struct_parm->C = 10000;
  struct_parm->slack_norm = 1;
  struct_parm->epsilon = DEFAULT_EPS;
  struct_parm->custom_argc = 0;
  struct_parm->loss_function = DEFAULT_LOSS_FCT;
  struct_parm->loss_type = DEFAULT_RESCALING;
  struct_parm->newconstretrain = 100;
  struct_parm->ccache_size = 5;
  struct_parm->batch_size = 100;

  learn_parm = (LEARN_PARM *)malloc(sizeof(LEARN_PARM));
  strcpy(learn_parm->predfile, "trans_predictions");
  strcpy(learn_parm->alphafile, "");
  learn_parm->biased_hyperplane = 1;
  learn_parm->remove_inconsistent = 0;
  learn_parm->skip_final_opt_check = 0;
  learn_parm->svm_maxqpsize = 10;
  learn_parm->svm_newvarsinqp = 0;
  // learn_parm->svm_iter_to_shrink = -9999;
  learn_parm->svm_iter_to_shrink = 100;
  learn_parm->maxiter = 100000;
  learn_parm->kernel_cache_size = 40;
  learn_parm->svm_c = 99999999; /* overridden by struct_parm->C */
  learn_parm->eps = 0.001;      /* overridden by struct_parm->epsilon */
  learn_parm->transduction_posratio = -1.0;
  learn_parm->svm_costratio = 1.0;
  learn_parm->svm_costratio_unlab = 1.0;
  learn_parm->svm_unlabbound = 1E-5;
  learn_parm->epsilon_crit = 0.001;
  learn_parm->epsilon_a = 1E-10; /* changed from 1e-15 */
  learn_parm->compute_loo = 0;
  learn_parm->rho = 1.0;
  learn_parm->xa_depth = 0;
  kernel_parm = (KERNEL_PARM *)malloc(sizeof(KERNEL_PARM));
  kernel_parm->kernel_type = 0;
  kernel_parm->poly_degree = 3;
  kernel_parm->rbf_gamma = 1.0;
  kernel_parm->coef_lin = 1;
  kernel_parm->coef_const = 1;
  strcpy(kernel_parm->custom, "empty");

  parse_struct_parameters(struct_parm);
}

SVMMultiClassTrainer::~SVMMultiClassTrainer() {
  if (learn_parm != NULL) {
    free(learn_parm);
    learn_parm = NULL;
  }
  if (kernel_parm != NULL) {
    free(kernel_parm);
    kernel_parm = NULL;
  }
  if (struct_parm != NULL) {
    free(struct_parm);
    struct_parm = NULL;
  }
}

void SVMMultiClassTrainer::Reset() {
  labels_ = 0;
  feats_ = 0;
  items_.clear();
}
void SVMMultiClassTrainer::SetLabels(int labels) { labels_ = labels; }
void SVMMultiClassTrainer::SetFeatures(int feats) { feats_ = feats; }
void SVMMultiClassTrainer::AddData(int label, const float *vec) {
  LabelItem itm;
  itm.label = label;
  for (int i = 0; i < feats_; ++i) {
    itm.vec.push_back(vec[i]);
  }
  items_.push_back(itm);
}

int SVMMultiClassTrainer::Train(const char *modelfile) {
  struct_verbosity = 2;
  int totdoc = items_.size();
  if (totdoc == 0 || feats_ == 0 || labels_ == 0) {
    return -1;
  }
  EXAMPLE *examples = (EXAMPLE *)my_malloc(sizeof(EXAMPLE) * totdoc);
  /* one WORD per feature plus room for the terminating entry */
  WORD *words = (WORD *)my_malloc(sizeof(WORD) * (feats_ + 10));
  for (int dnum = 0; dnum < totdoc; ++dnum) {
    const int docFeats = items_[dnum].vec.size();
    for (int i = 0; i < (feats_ + 10); ++i) {
      if (i >= feats_) {
        words[i].wnum = 0;
      } else {
        (words[i]).wnum = i + 1;
      }
      if (i >= docFeats) {
        (words[i]).weight = 0;
      } else {
        (words[i]).weight = (FVAL)items_[dnum].vec[i];
      }
    }
    DOC *doc =
        create_example(dnum, 0, 0, 0, create_svector(words, (char *)"", 1.0));
    examples[dnum].x.doc = doc;
    examples[dnum].y.class_ = (double)items_[dnum].label + 0.1;
    examples[dnum].y.scores = NULL;
    examples[dnum].y.num_classes_ = (double)labels_ + 0.1;
  }
  free(words);

  SAMPLE sample;
  sample.n = totdoc;
  sample.examples = examples;
  STRUCTMODEL structmodel;
  /* Do the learning and return structmodel. */
  if (alg_type == 0)
    svm_learn_struct(sample, struct_parm, learn_parm, kernel_parm, &structmodel,
                     NSLACK_ALG);
  else if (alg_type == 1)
    svm_learn_struct(sample, struct_parm, learn_parm, kernel_parm, &structmodel,
                     NSLACK_SHRINK_ALG);
  else if (alg_type == 2)
    svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
                           &structmodel, ONESLACK_PRIMAL_ALG);
  else if (alg_type == 3)
    svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
                           &structmodel, ONESLACK_DUAL_ALG);
  else if (alg_type == 4)
    svm_learn_struct_joint(sample, struct_parm, learn_parm, kernel_parm,
                           &structmodel, ONESLACK_DUAL_CACHE_ALG);
  else if (alg_type == 9)
    svm_learn_struct_joint_custom(sample, struct_parm, learn_parm, kernel_parm,
                                  &structmodel);
  else
    return -1;

  write_struct_model((char *)modelfile, &structmodel, struct_parm);

  free_struct_sample(sample);
  free_struct_model(structmodel);

  svm_struct_learn_api_exit();
  return 0;
}
} // namespace ovclassifier
31
src/classifier/svm/svm_multiclass_trainer.hpp
Normal file
@@ -0,0 +1,31 @@
#ifndef _SVM_MULTICLASS_TRAINER_H_
#define _SVM_MULTICLASS_TRAINER_H_

#include "svm_common.hpp"
#include "svm_light/svm_common.h"
#include "svm_struct_api_types.h"
#include "svm_trainer.hpp"
#include <vector>

namespace ovclassifier {
class SVMMultiClassTrainer : public SVMTrainer {
public:
  SVMMultiClassTrainer();
  ~SVMMultiClassTrainer();
  void Reset();
  void SetLabels(int labels);
  void SetFeatures(int feats);
  void AddData(int label, const float *vec);
  int Train(const char *modelfile);

private:
  KERNEL_PARM *kernel_parm = NULL;
  LEARN_PARM *learn_parm = NULL;
  STRUCT_LEARN_PARM *struct_parm = NULL;
  int alg_type;
  int feats_;
  int labels_;
  std::vector<LabelItem> items_;
};
} // namespace ovclassifier
#endif // _SVM_MULTICLASS_TRAINER_H_
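A usage sketch for the trainer (not part of the diff): set the class count and feature dimensionality, add labeled vectors, then write the model. The 3-class, 2-D data and the model path are hypothetical; labels are 1-based, as in SVM-multiclass.

```cpp
// Sketch only: trains a hypothetical 3-class model on 2-D points and writes
// it to "multiclass.model" (a hypothetical path).
#include "svm_multiclass_trainer.hpp"

int main() {
  ovclassifier::SVMMultiClassTrainer trainer;
  trainer.Reset();
  trainer.SetLabels(3);   // number of classes
  trainer.SetFeatures(2); // feature dimensionality

  const float a[2] = {0.1f, 0.9f};
  const float b[2] = {0.8f, 0.2f};
  const float c[2] = {0.5f, 0.5f};
  trainer.AddData(1, a);
  trainer.AddData(2, b);
  trainer.AddData(3, c);

  return trainer.Train("multiclass.model") == 0 ? 0 : 1;
}
```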
186
src/classifier/svm/svm_struct/svm_struct_classify.c
Executable file
@@ -0,0 +1,186 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_struct_classify.c */
|
||||
/* */
|
||||
/* Classification module of SVM-struct. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 03.07.04 */
|
||||
/* */
|
||||
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/************************************************************************/
|
||||
|
||||
#include <stdio.h>
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#include "../svm_light/svm_common.h"
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#include "../svm_struct_api.h"
|
||||
#include "svm_struct_common.h"
|
||||
|
||||
char testfile[200];
|
||||
char modelfile[200];
|
||||
char predictionsfile[200];
|
||||
|
||||
void read_input_parameters(int, char **, char *, char *, char *,
|
||||
STRUCT_LEARN_PARM *, long*, long *);
|
||||
void print_help(void);
|
||||
|
||||
|
||||
int main (int argc, char* argv[])
|
||||
{
|
||||
long correct=0,incorrect=0,no_accuracy=0;
|
||||
long i;
|
||||
double t1,runtime=0;
|
||||
double avgloss=0,l;
|
||||
FILE *predfl;
|
||||
STRUCTMODEL model;
|
||||
STRUCT_LEARN_PARM sparm;
|
||||
STRUCT_TEST_STATS teststats;
|
||||
SAMPLE testsample;
|
||||
LABEL y;
|
||||
|
||||
svm_struct_classify_api_init(argc,argv);
|
||||
|
||||
read_input_parameters(argc,argv,testfile,modelfile,predictionsfile,&sparm,
|
||||
&verbosity,&struct_verbosity);
|
||||
|
||||
if(struct_verbosity>=1) {
|
||||
printf("Reading model..."); fflush(stdout);
|
||||
}
|
||||
model=read_struct_model(modelfile,&sparm);
|
||||
if(struct_verbosity>=1) {
|
||||
fprintf(stdout, "done.\n");
|
||||
}
|
||||
|
||||
if(model.svm_model->kernel_parm.kernel_type == LINEAR) { /* linear kernel */
|
||||
/* compute weight vector */
|
||||
add_weight_vector_to_linear_model(model.svm_model);
|
||||
model.w=model.svm_model->lin_weights;
|
||||
}
|
||||
|
||||
if(struct_verbosity>=1) {
|
||||
printf("Reading test examples..."); fflush(stdout);
|
||||
}
|
||||
testsample=read_struct_examples(testfile,&sparm);
|
||||
if(struct_verbosity>=1) {
|
||||
printf("done.\n"); fflush(stdout);
|
||||
}
|
||||
|
||||
if(struct_verbosity>=1) {
|
||||
printf("Classifying test examples..."); fflush(stdout);
|
||||
}
|
||||
|
||||
if ((predfl = fopen (predictionsfile, "w")) == NULL)
|
||||
{ perror (predictionsfile); exit (1); }
|
||||
|
||||
for(i=0;i<testsample.n;i++) {
|
||||
t1=get_runtime();
|
||||
y=classify_struct_example(testsample.examples[i].x,&model,&sparm);
|
||||
runtime+=(get_runtime()-t1);
|
||||
|
||||
write_label(predfl,y);
|
||||
l=loss(testsample.examples[i].y,y,&sparm);
|
||||
avgloss+=l;
|
||||
if(l == 0)
|
||||
correct++;
|
||||
else
|
||||
incorrect++;
|
||||
eval_prediction(i,testsample.examples[i],y,&model,&sparm,&teststats);
|
||||
|
||||
if(empty_label(testsample.examples[i].y))
|
||||
{ no_accuracy=1; } /* test data is not labeled */
|
||||
if(struct_verbosity>=2) {
|
||||
if((i+1) % 100 == 0) {
|
||||
printf("%ld..",i+1); fflush(stdout);
|
||||
}
|
||||
}
|
||||
free_label(y);
|
||||
}
|
||||
avgloss/=testsample.n;
|
||||
fclose(predfl);
|
||||
|
||||
if(struct_verbosity>=1) {
|
||||
printf("done\n");
|
||||
printf("Runtime (without IO) in cpu-seconds: %.2f\n",
|
||||
(float)(runtime/100.0));
|
||||
}
|
||||
if((!no_accuracy) && (struct_verbosity>=1)) {
|
||||
printf("Average loss on test set: %.4f\n",(float)avgloss);
|
||||
printf("Zero/one-error on test set: %.2f%% (%ld correct, %ld incorrect, %d total)\n",(float)100.0*incorrect/testsample.n,correct,incorrect,testsample.n);
|
||||
}
|
||||
print_struct_testing_stats(testsample,&model,&sparm,&teststats);
|
||||
free_struct_sample(testsample);
|
||||
free_struct_model(model);
|
||||
|
||||
svm_struct_classify_api_exit();
|
||||
|
||||
return(0);
|
||||
}
|
||||
|
||||
void read_input_parameters(int argc,char *argv[],char *testfile,
|
||||
char *modelfile,char *predictionsfile,
|
||||
STRUCT_LEARN_PARM *struct_parm,
|
||||
long *verbosity,long *struct_verbosity)
|
||||
{
|
||||
long i;
|
||||
|
||||
/* set default */
|
||||
strcpy (modelfile, "svm_model");
|
||||
strcpy (predictionsfile, "svm_predictions");
|
||||
(*verbosity)=0;/*verbosity for svm_light*/
|
||||
(*struct_verbosity)=1; /*verbosity for struct learning portion*/
|
||||
struct_parm->custom_argc=0;
|
||||
|
||||
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||
switch ((argv[i])[1])
|
||||
{
|
||||
case 'h': print_help(); exit(0);
|
||||
case '?': print_help(); exit(0);
|
||||
case '-': strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);i++; strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);break;
|
||||
case 'v': i++; (*struct_verbosity)=atol(argv[i]); break;
|
||||
case 'y': i++; (*verbosity)=atol(argv[i]); break;
|
||||
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
if((i+1)>=argc) {
|
||||
printf("\nNot enough input parameters!\n\n");
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
strcpy (testfile, argv[i]);
|
||||
strcpy (modelfile, argv[i+1]);
|
||||
if((i+2)<argc) {
|
||||
strcpy (predictionsfile, argv[i+2]);
|
||||
}
|
||||
|
||||
parse_struct_parameters_classify(struct_parm);
|
||||
}
|
||||
|
||||
void print_help(void)
|
||||
{
|
||||
printf("\nSVM-struct classification module: %s, %s, %s\n",INST_NAME,INST_VERSION,INST_VERSION_DATE);
|
||||
printf(" includes SVM-struct %s for learning complex outputs, %s\n",STRUCT_VERSION,STRUCT_VERSION_DATE);
|
||||
printf(" includes SVM-light %s quadratic optimizer, %s\n",VERSION,VERSION_DATE);
|
||||
copyright_notice();
|
||||
printf(" usage: svm_struct_classify [options] example_file model_file output_file\n\n");
|
||||
printf("options: -h -> this help\n");
|
||||
printf(" -v [0..3] -> verbosity level (default 2)\n\n");
|
||||
|
||||
print_struct_help_classify();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
66
src/classifier/svm/svm_struct/svm_struct_common.c
Normal file
@@ -0,0 +1,66 @@
/***********************************************************************/
/*                                                                     */
/*   svm_struct_common.c                                               */
/*                                                                     */
/*   Functions and types used by multiple components of SVM-struct.   */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 03.07.04                                                    */
/*                                                                     */
/*   Copyright (c) 2004  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must  */
/*   not be modified and distributed without prior permission of the  */
/*   author. The author is not responsible for implications from the  */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>  /* for sqrt() used in printW() */

#include "svm_struct_common.h"

long struct_verbosity;              /* verbosity level (0-4) */

void printIntArray(int* x, int n)
{
  int i;
  for(i=0;i<n;i++)
    printf("%i:",x[i]);
}

void printDoubleArray(double* x, int n)
{
  int i;
  for(i=0;i<n;i++)
    printf("%f:",x[i]);
}

void printWordArray(WORD* x)
{
  int i=0;
  for(;x[i].wnum!=0;i++)
    if(x[i].weight != 0)
      printf(" %i:%.2f ",(int)x[i].wnum,x[i].weight);
}

void printW(double *w, long sizePhi, long n,double C)
{
  int i;
  printf("---- w ----\n");
  for(i=0;i<sizePhi;i++)
    {
      printf("%f ",w[i]);
    }
  printf("\n----- xi ----\n");
  for(;i<sizePhi+2*n;i++)
    {
      printf("%f ",1/sqrt(2*C)*w[i]);
    }
  printf("\n");
}
/**** end print methods ****/
61
src/classifier/svm/svm_struct/svm_struct_common.h
Executable file
@@ -0,0 +1,61 @@
/***********************************************************************/
/*                                                                     */
/*   svm_struct_common.h                                               */
/*                                                                     */
/*   Functions and types used by multiple components of SVM-struct.   */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 31.10.05                                                    */
/*                                                                     */
/*   Copyright (c) 2005  Thorsten Joachims - All rights reserved       */
/*                                                                     */
/*   This software is available for non-commercial use only. It must  */
/*   not be modified and distributed without prior permission of the  */
/*   author. The author is not responsible for implications from the  */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#ifndef svm_struct_common
#define svm_struct_common

# define STRUCT_VERSION       "V3.10"
# define STRUCT_VERSION_DATE  "14.08.08"

#ifdef __cplusplus
extern "C" {
#endif
#include "../svm_light/svm_common.h"
#ifdef __cplusplus
}
#endif
#include "../svm_struct_api_types.h"

typedef struct example {  /* an example is a pair of pattern and label */
  PATTERN x;
  LABEL y;
} EXAMPLE;

typedef struct sample { /* a sample is a set of examples */
  int     n;            /* n is the total number of examples */
  EXAMPLE *examples;
} SAMPLE;

typedef struct constset { /* a set of linear inequality constraints of
                             the form lhs[i]*w >= rhs[i] */
  int     m;              /* m is the total number of constraints */
  DOC     **lhs;
  double  *rhs;
} CONSTSET;

/**** print methods ****/
void printIntArray(int*,int);
void printDoubleArray(double*,int);
void printWordArray(WORD*);
void printModel(MODEL *);
void printW(double *, long, long, double);

extern long struct_verbosity;              /* verbosity level (0-4) */

#endif
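The CONSTSET above is the working set of cutting-plane constraints maintained by the learners in svm_struct_learn.c. As a hedged illustration (not code from this diff), the helper below measures how much a dense, 1-indexed weight vector `w` violates the cached constraints lhs[i]*w >= rhs[i]; for simplicity it reads only the first SVECTOR of each lhs and ignores the `next`/`factor` chaining that the real code handles.

```cpp
// Illustration only: assumes w is indexed by feature number (wnum), matching
// how SVM-light stores linear weights, and looks at the first SVECTOR of each
// left-hand side.
#include "svm_struct_common.h"

static double max_constraint_violation(const CONSTSET *cset, const double *w) {
  double worst = 0.0;
  for (int i = 0; i < cset->m; i++) {
    double lhs_dot_w = 0.0;
    for (WORD *f = cset->lhs[i]->fvec->words; f->wnum; f++) {
      lhs_dot_w += w[f->wnum] * f->weight;
    }
    double violation = cset->rhs[i] - lhs_dot_w; // > 0 means violated
    if (violation > worst) worst = violation;
  }
  return worst;
}
```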
1289
src/classifier/svm/svm_struct/svm_struct_learn.c
Executable file
File diff suppressed because it is too large
101
src/classifier/svm/svm_struct/svm_struct_learn.h
Executable file
@@ -0,0 +1,101 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_struct_learn.h */
|
||||
/* */
|
||||
/* Basic algorithm for learning structured outputs (e.g. parses, */
|
||||
/* sequences, multi-label classification) with a Support Vector */
|
||||
/* Machine. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 03.07.04 */
|
||||
/* */
|
||||
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
#ifndef SVM_STRUCT_LEARN
|
||||
#define SVM_STRUCT_LEARN
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#include "../svm_light/svm_common.h"
|
||||
#include "../svm_light/svm_learn.h"
|
||||
#include "../svm_struct_api_types.h"
|
||||
#include "svm_struct_common.h"
|
||||
|
||||
#define SLACK_RESCALING 1
|
||||
#define MARGIN_RESCALING 2
|
||||
|
||||
#define NSLACK_ALG 0
|
||||
#define NSLACK_SHRINK_ALG 1
|
||||
#define ONESLACK_PRIMAL_ALG 2
|
||||
#define ONESLACK_DUAL_ALG 3
|
||||
#define ONESLACK_DUAL_CACHE_ALG 4
|
||||
|
||||
typedef struct ccacheelem {
|
||||
SVECTOR *fydelta; /* left hand side of constraint */
|
||||
double rhs; /* right hand side of constraint */
|
||||
double viol; /* violation score under current model */
|
||||
struct ccacheelem *next; /* next in linked list */
|
||||
} CCACHEELEM;
|
||||
|
||||
typedef struct ccache {
|
||||
int n; /* number of examples */
|
||||
CCACHEELEM **constlist; /* array of pointers to constraint lists
|
||||
- one list per example. The first
|
||||
element of the list always points to
|
||||
the most violated constraint under the
|
||||
current model for each example. */
|
||||
STRUCTMODEL *sm; /* pointer to model */
|
||||
double *avg_viol_gain; /* array of average values by which
|
||||
violation of globally most violated
|
||||
constraint exceeds that of most violated
|
||||
constraint in cache */
|
||||
int *changed; /* array of boolean indicating whether the
|
||||
most violated ybar change compared to
|
||||
last iter? */
|
||||
} CCACHE;
|
||||
|
||||
void find_most_violated_constraint(SVECTOR **fydelta, double *lossval,
|
||||
EXAMPLE *ex, SVECTOR *fycached, long n,
|
||||
STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
|
||||
double *rt_viol, double *rt_psi,
|
||||
long *argmax_count);
|
||||
CCACHE *create_constraint_cache(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||
STRUCTMODEL *sm);
|
||||
void free_constraint_cache(CCACHE *ccache);
|
||||
double add_constraint_to_constraint_cache(CCACHE *ccache, MODEL *svmModel,
|
||||
int exnum, SVECTOR *fydelta,
|
||||
double rhs, double gainthresh,
|
||||
int maxconst, double *rt_cachesum);
|
||||
void update_constraint_cache_for_model(CCACHE *ccache, MODEL *svmModel);
|
||||
double compute_violation_of_constraint_in_cache(CCACHE *ccache, double thresh);
|
||||
double find_most_violated_joint_constraint_in_cache(CCACHE *ccache,
|
||||
double thresh,
|
||||
double *lhs_n,
|
||||
SVECTOR **lhs, double *rhs);
|
||||
void svm_learn_struct(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||
LEARN_PARM *lparm, KERNEL_PARM *kparm, STRUCTMODEL *sm,
|
||||
int alg_type);
|
||||
void svm_learn_struct_joint(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||
STRUCTMODEL *sm, int alg_type);
|
||||
void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
|
||||
LEARN_PARM *lparm, KERNEL_PARM *kparm,
|
||||
STRUCTMODEL *sm);
|
||||
void remove_inactive_constraints(CONSTSET *cset, double *alpha, long i,
|
||||
long *alphahist, long mininactive);
|
||||
MATRIX *init_kernel_matrix(CONSTSET *cset, KERNEL_PARM *kparm);
|
||||
MATRIX *update_kernel_matrix(MATRIX *matrix, int newpos, CONSTSET *cset,
|
||||
KERNEL_PARM *kparm);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
#endif
|
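The CCACHE declared above keeps one linked list of cached constraints per example, with the currently most violated constraint at the head of each list. A small illustrative sketch (assuming only the fields declared in this header; the real logic lives in svm_struct_learn.c) that scans those heads for the single worst violation:

```cpp
// Illustration only: constlist[i] points to the most violated cached
// constraint for example i under the current model.
#include "svm_struct/svm_struct_learn.h"

static double worst_cached_violation(const CCACHE *ccache) {
  double worst = 0.0;
  for (int i = 0; i < ccache->n; i++) {
    CCACHEELEM *head = ccache->constlist[i]; // head = most violated for i
    if (head && head->viol > worst) {
      worst = head->viol;
    }
  }
  return worst;
}
```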
417
src/classifier/svm/svm_struct/svm_struct_main.c
Executable file
@@ -0,0 +1,417 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_struct_main.c */
|
||||
/* */
|
||||
/* Command line interface to the alignment learning module of the */
|
||||
/* Support Vector Machine. */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 03.07.04 */
|
||||
/* */
|
||||
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
|
||||
/* the following enables you to use svm-learn out of C++ */
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
#include "../svm_light/svm_common.h"
|
||||
#include "../svm_light/svm_learn.h"
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
# include "svm_struct_learn.h"
|
||||
# include "svm_struct_common.h"
|
||||
# include "../svm_struct_api.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <assert.h>
|
||||
/* } */
|
||||
|
||||
char trainfile[200]; /* file with training examples */
|
||||
char modelfile[200]; /* file for resulting classifier */
|
||||
|
||||
void read_input_parameters(int, char **, char *, char *,long *, long *,
|
||||
STRUCT_LEARN_PARM *, LEARN_PARM *, KERNEL_PARM *,
|
||||
int *);
|
||||
void wait_any_key();
|
||||
void print_help();
|
||||
|
||||
|
||||
int main (int argc, char* argv[])
|
||||
{
|
||||
SAMPLE sample; /* training sample */
|
||||
LEARN_PARM learn_parm;
|
||||
KERNEL_PARM kernel_parm;
|
||||
STRUCT_LEARN_PARM struct_parm;
|
||||
STRUCTMODEL structmodel;
|
||||
int alg_type;
|
||||
|
||||
svm_struct_learn_api_init(argc,argv);
|
||||
|
||||
read_input_parameters(argc,argv,trainfile,modelfile,&verbosity,
|
||||
&struct_verbosity,&struct_parm,&learn_parm,
|
||||
&kernel_parm,&alg_type);
|
||||
|
||||
if(struct_verbosity>=1) {
|
||||
printf("Reading training examples..."); fflush(stdout);
|
||||
}
|
||||
/* read the training examples */
|
||||
sample=read_struct_examples(trainfile,&struct_parm);
|
||||
if(struct_verbosity>=1) {
|
||||
printf("done\n"); fflush(stdout);
|
||||
}
|
||||
|
||||
/* Do the learning and return structmodel. */
|
||||
if(alg_type == 0)
|
||||
svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_ALG);
|
||||
else if(alg_type == 1)
|
||||
svm_learn_struct(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,NSLACK_SHRINK_ALG);
|
||||
else if(alg_type == 2)
|
||||
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_PRIMAL_ALG);
|
||||
else if(alg_type == 3)
|
||||
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_ALG);
|
||||
else if(alg_type == 4)
|
||||
svm_learn_struct_joint(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel,ONESLACK_DUAL_CACHE_ALG);
|
||||
else if(alg_type == 9)
|
||||
svm_learn_struct_joint_custom(sample,&struct_parm,&learn_parm,&kernel_parm,&structmodel);
|
||||
else
|
||||
exit(1);
|
||||
|
||||
/* Warning: The model contains references to the original data 'docs'.
|
||||
If you want to free the original data, and only keep the model, you
|
||||
have to make a deep copy of 'model'. */
|
||||
if(struct_verbosity>=1) {
|
||||
printf("Writing learned model...");fflush(stdout);
|
||||
}
|
||||
write_struct_model(modelfile,&structmodel,&struct_parm);
|
||||
if(struct_verbosity>=1) {
|
||||
printf("done\n");fflush(stdout);
|
||||
}
|
||||
|
||||
free_struct_sample(sample);
|
||||
free_struct_model(structmodel);
|
||||
|
||||
svm_struct_learn_api_exit();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*---------------------------------------------------------------------------*/
|
||||
|
||||
void read_input_parameters(int argc,char *argv[],char *trainfile,
|
||||
char *modelfile,
|
||||
long *verbosity,long *struct_verbosity,
|
||||
STRUCT_LEARN_PARM *struct_parm,
|
||||
LEARN_PARM *learn_parm, KERNEL_PARM *kernel_parm,
|
||||
int *alg_type)
|
||||
{
|
||||
long i;
|
||||
char type[100];
|
||||
|
||||
/* set default */
|
||||
(*alg_type)=DEFAULT_ALG_TYPE;
|
||||
struct_parm->C=-0.01;
|
||||
struct_parm->slack_norm=1;
|
||||
struct_parm->epsilon=DEFAULT_EPS;
|
||||
struct_parm->custom_argc=0;
|
||||
struct_parm->loss_function=DEFAULT_LOSS_FCT;
|
||||
struct_parm->loss_type=DEFAULT_RESCALING;
|
||||
struct_parm->newconstretrain=100;
|
||||
struct_parm->ccache_size=5;
|
||||
struct_parm->batch_size=100;
|
||||
|
||||
strcpy (modelfile, "svm_struct_model");
|
||||
strcpy (learn_parm->predfile, "trans_predictions");
|
||||
strcpy (learn_parm->alphafile, "");
|
||||
(*verbosity)=0;/*verbosity for svm_light*/
|
||||
(*struct_verbosity)=1; /*verbosity for struct learning portion*/
|
||||
learn_parm->biased_hyperplane=1;
|
||||
learn_parm->remove_inconsistent=0;
|
||||
learn_parm->skip_final_opt_check=0;
|
||||
learn_parm->svm_maxqpsize=10;
|
||||
learn_parm->svm_newvarsinqp=0;
|
||||
learn_parm->svm_iter_to_shrink=-9999;
|
||||
learn_parm->maxiter=100000;
|
||||
learn_parm->kernel_cache_size=40;
|
||||
learn_parm->svm_c=99999999; /* overridden by struct_parm->C */
|
||||
learn_parm->eps=0.001; /* overridden by struct_parm->epsilon */
|
||||
learn_parm->transduction_posratio=-1.0;
|
||||
learn_parm->svm_costratio=1.0;
|
||||
learn_parm->svm_costratio_unlab=1.0;
|
||||
learn_parm->svm_unlabbound=1E-5;
|
||||
learn_parm->epsilon_crit=0.001;
|
||||
learn_parm->epsilon_a=1E-10; /* changed from 1e-15 */
|
||||
learn_parm->compute_loo=0;
|
||||
learn_parm->rho=1.0;
|
||||
learn_parm->xa_depth=0;
|
||||
kernel_parm->kernel_type=0;
|
||||
kernel_parm->poly_degree=3;
|
||||
kernel_parm->rbf_gamma=1.0;
|
||||
kernel_parm->coef_lin=1;
|
||||
kernel_parm->coef_const=1;
|
||||
strcpy(kernel_parm->custom,"empty");
|
||||
strcpy(type,"c");
|
||||
|
||||
for(i=1;(i<argc) && ((argv[i])[0] == '-');i++) {
|
||||
switch ((argv[i])[1])
|
||||
{
|
||||
case '?': print_help(); exit(0);
|
||||
case 'a': i++; strcpy(learn_parm->alphafile,argv[i]); break;
|
||||
case 'c': i++; struct_parm->C=atof(argv[i]); break;
|
||||
case 'p': i++; struct_parm->slack_norm=atol(argv[i]); break;
|
||||
case 'e': i++; struct_parm->epsilon=atof(argv[i]); break;
|
||||
case 'k': i++; struct_parm->newconstretrain=atol(argv[i]); break;
|
||||
case 'h': i++; learn_parm->svm_iter_to_shrink=atol(argv[i]); break;
|
||||
case '#': i++; learn_parm->maxiter=atol(argv[i]); break;
|
||||
case 'm': i++; learn_parm->kernel_cache_size=atol(argv[i]); break;
|
||||
case 'w': i++; (*alg_type)=atol(argv[i]); break;
|
||||
case 'o': i++; struct_parm->loss_type=atol(argv[i]); break;
|
||||
case 'n': i++; learn_parm->svm_newvarsinqp=atol(argv[i]); break;
|
||||
case 'q': i++; learn_parm->svm_maxqpsize=atol(argv[i]); break;
|
||||
case 'l': i++; struct_parm->loss_function=atol(argv[i]); break;
|
||||
case 'f': i++; struct_parm->ccache_size=atol(argv[i]); break;
|
||||
case 'b': i++; struct_parm->batch_size=atof(argv[i]); break;
|
||||
case 't': i++; kernel_parm->kernel_type=atol(argv[i]); break;
|
||||
case 'd': i++; kernel_parm->poly_degree=atol(argv[i]); break;
|
||||
case 'g': i++; kernel_parm->rbf_gamma=atof(argv[i]); break;
|
||||
case 's': i++; kernel_parm->coef_lin=atof(argv[i]); break;
|
||||
case 'r': i++; kernel_parm->coef_const=atof(argv[i]); break;
|
||||
case 'u': i++; strcpy(kernel_parm->custom,argv[i]); break;
|
||||
case '-': strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);i++; strcpy(struct_parm->custom_argv[struct_parm->custom_argc++],argv[i]);break;
|
||||
case 'v': i++; (*struct_verbosity)=atol(argv[i]); break;
|
||||
case 'y': i++; (*verbosity)=atol(argv[i]); break;
|
||||
default: printf("\nUnrecognized option %s!\n\n",argv[i]);
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
if(i>=argc) {
|
||||
printf("\nNot enough input parameters!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
strcpy (trainfile, argv[i]);
|
||||
if((i+1)<argc) {
|
||||
strcpy (modelfile, argv[i+1]);
|
||||
}
|
||||
if(learn_parm->svm_iter_to_shrink == -9999) {
|
||||
learn_parm->svm_iter_to_shrink=100;
|
||||
}
|
||||
|
||||
if((learn_parm->skip_final_opt_check)
|
||||
&& (kernel_parm->kernel_type == LINEAR)) {
|
||||
printf("\nIt does not make sense to skip the final optimality check for linear kernels.\n\n");
|
||||
learn_parm->skip_final_opt_check=0;
|
||||
}
|
||||
if((learn_parm->skip_final_opt_check)
|
||||
&& (learn_parm->remove_inconsistent)) {
|
||||
printf("\nIt is necessary to do the final optimality check when removing inconsistent \nexamples.\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((learn_parm->svm_maxqpsize<2)) {
|
||||
printf("\nMaximum size of QP-subproblems not in valid range: %ld [2..]\n",learn_parm->svm_maxqpsize);
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((learn_parm->svm_maxqpsize<learn_parm->svm_newvarsinqp)) {
|
||||
printf("\nMaximum size of QP-subproblems [%ld] must be larger than the number of\n",learn_parm->svm_maxqpsize);
|
||||
printf("new variables [%ld] entering the working set in each iteration.\n",learn_parm->svm_newvarsinqp);
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(learn_parm->svm_iter_to_shrink<1) {
|
||||
printf("\nMaximum number of iterations for shrinking not in valid range: %ld [1,..]\n",learn_parm->svm_iter_to_shrink);
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(struct_parm->C<0) {
|
||||
printf("\nYou have to specify a value for the parameter '-c' (C>0)!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(((*alg_type) < 0) || (((*alg_type) > 5) && ((*alg_type) != 9))) {
|
||||
printf("\nAlgorithm type must be either '0', '1', '2', '3', '4', or '9'!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(learn_parm->transduction_posratio>1) {
|
||||
printf("\nThe fraction of unlabeled examples to classify as positives must\n");
|
||||
printf("be less than 1.0 !!!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(learn_parm->svm_costratio<=0) {
|
||||
printf("\nThe COSTRATIO parameter must be greater than zero!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(struct_parm->epsilon<=0) {
|
||||
printf("\nThe epsilon parameter must be greater than zero!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((struct_parm->ccache_size<=0) && ((*alg_type) == 4)) {
|
||||
printf("\nThe cache size must be at least 1!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(((struct_parm->batch_size<=0) || (struct_parm->batch_size>100))
|
||||
&& ((*alg_type) == 4)) {
|
||||
printf("\nThe batch size must be in the interval ]0,100]!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((struct_parm->slack_norm<1) || (struct_parm->slack_norm>2)) {
|
||||
printf("\nThe norm of the slacks must be either 1 (L1-norm) or 2 (L2-norm)!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((struct_parm->loss_type != SLACK_RESCALING)
|
||||
&& (struct_parm->loss_type != MARGIN_RESCALING)) {
|
||||
printf("\nThe loss type must be either 1 (slack rescaling) or 2 (margin rescaling)!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if(learn_parm->rho<0) {
|
||||
printf("\nThe parameter rho for xi/alpha-estimates and leave-one-out pruning must\n");
|
||||
printf("be greater than zero (typically 1.0 or 2.0, see T. Joachims, Estimating the\n");
|
||||
printf("Generalization Performance of an SVM Efficiently, ICML, 2000.)!\n\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
if((learn_parm->xa_depth<0) || (learn_parm->xa_depth>100)) {
|
||||
printf("\nThe parameter depth for ext. xi/alpha-estimates must be in [0..100] (zero\n");
|
||||
printf("for switching to the conventional xa/estimates described in T. Joachims,\n");
|
||||
printf("Estimating the Generalization Performance of an SVM Efficiently, ICML, 2000.)\n");
|
||||
wait_any_key();
|
||||
print_help();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
parse_struct_parameters(struct_parm);
|
||||
}
|
||||
|
||||
void wait_any_key()
|
||||
{
|
||||
printf("\n(more)\n");
|
||||
(void)getc(stdin);
|
||||
}
|
||||
|
||||
void print_help()
|
||||
{
|
||||
printf("\nSVM-struct learning module: %s, %s, %s\n",INST_NAME,INST_VERSION,INST_VERSION_DATE);
|
||||
printf(" includes SVM-struct %s for learning complex outputs, %s\n",STRUCT_VERSION,STRUCT_VERSION_DATE);
|
||||
printf(" includes SVM-light %s quadratic optimizer, %s\n",VERSION,VERSION_DATE);
|
||||
copyright_notice();
|
||||
printf(" usage: svm_struct_learn [options] example_file model_file\n\n");
|
||||
printf("Arguments:\n");
|
||||
printf(" example_file-> file with training data\n");
|
||||
printf(" model_file -> file to store learned decision rule in\n");
|
||||
|
||||
printf("General Options:\n");
|
||||
printf(" -? -> this help\n");
|
||||
printf(" -v [0..3] -> verbosity level (default 1)\n");
|
||||
printf(" -y [0..3] -> verbosity level for svm_light (default 0)\n");
|
||||
printf("Learning Options:\n");
|
||||
printf(" -c float -> C: trade-off between training error\n");
|
||||
printf(" and margin (default 0.01)\n");
|
||||
printf(" -p [1,2] -> L-norm to use for slack variables. Use 1 for L1-norm,\n");
|
||||
printf(" use 2 for squared slacks. (default 1)\n");
|
||||
printf(" -o [1,2] -> Rescaling method to use for loss.\n");
|
||||
printf(" 1: slack rescaling\n");
|
||||
printf(" 2: margin rescaling\n");
|
||||
printf(" (default %d)\n",DEFAULT_RESCALING);
|
||||
printf(" -l [0..] -> Loss function to use.\n");
|
||||
printf(" 0: zero/one loss\n");
|
||||
printf(" ?: see below in application specific options\n");
|
||||
printf(" (default %d)\n",DEFAULT_LOSS_FCT);
|
||||
printf("Optimization Options (see [2][5]):\n");
|
||||
printf(" -w [0,..,9] -> choice of structural learning algorithm (default %d):\n",(int)DEFAULT_ALG_TYPE);
|
||||
printf(" 0: n-slack algorithm described in [2]\n");
|
||||
printf(" 1: n-slack algorithm with shrinking heuristic\n");
|
||||
printf(" 2: 1-slack algorithm (primal) described in [5]\n");
|
||||
printf(" 3: 1-slack algorithm (dual) described in [5]\n");
|
||||
printf(" 4: 1-slack algorithm (dual) with constraint cache [5]\n");
|
||||
printf(" 9: custom algorithm in svm_struct_learn_custom.c\n");
|
||||
printf(" -e float -> epsilon: allow that tolerance for termination\n");
|
||||
printf(" criterion (default %f)\n",DEFAULT_EPS);
|
||||
printf(" -k [1..] -> number of new constraints to accumulate before\n");
|
||||
printf(" recomputing the QP solution (default 100) (-w 0 and 1 only)\n");
|
||||
printf(" -f [5..] -> number of constraints to cache for each example\n");
|
||||
printf(" (default 5) (used with -w 4)\n");
|
||||
printf(" -b [1..100] -> percentage of training set for which to refresh cache\n");
|
||||
printf(" when no epsilon violated constraint can be constructed\n");
|
||||
printf(" from current cache (default 100%%) (used with -w 4)\n");
|
||||
printf("SVM-light Options for Solving QP Subproblems (see [3]):\n");
|
||||
printf(" -n [2..q] -> number of new variables entering the working set\n");
|
||||
printf(" in each svm-light iteration (default n = q). \n");
|
||||
printf(" Set n < q to prevent zig-zagging.\n");
|
||||
printf(" -m [5..] -> size of svm-light cache for kernel evaluations in MB\n");
|
||||
printf(" (default 40) (used only for -w 1 with kernels)\n");
|
||||
printf(" -h [5..] -> number of svm-light iterations a variable needs to be\n");
|
||||
printf(" optimal before considered for shrinking (default 100)\n");
|
||||
printf(" -# int -> terminate svm-light QP subproblem optimization, if no\n");
|
||||
printf(" progress after this number of iterations.\n");
|
||||
printf(" (default 100000)\n");
|
||||
printf("Kernel Options:\n");
|
||||
printf(" -t int -> type of kernel function:\n");
|
||||
printf(" 0: linear (default)\n");
|
||||
printf(" 1: polynomial (s a*b+c)^d\n");
|
||||
printf(" 2: radial basis function exp(-gamma ||a-b||^2)\n");
|
||||
printf(" 3: sigmoid tanh(s a*b + c)\n");
|
||||
printf(" 4: user defined kernel from kernel.h\n");
|
||||
printf(" -d int -> parameter d in polynomial kernel\n");
|
||||
printf(" -g float -> parameter gamma in rbf kernel\n");
|
||||
printf(" -s float -> parameter s in sigmoid/poly kernel\n");
|
||||
printf(" -r float -> parameter c in sigmoid/poly kernel\n");
|
||||
printf(" -u string -> parameter of user defined kernel\n");
|
||||
printf("Output Options:\n");
|
||||
printf(" -a string -> write all alphas to this file after learning\n");
|
||||
printf(" (in the same order as in the training set)\n");
|
||||
printf("Application-Specific Options:\n");
|
||||
print_struct_help();
|
||||
wait_any_key();
|
||||
|
||||
printf("\nMore details in:\n");
|
||||
printf("[1] T. Joachims, Learning to Align Sequences: A Maximum Margin Aproach.\n");
|
||||
printf(" Technical Report, September, 2003.\n");
|
||||
printf("[2] I. Tsochantaridis, T. Joachims, T. Hofmann, and Y. Altun, Large Margin\n");
|
||||
printf(" Methods for Structured and Interdependent Output Variables, Journal\n");
|
||||
printf(" of Machine Learning Research (JMLR), Vol. 6(Sep):1453-1484, 2005.\n");
|
||||
printf("[3] T. Joachims, Making Large-Scale SVM Learning Practical. Advances in\n");
|
||||
printf(" Kernel Methods - Support Vector Learning, B. Sch<63>lkopf and C. Burges and\n");
|
||||
printf(" A. Smola (ed.), MIT Press, 1999.\n");
|
||||
printf("[4] T. Joachims, Learning to Classify Text Using Support Vector\n");
|
||||
printf(" Machines: Methods, Theory, and Algorithms. Dissertation, Kluwer,\n");
|
||||
printf(" 2002.\n");
|
||||
printf("[5] T. Joachims, T. Finley, Chun-Nam Yu, Cutting-Plane Training of Structural\n");
|
||||
printf(" SVMs, Machine Learning Journal, to appear.\n");
|
||||
}
|
||||
|
||||
|
||||
|
615
src/classifier/svm/svm_struct_api.c
Executable file
@@ -0,0 +1,615 @@
|
||||
/***********************************************************************/
|
||||
/* */
|
||||
/* svm_struct_api.c */
|
||||
/* */
|
||||
/* Definition of API for attaching implementing SVM learning of */
|
||||
/* structures (e.g. parsing, multi-label classification, HMM) */
|
||||
/* */
|
||||
/* Author: Thorsten Joachims */
|
||||
/* Date: 03.07.04 */
|
||||
/* */
|
||||
/* Copyright (c) 2004 Thorsten Joachims - All rights reserved */
|
||||
/* */
|
||||
/* This software is available for non-commercial use only. It must */
|
||||
/* not be modified and distributed without prior permission of the */
|
||||
/* author. The author is not responsible for implications from the */
|
||||
/* use of this software. */
|
||||
/* */
|
||||
/***********************************************************************/
|
||||
|
||||
#include "svm_struct_api.h"
|
||||
#include "svm_struct/svm_struct_common.h"
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
void svm_struct_learn_api_init(int argc, char *argv[]) {
|
||||
/* Called in learning part before anything else is done to allow
|
||||
any initializations that might be necessary. */
|
||||
}
|
||||
|
||||
void svm_struct_learn_api_exit() {
|
||||
/* Called in learning part at the very end to allow any clean-up
|
||||
that might be necessary. */
|
||||
}
|
||||
|
||||
void svm_struct_classify_api_init(int argc, char *argv[]) {
|
||||
/* Called in prediction part before anything else is done to allow
|
||||
any initializations that might be necessary. */
|
||||
}
|
||||
|
||||
void svm_struct_classify_api_exit() {
|
||||
/* Called in prediction part at the very end to allow any clean-up
|
||||
that might be necessary. */
|
||||
}
|
||||
|
||||
SAMPLE read_struct_examples(char *file, STRUCT_LEARN_PARM *sparm) {
|
||||
/* Reads training examples and returns them in sample. The number of
|
||||
examples must be written into sample.n */
|
||||
SAMPLE sample; /* sample */
|
||||
EXAMPLE *examples;
|
||||
long n; /* number of examples */
|
||||
DOC **docs; /* examples in original SVM-light format */
|
||||
double *target;
|
||||
long totwords, i, num_classes = 0;
|
||||
|
||||
/* Using the read_documents function from SVM-light */
|
||||
read_documents(file, &docs, &target, &totwords, &n);
|
||||
examples = (EXAMPLE *)my_malloc(sizeof(EXAMPLE) * n);
|
||||
for (i = 0; i < n; i++) /* find highest class label */
|
||||
if (num_classes < (target[i] + 0.1))
|
||||
num_classes = target[i] + 0.1;
|
||||
for (i = 0; i < n; i++) /* make sure all class labels are positive */
|
||||
if (target[i] < 1) {
|
||||
printf("\nERROR: The class label '%lf' of example number %ld is not "
|
||||
"greater than '1'!\n",
|
||||
target[i], i + 1);
|
||||
exit(1);
|
||||
}
|
||||
for (i = 0; i < n; i++) { /* copy docs over into new datastructure */
|
||||
examples[i].x.doc = docs[i];
|
||||
examples[i].y.class_ = target[i] + 0.1;
|
||||
examples[i].y.scores = NULL;
|
||||
examples[i].y.num_classes_ = num_classes;
|
||||
}
|
||||
free(target);
|
||||
free(docs);
|
||||
sample.n = n;
|
||||
sample.examples = examples;
|
||||
|
||||
if (struct_verbosity >= 0)
|
||||
printf(" (%d examples) ", sample.n);
|
||||
return (sample);
|
||||
}
|
||||
|
||||
void init_struct_model(SAMPLE sample, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
|
||||
LEARN_PARM *lparm, KERNEL_PARM *kparm) {
|
||||
/* Initialize structmodel sm. The weight vector w does not need to be
|
||||
initialized, but you need to provide the maximum size of the
|
||||
feature space in sizePsi. This is the maximum number of different
|
||||
weights that can be learned. Later, the weight vector w will
|
||||
contain the learned weights for the model. */
|
||||
long i, totwords = 0;
|
||||
WORD *w;
|
||||
|
||||
sparm->num_classes_ = 1;
|
||||
for (i = 0; i < sample.n; i++) /* find highest class label */
|
||||
if (sparm->num_classes_ < (sample.examples[i].y.class_ + 0.1))
|
||||
sparm->num_classes_ = sample.examples[i].y.class_ + 0.1;
|
||||
for (i = 0; i < sample.n; i++) /* find highest feature number */
|
||||
for (w = sample.examples[i].x.doc->fvec->words; w->wnum; w++)
|
||||
if (totwords < w->wnum)
|
||||
totwords = w->wnum;
|
||||
sparm->num_features = totwords;
|
||||
if (struct_verbosity >= 0)
|
||||
printf("Training set properties: %d features, %d classes\n",
|
||||
sparm->num_features, sparm->num_classes_);
|
||||
sm->sizePsi = sparm->num_features * sparm->num_classes_;
|
||||
if (struct_verbosity >= 2)
|
||||
printf("Size of Phi: %ld\n", sm->sizePsi);
|
||||
}
|
||||
|
||||
CONSTSET init_struct_constraints(SAMPLE sample, STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm) {
|
||||
/* Initializes the optimization problem. Typically, you do not need
|
||||
to change this function, since you want to start with an empty
|
||||
set of constraints. However, if for example you have constraints
|
||||
that certain weights need to be positive, you might put that in
|
||||
here. The constraints are represented as lhs[i]*w >= rhs[i]. lhs
|
||||
is an array of feature vectors, rhs is an array of doubles. m is
|
||||
the number of constraints. The function returns the initial
|
||||
set of constraints. */
|
||||
CONSTSET c;
|
||||
long sizePsi = sm->sizePsi;
|
||||
long i;
|
||||
WORD words[2];
|
||||
|
||||
if (1) { /* normal case: start with empty set of constraints */
|
||||
c.lhs = NULL;
|
||||
c.rhs = NULL;
|
||||
c.m = 0;
|
||||
} else { /* add constraints so that all learned weights are
|
||||
positive. WARNING: Currently, they are positive only up to
|
||||
precision epsilon set by -e. */
|
||||
c.lhs = my_malloc(sizeof(DOC *) * sizePsi);
|
||||
c.rhs = my_malloc(sizeof(double) * sizePsi);
|
||||
for (i = 0; i < sizePsi; i++) {
|
||||
words[0].wnum = i + 1;
|
||||
words[0].weight = 1.0;
|
||||
words[1].wnum = 0;
|
||||
/* the following slackid is a hack. we will run into problems,
|
||||
if we have move than 1000000 slack sets (ie examples) */
|
||||
c.lhs[i] = create_example(i, 0, 1000000 + i, 1,
|
||||
create_svector(words, NULL, 1.0));
|
||||
c.rhs[i] = 0.0;
|
||||
}
|
||||
}
|
||||
return (c);
|
||||
}
|
||||
|
||||
LABEL classify_struct_example(PATTERN x, STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm) {
|
||||
/* Finds the label yhat for pattern x that scores the highest
|
||||
according to the linear evaluation function in sm, especially the
|
||||
weights sm.w. The returned label is taken as the prediction of sm
|
||||
for the pattern x. The weights correspond to the features defined
|
||||
by psi() and range from index 1 to index sm->sizePsi. If the
|
||||
function cannot find a label, it shall return an empty label as
|
||||
recognized by the function empty_label(y). */
|
||||
LABEL y;
|
||||
DOC doc;
|
||||
long class_, bestclass = -1, first = 1, j;
|
||||
double score, bestscore = -1;
|
||||
WORD *words;
|
||||
|
||||
doc = *(x.doc);
|
||||
y.scores = (double *)my_malloc(sizeof(double) * (sparm->num_classes_ + 1));
|
||||
y.num_classes_ = sparm->num_classes_;
|
||||
words = doc.fvec->words;
|
||||
for (j = 0; (words[j]).wnum != 0; j++) { /* Check if feature numbers */
|
||||
if ((words[j]).wnum > sparm->num_features) /* are not larger than in */
|
||||
(words[j]).wnum = 0; /* model. Remove feature if */
|
||||
} /* necessary. */
|
||||
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||
y.class_ = class_;
|
||||
doc.fvec = psi(x, y, sm, sparm);
|
||||
score = classify_example(sm->svm_model, &doc);
|
||||
free_svector(doc.fvec);
|
||||
y.scores[class_] = score;
|
||||
if ((bestscore < score) || (first)) {
|
||||
bestscore = score;
|
||||
bestclass = class_;
|
||||
first = 0;
|
||||
}
|
||||
}
|
||||
y.class_ = bestclass;
|
||||
return (y);
|
||||
}
|
||||
|
||||
LABEL find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y,
|
||||
STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm) {
|
||||
/* Finds the label ybar for pattern x that is responsible for
|
||||
the most violated constraint for the slack rescaling
|
||||
formulation. It has to take into account the scoring function in
|
||||
sm, especially the weights sm.w, as well as the loss
|
||||
function. The weights in sm.w correspond to the features defined
|
||||
by psi() and range from index 1 to index sm->sizePsi. Most simple
|
||||
is the case of the zero/one loss function. For the zero/one loss,
|
||||
this function should return the highest scoring label ybar, if
|
||||
ybar is unequal y; if it is equal to the correct label y, then
|
||||
the function shall return the second highest scoring label. If
|
||||
the function cannot find a label, it shall return an empty label
|
||||
as recognized by the function empty_label(y). */
|
||||
LABEL ybar;
|
||||
DOC doc;
|
||||
long class_, bestclass = -1, first = 1;
|
||||
double score, score_y, score_ybar, bestscore = -1;
|
||||
|
||||
/* NOTE: This function could be made much more efficient by not
|
||||
always computing a new PSI vector. */
|
||||
doc = *(x.doc);
|
||||
doc.fvec = psi(x, y, sm, sparm);
|
||||
score_y = classify_example(sm->svm_model, &doc);
|
||||
free_svector(doc.fvec);
|
||||
|
||||
ybar.scores = NULL;
|
||||
ybar.num_classes_ = sparm->num_classes_;
|
||||
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||
ybar.class_ = class_;
|
||||
doc.fvec = psi(x, ybar, sm, sparm);
|
||||
score_ybar = classify_example(sm->svm_model, &doc);
|
||||
free_svector(doc.fvec);
|
||||
score = loss(y, ybar, sparm) * (1.0 - score_y + score_ybar);
|
||||
if ((bestscore < score) || (first)) {
|
||||
bestscore = score;
|
||||
bestclass = class_;
|
||||
first = 0;
|
||||
}
|
||||
}
|
||||
if (bestclass == -1)
|
||||
printf("ERROR: Only one class\n");
|
||||
ybar.class_ = bestclass;
|
||||
if (struct_verbosity >= 3)
|
||||
printf("[%ld:%.2f] ", bestclass, bestscore);
|
||||
return (ybar);
|
||||
}
|
||||
|
||||
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
|
||||
STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm) {
|
||||
/* Finds the label ybar for pattern x that is responsible for
|
||||
the most violated constraint for the margin rescaling
|
||||
formulation. It has to take into account the scoring function in
|
||||
sm, especially the weights sm.w, as well as the loss
|
||||
function. The weights in sm.w correspond to the features defined
|
||||
by psi() and range from index 1 to index sm->sizePsi. Most simple
|
||||
is the case of the zero/one loss function. For the zero/one loss,
|
||||
this function should return the highest scoring label ybar, if
|
||||
ybar is unequal y; if it is equal to the correct label y, then
|
||||
the function shall return the second highest scoring label. If
|
||||
the function cannot find a label, it shall return an empty label
|
||||
as recognized by the function empty_label(y). */
|
||||
LABEL ybar;
|
||||
DOC doc;
|
||||
long class_, bestclass = -1, first = 1;
|
||||
double score, bestscore = -1;
|
||||
|
||||
/* NOTE: This function could be made much more efficient by not
|
||||
always computing a new PSI vector. */
|
||||
doc = *(x.doc);
|
||||
ybar.scores = NULL;
|
||||
ybar.num_classes_ = sparm->num_classes_;
|
||||
for (class_ = 1; class_ <= sparm->num_classes_; class_++) {
|
||||
ybar.class_ = class_;
|
||||
doc.fvec = psi(x, ybar, sm, sparm);
|
||||
score = classify_example(sm->svm_model, &doc);
|
||||
free_svector(doc.fvec);
|
||||
score += loss(y, ybar, sparm);
|
||||
if ((bestscore < score) || (first)) {
|
||||
bestscore = score;
|
||||
bestclass = class_;
|
||||
first = 0;
|
||||
}
|
||||
}
|
||||
if (bestclass == -1)
|
||||
printf("ERROR: Only one class\n");
|
||||
ybar.class_ = bestclass;
|
||||
if (struct_verbosity >= 3)
|
||||
printf("[%ld:%.2f] ", bestclass, bestscore);
|
||||
return (ybar);
|
||||
}
|
||||
|
||||
int empty_label(LABEL y) {
|
||||
/* Returns true, if y is an empty label. An empty label might be
|
||||
returned by find_most_violated_constraint_???(x, y, sm) if there
|
||||
is no incorrect label that can be found for x, or if it is unable
|
||||
to label x at all */
|
||||
return (y.class_ < 0.9);
|
||||
}
|
||||
|
||||
SVECTOR *psi(PATTERN x, LABEL y, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm) {
|
||||
/* Returns a feature vector describing the match between pattern x and
|
||||
label y. The feature vector is returned as an SVECTOR
|
||||
(i.e. pairs <featurenumber:featurevalue>), where the last pair has
|
||||
featurenumber 0 as a terminator. Featurenumbers start with 1 and end with
|
||||
sizePsi. This feature vector determines the linear evaluation
|
||||
function that is used to score labels. There will be one weight in
|
||||
sm.w for each feature. Note that psi has to match
|
||||
find_most_violated_constraint_???(x, y, sm) and vice versa. In
|
||||
particular, find_most_violated_constraint_???(x, y, sm) finds that
|
||||
ybar!=y that maximizes psi(x,ybar,sm)*sm.w (where * is the inner
|
||||
vector product) and the appropriate function of the loss. */
|
||||
SVECTOR *fvec;
|
||||
|
||||
/* shift the feature numbers to the position of weight vector of class y */
|
||||
fvec = shift_s(x.doc->fvec, (y.class_ - 1) * sparm->num_features);
|
||||
|
||||
/* The following makes sure that the weight vectors for each class
|
||||
are treated separately when kernels are used . */
|
||||
fvec->kernel_id = y.class_;
|
||||
|
||||
return (fvec);
|
||||
}
|
||||
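psi() above implements the usual multi-class joint feature map: the weight vector stacks one block of num_features weights per class, and shift_s() renumbers the example's sparse features into the block of its class. A one-line sketch of that index mapping (illustrative only, with 1-based feature numbers as in the code):

```cpp
// Illustration only: base feature j of an example assigned class k lands at
// global index (k - 1) * num_features + j, which is what shift_s() computes
// for every entry of the sparse vector.
static long multiclass_feature_index(long class_k, long feature_j,
                                     long num_features) {
  return (class_k - 1) * num_features + feature_j;
}
```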
|
||||
double loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm) {
|
||||
/* loss for correct label y and predicted label ybar. The loss for
|
||||
y==ybar has to be zero. sparm->loss_function is set with the -l option. */
|
||||
if (sparm->loss_function == 0) { /* type 0 loss: 0/1 loss */
|
||||
if (y.class_ == ybar.class_) /* return 0, if y==ybar. return 100 else */
|
||||
return (0);
|
||||
else
|
||||
return (100);
|
||||
}
|
||||
if (sparm->loss_function == 1) { /* type 1 loss: squared difference */
|
||||
return ((y.class_ - ybar.class_) * (y.class_ - ybar.class_));
|
||||
} else {
|
||||
/* Put your code for different loss functions here. But then
|
||||
find_most_violated_constraint_???(x, y, sm) has to return the
|
||||
highest scoring label with the largest loss. */
|
||||
printf("Unkown loss function\n");
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
int finalize_iteration(double ceps, int cached_constraint, SAMPLE sample,
|
||||
STRUCTMODEL *sm, CONSTSET cset, double *alpha,
|
||||
STRUCT_LEARN_PARM *sparm) {
|
||||
/* This function is called just before the end of each cutting plane
|
||||
* iteration. ceps is the amount by which the most violated constraint found
|
||||
* in the current iteration was violated. cached_constraint is true if the
|
||||
* added constraint was constructed from the cache. If the return value is
|
||||
* FALSE, then the algorithm is allowed to terminate. If it is TRUE, the
|
||||
* algorithm will keep iterating even if the desired precision sparm->epsilon
|
||||
* is already reached. */
|
||||
return (0);
|
||||
}
|
||||
|
||||
void print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm, CONSTSET cset,
|
||||
double *alpha, STRUCT_LEARN_PARM *sparm) {
|
||||
/* This function is called after training and allows final touches to
|
||||
the model sm. But primarly it allows computing and printing any
|
||||
kind of statistic (e.g. training error) you might want. */
|
||||
|
||||
/* Replace SV with single weight vector */
|
||||
MODEL *model = sm->svm_model;
|
||||
if (model->kernel_parm.kernel_type == LINEAR) {
|
||||
if (struct_verbosity >= 1) {
|
||||
printf("Compacting linear model...");
|
||||
fflush(stdout);
|
||||
}
|
||||
sm->svm_model = compact_linear_model(model);
|
||||
sm->w = sm->svm_model->lin_weights; /* short cut to weight vector */
|
||||
free_model(model, 1);
|
||||
if (struct_verbosity >= 1) {
|
||||
printf("done\n");
|
||||
fflush(stdout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void write_struct_model(char *file, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm) {
|
||||
/* Writes structural model sm to file file. */
|
||||
FILE *modelfl;
|
||||
long j, i, sv_num;
|
||||
MODEL *model = sm->svm_model;
|
||||
SVECTOR *v;
|
||||
|
||||
if ((modelfl = fopen(file, "w")) == NULL) {
|
||||
perror(file);
|
||||
exit(1);
|
||||
}
|
||||
fprintf(modelfl, "SVM-multiclass Version %s\n", INST_VERSION);
|
||||
fprintf(modelfl, "%d # number of classes\n", sparm->num_classes_);
|
||||
fprintf(modelfl, "%d # number of base features\n", sparm->num_features);
|
||||
fprintf(modelfl, "%d # loss function\n", sparm->loss_function);
|
||||
fprintf(modelfl, "%ld # kernel type\n", model->kernel_parm.kernel_type);
|
||||
fprintf(modelfl, "%ld # kernel parameter -d \n",
|
||||
model->kernel_parm.poly_degree);
|
||||
fprintf(modelfl, "%.8g # kernel parameter -g \n",
|
||||
model->kernel_parm.rbf_gamma);
|
||||
fprintf(modelfl, "%.8g # kernel parameter -s \n",
|
||||
model->kernel_parm.coef_lin);
|
||||
fprintf(modelfl, "%.8g # kernel parameter -r \n",
|
||||
model->kernel_parm.coef_const);
|
||||
fprintf(modelfl, "%s# kernel parameter -u \n", model->kernel_parm.custom);
|
||||
fprintf(modelfl, "%ld # highest feature index \n", model->totwords);
|
||||
fprintf(modelfl, "%ld # number of training documents \n", model->totdoc);
|
||||
|
||||
sv_num = 1;
|
||||
for (i = 1; i < model->sv_num; i++) {
|
||||
for (v = model->supvec[i]->fvec; v; v = v->next)
|
||||
sv_num++;
|
||||
}
|
||||
fprintf(modelfl, "%ld # number of support vectors plus 1 \n", sv_num);
|
||||
fprintf(modelfl,
|
||||
"%.8g # threshold b, each following line is a SV (starting with "
|
||||
"alpha*y)\n",
|
||||
model->b);
|
||||
|
||||
for (i = 1; i < model->sv_num; i++) {
|
||||
for (v = model->supvec[i]->fvec; v; v = v->next) {
|
||||
fprintf(modelfl, "%.32g ", model->alpha[i] * v->factor);
|
||||
fprintf(modelfl, "qid:%ld ", v->kernel_id);
|
||||
for (j = 0; (v->words[j]).wnum; j++) {
|
||||
fprintf(modelfl, "%ld:%.8g ", (long)(v->words[j]).wnum,
|
||||
(double)(v->words[j]).weight);
|
||||
}
|
||||
if (v->userdefined)
|
||||
fprintf(modelfl, "#%s\n", v->userdefined);
|
||||
else
|
||||
fprintf(modelfl, "#\n");
|
||||
/* NOTE: this could be made more efficient by summing the
|
||||
alpha's of identical vectors before writing them to the
|
||||
file. */
|
||||
}
|
||||
}
|
||||
fclose(modelfl);
|
||||
}
|
||||
|
||||
void print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm,
|
||||
STRUCT_TEST_STATS *teststats) {
|
||||
/* This function is called after making all test predictions in
|
||||
svm_struct_classify and allows computing and printing any kind of
|
||||
evaluation (e.g. precision/recall) you might want. You can use
|
||||
the function eval_prediction to accumulate the necessary
|
||||
statistics for each prediction. */
|
||||
}
|
||||
|
||||
void eval_prediction(long exnum, EXAMPLE ex, LABEL ypred, STRUCTMODEL *sm,
|
||||
STRUCT_LEARN_PARM *sparm, STRUCT_TEST_STATS *teststats) {
|
||||
/* This function allows you to accumlate statistic for how well the
|
||||
predicition matches the labeled example. It is called from
|
||||
svm_struct_classify. See also the function
|
||||
print_struct_testing_stats. */
|
||||
if (exnum == 0) { /* this is the first time the function is
|
||||
called. So initialize the teststats */
|
||||
}
|
||||
}
|
||||
|
||||
STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm) {
|
||||
/* Reads structural model sm from file file. This function is used
|
||||
only in the prediction module, not in the learning module. */
|
||||
FILE *modelfl;
|
||||
STRUCTMODEL sm;
|
||||
long i, queryid, slackid;
|
||||
double costfactor;
|
||||
long max_sv, max_words, ll, wpos;
|
||||
char *line, *comment;
|
||||
WORD *words;
|
||||
char version_buffer[100];
|
||||
MODEL *model;
|
||||
|
||||
nol_ll(file, &max_sv, &max_words, &ll); /* scan size of model file */
|
||||
max_words += 2;
|
||||
ll += 2;
|
||||
|
||||
words = (WORD *)my_malloc(sizeof(WORD) * (max_words + 10));
|
||||
line = (char *)my_malloc(sizeof(char) * ll);
|
||||
model = (MODEL *)my_malloc(sizeof(MODEL));
|
||||
|
||||
if ((modelfl = fopen(file, "r")) == NULL) {
|
||||
perror(file);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
fscanf(modelfl, "SVM-multiclass Version %s\n", version_buffer);
|
||||
if (strcmp(version_buffer, INST_VERSION)) {
|
||||
perror(
|
||||
"Version of model-file does not match version of svm_struct_classify!");
|
||||
exit(1);
|
||||
}
|
||||
fscanf(modelfl, "%d%*[^\n]\n", &sparm->num_classes_);
|
||||
fscanf(modelfl, "%d%*[^\n]\n", &sparm->num_features);
|
||||
fscanf(modelfl, "%d%*[^\n]\n", &sparm->loss_function);
|
||||
fscanf(modelfl, "%ld%*[^\n]\n", &model->kernel_parm.kernel_type);
|
||||
fscanf(modelfl, "%ld%*[^\n]\n", &model->kernel_parm.poly_degree);
|
||||
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.rbf_gamma);
|
||||
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.coef_lin);
|
||||
fscanf(modelfl, "%lf%*[^\n]\n", &model->kernel_parm.coef_const);
|
||||
fscanf(modelfl, "%[^#]%*[^\n]\n", model->kernel_parm.custom);
|
||||
|
||||
fscanf(modelfl, "%ld%*[^\n]\n", &model->totwords);
|
||||
fscanf(modelfl, "%ld%*[^\n]\n", &model->totdoc);
|
||||
fscanf(modelfl, "%ld%*[^\n]\n", &model->sv_num);
|
||||
fscanf(modelfl, "%lf%*[^\n]\n", &model->b);
|
||||
|
||||
model->supvec = (DOC **)my_malloc(sizeof(DOC *) * model->sv_num);
|
||||
model->alpha = (double *)my_malloc(sizeof(double) * model->sv_num);
|
||||
model->index = NULL;
|
||||
model->lin_weights = NULL;
|
||||
|
||||
for (i = 1; i < model->sv_num; i++) {
|
||||
fgets(line, (int)ll, modelfl);
|
||||
if (!parse_document(line, words, &(model->alpha[i]), &queryid, &slackid,
|
||||
&costfactor, &wpos, max_words, &comment)) {
|
||||
printf("\nParsing error while reading model file in SV %ld!\n%s", i,
|
||||
line);
|
||||
exit(1);
|
||||
}
|
||||
model->supvec[i] =
|
||||
create_example(-1, 0, 0, 0.0, create_svector(words, comment, 1.0));
|
||||
model->supvec[i]->fvec->kernel_id = queryid;
|
||||
}
|
||||
fclose(modelfl);
|
||||
free(line);
|
||||
free(words);
|
||||
if (verbosity >= 1) {
|
||||
fprintf(stdout, " (%d support vectors read) ", (int)(model->sv_num - 1));
|
||||
}
|
||||
sm.svm_model = model;
|
||||
sm.sizePsi = model->totwords;
|
||||
sm.w = NULL;
|
||||
return (sm);
|
||||
}
|
||||
|
||||
void write_label(FILE *fp, LABEL y) {
|
||||
/* Writes label y to file handle fp. */
|
||||
int i;
|
||||
fprintf(fp, "%d", y.class_);
|
||||
if (y.scores)
|
||||
for (i = 1; i <= y.num_classes_; i++)
|
||||
fprintf(fp, " %f", y.scores[i]);
|
||||
fprintf(fp, "\n");
|
||||
}
|
||||
|
||||
void free_pattern(PATTERN x) {
|
||||
/* Frees the memory of x. */
|
||||
free_example(x.doc, 1);
|
||||
}
|
||||
|
||||
void free_label(LABEL y) {
|
||||
/* Frees the memory of y. */
|
||||
if (y.scores)
|
||||
free(y.scores);
|
||||
}
|
||||
|
||||
void free_struct_model(STRUCTMODEL sm) {
|
||||
/* Frees the memory of model. */
|
||||
/* if(sm.w) free(sm.w); */ /* this is free'd in free_model */
|
||||
if (sm.svm_model)
|
||||
free_model(sm.svm_model, 1);
|
||||
/* add free calls for user defined data here */
|
||||
}
|
||||
|
||||
void free_struct_sample(SAMPLE s) {
|
||||
/* Frees the memory of sample s. */
|
||||
int i;
|
||||
for (i = 0; i < s.n; i++) {
|
||||
free_pattern(s.examples[i].x);
|
||||
free_label(s.examples[i].y);
|
||||
}
|
||||
free(s.examples);
|
||||
}
|
||||
|
||||
void print_struct_help() {
|
||||
/* Prints a help text that is appended to the common help text of
|
||||
svm_struct_learn. */
|
||||
|
||||
printf(" none\n\n");
|
||||
printf("Based on multi-class SVM formulation described in:\n");
|
||||
printf(" K. Crammer and Y. Singer. On the Algorithmic "
|
||||
"Implementation of\n");
|
||||
printf(" Multi-class SVMs, JMLR, 2001.\n");
|
||||
}
|
||||
|
||||
void parse_struct_parameters(STRUCT_LEARN_PARM *sparm) {
|
||||
/* Parses the command line parameters that start with -- */
|
||||
int i;
|
||||
|
||||
for (i = 0; (i < sparm->custom_argc) && ((sparm->custom_argv[i])[0] == '-');
|
||||
i++) {
|
||||
switch ((sparm->custom_argv[i])[2]) {
|
||||
case 'a':
|
||||
i++; /* strcpy(learn_parm->alphafile,argv[i]); */
|
||||
break;
|
||||
case 'e':
|
||||
i++; /* sparm->epsilon=atof(sparm->custom_argv[i]); */
|
||||
break;
|
||||
case 'k':
|
||||
i++; /* sparm->newconstretrain=atol(sparm->custom_argv[i]); */
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void print_struct_help_classify() {
|
||||
/* Prints a help text that is appended to the common help text of
|
||||
svm_struct_classify. */
|
||||
}
|
||||
|
||||
void parse_struct_parameters_classify(STRUCT_LEARN_PARM *sparm) {
|
||||
/* Parses the command line parameters that start with -- for the
|
||||
classification module */
|
||||
int i;
|
||||
|
||||
for (i = 0; (i < sparm->custom_argc) && ((sparm->custom_argv[i])[0] == '-');
|
||||
i++) {
|
||||
switch ((sparm->custom_argv[i])[2]) {
|
||||
/* case 'x': i++; strcpy(xvalue,sparm->custom_argv[i]); break; */
|
||||
default:
|
||||
printf("\nUnrecognized option %s!\n\n", sparm->custom_argv[i]);
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
}
|
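To make the weight layout used by psi() and the argmax in classify_struct_example concrete: the joint feature vector for class y is the document's sparse vector shifted into the y-th block of the weight vector (effectively num_classes * num_features weights), and prediction is an argmax over the per-block dot products. The short C++ sketch below is an independent illustration of that layout on dense vectors; it is not code from svm_struct_api.c.

// Illustration of the per-class weight blocks behind psi()/classify_struct_example().
#include <cstdio>
#include <vector>

// Score of class c (1-based) = dot(w[(c-1)*num_features .. c*num_features), x).
int predict_class(const std::vector<double> &w, const std::vector<double> &x,
                  int num_classes) {
  const int num_features = (int)x.size();
  int best_class = 1;
  double best_score = -1e300;
  for (int c = 1; c <= num_classes; c++) {
    double score = 0.0;
    for (int f = 0; f < num_features; f++)
      score += w[(c - 1) * num_features + f] * x[f];
    if (score > best_score) { best_score = score; best_class = c; }
  }
  return best_class;
}

int main() {
  // 3 classes, 2 features: w holds one 2-dimensional block per class.
  std::vector<double> w = {1.0, 0.0,    // class 1 block
                           0.0, 1.0,    // class 2 block
                           -1.0, -1.0}; // class 3 block
  std::vector<double> x = {0.2, 0.9};
  std::printf("predicted class: %d\n", predict_class(w, x, 3)); // prints 2
  return 0;
}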
76
src/classifier/svm/svm_struct_api.h
Executable file
@@ -0,0 +1,76 @@
/***********************************************************************/
/*                                                                     */
/*   svm_struct_api.h                                                  */
/*                                                                     */
/*   Definition of API for attaching implementing SVM learning of     */
/*   structures (e.g. parsing, multi-label classification, HMM)       */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 03.07.04                                                    */
/*                                                                     */
/*   Copyright (c) 2004 Thorsten Joachims - All rights reserved        */
/*                                                                     */
/*   This software is available for non-commercial use only. It must  */
/*   not be modified and distributed without prior permission of the  */
/*   author. The author is not responsible for implications from the  */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#include "svm_struct/svm_struct_common.h"
#include "svm_struct_api_types.h"

#ifndef svm_struct_api
#define svm_struct_api
#ifdef __cplusplus
extern "C" {
#endif
void svm_struct_learn_api_init(int argc, char *argv[]);
void svm_struct_learn_api_exit();
void svm_struct_classify_api_init(int argc, char *argv[]);
void svm_struct_classify_api_exit();
SAMPLE read_struct_examples(char *file, STRUCT_LEARN_PARM *sparm);
void init_struct_model(SAMPLE sample, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm,
                       LEARN_PARM *lparm, KERNEL_PARM *kparm);
CONSTSET init_struct_constraints(SAMPLE sample, STRUCTMODEL *sm,
                                 STRUCT_LEARN_PARM *sparm);
LABEL find_most_violated_constraint_slackrescaling(PATTERN x, LABEL y,
                                                   STRUCTMODEL *sm,
                                                   STRUCT_LEARN_PARM *sparm);
LABEL find_most_violated_constraint_marginrescaling(PATTERN x, LABEL y,
                                                    STRUCTMODEL *sm,
                                                    STRUCT_LEARN_PARM *sparm);
LABEL classify_struct_example(PATTERN x, STRUCTMODEL *sm,
                              STRUCT_LEARN_PARM *sparm);
int empty_label(LABEL y);
SVECTOR *psi(PATTERN x, LABEL y, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm);
double loss(LABEL y, LABEL ybar, STRUCT_LEARN_PARM *sparm);
int finalize_iteration(double ceps, int cached_constraint, SAMPLE sample,
                       STRUCTMODEL *sm, CONSTSET cset, double *alpha,
                       STRUCT_LEARN_PARM *sparm);
void print_struct_learning_stats(SAMPLE sample, STRUCTMODEL *sm, CONSTSET cset,
                                 double *alpha, STRUCT_LEARN_PARM *sparm);
void print_struct_testing_stats(SAMPLE sample, STRUCTMODEL *sm,
                                STRUCT_LEARN_PARM *sparm,
                                STRUCT_TEST_STATS *teststats);
void eval_prediction(long exnum, EXAMPLE ex, LABEL prediction, STRUCTMODEL *sm,
                     STRUCT_LEARN_PARM *sparm, STRUCT_TEST_STATS *teststats);
void write_struct_model(char *file, STRUCTMODEL *sm, STRUCT_LEARN_PARM *sparm);
STRUCTMODEL read_struct_model(char *file, STRUCT_LEARN_PARM *sparm);
void write_label(FILE *fp, LABEL y);
void free_pattern(PATTERN x);
void free_label(LABEL y);
void free_struct_model(STRUCTMODEL sm);
void free_struct_sample(SAMPLE s);
void print_struct_help();
void parse_struct_parameters(STRUCT_LEARN_PARM *sparm);
void print_struct_help_classify();
void parse_struct_parameters_classify(STRUCT_LEARN_PARM *sparm);
void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
                                   LEARN_PARM *lparm, KERNEL_PARM *kparm,
                                   STRUCTMODEL *sm);

#ifdef __cplusplus
}
#endif
#endif
114
src/classifier/svm/svm_struct_api_types.h
Executable file
@@ -0,0 +1,114 @@
/***********************************************************************/
/*                                                                     */
/*   svm_struct_api.h                                                  */
/*                                                                     */
/*   Definition of API for attaching implementing SVM learning of     */
/*   structures (e.g. parsing, multi-label classification, HMM)       */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 13.10.03                                                    */
/*                                                                     */
/*   Copyright (c) 2003 Thorsten Joachims - All rights reserved        */
/*                                                                     */
/*   This software is available for non-commercial use only. It must  */
/*   not be modified and distributed without prior permission of the  */
/*   author. The author is not responsible for implications from the  */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#ifndef svm_struct_api_types
#define svm_struct_api_types

#include "svm_light/svm_common.h"
#include "svm_light/svm_learn.h"

#define INST_NAME "Multi-Class SVM"
#define INST_VERSION "V2.20"
#define INST_VERSION_DATE "14.08.08"

/* default precision for solving the optimization problem */
#define DEFAULT_EPS 0.1
/* default loss rescaling method: 1=slack_rescaling, 2=margin_rescaling */
#define DEFAULT_RESCALING 2
/* default loss function: */
#define DEFAULT_LOSS_FCT 0
/* default optimization algorithm to use: */
#define DEFAULT_ALG_TYPE 4
/* store Psi(x,y) once instead of recomputing it every time: */
#define USE_FYCACHE 0
/* decide whether to evaluate sum before storing vectors in constraint
   cache:
   0 = NO,
   1 = YES (best, if sparse vectors and long vector lists),
   2 = YES (best, if short vector lists),
   3 = YES (best, if dense vectors and long vector lists) */
#define COMPACT_CACHED_VECTORS 2
/* minimum absolute value below which values in sparse vectors are
   rounded to zero. Values are stored in the FVAL type defined in svm_common.h
   RECOMMENDATION: assuming you use FVAL=float, use
   10E-15 if COMPACT_CACHED_VECTORS is 1
   10E-10 if COMPACT_CACHED_VECTORS is 2 or 3
*/
#define COMPACT_ROUNDING_THRESH 10E-15

typedef struct pattern {
  /* this defines the x-part of a training example, e.g. the structure
     for storing a natural language sentence in NLP parsing */
  DOC *doc;
} PATTERN;

typedef struct label {
  /* this defines the y-part (the label) of a training example,
     e.g. the parse tree of the corresponding sentence. */
  int class_;       /* class label */
  int num_classes_; /* total number of classes */
  double *scores;   /* value of linear function of each class */
} LABEL;

typedef struct structmodel {
  double *w;        /* pointer to the learned weights */
  MODEL *svm_model; /* the learned SVM model */
  long sizePsi;     /* maximum number of weights in w */
  double walpha;
  /* other information that is needed for the structural model can be
     added here, e.g. the grammar rules for NLP parsing */
} STRUCTMODEL;

typedef struct struct_learn_parm {
  double epsilon;            /* precision for which to solve
                                quadratic program */
  double newconstretrain;    /* number of new constraints to
                                accumulate before recomputing the QP
                                solution */
  int ccache_size;           /* maximum number of constraints to
                                cache for each example (used in w=4
                                algorithm) */
  double batch_size;         /* size of the mini batches in percent
                                of training set size (used in w=4
                                algorithm) */
  double C;                  /* trade-off between margin and loss */
  char custom_argv[20][300]; /* string set with the -u command line option */
  int custom_argc;           /* number of -u command line options */
  int slack_norm;            /* norm to use in objective function
                                for slack variables; 1 -> L1-norm,
                                2 -> L2-norm */
  int loss_type;             /* selected loss function from -r
                                command line option. Select between
                                slack rescaling (1) and margin
                                rescaling (2) */
  int loss_function;         /* select between different loss
                                functions via -l command line
                                option */
  /* further parameters that are passed to init_struct_model() */
  int num_classes_;
  int num_features;
} STRUCT_LEARN_PARM;

typedef struct struct_test_stats {
  /* you can add variables for keeping statistics when evaluating the
     test predictions in svm_struct_classify. This can be used in the
     function eval_prediction and print_struct_testing_stats. */
} STRUCT_TEST_STATS;

#endif
42
src/classifier/svm/svm_struct_learn_custom.c
Executable file
@@ -0,0 +1,42 @@
/***********************************************************************/
/*                                                                     */
/*   svm_struct_learn_custom.c (instantiated for SVM-perform)          */
/*                                                                     */
/*   Allows implementing a custom/alternate algorithm for solving     */
/*   the structural SVM optimization problem. The algorithm can use   */
/*   full access to the SVM-struct API and to SVM-light.              */
/*                                                                     */
/*   Author: Thorsten Joachims                                         */
/*   Date: 09.01.08                                                    */
/*                                                                     */
/*   Copyright (c) 2008 Thorsten Joachims - All rights reserved        */
/*                                                                     */
/*   This software is available for non-commercial use only. It must  */
/*   not be modified and distributed without prior permission of the  */
/*   author. The author is not responsible for implications from the  */
/*   use of this software.                                             */
/*                                                                     */
/***********************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "svm_struct_api.h"
#include "svm_light/svm_common.h"
#include "svm_struct/svm_struct_common.h"
#include "svm_struct/svm_struct_learn.h"


void svm_learn_struct_joint_custom(SAMPLE sample, STRUCT_LEARN_PARM *sparm,
                                   LEARN_PARM *lparm, KERNEL_PARM *kparm,
                                   STRUCTMODEL *sm)
/* Input: sample (training examples)
          sparm (structural learning parameters)
          lparm (svm learning parameters)
          kparm (kernel parameters)
   Output: sm (learned model) */
{
  /* Put your algorithm here. See svm_struct_learn.c for an example of
     how to access this API. */
}
34
src/classifier/svm/svm_trainer.cpp
Normal file
@@ -0,0 +1,34 @@
#include "../svm_trainer.h"
#include "svm_binary_trainer.hpp"
#include "svm_multiclass_trainer.hpp"

ISVMTrainer new_svm_binary_trainer() {
  return new ovclassifier::SVMBinaryTrainer();
}
ISVMTrainer new_svm_multiclass_trainer() {
  return new ovclassifier::SVMMultiClassTrainer();
}

void destroy_svm_trainer(ISVMTrainer trainer) {
  delete static_cast<ovclassifier::SVMTrainer *>(trainer);
}

void svm_trainer_reset(ISVMTrainer trainer) {
  static_cast<ovclassifier::SVMTrainer *>(trainer)->Reset();
}

void svm_trainer_set_labels(ISVMTrainer trainer, int labels) {
  static_cast<ovclassifier::SVMTrainer *>(trainer)->SetLabels(labels);
}

void svm_trainer_set_features(ISVMTrainer trainer, int feats) {
  static_cast<ovclassifier::SVMTrainer *>(trainer)->SetFeatures(feats);
}

void svm_trainer_add_data(ISVMTrainer trainer, int label, const float *vec) {
  static_cast<ovclassifier::SVMTrainer *>(trainer)->AddData(label, vec);
}

int svm_train(ISVMTrainer trainer, const char *modelfile) {
  return static_cast<ovclassifier::SVMTrainer *>(trainer)->Train(modelfile);
}
16
src/classifier/svm/svm_trainer.hpp
Normal file
@@ -0,0 +1,16 @@
#ifndef _SVM_TRAINER_H_
#define _SVM_TRAINER_H_

namespace ovclassifier {
class SVMTrainer {
public:
  virtual ~SVMTrainer(){};
  virtual void Reset() = 0;
  virtual void SetLabels(int labels) = 0;
  virtual void SetFeatures(int feats) = 0;
  virtual void AddData(int label, const float *vec) = 0;
  virtual int Train(const char *modelfile) = 0;
};

} // namespace ovclassifier
#endif // _SVM_TRAINER_H_
19
src/classifier/svm_classifier.h
Normal file
@@ -0,0 +1,19 @@
#ifndef _CLASSIFIER_SVM_CLASSIFIER_C_H_
#define _CLASSIFIER_SVM_CLASSIFIER_C_H_

#include "../common/common.h"
#ifdef __cplusplus
#include "svm/svm_classifier.hpp"
extern "C" {
#endif
typedef void *ISVMClassifier;
ISVMClassifier new_svm_binary_classifier();
ISVMClassifier new_svm_multiclass_classifier();
void destroy_svm_classifier(ISVMClassifier e);
int svm_classifier_load_model(ISVMClassifier e, const char *modelfile);
double svm_predict(ISVMClassifier e, const float *vec);
int svm_classify(ISVMClassifier e, const float *vec, FloatVector *scores);
#ifdef __cplusplus
}
#endif
#endif // !_CLASSIFIER_SVM_CLASSIFIER_C_H_
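A minimal usage sketch of this C classifier interface, written in C++ for consistency with the rest of the repository. The include path, the model file name, the feature dimension (3 here), and the assumption that svm_classifier_load_model returns 0 on success are all illustrative, not taken from the source; svm_predict is used instead of svm_classify so the sketch does not depend on the FloatVector layout defined in common/common.h.

// Sketch only: assumes svm_classifier.h is on the include path and that a
// model produced by the matching trainer API exists at the given path.
#include <cstdio>
#include "svm_classifier.h" // hypothetical include path

int main() {
  ISVMClassifier clf = new_svm_multiclass_classifier();
  if (svm_classifier_load_model(clf, "svm_model.bin") != 0) { // assumed: 0 = success
    std::fprintf(stderr, "failed to load model\n");
    destroy_svm_classifier(clf);
    return 1;
  }
  const float feature[3] = {0.1f, 0.7f, 0.2f}; // feature dimension is illustrative
  double label = svm_predict(clf, feature);
  std::printf("predicted class: %.0f\n", label);
  destroy_svm_classifier(clf);
  return 0;
}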
20
src/classifier/svm_trainer.h
Normal file
@@ -0,0 +1,20 @@
#ifndef _CLASSIFIER_SVM_TRAINER_C_H_
#define _CLASSIFIER_SVM_TRAINER_C_H_

#ifdef __cplusplus
#include "svm/svm_trainer.hpp"
extern "C" {
#endif
typedef void *ISVMTrainer;
ISVMTrainer new_svm_binary_trainer();
ISVMTrainer new_svm_multiclass_trainer();
void destroy_svm_trainer(ISVMTrainer trainer);
void svm_trainer_reset(ISVMTrainer trainer);
void svm_trainer_set_labels(ISVMTrainer trainer, int labels);
void svm_trainer_set_features(ISVMTrainer trainer, int feats);
void svm_trainer_add_data(ISVMTrainer trainer, int label, const float *vec);
int svm_train(ISVMTrainer trainer, const char *modelfile);
#ifdef __cplusplus
}
#endif
#endif // !_CLASSIFIER_SVM_TRAINER_C_H_
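The trainer side can be driven the same way. The sketch below (again C++, with a hypothetical include path and file name) assumes SetLabels takes the number of classes, SetFeatures the feature dimension, that class labels are 1-based as in the multi-class SVM code above, and that svm_train returns 0 on success; none of these conventions are stated in the header itself.

// Sketch only: include path, model path, label/return conventions, and data are illustrative.
#include <cstdio>
#include "svm_trainer.h"

int main() {
  ISVMTrainer trainer = new_svm_multiclass_trainer();
  svm_trainer_reset(trainer);
  svm_trainer_set_labels(trainer, 2);   // assumed: number of classes
  svm_trainer_set_features(trainer, 3); // assumed: feature dimension

  const float a[3] = {1.0f, 0.0f, 0.0f};
  const float b[3] = {0.0f, 1.0f, 0.0f};
  svm_trainer_add_data(trainer, 1, a); // class labels assumed 1-based
  svm_trainer_add_data(trainer, 2, b);

  int ret = svm_train(trainer, "svm_model.bin"); // assumed: 0 on success
  std::printf("svm_train returned %d\n", ret);
  destroy_svm_trainer(trainer);
  return 0;
}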
@@ -1,9 +1,9 @@
#include "common.h"
#include "cpu.h"
#include <algorithm>
#include <math.h>
#include <float.h>
#include <iostream>
#include "cpu.h"
#include <math.h>

#ifdef OV_VULKAN
#include "gpu.h"
@@ -29,9 +29,7 @@ void destroy_gpu_instance() {
#endif // OV_VULKAN
}

int get_big_cpu_count() {
return ncnn::get_big_cpu_count();
}
int get_big_cpu_count() { return ncnn::get_big_cpu_count(); }

void set_omp_num_threads(int n) {
#ifdef OV_OPENMP
@@ -43,9 +41,7 @@ int load_model(IEstimator d, const char *root_path) {
return static_cast<ov::Estimator *>(d)->LoadModel(root_path);
}

void destroy_estimator(IEstimator d) {
delete static_cast<ov::Estimator*>(d);
}
void destroy_estimator(IEstimator d) { delete static_cast<ov::Estimator *>(d); }

void set_num_threads(IEstimator d, int n) {
static_cast<ov::Estimator *>(d)->set_num_threads(n);
@@ -62,6 +58,13 @@ void FreePoint2fVector(Point2fVector* p) {
}
}

void FreePoint3dVector(Point3dVector *p) {
if (p->points != NULL) {
free(p->points);
p->points = NULL;
}
}

void Point2fVectorSetValue(Point2fVector *p, int i, const Point2f *val) {
if (p->points == NULL || i >= p->length) {
return;
@@ -134,6 +137,7 @@ Estimator::Estimator() : EstimatorBase() {
}
net_->opt.blob_allocator = &blob_allocator_;
net_->opt.workspace_allocator = &workspace_allocator_;
net_->opt.lightmode = light_mode_;
#ifdef OV_VULKAN
net_->opt.use_vulkan_compute = true;
#endif // OV_VULKAN
@@ -160,15 +164,11 @@ int Estimator::LoadModel(const char * root_path) {
return 0;
}

EstimatorBase::EstimatorBase() {
num_threads = ncnn::get_big_cpu_count();
}
EstimatorBase::EstimatorBase() { num_threads = ncnn::get_big_cpu_count(); }

EstimatorBase::~EstimatorBase() {}

void EstimatorBase::set_num_threads(int n) {
num_threads = n;
}
void EstimatorBase::set_num_threads(int n) { num_threads = n; }

void Estimator::set_num_threads(int n) {
EstimatorBase::set_num_threads(n);
@@ -184,8 +184,7 @@ void Estimator::set_light_mode(bool mode) {
}
}

int RatioAnchors(const Rect & anchor,
const std::vector<float>& ratios,
int RatioAnchors(const Rect &anchor, const std::vector<float> &ratios,
std::vector<Rect> *anchors, int threads_num) {
anchors->clear();
Point center = Point(anchor.x + (anchor.width - 1) * 0.5f,
@@ -202,15 +201,16 @@ int RatioAnchors(const Rect & anchor,
float curr_x = center.x - (curr_anchor_width - 1) * 0.5f;
float curr_y = center.y - (curr_anchor_height - 1) * 0.5f;

Rect curr_anchor = Rect(curr_x, curr_y,
curr_anchor_width - 1, curr_anchor_height - 1);
Rect curr_anchor =
Rect(curr_x, curr_y, curr_anchor_width - 1, curr_anchor_height - 1);
anchors->push_back(curr_anchor);
}
return 0;
}

int ScaleAnchors(const std::vector<Rect> &ratio_anchors,
const std::vector<float>& scales, std::vector<Rect>* anchors, int threads_num) {
const std::vector<float> &scales, std::vector<Rect> *anchors,
int threads_num) {
anchors->clear();
#if defined(_OPENMP)
#pragma omp parallel for num_threads(threads_num)
@@ -225,8 +225,7 @@ int ScaleAnchors(const std::vector<Rect>& ratio_anchors,
float curr_height = scale * (anchor.height + 1);
float curr_x = center.x - curr_width * 0.5f;
float curr_y = center.y - curr_height * 0.5f;
Rect curr_anchor = Rect(curr_x, curr_y,
curr_width, curr_height);
Rect curr_anchor = Rect(curr_x, curr_y, curr_width, curr_height);
anchors->push_back(curr_anchor);
}
}
@@ -234,10 +233,8 @@ int ScaleAnchors(const std::vector<Rect>& ratio_anchors,
return 0;
}

int GenerateAnchors(const int & base_size,
const std::vector<float>& ratios,
const std::vector<float> scales,
std::vector<Rect>* anchors,
int GenerateAnchors(const int &base_size, const std::vector<float> &ratios,
const std::vector<float> scales, std::vector<Rect> *anchors,
int threads_num) {
anchors->clear();
Rect anchor = Rect(0, 0, base_size, base_size);
@@ -250,27 +247,25 @@ int GenerateAnchors(const int & base_size,

float InterRectArea(const Rect &a, const Rect &b) {
Point left_top = Point(std::max(a.x, b.x), std::max(a.y, b.y));
Point right_bottom = Point(std::min(a.br().x, b.br().x), std::min(a.br().y, b.br().y));
Point right_bottom =
Point(std::min(a.br().x, b.br().x), std::min(a.br().y, b.br().y));
Point diff = right_bottom - left_top;
return (std::max(diff.x + 1, 0) * std::max(diff.y + 1, 0));
}

int ComputeIOU(const Rect & rect1,
const Rect & rect2, float * iou,
int ComputeIOU(const Rect &rect1, const Rect &rect2, float *iou,
const std::string &type) {

float inter_area = InterRectArea(rect1, rect2);
if (type == "UNION") {
*iou = inter_area / (rect1.area() + rect2.area() - inter_area);
}
else {
} else {
*iou = inter_area / std::min(rect1.area(), rect2.area());
}

return 0;
}


void EnlargeRect(const float &scale, Rect *rect) {
float offset_x = (scale - 1.f) / 2.f * rect->width;
float offset_y = (scale - 1.f) / 2.f * rect->height;
@@ -291,22 +286,20 @@ void RectifyRect(Rect* rect) {
rect->height = max_side;
}

void qsort_descent_inplace(std::vector<ObjectInfo>& objects, int left, int right)
{
void qsort_descent_inplace(std::vector<ObjectInfo> &objects, int left,
int right) {
int i = left;
int j = right;
float p = objects[(left + right) / 2].score;

while (i <= j)
{
while (i <= j) {
while (objects[i].score > p)
i++;

while (objects[j].score < p)
j--;

if (i <= j)
{
if (i <= j) {
// swap
std::swap(objects[i], objects[j]);

@@ -319,42 +312,40 @@ void qsort_descent_inplace(std::vector<ObjectInfo>& objects, int left, int right
{
#pragma omp section
{
if (left < j) qsort_descent_inplace(objects, left, j);
if (left < j)
qsort_descent_inplace(objects, left, j);
}
#pragma omp section
{
if (i < right) qsort_descent_inplace(objects, i, right);
if (i < right)
qsort_descent_inplace(objects, i, right);
}
}
}

void qsort_descent_inplace(std::vector<ObjectInfo>& objects)
{
void qsort_descent_inplace(std::vector<ObjectInfo> &objects) {
if (objects.empty())
return;

qsort_descent_inplace(objects, 0, objects.size() - 1);
}

void nms_sorted_bboxes(const std::vector<ObjectInfo>& objects, std::vector<int>& picked, float nms_threshold)
{
void nms_sorted_bboxes(const std::vector<ObjectInfo> &objects,
std::vector<int> &picked, float nms_threshold) {
picked.clear();

const int n = objects.size();

std::vector<float> areas(n);
for (int i = 0; i < n; i++)
{
for (int i = 0; i < n; i++) {
areas[i] = objects[i].rect.area();
}

for (int i = 0; i < n; i++)
{
for (int i = 0; i < n; i++) {
const ObjectInfo &a = objects[i];

int keep = 1;
for (int j = 0; j < (int)picked.size(); j++)
{
for (int j = 0; j < (int)picked.size(); j++) {
const ObjectInfo &b = objects[picked[j]];

// intersection over union
@@ -370,9 +361,10 @@ void nms_sorted_bboxes(const std::vector<ObjectInfo>& objects, std::vector<int>&
}
}
//
// insightface/detection/scrfd/mmdet/core/anchor/anchor_generator.py gen_single_level_base_anchors()
ncnn::Mat generate_anchors(int base_size, const ncnn::Mat& ratios, const ncnn::Mat& scales)
{
// insightface/detection/scrfd/mmdet/core/anchor/anchor_generator.py
// gen_single_level_base_anchors()
ncnn::Mat generate_anchors(int base_size, const ncnn::Mat &ratios,
const ncnn::Mat &scales) {
int num_ratio = ratios.w;
int num_scale = scales.w;

@@ -382,15 +374,13 @@ ncnn::Mat generate_anchors(int base_size, const ncnn::Mat& ratios, const ncnn::M
const float cx = 0;
const float cy = 0;

for (int i = 0; i < num_ratio; i++)
{
for (int i = 0; i < num_ratio; i++) {
float ar = ratios[i];

int r_w = round(base_size / sqrt(ar));
int r_h = round(r_w * ar); // round(base_size * sqrt(ar));

for (int j = 0; j < num_scale; j++)
{
for (int j = 0; j < num_scale; j++) {
float scale = scales[j];

float rs_w = r_w * scale;
@@ -408,15 +398,12 @@ ncnn::Mat generate_anchors(int base_size, const ncnn::Mat& ratios, const ncnn::M
return anchors;
}

int generate_grids_and_stride(const int target_size, std::vector<int>& strides, std::vector<GridAndStride>& grid_strides)
{
for (auto stride : strides)
{
int generate_grids_and_stride(const int target_size, std::vector<int> &strides,
std::vector<GridAndStride> &grid_strides) {
for (auto stride : strides) {
int num_grid = target_size / stride;
for (int g1 = 0; g1 < num_grid; g1++)
{
for (int g0 = 0; g0 < num_grid; g0++)
{
for (int g1 = 0; g1 < num_grid; g1++) {
for (int g0 = 0; g0 < num_grid; g0++) {
grid_strides.push_back((GridAndStride){g0, g1, stride});
}
}
@@ -425,9 +412,6 @@ int generate_grids_and_stride(const int target_size, std::vector<int>& strides,
return 0;
}

float sigmoid(float x)
{
return static_cast<float>(1.f / (1.f + exp(-x)));
}
float sigmoid(float x) { return static_cast<float>(1.f / (1.f + exp(-x))); }

}
} // namespace ov
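The ComputeIOU / nms_sorted_bboxes pair in this diff implements the usual greedy non-maximum suppression over score-sorted boxes: keep the highest-scoring box, then discard any later box whose intersection-over-union with an already kept box exceeds the threshold. The standalone C++ sketch below illustrates that scheme with its own minimal Box type rather than the library's Rect/ObjectInfo, so it compiles on its own; it is an illustration of the idea, not the repository's implementation.

// Standalone illustration of greedy NMS over score-sorted boxes (not the ov:: code).
#include <algorithm>
#include <cstdio>
#include <vector>

struct Box {
  float x, y, w, h, score;
  float area() const { return w * h; }
};

// Intersection-over-union, "UNION" variant as in ComputeIOU.
static float iou(const Box &a, const Box &b) {
  float ix = std::max(0.f, std::min(a.x + a.w, b.x + b.w) - std::max(a.x, b.x));
  float iy = std::max(0.f, std::min(a.y + a.h, b.y + b.h) - std::max(a.y, b.y));
  float inter = ix * iy;
  return inter / (a.area() + b.area() - inter);
}

// Sorts a copy by descending score, then greedily keeps non-overlapping boxes.
static std::vector<int> nms(std::vector<Box> boxes, float threshold) {
  std::sort(boxes.begin(), boxes.end(),
            [](const Box &a, const Box &b) { return a.score > b.score; });
  std::vector<int> picked; // indices into the sorted copy
  for (int i = 0; i < (int)boxes.size(); i++) {
    bool keep = true;
    for (int j : picked) {
      if (iou(boxes[i], boxes[j]) > threshold) { keep = false; break; }
    }
    if (keep) picked.push_back(i);
  }
  return picked;
}

int main() {
  std::vector<Box> boxes = {{0, 0, 10, 10, 0.9f},
                            {1, 1, 10, 10, 0.8f},
                            {20, 20, 5, 5, 0.7f}};
  for (int idx : nms(boxes, 0.5f)) std::printf("kept box %d\n", idx); // keeps 0 and 2
  return 0;
}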