feat: added grpc client, model interface

kweijack
2024-07-03 11:58:10 +00:00
parent 36ac248d2e
commit 5d6050f78e
3 changed files with 188 additions and 0 deletions

client/client.go

@@ -0,0 +1,111 @@
package client

import (
	"github.com/dev6699/face/model"
	"github.com/dev6699/face/protobuf"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

var (
	conn           *grpc.ClientConn
	client         protobuf.GRPCInferenceServiceClient
	modelsMetadata = make(map[string]*modelMetadata)
)
// Init initializes the grpc connection and fetches the metadata of all given models from the grpc server.
func Init(url string, models []model.ModelMeta) error {
	var err error
	conn, err = grpc.NewClient(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return err
	}

	client = protobuf.NewGRPCInferenceServiceClient(conn)

	for _, m := range models {
		meta, err := newModelMetadata(client, m.ModelName(), m.ModelVersion())
		if err != nil {
			return err
		}
		modelsMetadata[m.ModelName()] = meta
	}

	return nil
}
// Close tears down the underlying grpc connection.
func Close() error {
	return conn.Close()
}
// Infer is a generic helper that takes a modelFactory to create a model.Model and an input for model.PreProcess(),
// then performs the infer request automatically based on the model's metadata.
func Infer[I, O any](modelFactory func() model.Model[I, O], input I) (O, error) {
	var zeroOutput O

	model := modelFactory()
	contents, err := model.PreProcess(input)
	if err != nil {
		return zeroOutput, err
	}

	modelInferRequest := modelsMetadata[model.ModelName()].formInferRequest(contents)
	inferResponse, err := ModelInferRequest(client, modelInferRequest)
	if err != nil {
		return zeroOutput, err
	}

	return model.PostProcess(inferResponse.RawOutputContents)
}
// modelMetadata pairs a model's name and version with the metadata reported by the server.
type modelMetadata struct {
	modelName    string
	modelVersion string
	*protobuf.ModelMetadataResponse
}

// newModelMetadata queries the server for a model's metadata and caches it alongside its name and version.
func newModelMetadata(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*modelMetadata, error) {
	metaResponse, err := ModelMetadataRequest(client, modelName, modelVersion)
	if err != nil {
		return nil, err
	}

	return &modelMetadata{
		modelName:             modelName,
		modelVersion:          modelVersion,
		ModelMetadataResponse: metaResponse,
	}, nil
}
// formInferRequest builds a ModelInferRequest from the given tensor contents,
// using the cached metadata to fill in input names, datatypes, shapes and the requested outputs.
func (m *modelMetadata) formInferRequest(contents []*protobuf.InferTensorContents) *protobuf.ModelInferRequest {
	inputs := []*protobuf.ModelInferRequest_InferInputTensor{}
	for i, c := range contents {
		input := m.Inputs[i]

		// A leading -1 marks a dynamic batch dimension; request a batch size of 1.
		shape := input.Shape
		if shape[0] == -1 {
			shape[0] = 1
		}

		inputs = append(inputs, &protobuf.ModelInferRequest_InferInputTensor{
			Name:     input.Name,
			Datatype: input.Datatype,
			Shape:    shape,
			Contents: c,
		})
	}

	outputs := make([]*protobuf.ModelInferRequest_InferRequestedOutputTensor, len(m.Outputs))
	for i, o := range m.Outputs {
		outputs[i] = &protobuf.ModelInferRequest_InferRequestedOutputTensor{
			Name: o.Name,
		}
	}

	return &protobuf.ModelInferRequest{
		ModelName:    m.modelName,
		ModelVersion: m.modelVersion,
		Inputs:       inputs,
		Outputs:      outputs,
	}
}
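A minimal usage sketch of the client API above, illustrative only and not part of this commit: the detector package, its Input/Output types, and the address localhost:8001 are hypothetical placeholders for any model that satisfies model.Model (a possible implementation is sketched after model/model.go below).

package main

import (
	"log"

	"github.com/dev6699/face/client"
	"github.com/dev6699/face/model"

	"example.com/detector" // hypothetical model implementation
)

func main() {
	// Register the model so Init can fetch its metadata up front.
	if err := client.Init("localhost:8001", []model.ModelMeta{detector.New()}); err != nil {
		log.Fatal(err)
	}
	defer client.Close()

	// Infer wires PreProcess -> ModelInferRequest -> PostProcess using the cached metadata.
	out, err := client.Infer(detector.New, detector.Input{ /* ... */ })
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("detections: %+v", out)
}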

client/conn.go

@@ -0,0 +1,61 @@
package client

import (
	"context"
	"time"

	"github.com/dev6699/face/protobuf"
)

// requestTimeout bounds every request made to the inference server.
var requestTimeout = 10 * time.Second
// ServerLiveRequest reports whether the inference server is live.
func ServerLiveRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerLiveResponse, error) {
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	serverLiveRequest := protobuf.ServerLiveRequest{}
	serverLiveResponse, err := client.ServerLive(ctx, &serverLiveRequest)
	if err != nil {
		return nil, err
	}
	return serverLiveResponse, nil
}
// ServerReadyRequest reports whether the inference server is ready to serve requests.
func ServerReadyRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerReadyResponse, error) {
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	serverReadyRequest := protobuf.ServerReadyRequest{}
	serverReadyResponse, err := client.ServerReady(ctx, &serverReadyRequest)
	if err != nil {
		return nil, err
	}
	return serverReadyResponse, nil
}
// ModelMetadataRequest fetches the metadata of the given model name and version from the server.
func ModelMetadataRequest(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*protobuf.ModelMetadataResponse, error) {
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	modelMetadataRequest := protobuf.ModelMetadataRequest{
		Name:    modelName,
		Version: modelVersion,
	}
	modelMetadataResponse, err := client.ModelMetadata(ctx, &modelMetadataRequest)
	if err != nil {
		return nil, err
	}
	return modelMetadataResponse, nil
}
// ModelInferRequest sends the given inference request and returns the server's response.
func ModelInferRequest(client protobuf.GRPCInferenceServiceClient, modelInferRequest *protobuf.ModelInferRequest) (*protobuf.ModelInferResponse, error) {
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	modelInferResponse, err := client.ModelInfer(ctx, modelInferRequest)
	if err != nil {
		return nil, err
	}
	return modelInferResponse, nil
}
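A small health-check sketch built on the request helpers above, illustrative only and not part of this commit: the address is a placeholder, and the Live/Ready fields assume the generated protobuf follows the KServe v2 inference protocol.

package main

import (
	"log"

	"github.com/dev6699/face/client"
	"github.com/dev6699/face/protobuf"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Dial the inference server directly; the helpers take the generated client as a parameter.
	conn, err := grpc.NewClient("localhost:8001", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()
	c := protobuf.NewGRPCInferenceServiceClient(conn)

	live, err := client.ServerLiveRequest(c)
	if err != nil {
		log.Fatal(err)
	}
	ready, err := client.ServerReadyRequest(c)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("live=%v ready=%v", live.Live, ready.Ready)
}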

model/model.go

@@ -0,0 +1,16 @@
package model

import (
	"github.com/dev6699/face/protobuf"
)

// Model describes a model that can convert its input I into inference tensor contents
// and convert the raw inference outputs back into its output O.
type Model[I any, O any] interface {
	ModelMeta
	// PreProcess converts the input into the tensor contents sent to the inference server.
	PreProcess(input I) ([]*protobuf.InferTensorContents, error)
	// PostProcess converts the raw output bytes returned by the inference server into the model's output.
	PostProcess(rawOutputContents [][]byte) (O, error)
}

// ModelMeta identifies a model by name and version as registered on the inference server.
type ModelMeta interface {
	ModelName() string
	ModelVersion() string
}
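A sketch of a minimal Model implementation, illustrative only and not part of this commit: the detector package and its types are hypothetical, Fp32Contents assumes the generated InferTensorContents follows the KServe v2 inference protocol, and the output decoding assumes a little-endian FP32 tensor.

package detector // hypothetical package, e.g. example.com/detector

import (
	"encoding/binary"
	"math"

	"github.com/dev6699/face/model"
	"github.com/dev6699/face/protobuf"
)

type Input struct{ Pixels []float32 } // flattened, already-normalized image tensor
type Output struct{ Scores []float32 }

type detector struct{}

// New returns the model; its return type lets it be passed straight to client.Init and client.Infer.
func New() model.Model[Input, Output] { return &detector{} }

func (d *detector) ModelName() string    { return "detector" }
func (d *detector) ModelVersion() string { return "1" }

// PreProcess packs the input into one InferTensorContents per model input.
func (d *detector) PreProcess(in Input) ([]*protobuf.InferTensorContents, error) {
	return []*protobuf.InferTensorContents{
		{Fp32Contents: in.Pixels},
	}, nil
}

// PostProcess decodes the first raw output as little-endian float32 values.
func (d *detector) PostProcess(raw [][]byte) (Output, error) {
	buf := raw[0]
	scores := make([]float32, len(buf)/4)
	for i := range scores {
		scores[i] = math.Float32frombits(binary.LittleEndian.Uint32(buf[i*4:]))
	}
	return Output{Scores: scores}, nil
}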