mirror of
https://github.com/luscis/openlan.git
synced 2025-10-05 16:47:11 +08:00
627 lines
13 KiB
Go
Executable File
627 lines
13 KiB
Go
Executable File
package libol
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/md5"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Lifecycle states of a SocketClient. The values are consecutive untyped
// integer constants starting at 0x00; SocketStatus.String names them.
const (
	ClInit        = iota // 0x00
	ClConnected          // 0x01
	ClUnAuth             // 0x02
	ClAuth               // 0x03
	ClConnecting         // 0x04
	ClTerminal           // 0x05
	ClClosed             // 0x06
	ClNegotiating        // 0x07
	ClNegotiated         // 0x08
)

// SocketStatus holds one of the Cl* lifecycle constants.
type SocketStatus uint8

// String returns the lower-case name of the status, or "" when the value
// does not correspond to any Cl* constant.
func (s SocketStatus) String() string {
	names := [...]string{
		ClInit:        "initialized",
		ClConnected:   "connected",
		ClUnAuth:      "unauthenticated",
		ClAuth:        "authenticated",
		ClConnecting:  "connecting",
		ClTerminal:    "terminal",
		ClClosed:      "closed",
		ClNegotiating: "negotiating",
		ClNegotiated:  "negotiated",
	}
	if int(s) >= len(names) {
		return ""
	}
	return names[s]
}
|
|
|
|
// Socket client interface and its base implementation.
|
|
|
|
// Keys of the per-client counters kept in StreamSocket.statistics.
const (
	// CsSendOkay accumulates the size returned by Messager.Send on success.
	CsSendOkay = "send"
	// CsRecvOkay accumulates the length of every frame successfully read.
	CsRecvOkay = "recv"
	// CsSendError counts failed Messager.Send calls.
	CsSendError = "error"
	// CsDropped counts writes refused because no connection was attached.
	CsDropped = "dropped"
)
|
|
|
|
type ClientListener struct {
|
|
OnClose func(client SocketClient) error
|
|
OnConnected func(client SocketClient) error
|
|
OnStatus func(client SocketClient, old, new SocketStatus)
|
|
}
|
|
|
|
type SocketClient interface {
|
|
LocalAddr() string
|
|
RemoteAddr() string
|
|
Connect() error
|
|
Close()
|
|
WriteMsg(frame *FrameMessage) error
|
|
ReadMsg() (*FrameMessage, error)
|
|
UpTime() int64
|
|
AliveTime() int64
|
|
String() string
|
|
Terminal()
|
|
Private() interface{}
|
|
SetPrivate(v interface{})
|
|
Status() SocketStatus
|
|
SetStatus(v SocketStatus)
|
|
MaxSize() int
|
|
SetMaxSize(value int)
|
|
MinSize() int
|
|
IsOk() bool
|
|
Have(status SocketStatus) bool
|
|
Statistics() map[string]int64
|
|
SetListener(listener ClientListener)
|
|
SetTimeout(v int64)
|
|
Out() *SubLogger
|
|
SetKey(key string)
|
|
Key() string
|
|
}
|
|
|
|
type StreamSocket struct {
|
|
message Messager
|
|
connection net.Conn
|
|
statistics *SafeStrInt64
|
|
maxSize int
|
|
minSize int
|
|
out *SubLogger
|
|
remoteAddr string
|
|
localAddr string
|
|
address string
|
|
Block *BlockCrypt
|
|
}
|
|
|
|
func (t *StreamSocket) LocalAddr() string {
|
|
return t.localAddr
|
|
}
|
|
|
|
func (t *StreamSocket) RemoteAddr() string {
|
|
return t.remoteAddr
|
|
}
|
|
|
|
func (t *StreamSocket) String() string {
|
|
return t.address
|
|
}
|
|
|
|
func (t *StreamSocket) IsOk() bool {
|
|
return t.connection != nil
|
|
}
|
|
|
|
func (t *StreamSocket) WriteMsg(frame *FrameMessage) error {
|
|
if !t.IsOk() {
|
|
t.statistics.Add(CsDropped, 1)
|
|
return NewErr("%s not okay", t)
|
|
}
|
|
if frame.IsControl() {
|
|
action, params := frame.CmdAndParams()
|
|
Cmd("StreamSocket.WriteMsg: %s%s", action, params)
|
|
}
|
|
if t.message == nil { // default is stream message
|
|
t.message = &StreamMessagerImpl{}
|
|
}
|
|
size, err := t.message.Send(t.connection, frame)
|
|
if err != nil {
|
|
t.statistics.Add(CsSendError, 1)
|
|
return err
|
|
}
|
|
t.statistics.Add(CsSendOkay, int64(size))
|
|
return nil
|
|
}
|
|
|
|
func (t *StreamSocket) ReadMsg() (*FrameMessage, error) {
|
|
if HasLog(LOG) {
|
|
Log("StreamSocket.ReadMsg: %s", t)
|
|
}
|
|
if !t.IsOk() {
|
|
return nil, NewErr("%s not okay", t)
|
|
}
|
|
if t.message == nil { // default is stream message
|
|
t.message = &StreamMessagerImpl{}
|
|
}
|
|
frame, err := t.message.Receive(t.connection, t.maxSize, t.minSize)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
size := len(frame.frame)
|
|
t.statistics.Add(CsRecvOkay, int64(size))
|
|
return frame, nil
|
|
}
|
|
|
|
func (t *StreamSocket) SetKey(key string) {
|
|
if block := t.message.Crypt(); block != nil {
|
|
block.Update(key)
|
|
}
|
|
}
|
|
|
|
func (t *StreamSocket) Key() string {
|
|
key := ""
|
|
if block := t.message.Crypt(); block != nil {
|
|
key = block.key
|
|
}
|
|
return key
|
|
}
|
|
|
|
type SocketConfig struct {
|
|
Address string
|
|
Block *BlockCrypt
|
|
}
|
|
|
|
type SocketClientImpl struct {
|
|
*StreamSocket
|
|
lock sync.RWMutex
|
|
listener ClientListener
|
|
newTime int64
|
|
connectedTime int64
|
|
private interface{}
|
|
status SocketStatus
|
|
timeout int64 // sec for read and write timeout
|
|
}
|
|
|
|
func NewSocketClient(cfg SocketConfig, message Messager) *SocketClientImpl {
|
|
return &SocketClientImpl{
|
|
StreamSocket: &StreamSocket{
|
|
maxSize: 1514,
|
|
minSize: 15,
|
|
message: message,
|
|
statistics: NewSafeStrInt64(),
|
|
out: NewSubLogger(cfg.Address),
|
|
remoteAddr: cfg.Address,
|
|
address: cfg.Address,
|
|
Block: cfg.Block,
|
|
},
|
|
newTime: time.Now().Unix(),
|
|
status: ClInit,
|
|
}
|
|
}
|
|
|
|
func (s *SocketClientImpl) negotiate() error {
|
|
if s.Key() == "" {
|
|
return nil
|
|
}
|
|
key := GenLetters(64)
|
|
request := NewControlFrame(NegoReq, key)
|
|
if err := s.WriteMsg(request); err != nil {
|
|
return err
|
|
}
|
|
s.status = ClNegotiating
|
|
reply, err := s.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !reply.IsControl() {
|
|
Info("SocketClientImpl.negotiate %s", reply.String())
|
|
return NewErr("wrong message type")
|
|
}
|
|
action, params := reply.CmdAndParams()
|
|
if action != NegoResp {
|
|
return NewErr("wrong message type: %s", action)
|
|
}
|
|
Cmd("SocketClientImpl.negotiate %s %x", action, params)
|
|
sum := md5.Sum(key)
|
|
if bytes.Compare(sum[:md5.Size], params) != 0 {
|
|
return NewErr("negotiate key failed: %x != %x", key, params)
|
|
}
|
|
if block := s.message.Crypt(); block != nil {
|
|
block.Update(string(key))
|
|
}
|
|
s.status = ClNegotiated
|
|
return nil
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (s *SocketClientImpl) Connect() error {
|
|
return nil
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (s *SocketClientImpl) Close() {
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (s *SocketClientImpl) Terminal() {
|
|
}
|
|
|
|
func (s *SocketClientImpl) Out() *SubLogger {
|
|
if s.out == nil {
|
|
s.out = NewSubLogger(s.address)
|
|
}
|
|
return s.out
|
|
}
|
|
|
|
func (s *SocketClientImpl) Retry() bool {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
if s.connection != nil ||
|
|
s.status == ClTerminal ||
|
|
s.status == ClUnAuth {
|
|
return false
|
|
}
|
|
s.status = ClConnecting
|
|
return true
|
|
}
|
|
|
|
func (s *SocketClientImpl) Status() SocketStatus {
|
|
s.lock.RLock()
|
|
defer s.lock.RUnlock()
|
|
return s.status
|
|
}
|
|
|
|
func (s *SocketClientImpl) UpTime() int64 {
|
|
return time.Now().Unix() - s.newTime
|
|
}
|
|
|
|
func (s *SocketClientImpl) AliveTime() int64 {
|
|
if s.connectedTime == 0 {
|
|
return 0
|
|
}
|
|
return time.Now().Unix() - s.connectedTime
|
|
}
|
|
|
|
func (s *SocketClientImpl) Private() interface{} {
|
|
s.lock.RLock()
|
|
defer s.lock.RUnlock()
|
|
return s.private
|
|
}
|
|
|
|
func (s *SocketClientImpl) SetPrivate(v interface{}) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.private = v
|
|
}
|
|
|
|
func (s *SocketClientImpl) MaxSize() int {
|
|
return s.maxSize
|
|
}
|
|
|
|
func (s *SocketClientImpl) SetMaxSize(value int) {
|
|
s.maxSize = value
|
|
}
|
|
|
|
func (s *SocketClientImpl) MinSize() int {
|
|
return s.minSize
|
|
}
|
|
|
|
func (s *SocketClientImpl) Have(state SocketStatus) bool {
|
|
return s.Status() == state
|
|
}
|
|
|
|
func (s *SocketClientImpl) Statistics() map[string]int64 {
|
|
sts := make(map[string]int64)
|
|
s.statistics.Copy(sts)
|
|
return sts
|
|
}
|
|
|
|
func (s *SocketClientImpl) SetListener(listener ClientListener) {
|
|
s.listener = listener
|
|
}
|
|
|
|
func (s *SocketClientImpl) SetTimeout(v int64) {
|
|
s.timeout = v
|
|
}
|
|
|
|
func (s *SocketClientImpl) update(conn net.Conn) {
|
|
if conn != nil {
|
|
s.connection = conn
|
|
s.connectedTime = time.Now().Unix()
|
|
s.localAddr = conn.LocalAddr().String()
|
|
s.remoteAddr = conn.RemoteAddr().String()
|
|
} else {
|
|
if s.connection != nil {
|
|
_ = s.connection.Close()
|
|
}
|
|
s.connection = nil
|
|
s.localAddr = ""
|
|
s.remoteAddr = ""
|
|
s.message.Flush()
|
|
}
|
|
if s.Block != nil {
|
|
s.message.SetCrypt(s.Block)
|
|
}
|
|
s.out.Event("SocketClientImpl.update: %s %s", s.localAddr, s.remoteAddr)
|
|
}
|
|
|
|
func (s *SocketClientImpl) Reset(conn net.Conn) {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
s.update(conn)
|
|
s.status = ClConnected
|
|
if err := s.negotiate(); err != nil {
|
|
s.out.Error("SocketClientImpl.Reset %s", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (s *SocketClientImpl) SetStatus(v SocketStatus) {
|
|
}
|
|
|
|
// Socket server interface and its base implementation.
|
|
|
|
// Keys of the server-wide counters kept in SocketServerImpl.statistics.
const (
	SsRecv   = "recv"    // frames read from clients
	SsDeny   = "deny"    // connections rejected because maxClient was reached
	SsAlive  = "alive"   // admitted clients not yet removed
	SsSend   = "send"    // not updated in this file
	SsDrop   = "dropped" // not updated in this file
	SsAccept = "accept"  // connections returned by the listener
	SsClose  = "closed"  // clients removed plus rejected connections
)
|
|
|
|
type ServerListener struct {
|
|
OnClient func(client SocketClient) error
|
|
OnClose func(client SocketClient) error
|
|
ReadAt func(client SocketClient, f *FrameMessage) error
|
|
}
|
|
|
|
type ReadClient func(client SocketClient, f *FrameMessage) error
|
|
|
|
type SocketServer interface {
|
|
Listen() (err error)
|
|
Close()
|
|
Accept()
|
|
ListClient() <-chan SocketClient
|
|
OffClient(client SocketClient)
|
|
TotalClient() int
|
|
Loop(call ServerListener)
|
|
Read(client SocketClient, ReadAt ReadClient)
|
|
String() string
|
|
Address() string
|
|
Statistics() map[string]int64
|
|
SetTimeout(v int64)
|
|
}
|
|
|
|
// TODO keepalive to release zombie connections.
|
|
type SocketServerImpl struct {
|
|
lock sync.RWMutex
|
|
statistics *SafeStrInt64
|
|
address string
|
|
maxClient int
|
|
clients *SafeStrMap
|
|
onClients chan SocketClient
|
|
offClients chan SocketClient
|
|
close func()
|
|
timeout int64 // sec for read and write timeout
|
|
WrQus int // per frames.
|
|
error error
|
|
}
|
|
|
|
func NewSocketServer(listen string) *SocketServerImpl {
|
|
return &SocketServerImpl{
|
|
address: listen,
|
|
statistics: NewSafeStrInt64(),
|
|
maxClient: 128,
|
|
clients: NewSafeStrMap(1024),
|
|
onClients: make(chan SocketClient, 1024),
|
|
offClients: make(chan SocketClient, 1024),
|
|
WrQus: 1024,
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) ListClient() <-chan SocketClient {
|
|
list := make(chan SocketClient, 32)
|
|
Go(func() {
|
|
t.clients.Iter(func(k string, v interface{}) {
|
|
if client, ok := v.(SocketClient); ok {
|
|
list <- client
|
|
}
|
|
})
|
|
list <- nil
|
|
})
|
|
return list
|
|
}
|
|
|
|
func (t *SocketServerImpl) TotalClient() int {
|
|
return t.clients.Len()
|
|
}
|
|
|
|
func (t *SocketServerImpl) OffClient(client SocketClient) {
|
|
Warn("SocketServerImpl.OffClient %s", client)
|
|
if client != nil {
|
|
t.offClients <- client
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) negotiate(client SocketClient) error {
|
|
if client.Key() == "" {
|
|
return nil
|
|
}
|
|
request, err := client.ReadMsg()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !request.IsControl() {
|
|
Info("SocketServerImpl.negotiate %s", request.String())
|
|
return NewErr("wrong message type")
|
|
}
|
|
client.SetStatus(ClNegotiated)
|
|
action, params := request.CmdAndParams()
|
|
if action == NegoReq {
|
|
Cmd("SocketServerImpl.negotiate %s", params)
|
|
sum := md5.Sum(params)
|
|
reply := NewControlFrame(NegoResp, sum[:md5.Size])
|
|
if err := client.WriteMsg(reply); err != nil {
|
|
return err
|
|
}
|
|
client.SetKey(string(params))
|
|
return nil
|
|
}
|
|
return NewErr("wrong message type: %s", action)
|
|
|
|
}
|
|
|
|
func (t *SocketServerImpl) doOnClient(call ServerListener, client SocketClient) {
|
|
Info("SocketServerImpl.doOnClient: +%s", client)
|
|
_ = t.clients.Set(client.RemoteAddr(), client)
|
|
if call.OnClient != nil {
|
|
Go(func() {
|
|
if err := t.negotiate(client); err != nil {
|
|
t.OffClient(client)
|
|
Warn("SocketServerImpl.doOnClient %s %s", client, err)
|
|
return
|
|
}
|
|
_ = call.OnClient(client)
|
|
if call.ReadAt != nil {
|
|
t.Read(client, call.ReadAt)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) doOffClient(call ServerListener, client SocketClient) {
|
|
Info("SocketServerImpl.doOffClient: -%s", client)
|
|
addr := client.RemoteAddr()
|
|
if _, ok := t.clients.GetEx(addr); ok {
|
|
Info("SocketServerImpl.doOffClient: close %s", addr)
|
|
t.statistics.Add(SsClose, 1)
|
|
if call.OnClose != nil {
|
|
_ = call.OnClose(client)
|
|
}
|
|
client.Close()
|
|
t.clients.Del(addr)
|
|
t.statistics.Add(SsAlive, -1)
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) Loop(call ServerListener) {
|
|
Debug("SocketServerImpl.Loop")
|
|
defer t.close()
|
|
for {
|
|
select {
|
|
case client := <-t.onClients:
|
|
t.doOnClient(call, client)
|
|
case client := <-t.offClients:
|
|
t.doOffClient(call, client)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) Read(client SocketClient, ReadAt ReadClient) {
|
|
Log("SocketServerImpl.Read: %s", client)
|
|
done := make(chan bool, 2)
|
|
queue := make(chan *FrameMessage, t.WrQus)
|
|
Go(func() {
|
|
for {
|
|
select {
|
|
case frame := <-queue:
|
|
if err := ReadAt(client, frame); err != nil {
|
|
Error("SocketServerImpl.Read: readAt %s", err)
|
|
return
|
|
}
|
|
case <-done:
|
|
return
|
|
}
|
|
}
|
|
})
|
|
for {
|
|
frame, err := client.ReadMsg()
|
|
if err != nil || frame.size <= 0 {
|
|
if frame != nil {
|
|
Error("SocketServerImpl.Read: %s %d", client, frame.size)
|
|
} else {
|
|
Error("SocketServerImpl.Read: %s %s", client, err)
|
|
}
|
|
done <- true
|
|
t.OffClient(client)
|
|
break
|
|
}
|
|
t.statistics.Add(SsRecv, 1)
|
|
if HasLog(LOG) {
|
|
Log("SocketServerImpl.Read: length: %d ", frame.size)
|
|
Log("SocketServerImpl.Read: frame : %x", frame)
|
|
}
|
|
queue <- frame
|
|
}
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (t *SocketServerImpl) Listen() error {
|
|
return nil
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (t *SocketServerImpl) Accept() {
|
|
}
|
|
|
|
// MUST IMPLEMENT
|
|
func (t *SocketServerImpl) Close() {
|
|
if t.close != nil {
|
|
t.close()
|
|
}
|
|
}
|
|
|
|
func (t *SocketServerImpl) Address() string {
|
|
return t.address
|
|
}
|
|
|
|
func (t *SocketServerImpl) String() string {
|
|
return t.Address()
|
|
}
|
|
|
|
func (t *SocketServerImpl) Statistics() map[string]int64 {
|
|
sts := make(map[string]int64, 32)
|
|
t.statistics.Copy(sts)
|
|
return sts
|
|
}
|
|
|
|
func (t *SocketServerImpl) SetTimeout(v int64) {
|
|
t.timeout = v
|
|
}
|
|
|
|
// pre-process when accept connection,
|
|
// and allowed accept new connection, will return nil.
|
|
func (t *SocketServerImpl) preAccept(conn net.Conn, err error) error {
|
|
if err != nil {
|
|
if t.error == nil || t.error.Error() != err.Error() {
|
|
Warn("SocketServerImpl.preAccept: %s", err)
|
|
}
|
|
t.error = err
|
|
return err
|
|
}
|
|
t.error = nil
|
|
addr := conn.RemoteAddr()
|
|
Debug("SocketServerImpl.preAccept: %s", addr)
|
|
t.statistics.Add(SsAccept, 1)
|
|
alive := t.statistics.Get(SsAlive)
|
|
if alive >= int64(t.maxClient) {
|
|
Debug("SocketServerImpl.preAccept: close %s", addr)
|
|
t.statistics.Add(SsDeny, 1)
|
|
t.statistics.Add(SsClose, 1)
|
|
_ = conn.Close()
|
|
return NewErr("too many open clients")
|
|
}
|
|
Debug("SocketServerImpl.preAccept: allow %s", addr)
|
|
t.statistics.Add(SsAlive, 1)
|
|
return nil
|
|
}
|