Files
natpass/code/server/handler/handler.go
2021-09-27 17:02:56 +08:00

232 lines
5.3 KiB
Go

package handler
import (
"bytes"
"natpass/code/network"
"natpass/code/server/global"
"net"
"sync"
"time"
"github.com/lwch/logging"
)
// Handler handler
type Handler struct {
cfg *global.Configure
lockClients sync.RWMutex
clients map[string]*clients // client id => client
lockLinks sync.RWMutex
links map[string][2]*client // link id => endpoints
}
// New create handler
func New(cfg *global.Configure) *Handler {
return &Handler{
cfg: cfg,
clients: make(map[string]*clients),
links: make(map[string][2]*client),
}
}
// Handle main loop
func (h *Handler) Handle(conn net.Conn) {
c := network.NewConn(conn)
var id string
var idx uint32
defer func() {
if len(id) > 0 {
logging.Info("%s disconnected", id)
}
c.Close()
}()
var err error
for i := 0; i < 10; i++ {
id, idx, err = h.readHandshake(c)
if err != nil {
if err == errInvalidHandshake {
logging.Error("invalid handshake from %s", c.RemoteAddr().String())
return
}
logging.Error("read handshake from %s %d times, err=%v", c.RemoteAddr().String(), i+1, err)
continue
}
break
}
if err != nil {
return
}
logging.Info("%s-%d connected", id, idx)
clients := h.tryGetClients(id)
cli := clients.new(idx, c)
defer h.closeClient(cli)
go cli.keepalive()
cli.run()
}
func (h *Handler) tryGetClients(id string) *clients {
h.lockClients.Lock()
defer h.lockClients.Unlock()
clients := h.clients[id]
if clients != nil {
return clients
}
clients = newClients(h, id)
h.clients[id] = clients
return clients
}
// readHandshake read handshake message and compare secret encoded from md5
func (h *Handler) readHandshake(c *network.Conn) (string, uint32, error) {
msg, err := c.ReadMessage(5 * time.Second)
if err != nil {
return "", 0, err
}
if msg.GetXType() != network.Msg_handshake {
return "", 0, errNotHandshake
}
n := bytes.Compare(msg.GetHsp().GetEnc(), h.cfg.Enc[:])
if n != 0 {
return "", 0, errInvalidHandshake
}
return msg.GetFrom(), msg.GetFromIdx(), nil
}
func (h *Handler) getClient(linkID, to string, toIdx uint32) *client {
h.lockLinks.RLock()
pair := h.links[linkID]
h.lockLinks.RUnlock()
if pair[0] != nil && pair[0].is(to, toIdx) {
return pair[0]
}
if pair[1] != nil && pair[1].is(to, toIdx) {
return pair[1]
}
h.lockClients.RLock()
clients := h.clients[to]
h.lockClients.RUnlock()
if clients == nil {
return nil
}
return clients.next()
}
func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg) {
to := msg.GetTo()
toIdx := msg.GetToIdx()
switch msg.GetXType() {
case network.Msg_connect_req:
case network.Msg_connect_rep:
case network.Msg_disconnect:
case network.Msg_forward:
case network.Msg_shell_create:
case network.Msg_shell_resize:
case network.Msg_shell_data:
default:
return
}
cli := h.getClient(msg.GetLinkId(), to, toIdx)
if cli == nil {
logging.Error("client %s-%d not found", to, toIdx)
return
}
h.msgHook(msg, from, cli)
cli.writeMessage(msg)
}
// msgHook hook from on message
func (h *Handler) msgHook(msg *network.Msg, from, to *client) {
switch msg.GetXType() {
case network.Msg_connect_req,
network.Msg_shell_create:
id := msg.GetLinkId()
var pair [2]*client
if from != nil {
from.addLink(id)
pair[0] = from
}
if to != nil {
to.addLink(id)
pair[1] = to
}
h.lockLinks.Lock()
h.links[id] = pair
h.lockLinks.Unlock()
logging.Info("link %s name %s request from %s-%d to %s-%d",
id, msg.GetCreq().GetName(), from.parent.id, from.idx, to.parent.id, to.idx)
case network.Msg_connect_rep:
rep := msg.GetCrep()
if rep.GetOk() {
logging.Info("link %s from %s-%d to %s-%d connect successed",
msg.GetLinkId(), from.parent.id, from.idx, to.parent.id, to.idx)
} else {
logging.Info("link %s from %s-%d to %s-%d connect failed, %s",
msg.GetLinkId(), from.parent.id, from.idx, to.parent.id, to.idx, rep.GetMsg())
}
case network.Msg_forward,
network.Msg_shell_data:
data := msg.GetXData()
logging.Debug("link %s forward %d bytes from %s-%d to %s-%d",
msg.GetLinkId(), len(data.GetData()), from.parent.id, from.idx, to.parent.id, to.idx)
case network.Msg_disconnect,
network.Msg_shell_close:
id := msg.GetLinkId()
if from != nil {
from.removeLink(id)
}
if to != nil {
to.removeLink(id)
}
h.lockLinks.Lock()
delete(h.links, id)
h.lockLinks.Unlock()
logging.Info("link %s disconnect from %s-%d to %s-%d",
msg.GetLinkId(), from.parent.id, from.idx, to.parent.id, to.idx)
case network.Msg_shell_resize:
data := msg.GetSresize()
logging.Info("shell %s from %s-%d to %s-%d resize to (%d,%d)",
msg.GetLinkId(), from.parent.id, from.idx, to.parent.id, to.idx,
data.GetRows(), data.GetCols())
}
msg.From = from.parent.id
msg.FromIdx = from.idx
msg.To = to.parent.id
msg.ToIdx = to.idx
}
func (h *Handler) closeClient(cli *client) {
links := cli.getLinks()
for _, t := range links {
h.lockLinks.RLock()
pair := h.links[t]
h.lockLinks.RUnlock()
if pair[0] != nil {
pair[0].closeLink(t)
}
if pair[1] != nil {
pair[1].closeLink(t)
}
h.lockLinks.Lock()
delete(h.links, t)
h.lockLinks.Unlock()
}
h.lockClients.RLock()
clients := h.clients[cli.parent.id]
h.lockClients.RUnlock()
if clients != nil {
clients.close(cli.idx)
}
}
func (h *Handler) removeClients(id string) {
h.lockClients.Lock()
delete(h.clients, id)
h.lockClients.Unlock()
}