mirror of
https://github.com/xxjwxc/public.git
synced 2025-09-26 20:01:19 +08:00
1
This commit is contained in:
@@ -41,6 +41,10 @@ func (t *TritonInfo) Close() {
|
|||||||
t.conn.Close()
|
t.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TritonInfo) GetClient() triton.GRPCInferenceServiceClient {
|
||||||
|
return t.client
|
||||||
|
}
|
||||||
|
|
||||||
// ServerLive 心跳检测
|
// ServerLive 心跳检测
|
||||||
func (t *TritonInfo) ServerLive(ctx context.Context) (bool, error) {
|
func (t *TritonInfo) ServerLive(ctx context.Context) (bool, error) {
|
||||||
// Create context for our request with 10 second timeout
|
// Create context for our request with 10 second timeout
|
||||||
@@ -204,3 +208,17 @@ func (t *TritonInfo) RequestFromTexts(ctx context.Context, texts []string, outTe
|
|||||||
}
|
}
|
||||||
return outputData0, nil
|
return outputData0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 封装字符串为 Triton 的 BYTES 格式
|
||||||
|
func (t *TritonInfo) EncodeStringsToBytes(inputData []string) []byte {
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
for _, str := range inputData {
|
||||||
|
// 写入字符串长度(4 字节,小端序)
|
||||||
|
length := uint32(len(str))
|
||||||
|
binary.Write(&buffer, binary.LittleEndian, length)
|
||||||
|
|
||||||
|
// 写入字符串内容
|
||||||
|
buffer.Write([]byte(str))
|
||||||
|
}
|
||||||
|
return buffer.Bytes()
|
||||||
|
}
|
||||||
|
74
mywebsocket/broadcast.go
Normal file
74
mywebsocket/broadcast.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package mywebsocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/xxjwxc/public/message"
|
||||||
|
"github.com/xxjwxc/public/mylog"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
clients sync.Map
|
||||||
|
|
||||||
|
joinChannel = make(chan *Socket)
|
||||||
|
quitChannel = make(chan *Socket)
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go connect()
|
||||||
|
}
|
||||||
|
func connect() {
|
||||||
|
defer func() {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
mylog.Errorf("recover: %v", err)
|
||||||
|
}
|
||||||
|
mylog.Errorf("websocket connect goroutine exit!!!")
|
||||||
|
}()
|
||||||
|
|
||||||
|
mylog.Infof("connect goroutine started ...")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case cli := <-joinChannel:
|
||||||
|
mylog.Infof("socket join: %v", cli.ID())
|
||||||
|
clients.Store(cli.ID(), cli)
|
||||||
|
case cli := <-quitChannel:
|
||||||
|
mylog.Infof("socket quit: %v", cli.ID())
|
||||||
|
clients.Delete(cli.ID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加socket客户端
|
||||||
|
func AddSocketClient(cli *Socket) {
|
||||||
|
joinChannel <- cli
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除socket客户端
|
||||||
|
func DelSocketClient(cli *Socket) {
|
||||||
|
quitChannel <- cli
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendOneMessageFromId 发送消息给所有客户端
|
||||||
|
func SendOneMessageFromId(clientId string, byteMessage []byte) error {
|
||||||
|
cli, ok := clients.Load(clientId)
|
||||||
|
if !ok {
|
||||||
|
return message.GetError(message.NotFindError)
|
||||||
|
}
|
||||||
|
return cli.(*Socket).WriteMessage(byteMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendAllMessage(byteMessage []byte) {
|
||||||
|
clients.Range(func(key, value interface{}) bool {
|
||||||
|
value.(*Socket).WriteMessage(byteMessage)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Length() int {
|
||||||
|
count := 0
|
||||||
|
clients.Range(func(key, value interface{}) bool {
|
||||||
|
count++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return count
|
||||||
|
}
|
69
mywebsocket/socket.go
Normal file
69
mywebsocket/socket.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package mywebsocket
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
gws "github.com/gorilla/websocket"
|
||||||
|
"github.com/xxjwxc/public/myglobal"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
upgrader = gws.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
HandshakeTimeout: time.Second * 10,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
var node *myglobal.NodeInfo
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
node = myglobal.GetNode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSocketUpgrader(sessionId string, w http.ResponseWriter, r *http.Request) (*Socket, error) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if sessionId == "" {
|
||||||
|
sessionId = node.GetIDStr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Socket{
|
||||||
|
SessionId: sessionId,
|
||||||
|
conn: conn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Socket struct {
|
||||||
|
SessionId string
|
||||||
|
conn *gws.Conn
|
||||||
|
sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) ID() string {
|
||||||
|
return s.SessionId
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) WriteMessage(byteMessage []byte) error {
|
||||||
|
s.Mutex.Lock()
|
||||||
|
defer s.Mutex.Unlock()
|
||||||
|
return s.conn.WriteMessage(gws.BinaryMessage, byteMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) ReadMessage() (messageType int, p []byte, err error) {
|
||||||
|
return s.conn.ReadMessage()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Socket) Close() {
|
||||||
|
s.Mutex.Lock()
|
||||||
|
defer s.Mutex.Unlock()
|
||||||
|
s.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// func (s *Socket) Upgrade(w http.ResponseWriter, r *http.Request) (*gws.Conn, error) {
|
||||||
|
// return upgrader.Upgrade(w, r, nil)
|
||||||
|
// }
|
1
mywebsocket/socket_test.go
Normal file
1
mywebsocket/socket_test.go
Normal file
@@ -0,0 +1 @@
|
|||||||
|
package mywebsocket
|
Reference in New Issue
Block a user