fea: socket: support to negotiate key

This commit is contained in:
Daniel Ding
2022-10-01 08:21:53 +08:00
parent 1f5bdbe45c
commit 4a6eaeb802
24 changed files with 315 additions and 128 deletions

View File

@@ -1,19 +1,22 @@
package libol
import (
"bytes"
"net"
"sync"
"time"
)
const (
ClInit = 0x00
ClConnected = 0x01
ClUnAuth = 0x02
ClAuth = 0x03
ClConnecting = 0x04
ClTerminal = 0x05
ClClosed = 0x06
ClInit = 0x00
ClConnected = 0x01
ClUnAuth = 0x02
ClAuth = 0x03
ClConnecting = 0x04
ClTerminal = 0x05
ClClosed = 0x06
ClNegotiating = 0x07
ClNegotiated = 0x08
)
type SocketStatus uint8
@@ -34,6 +37,10 @@ func (s SocketStatus) String() string {
return "connecting"
case ClTerminal:
return "terminal"
case ClNegotiating:
return "negotiating"
case ClNegotiated:
return "negotiated"
}
return ""
}
@@ -77,6 +84,8 @@ type SocketClient interface {
SetListener(listener ClientListener)
SetTimeout(v int64)
Out() *SubLogger
SetKey(key string)
Key() string
}
type StreamSocket struct {
@@ -89,6 +98,7 @@ type StreamSocket struct {
remoteAddr string
localAddr string
address string
Block *BlockCrypt
}
func (t *StreamSocket) LocalAddr() string {
@@ -147,6 +157,25 @@ func (t *StreamSocket) ReadMsg() (*FrameMessage, error) {
return frame, nil
}
func (t *StreamSocket) SetKey(key string) {
if block := t.message.Crypt(); block != nil {
block.Update(key)
}
}
func (t *StreamSocket) Key() string {
key := ""
if block := t.message.Crypt(); block != nil {
key = block.key
}
return key
}
type SocketConfig struct {
Address string
Block *BlockCrypt
}
type SocketClientImpl struct {
*StreamSocket
lock sync.RWMutex
@@ -158,22 +187,56 @@ type SocketClientImpl struct {
timeout int64 // sec for read and write timeout
}
func NewSocketClient(address string, message Messager) *SocketClientImpl {
func NewSocketClient(cfg SocketConfig, message Messager) *SocketClientImpl {
return &SocketClientImpl{
StreamSocket: &StreamSocket{
maxSize: 1514,
minSize: 15,
message: message,
statistics: NewSafeStrInt64(),
out: NewSubLogger(address),
remoteAddr: address,
address: address,
out: NewSubLogger(cfg.Address),
remoteAddr: cfg.Address,
address: cfg.Address,
Block: cfg.Block,
},
newTime: time.Now().Unix(),
status: ClInit,
}
}
func (s *SocketClientImpl) negotiate() error {
if s.Key() == "" {
return nil
}
key := GenLetters(64)
request := NewControlFrame(NegoReq, key)
if err := s.WriteMsg(request); err != nil {
return err
}
s.status = ClNegotiating
if reply, err := s.ReadMsg(); err == nil {
if reply.IsControl() {
action, params := reply.CmdAndParams()
if action != NegoResp {
return NewErr("wrong message type: %s", action)
}
if bytes.Compare(key, params) != 0 {
return NewErr("negotiate key failed: %s != %s", key, params)
}
if block := s.message.Crypt(); block != nil {
block.Update(string(key))
}
s.status = ClNegotiated
return nil
} else {
Info("SocketClientImpl.negotiate %s", reply.String())
}
return NewErr("wrong message type")
} else {
return err
}
}
// MUST IMPLEMENT
func (s *SocketClientImpl) Connect() error {
return nil
@@ -265,7 +328,7 @@ func (s *SocketClientImpl) SetTimeout(v int64) {
s.timeout = v
}
func (s *SocketClientImpl) updateConn(conn net.Conn) {
func (s *SocketClientImpl) update(conn net.Conn) {
if conn != nil {
s.connection = conn
s.connectedTime = time.Now().Unix()
@@ -280,14 +343,21 @@ func (s *SocketClientImpl) updateConn(conn net.Conn) {
s.remoteAddr = ""
s.message.Flush()
}
s.out.Event("SocketClientImpl.updateConn: %s %s", s.localAddr, s.remoteAddr)
if s.Block != nil {
s.message.SetCrypt(s.Block)
}
s.out.Event("SocketClientImpl.update: %s %s", s.localAddr, s.remoteAddr)
}
func (s *SocketClientImpl) SetConnection(conn net.Conn) {
func (s *SocketClientImpl) Reset(conn net.Conn) {
s.lock.Lock()
defer s.lock.Unlock()
s.updateConn(conn)
s.update(conn)
s.status = ClConnected
if err := s.negotiate(); err != nil {
s.out.Error("SocketClientImpl.Reset %s", err)
return
}
}
// MUST IMPLEMENT
@@ -380,14 +450,48 @@ func (t *SocketServerImpl) OffClient(client SocketClient) {
}
}
func (t *SocketServerImpl) negotiate(client SocketClient) error {
if client.Key() == "" {
return nil
}
if request, err := client.ReadMsg(); err == nil {
if request.IsControl() {
client.SetStatus(ClNegotiated)
action, params := request.CmdAndParams()
if action == NegoReq {
Info("SocketServerImpl.negotiate %s", params)
reply := NewControlFrame(NegoResp, params)
if err := client.WriteMsg(reply); err != nil {
return err
}
client.SetKey(string(params))
return nil
}
return NewErr("wrong message type: %s", action)
} else {
Info("SocketServerImpl.negotiate %s", request.String())
}
return NewErr("wrong message type")
} else {
return err
}
}
func (t *SocketServerImpl) doOnClient(call ServerListener, client SocketClient) {
Info("SocketServerImpl.doOnClient: +%s", client)
_ = t.clients.Set(client.RemoteAddr(), client)
if call.OnClient != nil {
_ = call.OnClient(client)
if call.ReadAt != nil {
Go(func() { t.Read(client, call.ReadAt) })
}
Go(func() {
if err := t.negotiate(client); err != nil {
t.OffClient(client)
Warn("SocketServerImpl.doOnClient %s %s", client, err)
return
}
_ = call.OnClient(client)
if call.ReadAt != nil {
t.Read(client, call.ReadAt)
}
})
}
}