Files
go-darknet/network.go
2023-01-05 17:16:47 +03:00

163 lines
4.8 KiB
Go

package darknet
// #include <stdlib.h>
//
// #include <darknet.h>
//
// #include "network.h"
import "C"
import (
"io/ioutil"
"os"
"time"
"unsafe"
"github.com/edsrzf/mmap-go"
"github.com/pkg/errors"
)
// YOLONetwork represents a neural network using YOLO.
type YOLONetwork struct {
GPUDeviceIndex int
NetworkConfigurationFile string
WeightsFile string
Threshold float32
ClassNames []string
Classes int
cNet *C.network
hierarchalThreshold float32
nms float32
}
var (
errNetworkNotInit = errors.New("Network not initialised")
errUnableToInitNetwork = errors.New("Unable to initialise")
)
// Init the network.
func (n *YOLONetwork) Init() error {
nCfg := C.CString(n.NetworkConfigurationFile)
defer C.free(unsafe.Pointer(nCfg))
wFile := C.CString(n.WeightsFile)
defer C.free(unsafe.Pointer(wFile))
// GPU device ID must be set before `load_network()` is invoked.
C.cuda_set_device(C.int(n.GPUDeviceIndex))
n.cNet = C.load_network(nCfg, wFile, 0)
if n.cNet == nil {
return errUnableToInitNetwork
}
C.srand(2222222)
n.hierarchalThreshold = 0.5
n.nms = 0.45
metadata := C.get_metadata(nCfg)
n.Classes = int(metadata.classes)
n.ClassNames = makeClassNames(metadata.names, n.Classes)
return nil
}
/* EXPERIMENTAL */
/*
By default AlexeyAB's Darknet doesn't export any functions in darknet.h to give ability to create network from scratch via code.
So I can't modify `parse_network_cfg_custom` to load `list *sections = read_cfg(filename);` from memory.
So, the point of this method is to be able create network configuration via Golang and then pass it to `C.load_network`
This code is portable for Windows/Linux/MacOS. See the ref.: https://github.com/edsrzf/mmap-go#mmap-go
*/
func (n *YOLONetwork) InitFromDefinedCfg() error {
wFile := C.CString(n.WeightsFile)
defer C.free(unsafe.Pointer(wFile))
/* Prepare network sections via Go */
/*
instead of using:
wFile := C.CString(n.WeightsFile)
defer C.free(unsafe.Pointer(wFile))
We call `load_network` that takes the first char* parameter (representing a file path to network configuration) with a Go function that takes an in-memory file
*/
cfgBytes, err := os.ReadFile(n.NetworkConfigurationFile)
if err != nil {
return errors.Wrap(err, "Can't read file bytes")
}
// Create a temporary file.
tmpFile, err := ioutil.TempFile("", "")
if err != nil {
return errors.Wrap(err, "Can't create temporary file")
}
defer os.Remove(tmpFile.Name())
// Write the file content to the temporary file.
if _, err := tmpFile.Write(cfgBytes); err != nil {
return errors.Wrap(err, "Can't write network's configuration into temporary file")
}
defer tmpFile.Close()
// Open the temporary file.
// fd, err := syscall.Open(tmpFile.Name(), syscall.O_RDWR, 0)
// if err != nil {
// return errors.Wrap(err, "Can't re-open temporary file")
// }
// defer syscall.Close(fd)
// Create a memory mapping of the file.
// addr, err := syscall.Mmap(fd, 0, len(cfgBytes), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED)
// if err != nil {
// return errors.Wrap(err, "Can't mmap on temporary file")
// }
// Unmap the memory-mapped file.
// defer syscall.Munmap(addr)
// Map the temporary file to memory.
mapping, err := mmap.Map(tmpFile, mmap.RDWR, 0)
if err != nil {
return errors.Wrap(err, "Can't mmap on temporary file")
}
defer mapping.Unmap() // Unmap the memory-mapped file.
// GPU device ID must be set before `load_network()` is invoked.
C.cuda_set_device(C.int(n.GPUDeviceIndex))
nCfg := C.CString(tmpFile.Name())
defer C.free(unsafe.Pointer(nCfg))
n.cNet = C.load_network(nCfg, wFile, 0)
if n.cNet == nil {
return errUnableToInitNetwork
}
C.srand(2222222)
n.hierarchalThreshold = 0.5
n.nms = 0.45
metadata := C.get_metadata(nCfg)
n.Classes = int(metadata.classes)
n.ClassNames = makeClassNames(metadata.names, n.Classes)
return nil
}
// Close and release resources.
func (n *YOLONetwork) Close() error {
if n.cNet == nil {
return errNetworkNotInit
}
C.free_network_ptr(n.cNet)
n.cNet = nil
return nil
}
// Detect specified image
func (n *YOLONetwork) Detect(img *DarknetImage) (*DetectionResult, error) {
if n.cNet == nil {
return nil, errNetworkNotInit
}
startTime := time.Now()
result := C.perform_network_detect(n.cNet, &img.image, C.int(n.Classes), C.float(n.Threshold), C.float(n.hierarchalThreshold), C.float(n.nms), C.int(0))
endTime := time.Now()
ds := makeDetections(img, result.detections, int(result.detections_len), n.Threshold, n.Classes, n.ClassNames)
C.free_detections(result.detections, result.detections_len)
endTimeOverall := time.Now()
out := DetectionResult{
Detections: ds,
NetworkOnlyTimeTaken: endTime.Sub(startTime),
OverallTimeTaken: endTimeOverall.Sub(startTime),
}
return &out, nil
}