Files
openlan/pkg/libol/socket.go
2022-10-28 20:52:02 +08:00

627 lines
13 KiB
Go
Executable File

package libol
import (
"bytes"
"crypto/md5"
"net"
"sync"
"time"
)
// Status values a socket client can be in. They are untyped constants so
// they can be compared with and assigned to SocketStatus directly.
const (
	ClInit        = 0x00
	ClConnected   = 0x01
	ClUnAuth      = 0x02
	ClAuth        = 0x03
	ClConnecting  = 0x04
	ClTerminal    = 0x05
	ClClosed      = 0x06
	ClNegotiating = 0x07
	ClNegotiated  = 0x08
)

// SocketStatus is the state of a socket client connection.
type SocketStatus uint8

// String returns the lower-case name of the status. Values that are not
// one of the Cl* constants yield the empty string.
func (s SocketStatus) String() string {
	// Indexed by status value; every Cl* constant has an entry.
	names := [...]string{
		ClInit:        "initialized",
		ClConnected:   "connected",
		ClUnAuth:      "unauthenticated",
		ClAuth:        "authenticated",
		ClConnecting:  "connecting",
		ClTerminal:    "terminal",
		ClClosed:      "closed",
		ClNegotiating: "negotiating",
		ClNegotiated:  "negotiated",
	}
	if int(s) >= len(names) {
		return ""
	}
	return names[s]
}
// Socket client interface and implementation.
//
// Keys of the per-client statistics counters.
const (
// CsSendOkay accumulates the size returned by each successful send.
CsSendOkay = "send"
// CsRecvOkay accumulates the length of each received frame.
CsRecvOkay = "recv"
// CsSendError counts sends that returned an error.
CsSendError = "error"
// CsDropped counts writes attempted while the socket had no connection.
CsDropped = "dropped"
)
// ClientListener groups the optional callbacks a client can be given via
// SetListener. None of them is invoked in this file.
// NOTE(review): presumably called by the concrete client types — confirm.
type ClientListener struct {
// OnClose receives the client; presumably called when it is closed.
OnClose func(client SocketClient) error
// OnConnected receives the client; presumably called once connected.
OnConnected func(client SocketClient) error
// OnStatus receives the client with its previous and new status.
OnStatus func(client SocketClient, old, new SocketStatus)
}
// SocketClient is the interface of a framed socket client. In this file
// SocketClientImpl (together with the embedded StreamSocket) supplies most
// of the methods and leaves Connect, Close, Terminal and SetStatus as stubs.
type SocketClient interface {
// Addresses of the two ends of the connection.
LocalAddr() string
RemoteAddr() string
// Connection life cycle.
Connect() error
Close()
// Frame I/O.
WriteMsg(frame *FrameMessage) error
ReadMsg() (*FrameMessage, error)
// Seconds since creation and since the connection was established.
UpTime() int64
AliveTime() int64
String() string
Terminal()
// Opaque per-client value owned by the caller.
Private() interface{}
SetPrivate(v interface{})
Status() SocketStatus
SetStatus(v SocketStatus)
// Frame size limits handed to the messager when receiving.
MaxSize() int
SetMaxSize(value int)
MinSize() int
// IsOk reports whether a connection is present.
IsOk() bool
// Have reports whether the current status equals the given one.
Have(status SocketStatus) bool
// Statistics returns a copy of the counters.
Statistics() map[string]int64
SetListener(listener ClientListener)
SetTimeout(v int64)
Out() *SubLogger
// Key of the block cipher, empty when none is configured.
SetKey(key string)
Key() string
}
// StreamSocket holds the state shared by socket clients: the underlying
// connection, the messager used to frame data on it, and counters.
type StreamSocket struct {
// message sends and receives frames; WriteMsg and ReadMsg install a
// StreamMessagerImpl when it is nil.
message Messager
// connection is nil while the socket is not connected (see IsOk).
connection net.Conn
// statistics holds the Cs* counters.
statistics *SafeStrInt64
// maxSize and minSize are passed to the messager on receive.
maxSize int
minSize int
out *SubLogger
remoteAddr string
localAddr string
// address is the configured address and the value returned by String.
address string
// Block, when non-nil, is applied to the messager on each update.
Block *BlockCrypt
}
// LocalAddr returns the recorded local address of the connection.
func (t *StreamSocket) LocalAddr() string {
	addr := t.localAddr
	return addr
}
// RemoteAddr returns the recorded remote address of the connection.
func (t *StreamSocket) RemoteAddr() string {
	addr := t.remoteAddr
	return addr
}
// String returns the configured address of the socket.
func (t *StreamSocket) String() string {
	name := t.address
	return name
}
// IsOk reports whether the socket currently holds a connection.
func (t *StreamSocket) IsOk() bool {
	if t.connection == nil {
		return false
	}
	return true
}
// WriteMsg sends one frame over the connection and updates the counters:
// CsDropped when there is no connection, CsSendError when the send fails,
// CsSendOkay (by the sent size) otherwise.
func (t *StreamSocket) WriteMsg(frame *FrameMessage) error {
	if ok := t.IsOk(); !ok {
		t.statistics.Add(CsDropped, 1)
		return NewErr("%s not okay", t)
	}
	if frame.IsControl() {
		cmd, args := frame.CmdAndParams()
		Cmd("StreamSocket.WriteMsg: %s%s", cmd, args)
	}
	// Fall back to the stream messager when none was configured.
	if t.message == nil {
		t.message = &StreamMessagerImpl{}
	}
	sent, err := t.message.Send(t.connection, frame)
	if err == nil {
		t.statistics.Add(CsSendOkay, int64(sent))
		return nil
	}
	t.statistics.Add(CsSendError, 1)
	return err
}
// ReadMsg receives one frame from the connection, honouring the configured
// size limits, and adds the frame length to CsRecvOkay.
func (t *StreamSocket) ReadMsg() (*FrameMessage, error) {
	if HasLog(LOG) {
		Log("StreamSocket.ReadMsg: %s", t)
	}
	if ok := t.IsOk(); !ok {
		return nil, NewErr("%s not okay", t)
	}
	// Fall back to the stream messager when none was configured.
	if t.message == nil {
		t.message = &StreamMessagerImpl{}
	}
	received, err := t.message.Receive(t.connection, t.maxSize, t.minSize)
	if err != nil {
		return nil, err
	}
	t.statistics.Add(CsRecvOkay, int64(len(received.frame)))
	return received, nil
}
// SetKey updates the key of the messager's block cipher. It is a no-op
// when no messager is installed or the messager has no cipher.
func (t *StreamSocket) SetKey(key string) {
	// WriteMsg/ReadMsg tolerate a nil messager; do the same here rather
	// than panic on the method call.
	if t.message == nil {
		return
	}
	if block := t.message.Crypt(); block != nil {
		block.Update(key)
	}
}
// Key returns the key of the messager's block cipher, or "" when no
// messager is installed or the messager has no cipher.
func (t *StreamSocket) Key() string {
	// WriteMsg/ReadMsg tolerate a nil messager; do the same here rather
	// than panic on the method call.
	if t.message == nil {
		return ""
	}
	if block := t.message.Crypt(); block != nil {
		return block.key
	}
	return ""
}
// SocketConfig carries the settings NewSocketClient needs.
type SocketConfig struct {
// Address is used as the client's address, initial remote address and
// logger name.
Address string
// Block is the optional block cipher applied to the client's messager.
Block *BlockCrypt
}
// SocketClientImpl is the common base of socket clients. It embeds
// StreamSocket for frame I/O and adds status, timing and listener state.
type SocketClientImpl struct {
*StreamSocket
// lock is taken by Status, Private, SetPrivate, Retry and Reset.
lock sync.RWMutex
listener ClientListener
// newTime is the Unix time (seconds) at which the client was created.
newTime int64
// connectedTime is the Unix time (seconds) of the last connection
// installed by update; zero until then.
connectedTime int64
private interface{}
status SocketStatus
timeout int64 // sec for read and write timeout
}
// NewSocketClient builds a client in the ClInit state for cfg.Address using
// the given messager. The frame size limits default to 15..1514.
func NewSocketClient(cfg SocketConfig, message Messager) *SocketClientImpl {
	stream := &StreamSocket{
		maxSize:    1514,
		minSize:    15,
		message:    message,
		statistics: NewSafeStrInt64(),
		out:        NewSubLogger(cfg.Address),
		remoteAddr: cfg.Address,
		address:    cfg.Address,
		Block:      cfg.Block,
	}
	client := &SocketClientImpl{
		StreamSocket: stream,
		newTime:      time.Now().Unix(),
		status:       ClInit,
	}
	return client
}
// negotiate performs the client side of the key negotiation. When no
// cipher key is configured it does nothing. Otherwise it sends a freshly
// generated 64-letter key in a NegoReq control frame, expects a NegoResp
// carrying the MD5 digest of that key, and on success installs the new key
// on the messager's cipher.
//
// The status is written directly (not via SetStatus) because the visible
// caller, Reset, already holds s.lock.
func (s *SocketClientImpl) negotiate() error {
	if s.Key() == "" {
		return nil
	}
	key := GenLetters(64)
	request := NewControlFrame(NegoReq, key)
	if err := s.WriteMsg(request); err != nil {
		return err
	}
	s.status = ClNegotiating
	reply, err := s.ReadMsg()
	if err != nil {
		return err
	}
	if !reply.IsControl() {
		Info("SocketClientImpl.negotiate %s", reply.String())
		return NewErr("wrong message type")
	}
	action, params := reply.CmdAndParams()
	if action != NegoResp {
		return NewErr("wrong message type: %s", action)
	}
	Cmd("SocketClientImpl.negotiate %s %x", action, params)
	sum := md5.Sum(key)
	if !bytes.Equal(sum[:], params) {
		// Report the digest that was compared, not the secret key itself.
		return NewErr("negotiate key failed: %x != %x", sum[:], params)
	}
	if block := s.message.Crypt(); block != nil {
		block.Update(string(key))
	}
	s.status = ClNegotiated
	return nil
}
// Connect is a placeholder that does nothing and returns nil.
// MUST IMPLEMENT: presumably concrete client types embedding
// SocketClientImpl provide the real connect logic.
func (s *SocketClientImpl) Connect() error {
return nil
}
// Close is a placeholder that does nothing.
// MUST IMPLEMENT: presumably concrete client types embedding
// SocketClientImpl provide the real close logic.
func (s *SocketClientImpl) Close() {
}
// Terminal is a placeholder that does nothing.
// MUST IMPLEMENT: presumably concrete client types embedding
// SocketClientImpl provide the real termination logic.
func (s *SocketClientImpl) Terminal() {
}
// Out returns the client's logger, creating one named after the client's
// address on first use.
func (s *SocketClientImpl) Out() *SubLogger {
	if s.out != nil {
		return s.out
	}
	s.out = NewSubLogger(s.address)
	return s.out
}
// Retry reports whether a new connection attempt may start. It refuses when
// a connection already exists or the status is ClTerminal or ClUnAuth;
// otherwise it moves the client to ClConnecting and returns true.
func (s *SocketClientImpl) Retry() bool {
	s.lock.Lock()
	defer s.lock.Unlock()
	switch {
	case s.connection != nil, s.status == ClTerminal, s.status == ClUnAuth:
		return false
	}
	s.status = ClConnecting
	return true
}
// Status returns the current status, read under the read lock.
func (s *SocketClientImpl) Status() SocketStatus {
	s.lock.RLock()
	current := s.status
	s.lock.RUnlock()
	return current
}
// UpTime returns the number of seconds since the client was created.
func (s *SocketClientImpl) UpTime() int64 {
	now := time.Now().Unix()
	return now - s.newTime
}
// AliveTime returns the number of seconds since the current connection was
// installed, or 0 when the client has never been connected.
func (s *SocketClientImpl) AliveTime() int64 {
	if since := s.connectedTime; since != 0 {
		return time.Now().Unix() - since
	}
	return 0
}
// Private returns the caller-owned value stored with SetPrivate.
func (s *SocketClientImpl) Private() interface{} {
	s.lock.RLock()
	value := s.private
	s.lock.RUnlock()
	return value
}
// SetPrivate stores a caller-owned value on the client under the write lock.
func (s *SocketClientImpl) SetPrivate(v interface{}) {
	s.lock.Lock()
	s.private = v
	s.lock.Unlock()
}
// MaxSize returns the maximum frame size passed to the messager on receive.
func (s *SocketClientImpl) MaxSize() int {
	limit := s.maxSize
	return limit
}
// SetMaxSize sets the maximum frame size passed to the messager on receive.
// The field is written without taking s.lock.
func (s *SocketClientImpl) SetMaxSize(value int) {
s.maxSize = value
}
// MinSize returns the minimum frame size passed to the messager on receive.
func (s *SocketClientImpl) MinSize() int {
	limit := s.minSize
	return limit
}
// Have reports whether the client's current status equals state.
func (s *SocketClientImpl) Have(state SocketStatus) bool {
	current := s.Status()
	return current == state
}
// Statistics returns a fresh map holding a copy of the client's counters.
func (s *SocketClientImpl) Statistics() map[string]int64 {
	snapshot := map[string]int64{}
	s.statistics.Copy(snapshot)
	return snapshot
}
// SetListener replaces the client's callbacks. The field is written
// without taking s.lock.
func (s *SocketClientImpl) SetListener(listener ClientListener) {
s.listener = listener
}
// SetTimeout sets the read/write timeout in seconds. The value is only
// stored here; nothing in this file applies it to the connection.
func (s *SocketClientImpl) SetTimeout(v int64) {
s.timeout = v
}
// update installs conn as the client's connection, or tears the current
// connection down when conn is nil. In both cases the configured block
// cipher (if any) is re-applied to the messager and an event is logged.
// It takes no lock itself; the visible caller, Reset, holds s.lock.
func (s *SocketClientImpl) update(conn net.Conn) {
	if conn == nil {
		// Disconnect: close what we have and forget the addresses.
		if current := s.connection; current != nil {
			_ = current.Close()
		}
		s.connection = nil
		s.localAddr = ""
		s.remoteAddr = ""
		s.message.Flush()
	} else {
		s.connection = conn
		s.connectedTime = time.Now().Unix()
		s.localAddr = conn.LocalAddr().String()
		s.remoteAddr = conn.RemoteAddr().String()
	}
	if cipher := s.Block; cipher != nil {
		s.message.SetCrypt(cipher)
	}
	s.out.Event("SocketClientImpl.update: %s %s", s.localAddr, s.remoteAddr)
}
// Reset swaps in a new connection under the write lock, marks the client
// connected and runs the key negotiation. A negotiation failure is logged,
// not returned.
func (s *SocketClientImpl) Reset(conn net.Conn) {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.update(conn)
	s.status = ClConnected
	err := s.negotiate()
	if err == nil {
		return
	}
	s.out.Error("SocketClientImpl.Reset %s", err)
}
// SetStatus is a placeholder that ignores v and leaves the status unchanged.
// MUST IMPLEMENT: presumably concrete client types embedding
// SocketClientImpl provide the real status transition.
func (s *SocketClientImpl) SetStatus(v SocketStatus) {
}
// Socket server interface and implementation.
//
// Keys of the server statistics counters.
const (
// SsRecv counts frames read from clients.
SsRecv = "recv"
// SsDeny counts connections refused because the client limit was reached.
SsDeny = "deny"
// SsAlive is the number of currently admitted clients.
SsAlive = "alive"
// SsSend is not updated in this file.
SsSend = "send"
// SsDrop is not updated in this file.
SsDrop = "dropped"
// SsAccept counts accepted connections, admitted or not.
SsAccept = "accept"
// SsClose counts closed connections, including refused ones.
SsClose = "closed"
)
// ServerListener groups the callbacks driven by SocketServerImpl.Loop.
type ServerListener struct {
// OnClient is called for a new client once negotiation succeeded. When it
// is nil the client is only registered; no negotiation or reading starts.
OnClient func(client SocketClient) error
// OnClose is called before a registered client is closed and removed.
OnClose func(client SocketClient) error
// ReadAt, when set, receives every frame read from the client.
ReadAt func(client SocketClient, f *FrameMessage) error
}
// ReadClient is the frame handler passed to SocketServer.Read; it has the
// same signature as ServerListener.ReadAt.
type ReadClient func(client SocketClient, f *FrameMessage) error
// SocketServer is the interface of a socket server. In this file
// SocketServerImpl supplies the client bookkeeping and leaves Listen and
// Accept as stubs.
type SocketServer interface {
Listen() (err error)
Close()
Accept()
// ListClient streams the registered clients; a nil value ends the list.
ListClient() <-chan SocketClient
// OffClient queues a client for removal.
OffClient(client SocketClient)
TotalClient() int
// Loop dispatches queued client arrivals and removals to call.
Loop(call ServerListener)
// Read reads frames from client and hands them to ReadAt.
Read(client SocketClient, ReadAt ReadClient)
String() string
Address() string
// Statistics returns a copy of the counters.
Statistics() map[string]int64
SetTimeout(v int64)
}
// SocketServerImpl is the common base of socket servers: it tracks the
// connected clients, enforces the client limit and keeps statistics.
//
// TODO keepalive to release zombie connections.
type SocketServerImpl struct {
lock sync.RWMutex
// statistics holds the Ss* counters.
statistics *SafeStrInt64
// address is the listen address given to NewSocketServer.
address string
// maxClient is the admitted-client limit checked in preAccept.
maxClient int
// clients maps a client's remote address to the client.
clients *SafeStrMap
// onClients and offClients are the queues consumed by Loop.
onClients chan SocketClient
offClients chan SocketClient
// close, when non-nil, is invoked by Close.
close func()
timeout int64 // sec for read and write timeout
// WrQus is the capacity of the per-client frame queue used by Read.
WrQus int // per frames.
// error is the last accept error, kept so repeats are logged only once.
error error
}
// NewSocketServer builds a server for the given listen address with a limit
// of 128 clients and 1024-entry client, arrival, removal and frame queues.
func NewSocketServer(listen string) *SocketServerImpl {
	server := &SocketServerImpl{
		address:    listen,
		statistics: NewSafeStrInt64(),
		maxClient:  128,
		clients:    NewSafeStrMap(1024),
		WrQus:      1024,
	}
	server.onClients = make(chan SocketClient, 1024)
	server.offClients = make(chan SocketClient, 1024)
	return server
}
// ListClient returns a channel on which every registered client is sent
// from a background goroutine, followed by a nil value marking the end.
// The channel is never closed.
func (t *SocketServerImpl) ListClient() <-chan SocketClient {
	out := make(chan SocketClient, 32)
	emit := func(_ string, v interface{}) {
		client, ok := v.(SocketClient)
		if !ok {
			return
		}
		out <- client
	}
	Go(func() {
		t.clients.Iter(emit)
		out <- nil // end-of-list marker
	})
	return out
}
// TotalClient returns the number of registered clients.
func (t *SocketServerImpl) TotalClient() int {
	total := t.clients.Len()
	return total
}
// OffClient logs the request and queues a non-nil client for removal by
// Loop. It blocks while the removal queue is full.
func (t *SocketServerImpl) OffClient(client SocketClient) {
	Warn("SocketServerImpl.OffClient %s", client)
	if client == nil {
		return
	}
	t.offClients <- client
}
// negotiate performs the server side of the key negotiation with client.
// When the client has no key it does nothing. Otherwise it reads a NegoReq
// control frame, answers with a NegoResp carrying the MD5 digest of the
// request parameters, and installs those parameters as the client's key.
func (t *SocketServerImpl) negotiate(client SocketClient) error {
	if client.Key() == "" {
		return nil
	}
	request, err := client.ReadMsg()
	if err != nil {
		return err
	}
	if !request.IsControl() {
		Info("SocketServerImpl.negotiate %s", request.String())
		return NewErr("wrong message type")
	}
	// NOTE(review): the status is set before the action is validated and
	// before the reply is written — confirm this ordering is intended.
	client.SetStatus(ClNegotiated)
	action, params := request.CmdAndParams()
	if action != NegoReq {
		return NewErr("wrong message type: %s", action)
	}
	Cmd("SocketServerImpl.negotiate %s", params)
	digest := md5.Sum(params)
	reply := NewControlFrame(NegoResp, digest[:md5.Size])
	if err := client.WriteMsg(reply); err != nil {
		return err
	}
	client.SetKey(string(params))
	return nil
}
// doOnClient registers a newly arrived client under its remote address.
// When an OnClient callback is set, a goroutine negotiates the key, calls
// the callback and, if ReadAt is set, runs the read loop. A failed
// negotiation queues the client for removal.
func (t *SocketServerImpl) doOnClient(call ServerListener, client SocketClient) {
	Info("SocketServerImpl.doOnClient: +%s", client)
	_ = t.clients.Set(client.RemoteAddr(), client)
	if call.OnClient == nil {
		return
	}
	serve := func() {
		if err := t.negotiate(client); err != nil {
			t.OffClient(client)
			Warn("SocketServerImpl.doOnClient %s %s", client, err)
			return
		}
		_ = call.OnClient(client)
		if call.ReadAt == nil {
			return
		}
		t.Read(client, call.ReadAt)
	}
	Go(serve)
}
// doOffClient removes a client that is still registered under its remote
// address: it notifies OnClose, closes the client, deletes the entry and
// adjusts the SsClose and SsAlive counters. Unknown clients are ignored.
func (t *SocketServerImpl) doOffClient(call ServerListener, client SocketClient) {
	Info("SocketServerImpl.doOffClient: -%s", client)
	addr := client.RemoteAddr()
	_, registered := t.clients.GetEx(addr)
	if !registered {
		return
	}
	Info("SocketServerImpl.doOffClient: close %s", addr)
	t.statistics.Add(SsClose, 1)
	if onClose := call.OnClose; onClose != nil {
		_ = onClose(client)
	}
	client.Close()
	t.clients.Del(addr)
	t.statistics.Add(SsAlive, -1)
}
// Loop dispatches queued client arrivals and removals to call. It never
// returns normally; the deferred close only runs if a handler panics.
func (t *SocketServerImpl) Loop(call ServerListener) {
	Debug("SocketServerImpl.Loop")
	// Use the nil-safe Close wrapper: deferring t.close() directly would
	// panic on a nil func value and mask the original panic.
	defer t.Close()
	for {
		select {
		case client := <-t.onClients:
			t.doOnClient(call, client)
		case client := <-t.offClients:
			t.doOffClient(call, client)
		}
	}
}
// Read reads frames from client until an error or an empty frame occurs and
// hands each frame to ReadAt through a bounded queue served by a worker
// goroutine. When reading ends the worker is told to stop and the client is
// queued for removal.
//
// If ReadAt fails the worker logs the error and exits. The reader notices
// this on a subsequent frame and releases the client, instead of filling
// the queue and then blocking forever.
func (t *SocketServerImpl) Read(client SocketClient, ReadAt ReadClient) {
	Log("SocketServerImpl.Read: %s", client)
	done := make(chan bool, 2)
	stopped := make(chan struct{}) // closed when the worker exits
	queue := make(chan *FrameMessage, t.WrQus)
	Go(func() {
		defer close(stopped)
		for {
			select {
			case frame := <-queue:
				if err := ReadAt(client, frame); err != nil {
					Error("SocketServerImpl.Read: readAt %s", err)
					return
				}
			case <-done:
				return
			}
		}
	})
	for {
		frame, err := client.ReadMsg()
		if err != nil || frame == nil || frame.size <= 0 {
			switch {
			case err != nil:
				Error("SocketServerImpl.Read: %s %s", client, err)
			case frame == nil:
				Error("SocketServerImpl.Read: %s empty frame", client)
			default:
				Error("SocketServerImpl.Read: %s %d", client, frame.size)
			}
			// done is buffered, so this never blocks even if the worker
			// has already exited.
			done <- true
			t.OffClient(client)
			return
		}
		t.statistics.Add(SsRecv, 1)
		if HasLog(LOG) {
			Log("SocketServerImpl.Read: length: %d ", frame.size)
			Log("SocketServerImpl.Read: frame : %x", frame)
		}
		select {
		case queue <- frame:
		case <-stopped:
			// Nobody consumes the queue any more; drop the client.
			t.OffClient(client)
			return
		}
	}
}
// Listen is a placeholder that does nothing and returns nil.
// MUST IMPLEMENT: presumably concrete server types embedding
// SocketServerImpl provide the real listener setup.
func (t *SocketServerImpl) Listen() error {
return nil
}
// Accept is a placeholder that does nothing.
// MUST IMPLEMENT: presumably concrete server types embedding
// SocketServerImpl provide the real accept loop.
func (t *SocketServerImpl) Accept() {
}
// Close invokes the registered close function, if one has been set.
// MUST IMPLEMENT
func (t *SocketServerImpl) Close() {
	if shutdown := t.close; shutdown != nil {
		shutdown()
	}
}
// Address returns the listen address the server was created with.
func (t *SocketServerImpl) Address() string {
	addr := t.address
	return addr
}
// String returns the same value as Address.
func (t *SocketServerImpl) String() string {
	name := t.Address()
	return name
}
// Statistics returns a fresh map holding a copy of the server's counters.
func (t *SocketServerImpl) Statistics() map[string]int64 {
	snapshot := make(map[string]int64, 32)
	t.statistics.Copy(snapshot)
	return snapshot
}
// SetTimeout sets the read/write timeout in seconds. The value is only
// stored here; nothing in this file applies it to a connection.
func (t *SocketServerImpl) SetTimeout(v int64) {
t.timeout = v
}
// preAccept pre-processes the result of an accept call and returns nil when
// the new connection may be served.
//
// An accept error is returned as is; it is logged only when it differs from
// the previous one. A connection arriving while the number of alive clients
// has reached maxClient is closed and rejected with an error. Otherwise the
// alive counter is incremented.
func (t *SocketServerImpl) preAccept(conn net.Conn, err error) error {
	if err != nil {
		repeated := t.error != nil && t.error.Error() == err.Error()
		if !repeated {
			Warn("SocketServerImpl.preAccept: %s", err)
		}
		t.error = err
		return err
	}
	t.error = nil
	peer := conn.RemoteAddr()
	Debug("SocketServerImpl.preAccept: %s", peer)
	t.statistics.Add(SsAccept, 1)
	if t.statistics.Get(SsAlive) < int64(t.maxClient) {
		Debug("SocketServerImpl.preAccept: allow %s", peer)
		t.statistics.Add(SsAlive, 1)
		return nil
	}
	// Limit reached: refuse and account for the immediate close.
	Debug("SocketServerImpl.preAccept: close %s", peer)
	t.statistics.Add(SsDeny, 1)
	t.statistics.Add(SsClose, 1)
	_ = conn.Close()
	return NewErr("too many open clients")
}