mirror of
https://github.com/luscis/openlan.git
synced 2025-10-06 17:17:00 +08:00
fea: socket: support to negotiate key
This commit is contained in:
@@ -1,19 +1,22 @@
|
||||
package libol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ClInit = 0x00
|
||||
ClConnected = 0x01
|
||||
ClUnAuth = 0x02
|
||||
ClAuth = 0x03
|
||||
ClConnecting = 0x04
|
||||
ClTerminal = 0x05
|
||||
ClClosed = 0x06
|
||||
ClInit = 0x00
|
||||
ClConnected = 0x01
|
||||
ClUnAuth = 0x02
|
||||
ClAuth = 0x03
|
||||
ClConnecting = 0x04
|
||||
ClTerminal = 0x05
|
||||
ClClosed = 0x06
|
||||
ClNegotiating = 0x07
|
||||
ClNegotiated = 0x08
|
||||
)
|
||||
|
||||
type SocketStatus uint8
|
||||
@@ -34,6 +37,10 @@ func (s SocketStatus) String() string {
|
||||
return "connecting"
|
||||
case ClTerminal:
|
||||
return "terminal"
|
||||
case ClNegotiating:
|
||||
return "negotiating"
|
||||
case ClNegotiated:
|
||||
return "negotiated"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -77,6 +84,8 @@ type SocketClient interface {
|
||||
SetListener(listener ClientListener)
|
||||
SetTimeout(v int64)
|
||||
Out() *SubLogger
|
||||
SetKey(key string)
|
||||
Key() string
|
||||
}
|
||||
|
||||
type StreamSocket struct {
|
||||
@@ -89,6 +98,7 @@ type StreamSocket struct {
|
||||
remoteAddr string
|
||||
localAddr string
|
||||
address string
|
||||
Block *BlockCrypt
|
||||
}
|
||||
|
||||
func (t *StreamSocket) LocalAddr() string {
|
||||
@@ -147,6 +157,25 @@ func (t *StreamSocket) ReadMsg() (*FrameMessage, error) {
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func (t *StreamSocket) SetKey(key string) {
|
||||
if block := t.message.Crypt(); block != nil {
|
||||
block.Update(key)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *StreamSocket) Key() string {
|
||||
key := ""
|
||||
if block := t.message.Crypt(); block != nil {
|
||||
key = block.key
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
type SocketConfig struct {
|
||||
Address string
|
||||
Block *BlockCrypt
|
||||
}
|
||||
|
||||
type SocketClientImpl struct {
|
||||
*StreamSocket
|
||||
lock sync.RWMutex
|
||||
@@ -158,22 +187,56 @@ type SocketClientImpl struct {
|
||||
timeout int64 // sec for read and write timeout
|
||||
}
|
||||
|
||||
func NewSocketClient(address string, message Messager) *SocketClientImpl {
|
||||
func NewSocketClient(cfg SocketConfig, message Messager) *SocketClientImpl {
|
||||
return &SocketClientImpl{
|
||||
StreamSocket: &StreamSocket{
|
||||
maxSize: 1514,
|
||||
minSize: 15,
|
||||
message: message,
|
||||
statistics: NewSafeStrInt64(),
|
||||
out: NewSubLogger(address),
|
||||
remoteAddr: address,
|
||||
address: address,
|
||||
out: NewSubLogger(cfg.Address),
|
||||
remoteAddr: cfg.Address,
|
||||
address: cfg.Address,
|
||||
Block: cfg.Block,
|
||||
},
|
||||
newTime: time.Now().Unix(),
|
||||
status: ClInit,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SocketClientImpl) negotiate() error {
|
||||
if s.Key() == "" {
|
||||
return nil
|
||||
}
|
||||
key := GenLetters(64)
|
||||
request := NewControlFrame(NegoReq, key)
|
||||
if err := s.WriteMsg(request); err != nil {
|
||||
return err
|
||||
}
|
||||
s.status = ClNegotiating
|
||||
if reply, err := s.ReadMsg(); err == nil {
|
||||
if reply.IsControl() {
|
||||
action, params := reply.CmdAndParams()
|
||||
if action != NegoResp {
|
||||
return NewErr("wrong message type: %s", action)
|
||||
}
|
||||
if bytes.Compare(key, params) != 0 {
|
||||
return NewErr("negotiate key failed: %s != %s", key, params)
|
||||
}
|
||||
if block := s.message.Crypt(); block != nil {
|
||||
block.Update(string(key))
|
||||
}
|
||||
s.status = ClNegotiated
|
||||
return nil
|
||||
} else {
|
||||
Info("SocketClientImpl.negotiate %s", reply.String())
|
||||
}
|
||||
return NewErr("wrong message type")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// MUST IMPLEMENT
|
||||
func (s *SocketClientImpl) Connect() error {
|
||||
return nil
|
||||
@@ -265,7 +328,7 @@ func (s *SocketClientImpl) SetTimeout(v int64) {
|
||||
s.timeout = v
|
||||
}
|
||||
|
||||
func (s *SocketClientImpl) updateConn(conn net.Conn) {
|
||||
func (s *SocketClientImpl) update(conn net.Conn) {
|
||||
if conn != nil {
|
||||
s.connection = conn
|
||||
s.connectedTime = time.Now().Unix()
|
||||
@@ -280,14 +343,21 @@ func (s *SocketClientImpl) updateConn(conn net.Conn) {
|
||||
s.remoteAddr = ""
|
||||
s.message.Flush()
|
||||
}
|
||||
s.out.Event("SocketClientImpl.updateConn: %s %s", s.localAddr, s.remoteAddr)
|
||||
if s.Block != nil {
|
||||
s.message.SetCrypt(s.Block)
|
||||
}
|
||||
s.out.Event("SocketClientImpl.update: %s %s", s.localAddr, s.remoteAddr)
|
||||
}
|
||||
|
||||
func (s *SocketClientImpl) SetConnection(conn net.Conn) {
|
||||
func (s *SocketClientImpl) Reset(conn net.Conn) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
s.updateConn(conn)
|
||||
s.update(conn)
|
||||
s.status = ClConnected
|
||||
if err := s.negotiate(); err != nil {
|
||||
s.out.Error("SocketClientImpl.Reset %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// MUST IMPLEMENT
|
||||
@@ -380,14 +450,48 @@ func (t *SocketServerImpl) OffClient(client SocketClient) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SocketServerImpl) negotiate(client SocketClient) error {
|
||||
if client.Key() == "" {
|
||||
return nil
|
||||
}
|
||||
if request, err := client.ReadMsg(); err == nil {
|
||||
if request.IsControl() {
|
||||
client.SetStatus(ClNegotiated)
|
||||
action, params := request.CmdAndParams()
|
||||
if action == NegoReq {
|
||||
Info("SocketServerImpl.negotiate %s", params)
|
||||
reply := NewControlFrame(NegoResp, params)
|
||||
if err := client.WriteMsg(reply); err != nil {
|
||||
return err
|
||||
}
|
||||
client.SetKey(string(params))
|
||||
return nil
|
||||
}
|
||||
return NewErr("wrong message type: %s", action)
|
||||
} else {
|
||||
Info("SocketServerImpl.negotiate %s", request.String())
|
||||
}
|
||||
return NewErr("wrong message type")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SocketServerImpl) doOnClient(call ServerListener, client SocketClient) {
|
||||
Info("SocketServerImpl.doOnClient: +%s", client)
|
||||
_ = t.clients.Set(client.RemoteAddr(), client)
|
||||
if call.OnClient != nil {
|
||||
_ = call.OnClient(client)
|
||||
if call.ReadAt != nil {
|
||||
Go(func() { t.Read(client, call.ReadAt) })
|
||||
}
|
||||
Go(func() {
|
||||
if err := t.negotiate(client); err != nil {
|
||||
t.OffClient(client)
|
||||
Warn("SocketServerImpl.doOnClient %s %s", client, err)
|
||||
return
|
||||
}
|
||||
_ = call.OnClient(client)
|
||||
if call.ReadAt != nil {
|
||||
t.Read(client, call.ReadAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user