From 5d6050f78e7b81193041ba566c01fbdb1d82faef Mon Sep 17 00:00:00 2001
From: kweijack
Date: Wed, 3 Jul 2024 11:58:10 +0000
Subject: [PATCH] feat: added grpc client, model interface

---
 client/client.go | 154 +++++++++++++++++++++++++++++++++++++++++++++++
 client/conn.go   |  70 +++++++++++++++++++++
 model/model.go   |  24 ++++++++
 3 files changed, 248 insertions(+)
 create mode 100644 client/client.go
 create mode 100644 client/conn.go
 create mode 100644 model/model.go

diff --git a/client/client.go b/client/client.go
new file mode 100644
index 0000000..7d73405
--- /dev/null
+++ b/client/client.go
@@ -0,0 +1,154 @@
+package client
+
+import (
+	"fmt"
+
+	"github.com/dev6699/face/model"
+	"github.com/dev6699/face/protobuf"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+)
+
+// Package-level state written by Init and read by Close and Infer.
+var (
+	conn           *grpc.ClientConn
+	client         protobuf.GRPCInferenceServiceClient
+	modelsMetadata = make(map[string]*modelMetadata)
+)
+
+// Init creates the grpc connection to url without transport security and
+// fetches the metadata of every model in models from the grpc server.
+//
+// The connection, the client and the fetched metadata are kept in
+// package-level variables for later use by Infer and Close. When a metadata
+// request fails, Init returns the error and leaves the connection open; call
+// Close to release it.
+func Init(url string, models []model.ModelMeta) error {
+	var err error
+	conn, err = grpc.NewClient(url, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		return fmt.Errorf("creating grpc client for %q: %w", url, err)
+	}
+
+	client = protobuf.NewGRPCInferenceServiceClient(conn)
+	for _, m := range models {
+		meta, err := newModelMetadata(client, m.ModelName(), m.ModelVersion())
+		if err != nil {
+			return fmt.Errorf("fetching metadata of model %q: %w", m.ModelName(), err)
+		}
+
+		modelsMetadata[m.ModelName()] = meta
+	}
+	return nil
+}
+
+// Close tears down the underlying grpc connection.
+// It returns nil without doing anything when there is no connection, e.g.
+// because Init was never called or failed to create one.
+func Close() error {
+	if conn == nil {
+		return nil
+	}
+	return conn.Close()
+}
+
+// Infer is a generic function that takes modelFactory to create a model.Model
+// and the input for its PreProcess, and performs the infer request based on
+// the model metadata fetched by Init.
+//
+// The raw output contents of the response are passed to the model's
+// PostProcess, whose result is returned. The zero value of O is returned
+// together with an error when PreProcess fails, when no metadata is loaded for
+// the model, when PreProcess yields more tensors than the metadata declares
+// inputs, or when the infer request fails.
+func Infer[I, O any](modelFactory func() model.Model[I, O], input I) (O, error) {
+	var zeroOutput O
+
+	// Named m rather than model so the imported model package is not shadowed.
+	m := modelFactory()
+	contents, err := m.PreProcess(input)
+	if err != nil {
+		return zeroOutput, err
+	}
+
+	// A model that was not passed to Init has no entry here; report it instead
+	// of dereferencing a nil *modelMetadata.
+	meta := modelsMetadata[m.ModelName()]
+	if meta == nil {
+		return zeroOutput, fmt.Errorf("model %q: metadata not loaded, pass the model to Init first", m.ModelName())
+	}
+	// formInferRequest pairs contents[i] with the i-th metadata input.
+	if len(contents) > len(meta.Inputs) {
+		return zeroOutput, fmt.Errorf("model %q: got %d input tensors, metadata declares %d inputs", m.ModelName(), len(contents), len(meta.Inputs))
+	}
+
+	inferResponse, err := ModelInferRequest(client, meta.formInferRequest(contents))
+	if err != nil {
+		return zeroOutput, err
+	}
+
+	return m.PostProcess(inferResponse.RawOutputContents)
+}
+
+// modelMetadata couples a model's name and version with the metadata response
+// returned by the server for it.
+type modelMetadata struct {
+	modelName    string
+	modelVersion string
+	*protobuf.ModelMetadataResponse
+}
+
+// newModelMetadata requests the metadata of the given model name and version
+// through client and wraps the response in a modelMetadata.
+func newModelMetadata(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*modelMetadata, error) {
+	metaResponse, err := ModelMetadataRequest(client, modelName, modelVersion)
+	if err != nil {
+		return nil, err
+	}
+
+	return &modelMetadata{
+		modelName:             modelName,
+		modelVersion:          modelVersion,
+		ModelMetadataResponse: metaResponse,
+	}, nil
+}
+
+// formInferRequest builds the infer request for this model from contents.
+//
+// contents[i] is paired with the i-th input of the model metadata, so
+// len(contents) must not exceed the number of metadata inputs. Every output
+// listed in the metadata is requested by name.
+func (m *modelMetadata) formInferRequest(contents []*protobuf.InferTensorContents) *protobuf.ModelInferRequest {
+	inputs := make([]*protobuf.ModelInferRequest_InferInputTensor, 0, len(contents))
+	for i, c := range contents {
+		input := m.Inputs[i]
+		// Copy the shape: the metadata is reused for every request, so the
+		// stored slice must not be modified in place.
+		shape := append(input.Shape[:0:0], input.Shape...)
+		// A leading -1 (presumably the variable batch dimension; confirm
+		// against the server's metadata) is sent as 1.
+		if len(shape) > 0 && shape[0] == -1 {
+			shape[0] = 1
+		}
+		inputs = append(inputs, &protobuf.ModelInferRequest_InferInputTensor{
+			Name:     input.Name,
+			Datatype: input.Datatype,
+			Shape:    shape,
+			Contents: c,
+		})
+	}
+
+	outputs := make([]*protobuf.ModelInferRequest_InferRequestedOutputTensor, len(m.Outputs))
+	for i, o := range m.Outputs {
+		outputs[i] = &protobuf.ModelInferRequest_InferRequestedOutputTensor{
+			Name: o.Name,
+		}
+	}
+
+	return &protobuf.ModelInferRequest{
+		ModelName:    m.modelName,
+		ModelVersion: m.modelVersion,
+		Inputs:       inputs,
+		Outputs:      outputs,
+	}
+}
diff --git a/client/conn.go b/client/conn.go
new file mode 100644
index 0000000..8b3679d
--- /dev/null
+++ b/client/conn.go
@@ -0,0 +1,70 @@
+package client
+
+import (
+	"context"
+	"time"
+
+	"github.com/dev6699/face/protobuf"
+)
+
+// requestTimeout bounds every request issued by the helpers in this file.
+var requestTimeout = 10 * time.Second
+
+// ServerLiveRequest sends a ServerLive request through client, bounded by
+// requestTimeout, and returns the server's response.
+func ServerLiveRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerLiveResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	defer cancel()
+
+	serverLiveRequest := protobuf.ServerLiveRequest{}
+	serverLiveResponse, err := client.ServerLive(ctx, &serverLiveRequest)
+	if err != nil {
+		return nil, err
+	}
+	return serverLiveResponse, nil
+}
+
+// ServerReadyRequest sends a ServerReady request through client, bounded by
+// requestTimeout, and returns the server's response.
+func ServerReadyRequest(client protobuf.GRPCInferenceServiceClient) (*protobuf.ServerReadyResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	defer cancel()
+
+	serverReadyRequest := protobuf.ServerReadyRequest{}
+	serverReadyResponse, err := client.ServerReady(ctx, &serverReadyRequest)
+	if err != nil {
+		return nil, err
+	}
+	return serverReadyResponse, nil
+}
+
+// ModelMetadataRequest requests the metadata of the model identified by
+// modelName and modelVersion through client, bounded by requestTimeout.
+func ModelMetadataRequest(client protobuf.GRPCInferenceServiceClient, modelName string, modelVersion string) (*protobuf.ModelMetadataResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	defer cancel()
+
+	modelMetadataRequest := protobuf.ModelMetadataRequest{
+		Name:    modelName,
+		Version: modelVersion,
+	}
+	modelMetadataResponse, err := client.ModelMetadata(ctx, &modelMetadataRequest)
+	if err != nil {
+		return nil, err
+	}
+	return modelMetadataResponse, nil
+}
+
+// ModelInferRequest sends modelInferRequest through client, bounded by
+// requestTimeout, and returns the inference response.
+func ModelInferRequest(client protobuf.GRPCInferenceServiceClient, modelInferRequest *protobuf.ModelInferRequest) (*protobuf.ModelInferResponse, error) {
+	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
+	defer cancel()
+
+	modelInferResponse, err := client.ModelInfer(ctx, modelInferRequest)
+	if err != nil {
+		return nil, err
+	}
+
+	return modelInferResponse, nil
+}
diff --git a/model/model.go b/model/model.go
new file mode 100644
index 0000000..fbc3a24
--- /dev/null
+++ b/model/model.go
@@ -0,0 +1,24 @@
+package model
+
+import (
+	"github.com/dev6699/face/protobuf"
+)
+
+// Model describes a model that can be run through an inference request:
+// besides the ModelMeta identification it converts an input of type I into
+// tensor contents and the raw output contents into a result of type O.
+type Model[I any, O any] interface {
+	ModelMeta
+	// PreProcess converts input into the tensor contents of the request.
+	PreProcess(input I) ([]*protobuf.InferTensorContents, error)
+	// PostProcess converts the raw output contents of the response into O.
+	PostProcess(rawOutputContents [][]byte) (O, error)
+}
+
+// ModelMeta identifies a model by name and version.
+type ModelMeta interface {
+	// ModelName returns the name of the model.
+	ModelName() string
+	// ModelVersion returns the version of the model.
+	ModelVersion() string
+}