mirror of
https://github.com/xxjwxc/public.git
synced 2025-09-26 11:51:14 +08:00
1
This commit is contained in:
@@ -41,6 +41,10 @@ func (t *TritonInfo) Close() {
|
||||
t.conn.Close()
|
||||
}
|
||||
|
||||
func (t *TritonInfo) GetClient() triton.GRPCInferenceServiceClient {
|
||||
return t.client
|
||||
}
|
||||
|
||||
// ServerLive 心跳检测
|
||||
func (t *TritonInfo) ServerLive(ctx context.Context) (bool, error) {
|
||||
// Create context for our request with 10 second timeout
|
||||
@@ -204,3 +208,17 @@ func (t *TritonInfo) RequestFromTexts(ctx context.Context, texts []string, outTe
|
||||
}
|
||||
return outputData0, nil
|
||||
}
|
||||
|
||||
// 封装字符串为 Triton 的 BYTES 格式
|
||||
func (t *TritonInfo) EncodeStringsToBytes(inputData []string) []byte {
|
||||
var buffer bytes.Buffer
|
||||
for _, str := range inputData {
|
||||
// 写入字符串长度(4 字节,小端序)
|
||||
length := uint32(len(str))
|
||||
binary.Write(&buffer, binary.LittleEndian, length)
|
||||
|
||||
// 写入字符串内容
|
||||
buffer.Write([]byte(str))
|
||||
}
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
74
mywebsocket/broadcast.go
Normal file
74
mywebsocket/broadcast.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package mywebsocket
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/xxjwxc/public/message"
|
||||
"github.com/xxjwxc/public/mylog"
|
||||
)
|
||||
|
||||
var (
|
||||
clients sync.Map
|
||||
|
||||
joinChannel = make(chan *Socket)
|
||||
quitChannel = make(chan *Socket)
|
||||
)
|
||||
|
||||
func init() {
|
||||
go connect()
|
||||
}
|
||||
func connect() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
mylog.Errorf("recover: %v", err)
|
||||
}
|
||||
mylog.Errorf("websocket connect goroutine exit!!!")
|
||||
}()
|
||||
|
||||
mylog.Infof("connect goroutine started ...")
|
||||
for {
|
||||
select {
|
||||
case cli := <-joinChannel:
|
||||
mylog.Infof("socket join: %v", cli.ID())
|
||||
clients.Store(cli.ID(), cli)
|
||||
case cli := <-quitChannel:
|
||||
mylog.Infof("socket quit: %v", cli.ID())
|
||||
clients.Delete(cli.ID())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加socket客户端
|
||||
func AddSocketClient(cli *Socket) {
|
||||
joinChannel <- cli
|
||||
}
|
||||
|
||||
// 删除socket客户端
|
||||
func DelSocketClient(cli *Socket) {
|
||||
quitChannel <- cli
|
||||
}
|
||||
|
||||
// SendOneMessageFromId 发送消息给所有客户端
|
||||
func SendOneMessageFromId(clientId string, byteMessage []byte) error {
|
||||
cli, ok := clients.Load(clientId)
|
||||
if !ok {
|
||||
return message.GetError(message.NotFindError)
|
||||
}
|
||||
return cli.(*Socket).WriteMessage(byteMessage)
|
||||
}
|
||||
|
||||
func SendAllMessage(byteMessage []byte) {
|
||||
clients.Range(func(key, value interface{}) bool {
|
||||
value.(*Socket).WriteMessage(byteMessage)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func Length() int {
|
||||
count := 0
|
||||
clients.Range(func(key, value interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
69
mywebsocket/socket.go
Normal file
69
mywebsocket/socket.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package mywebsocket
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
gws "github.com/gorilla/websocket"
|
||||
"github.com/xxjwxc/public/myglobal"
|
||||
)
|
||||
|
||||
var (
|
||||
upgrader = gws.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
HandshakeTimeout: time.Second * 10,
|
||||
}
|
||||
)
|
||||
var node *myglobal.NodeInfo
|
||||
|
||||
func init() {
|
||||
node = myglobal.GetNode()
|
||||
}
|
||||
|
||||
func NewSocketUpgrader(sessionId string, w http.ResponseWriter, r *http.Request) (*Socket, error) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sessionId == "" {
|
||||
sessionId = node.GetIDStr()
|
||||
}
|
||||
|
||||
return &Socket{
|
||||
SessionId: sessionId,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Socket struct {
|
||||
SessionId string
|
||||
conn *gws.Conn
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
func (s *Socket) ID() string {
|
||||
return s.SessionId
|
||||
}
|
||||
|
||||
func (s *Socket) WriteMessage(byteMessage []byte) error {
|
||||
s.Mutex.Lock()
|
||||
defer s.Mutex.Unlock()
|
||||
return s.conn.WriteMessage(gws.BinaryMessage, byteMessage)
|
||||
}
|
||||
|
||||
func (s *Socket) ReadMessage() (messageType int, p []byte, err error) {
|
||||
return s.conn.ReadMessage()
|
||||
}
|
||||
|
||||
func (s *Socket) Close() {
|
||||
s.Mutex.Lock()
|
||||
defer s.Mutex.Unlock()
|
||||
s.conn.Close()
|
||||
}
|
||||
|
||||
// func (s *Socket) Upgrade(w http.ResponseWriter, r *http.Request) (*gws.Conn, error) {
|
||||
// return upgrader.Upgrade(w, r, nil)
|
||||
// }
|
1
mywebsocket/socket_test.go
Normal file
1
mywebsocket/socket_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package mywebsocket
|
Reference in New Issue
Block a user