Files
natpass/code/server/handler/handler.go
2022-05-10 19:54:49 +08:00

224 lines
5.0 KiB
Go

package handler
import (
	"bytes"
	"errors"
	"io"
	"net"
	"sync"
	"time"

	"github.com/lwch/logging"
	"github.com/lwch/natpass/code/network"
	"github.com/lwch/natpass/code/server/global"
)
// link is one link between two clients, stored in Handler.links under its id.
type link struct {
	id        string                     // link id; also the key in Handler.links
	t         network.ConnectRequestType // connect type taken from the connect request
	endPoints [2]*client                 // [0] = from, [1] = to as passed to addLink; either may be nil
}
// close tells each endpoint attached to the link that the link is closing by
// calling sendClose with the link id. Endpoint 0 is notified before endpoint 1,
// and missing (nil) endpoints are skipped.
func (l *link) close() {
	for _, ep := range l.endPoints {
		if ep == nil {
			continue
		}
		ep.sendClose(l.id)
	}
}
// Handler accepts client connections, forwards messages between connected
// clients and keeps track of the links established between them.
type Handler struct {
	cfg       *global.Configure
	clis      *clients
	lockLinks sync.RWMutex    // guards links
	links     map[string]link // link id => endpoints
}
// New returns a Handler that uses cfg and starts with an empty link table.
// Its client registry is built by newClients, which receives the handler itself.
func New(cfg *global.Configure) *Handler {
	h := new(Handler)
	h.cfg = cfg
	h.links = make(map[string]link)
	h.clis = newClients(h)
	return h
}
// Handle serves one accepted connection until it ends.
//
// It first reads the handshake, making up to 10 attempts on read errors.
// An invalid handshake (errInvalidHandshake) aborts immediately. End-of-stream
// (io.EOF / io.ErrUnexpectedEOF) also aborts, because no later read on a
// connection the peer has closed can succeed.
//
// After a successful handshake the client is registered under the id it
// announced and keepalive runs in its own goroutine. cli.run() then drives the
// connection. When run returns, the client is unregistered and the connection
// is closed.
func (h *Handler) Handle(conn net.Conn) {
	const maxHandshakeAttempts = 10

	c := network.NewConn(conn)
	var id string
	defer func() {
		if len(id) > 0 {
			logging.Info("%s disconnected", id)
		}
		c.Close()
	}()

	var err error
	for i := 0; i < maxHandshakeAttempts; i++ {
		id, err = h.readHandshake(c)
		if err == nil {
			break
		}
		if errors.Is(err, errInvalidHandshake) {
			logging.Error("invalid handshake from %s", c.RemoteAddr().String())
			return
		}
		logging.Error("read handshake from %s %d times, err=%v", c.RemoteAddr().String(), i+1, err)
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			// The peer closed the connection; retrying would only repeat this error.
			return
		}
	}
	if err != nil {
		return
	}

	logging.Info("%s connected", id)
	cli := h.clis.new(id, c)
	defer h.clis.close(id)
	go cli.keepalive()
	cli.run()
}
// readHandshake waits up to 5 seconds for one message on c and validates it
// as a handshake.
//
// It returns:
//   - the read error, if the read fails;
//   - errNotHandshake, if the message is of another type;
//   - errInvalidHandshake, if the Enc bytes it carries differ from h.cfg.Enc;
//   - otherwise the sender's id (msg.GetFrom()) and a nil error.
//
// NOTE(review): bytes.Equal returns at the first differing byte, so this
// comparison is not constant-time and may leak timing information about
// h.cfg.Enc. Consider crypto/subtle.ConstantTimeCompare.
func (h *Handler) readHandshake(c *network.Conn) (string, error) {
	hs, _, err := c.ReadMessage(5 * time.Second)
	switch {
	case err != nil:
		return "", err
	case hs.GetXType() != network.Msg_handshake:
		return "", errNotHandshake
	case !bytes.Equal(hs.GetHsp().GetEnc(), h.cfg.Enc[:]):
		return "", errInvalidHandshake
	}
	return hs.GetFrom(), nil
}
// getClient resolves the client a message on linkID should be delivered to.
// It first looks for an endpoint of that link whose id equals to. If none
// matches, or the link is unknown, it falls back to h.clis.lookup(to).
func (h *Handler) getClient(linkID, to string) *client {
	h.lockLinks.RLock()
	ends := h.links[linkID].endPoints
	h.lockLinks.RUnlock()

	for _, ep := range ends {
		if ep != nil && ep.id == to {
			return ep
		}
	}
	return h.clis.lookup(to)
}
// onMessage routes msg, received from client from, to its destination.
//
// Keepalive messages are dropped. If no destination client can be resolved,
// the message is discarded with an error log. Otherwise msgHook runs first
// (link bookkeeping and From/To rewriting), then the message is written to the
// destination. conn is not used here.
func (h *Handler) onMessage(from *client, conn *network.Conn, msg *network.Msg, size uint16) {
	if msg.GetXType() == network.Msg_keepalive {
		return
	}
	to := msg.GetTo()
	dst := h.getClient(msg.GetLinkId(), to)
	if dst == nil {
		logging.Error("client %s not found", to)
		return
	}
	h.msgHook(msg, from, dst, size)
	if err := dst.writeMessage(msg); err != nil {
		logging.Error("write message %s from %s to %s: %v",
			msg.GetXType().String(), msg.GetFrom(), msg.GetTo(), err)
	}
}
// addLink registers link id (of connect type t) between from and to.
//
// Each non-nil endpoint records the id via its own addLink and is stored in
// the link: from at index 0, to at index 1. The link then replaces any
// existing entry for id in h.links. name is used only for logging.
// Either endpoint may be nil; the log line then shows "<nil>" instead of
// dereferencing it.
func (h *Handler) addLink(name, id string, t network.ConnectRequestType, from, to *client) {
	idOf := func(cli *client) string {
		if cli == nil {
			return "<nil>"
		}
		return cli.id
	}

	l := link{id: id, t: t}
	if from != nil {
		from.addLink(id)
		l.endPoints[0] = from
	}
	if to != nil {
		to.addLink(id)
		l.endPoints[1] = to
	}

	h.lockLinks.Lock()
	h.links[id] = l
	h.lockLinks.Unlock()

	logging.Info("add link %s name %s from %s to %s",
		id, name, idOf(from), idOf(to))
}
// removeLink detaches link id from each non-nil endpoint (via the client's
// removeLink) and deletes it from h.links.
// Either endpoint may be nil; the log line then shows "<nil>" instead of
// dereferencing it.
func (h *Handler) removeLink(id string, from, to *client) {
	idOf := func(cli *client) string {
		if cli == nil {
			return "<nil>"
		}
		return cli.id
	}

	if from != nil {
		from.removeLink(id)
	}
	if to != nil {
		to.removeLink(id)
	}

	h.lockLinks.Lock()
	delete(h.links, id)
	h.lockLinks.Unlock()

	logging.Info("remove link %s from %s to %s",
		id, idOf(from), idOf(to))
}
// responseLink logs the result of a connect response for link id sent from
// from to to. When ok is false, msg is included as the failure reason.
// It only logs; no link state is changed.
func (h *Handler) responseLink(id string, ok bool, msg string, from, to *client) {
	if !ok {
		logging.Info("link %s from %s to %s connect failed, %s",
			id, from.id, to.id, msg)
		// TODO: remove link?
		return
	}
	logging.Info("link %s from %s to %s connect successed",
		id, from.id, to.id)
}
// msgHook runs per-type bookkeeping for a message that is about to be
// forwarded from client from to client to. It then overwrites msg.From and
// msg.To with those clients' ids.
//
// By message type:
//   - connect_req: registers the link via addLink
//   - disconnect: unregisters the link via removeLink
//   - connect_rep: reports the result via responseLink
//   - forward, shell_data, shell_resize: logging only
//
// size is used only in the final debug log.
func (h *Handler) msgHook(msg *network.Msg, from, to *client, size uint16) {
	linkID := msg.GetLinkId()
	switch msg.GetXType() {
	case network.Msg_connect_req:
		req := msg.GetCreq()
		h.addLink(req.GetName(), linkID, req.GetXType(), from, to)
	case network.Msg_disconnect:
		h.removeLink(linkID, from, to)
	case network.Msg_connect_rep:
		rep := msg.GetCrep()
		h.responseLink(linkID, rep.GetOk(), rep.GetMsg(), from, to)
	case network.Msg_forward:
		payload := msg.GetXData().GetData()
		logging.Debug("link %s forward %d bytes from %s to %s",
			linkID, len(payload), from.id, to.id)
	case network.Msg_shell_data:
		payload := msg.GetSdata().GetData()
		logging.Debug("shell %s forward %d bytes from %s to %s",
			linkID, len(payload), from.id, to.id)
	case network.Msg_shell_resize:
		rs := msg.GetSresize()
		logging.Info("shell %s from %s to %s resize to (%d,%d)",
			linkID, from.id, to.id, rs.GetRows(), rs.GetCols())
	}

	msg.From = from.id
	msg.To = to.id
	logging.Debug("forward %d bytes on link %s from %s to %s", size, linkID,
		from.id, to.id)
}
// closeLink removes link id from h.links and sends a close notice (sendClose)
// to each of its non-nil endpoints. An unknown id is a no-op.
//
// Lookup and deletion happen under a single write lock. This means two
// concurrent callers for the same id cannot both observe the link and both
// notify its endpoints. It also means a link re-added under the same id in
// between cannot be deleted by mistake.
// sendClose is called after lockLinks has been released, so endpoint
// notification never runs while the lock is held.
func (h *Handler) closeLink(id string) {
	h.lockLinks.Lock()
	l, ok := h.links[id]
	if ok {
		delete(h.links, id)
	}
	h.lockLinks.Unlock()

	if !ok {
		return
	}
	l.close()
}