From 64d2f7032638e46f53f1c4384866265480103fc3 Mon Sep 17 00:00:00 2001 From: xxj <346944475@qq.com> Date: Fri, 28 Mar 2025 18:49:19 +0800 Subject: [PATCH] 1 --- mytriton/triton.go | 18 ++++++++++ mywebsocket/broadcast.go | 74 ++++++++++++++++++++++++++++++++++++++ mywebsocket/socket.go | 69 +++++++++++++++++++++++++++++++++++ mywebsocket/socket_test.go | 1 + 4 files changed, 162 insertions(+) create mode 100644 mywebsocket/broadcast.go create mode 100644 mywebsocket/socket.go create mode 100644 mywebsocket/socket_test.go diff --git a/mytriton/triton.go b/mytriton/triton.go index cd16ad9..89a0708 100644 --- a/mytriton/triton.go +++ b/mytriton/triton.go @@ -41,6 +41,10 @@ func (t *TritonInfo) Close() { t.conn.Close() } +func (t *TritonInfo) GetClient() triton.GRPCInferenceServiceClient { + return t.client +} + // ServerLive 心跳检测 func (t *TritonInfo) ServerLive(ctx context.Context) (bool, error) { // Create context for our request with 10 second timeout @@ -204,3 +208,17 @@ func (t *TritonInfo) RequestFromTexts(ctx context.Context, texts []string, outTe } return outputData0, nil } + +// 封装字符串为 Triton 的 BYTES 格式 +func (t *TritonInfo) EncodeStringsToBytes(inputData []string) []byte { + var buffer bytes.Buffer + for _, str := range inputData { + // 写入字符串长度(4 字节,小端序) + length := uint32(len(str)) + binary.Write(&buffer, binary.LittleEndian, length) + + // 写入字符串内容 + buffer.Write([]byte(str)) + } + return buffer.Bytes() +} diff --git a/mywebsocket/broadcast.go b/mywebsocket/broadcast.go new file mode 100644 index 0000000..513c6c4 --- /dev/null +++ b/mywebsocket/broadcast.go @@ -0,0 +1,74 @@ +package mywebsocket + +import ( + "sync" + + "github.com/xxjwxc/public/message" + "github.com/xxjwxc/public/mylog" +) + +var ( + clients sync.Map + + joinChannel = make(chan *Socket) + quitChannel = make(chan *Socket) +) + +func init() { + go connect() +} +func connect() { + defer func() { + if err := recover(); err != nil { + mylog.Errorf("recover: %v", err) + } + mylog.Errorf("websocket connect goroutine exit!!!") + }() + + mylog.Infof("connect goroutine started ...") + for { + select { + case cli := <-joinChannel: + mylog.Infof("socket join: %v", cli.ID()) + clients.Store(cli.ID(), cli) + case cli := <-quitChannel: + mylog.Infof("socket quit: %v", cli.ID()) + clients.Delete(cli.ID()) + } + } +} + +// 添加socket客户端 +func AddSocketClient(cli *Socket) { + joinChannel <- cli +} + +// 删除socket客户端 +func DelSocketClient(cli *Socket) { + quitChannel <- cli +} + +// SendOneMessageFromId 发送消息给所有客户端 +func SendOneMessageFromId(clientId string, byteMessage []byte) error { + cli, ok := clients.Load(clientId) + if !ok { + return message.GetError(message.NotFindError) + } + return cli.(*Socket).WriteMessage(byteMessage) +} + +func SendAllMessage(byteMessage []byte) { + clients.Range(func(key, value interface{}) bool { + value.(*Socket).WriteMessage(byteMessage) + return true + }) +} + +func Length() int { + count := 0 + clients.Range(func(key, value interface{}) bool { + count++ + return true + }) + return count +} diff --git a/mywebsocket/socket.go b/mywebsocket/socket.go new file mode 100644 index 0000000..d7bad42 --- /dev/null +++ b/mywebsocket/socket.go @@ -0,0 +1,69 @@ +package mywebsocket + +import ( + "net/http" + "sync" + "time" + + gws "github.com/gorilla/websocket" + "github.com/xxjwxc/public/myglobal" +) + +var ( + upgrader = gws.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true + }, + HandshakeTimeout: time.Second * 10, + } +) +var node *myglobal.NodeInfo + +func init() { + node = myglobal.GetNode() +} + +func NewSocketUpgrader(sessionId string, w http.ResponseWriter, r *http.Request) (*Socket, error) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, err + } + if sessionId == "" { + sessionId = node.GetIDStr() + } + + return &Socket{ + SessionId: sessionId, + conn: conn, + }, nil +} + +type Socket struct { + SessionId string + conn *gws.Conn + sync.Mutex +} + +func (s *Socket) ID() string { + return s.SessionId +} + +func (s *Socket) WriteMessage(byteMessage []byte) error { + s.Mutex.Lock() + defer s.Mutex.Unlock() + return s.conn.WriteMessage(gws.BinaryMessage, byteMessage) +} + +func (s *Socket) ReadMessage() (messageType int, p []byte, err error) { + return s.conn.ReadMessage() +} + +func (s *Socket) Close() { + s.Mutex.Lock() + defer s.Mutex.Unlock() + s.conn.Close() +} + +// func (s *Socket) Upgrade(w http.ResponseWriter, r *http.Request) (*gws.Conn, error) { +// return upgrader.Upgrade(w, r, nil) +// } diff --git a/mywebsocket/socket_test.go b/mywebsocket/socket_test.go new file mode 100644 index 0000000..173ff83 --- /dev/null +++ b/mywebsocket/socket_test.go @@ -0,0 +1 @@ +package mywebsocket