This commit is contained in:
xxj
2025-03-28 18:49:19 +08:00
parent e1b3f17fd4
commit 64d2f70326
4 changed files with 162 additions and 0 deletions

View File

@@ -41,6 +41,10 @@ func (t *TritonInfo) Close() {
t.conn.Close()
}
func (t *TritonInfo) GetClient() triton.GRPCInferenceServiceClient {
return t.client
}
// ServerLive 心跳检测
func (t *TritonInfo) ServerLive(ctx context.Context) (bool, error) {
// Create context for our request with 10 second timeout
@@ -204,3 +208,17 @@ func (t *TritonInfo) RequestFromTexts(ctx context.Context, texts []string, outTe
}
return outputData0, nil
}
// 封装字符串为 Triton 的 BYTES 格式
func (t *TritonInfo) EncodeStringsToBytes(inputData []string) []byte {
var buffer bytes.Buffer
for _, str := range inputData {
// 写入字符串长度4 字节,小端序)
length := uint32(len(str))
binary.Write(&buffer, binary.LittleEndian, length)
// 写入字符串内容
buffer.Write([]byte(str))
}
return buffer.Bytes()
}

74
mywebsocket/broadcast.go Normal file
View File

@@ -0,0 +1,74 @@
package mywebsocket
import (
"sync"
"github.com/xxjwxc/public/message"
"github.com/xxjwxc/public/mylog"
)
var (
clients sync.Map
joinChannel = make(chan *Socket)
quitChannel = make(chan *Socket)
)
func init() {
go connect()
}
func connect() {
defer func() {
if err := recover(); err != nil {
mylog.Errorf("recover: %v", err)
}
mylog.Errorf("websocket connect goroutine exit!!!")
}()
mylog.Infof("connect goroutine started ...")
for {
select {
case cli := <-joinChannel:
mylog.Infof("socket join: %v", cli.ID())
clients.Store(cli.ID(), cli)
case cli := <-quitChannel:
mylog.Infof("socket quit: %v", cli.ID())
clients.Delete(cli.ID())
}
}
}
// 添加socket客户端
func AddSocketClient(cli *Socket) {
joinChannel <- cli
}
// 删除socket客户端
func DelSocketClient(cli *Socket) {
quitChannel <- cli
}
// SendOneMessageFromId 发送消息给所有客户端
func SendOneMessageFromId(clientId string, byteMessage []byte) error {
cli, ok := clients.Load(clientId)
if !ok {
return message.GetError(message.NotFindError)
}
return cli.(*Socket).WriteMessage(byteMessage)
}
func SendAllMessage(byteMessage []byte) {
clients.Range(func(key, value interface{}) bool {
value.(*Socket).WriteMessage(byteMessage)
return true
})
}
func Length() int {
count := 0
clients.Range(func(key, value interface{}) bool {
count++
return true
})
return count
}

69
mywebsocket/socket.go Normal file
View File

@@ -0,0 +1,69 @@
package mywebsocket
import (
"net/http"
"sync"
"time"
gws "github.com/gorilla/websocket"
"github.com/xxjwxc/public/myglobal"
)
var (
upgrader = gws.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
HandshakeTimeout: time.Second * 10,
}
)
var node *myglobal.NodeInfo
func init() {
node = myglobal.GetNode()
}
func NewSocketUpgrader(sessionId string, w http.ResponseWriter, r *http.Request) (*Socket, error) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return nil, err
}
if sessionId == "" {
sessionId = node.GetIDStr()
}
return &Socket{
SessionId: sessionId,
conn: conn,
}, nil
}
type Socket struct {
SessionId string
conn *gws.Conn
sync.Mutex
}
func (s *Socket) ID() string {
return s.SessionId
}
func (s *Socket) WriteMessage(byteMessage []byte) error {
s.Mutex.Lock()
defer s.Mutex.Unlock()
return s.conn.WriteMessage(gws.BinaryMessage, byteMessage)
}
func (s *Socket) ReadMessage() (messageType int, p []byte, err error) {
return s.conn.ReadMessage()
}
func (s *Socket) Close() {
s.Mutex.Lock()
defer s.Mutex.Unlock()
s.conn.Close()
}
// func (s *Socket) Upgrade(w http.ResponseWriter, r *http.Request) (*gws.Conn, error) {
// return upgrader.Upgrade(w, r, nil)
// }

View File

@@ -0,0 +1 @@
package mywebsocket