refactor: refactor code (#306)

This commit is contained in:
naison
2024-07-23 19:11:58 +08:00
committed by GitHub
parent a37bfc28da
commit f13e21a049
4 changed files with 78 additions and 141 deletions

View File

@@ -7,9 +7,7 @@ import (
"strings"
)
var (
ErrorInvalidNode = errors.New("invalid node")
)
var ErrorInvalidNode = errors.New("invalid node")
type Node struct {
Addr string
@@ -29,12 +27,13 @@ func ParseNode(s string) (*Node, error) {
if err != nil {
return nil, err
}
return &Node{
node := &Node{
Addr: u.Host,
Remote: strings.Trim(u.EscapedPath(), "/"),
Values: u.Query(),
Protocol: u.Scheme,
}, nil
}
return node, nil
}
// Get returns node parameter specified by key.

View File

@@ -16,12 +16,10 @@ import (
)
var (
// RouteNAT Globe route table for inner ip
RouteNAT = NewNAT()
// RouteConnNAT map[srcIP]net.Conn
RouteConnNAT = &sync.Map{}
// Chan tcp connects
Chan = make(chan *datagramPacket, MaxSize)
// RouteMapTCP map[srcIP]net.Conn Globe route table for inner ip
RouteMapTCP = &sync.Map{}
// TCPPacketChan tcp connects
TCPPacketChan = make(chan *datagramPacket, MaxSize)
)
type TCPUDPacket struct {
@@ -39,7 +37,6 @@ type Route struct {
}
func (r *Route) parseChain() (*Chain, error) {
// parse the base nodes
node, err := parseChainNode(r.ChainNode)
if err != nil {
return nil, err
@@ -50,7 +47,6 @@ func (r *Route) parseChain() (*Chain, error) {
func parseChainNode(ns string) (*Node, error) {
node, err := ParseNode(ns)
if err != nil {
log.Errorf("parse node error: %v", err)
return nil, err
}
node.Client = &Client{

View File

@@ -3,6 +3,7 @@ package core
import (
"context"
"net"
"strings"
"sync"
"time"
@@ -41,33 +42,33 @@ func (c *fakeUDPTunnelConnector) ConnectContext(ctx context.Context, conn net.Co
type fakeUdpHandler struct {
// map[srcIP]net.Conn
connNAT *sync.Map
ch chan *datagramPacket
routeMapTCP *sync.Map
packetChan chan *datagramPacket
}
func TCPHandler() Handler {
return &fakeUdpHandler{
connNAT: RouteConnNAT,
ch: Chan,
routeMapTCP: RouteMapTCP,
packetChan: TCPPacketChan,
}
}
func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
defer tcpConn.Close()
log.Debugf("[tcpserver] %s -> %s\n", tcpConn.RemoteAddr(), tcpConn.LocalAddr())
log.Debugf("[tcpserver] %s -> %s", tcpConn.RemoteAddr(), tcpConn.LocalAddr())
defer func(addr net.Addr) {
var keys []string
h.connNAT.Range(func(key, value any) bool {
h.routeMapTCP.Range(func(key, value any) bool {
if value.(net.Conn) == tcpConn {
keys = append(keys, key.(string))
}
return true
})
for _, key := range keys {
h.connNAT.Delete(key)
h.routeMapTCP.Delete(key)
}
log.Debugf("[tcpserver] delete conn %s from globle routeConnNAT, deleted count %d", addr, len(keys))
log.Debugf("[tcpserver] to %s by conn %s from globle route map TCP", strings.Join(keys, " "), addr)
}(tcpConn.LocalAddr())
for {
@@ -80,7 +81,7 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
b := config.LPool.Get().([]byte)[:]
dgram, err := readDatagramPacketServer(tcpConn, b[:])
if err != nil {
log.Debugf("[tcpserver] %s -> 0 : %v", tcpConn.RemoteAddr(), err)
log.Debugf("[tcpserver] %s -> %s : %v", tcpConn.RemoteAddr(), tcpConn.LocalAddr(), err)
return
}
@@ -94,17 +95,17 @@ func (h *fakeUdpHandler) Handle(ctx context.Context, tcpConn net.Conn) {
log.Errorf("[tcpserver] unknown packet")
continue
}
value, loaded := h.connNAT.LoadOrStore(src.String(), tcpConn)
value, loaded := h.routeMapTCP.LoadOrStore(src.String(), tcpConn)
if loaded {
if tcpConn != value.(net.Conn) {
h.connNAT.Store(src.String(), tcpConn)
log.Debugf("[tcpserver] replace routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
h.routeMapTCP.Store(src.String(), tcpConn)
log.Debugf("[tcpserver] replace route map TCP: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
}
log.Debugf("[tcpserver] find routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
log.Debugf("[tcpserver] find route map TCP: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
} else {
log.Debugf("[tcpserver] new routeConnNAT: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
log.Debugf("[tcpserver] new route map TCP: %s -> %s-%s", src, tcpConn.LocalAddr(), tcpConn.RemoteAddr())
}
util.SafeWrite(h.ch, dgram)
util.SafeWrite(h.packetChan, dgram)
}
}

View File

@@ -3,9 +3,7 @@ package core
import (
"context"
"fmt"
"math/rand"
"net"
"strings"
"sync"
"time"
@@ -22,98 +20,47 @@ const (
)
type tunHandler struct {
chain *Chain
node *Node
routeNAT *NAT
chain *Chain
node *Node
routeMapUDP *RouteMap
// map[srcIP]net.Conn
routeConnNAT *sync.Map
chExit chan error
routeMapTCP *sync.Map
chExit chan error
}
type NAT struct {
type RouteMap struct {
lock *sync.RWMutex
routes map[string][]net.Addr
routes map[string]net.Addr
}
func NewNAT() *NAT {
return &NAT{
func NewRouteMap() *RouteMap {
return &RouteMap{
lock: &sync.RWMutex{},
routes: map[string][]net.Addr{},
routes: map[string]net.Addr{},
}
}
func (n *NAT) RemoveAddr(addr net.Addr) (count int) {
n.lock.Lock()
defer n.lock.Unlock()
for k, v := range n.routes {
for i := 0; i < len(v); i++ {
if v[i].String() == addr.String() {
v = append(v[:i], v[i+1:]...)
i--
count++
}
}
n.routes[k] = v
}
return
}
func (n *NAT) LoadOrStore(to net.IP, addr net.Addr) (result net.Addr, load bool) {
func (n *RouteMap) LoadOrStore(to net.IP, addr net.Addr) (result net.Addr, load bool) {
n.lock.RLock()
addrList := n.routes[to.String()]
route, ok := n.routes[to.String()]
n.lock.RUnlock()
for _, add := range addrList {
if add.String() == addr.String() {
load = true
result = addr
return
}
if ok && route.String() == addr.String() {
return addr, true
}
n.lock.Lock()
defer n.lock.Unlock()
if addrList == nil {
n.routes[to.String()] = []net.Addr{addr}
result = addr
return
} else {
n.routes[to.String()] = append(n.routes[to.String()], addr)
result = addr
return
}
n.routes[to.String()] = addr
return addr, false
}
func (n *NAT) RouteTo(ip net.IP) net.Addr {
func (n *RouteMap) RouteTo(ip net.IP) net.Addr {
n.lock.RLock()
defer n.lock.RUnlock()
addrList := n.routes[ip.String()]
if len(addrList) == 0 {
return nil
}
// for load balance
index := rand.Intn(len(n.routes[ip.String()]))
return addrList[index]
return n.routes[ip.String()]
}
func (n *NAT) Remove(ip net.IP, addr net.Addr) {
n.lock.Lock()
defer n.lock.Unlock()
addrList, ok := n.routes[ip.String()]
if !ok {
return
}
for i := 0; i < len(addrList); i++ {
if addrList[i].String() == addr.String() {
addrList = append(addrList[:i], addrList[i+1:]...)
i--
}
}
n.routes[ip.String()] = addrList
return
}
func (n *NAT) Range(f func(key string, v []net.Addr)) {
func (n *RouteMap) Range(f func(key string, value net.Addr)) {
n.lock.RLock()
defer n.lock.RUnlock()
for k, v := range n.routes {
@@ -124,11 +71,11 @@ func (n *NAT) Range(f func(key string, v []net.Addr)) {
// TunHandler creates a handler for tun tunnel.
func TunHandler(chain *Chain, node *Node) Handler {
return &tunHandler{
chain: chain,
node: node,
routeNAT: RouteNAT,
routeConnNAT: RouteConnNAT,
chExit: make(chan error, 1),
chain: chain,
node: node,
routeMapUDP: NewRouteMap(),
routeMapTCP: RouteMapTCP,
chExit: make(chan error, 1),
}
}
@@ -143,27 +90,12 @@ func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) {
func (h *tunHandler) printRoute(ctx context.Context) {
ticker := time.NewTicker(time.Second * 5)
defer ticker.Stop()
var sb strings.Builder
var i int
for ctx.Err() == nil {
select {
case <-ticker.C:
i = 0
sb.Reset()
h.routeNAT.Range(func(key string, value []net.Addr) {
i++
var s []string
for _, addr := range value {
if addr != nil {
s = append(s, addr.String())
}
}
if len(s) != 0 {
sb.WriteString(fmt.Sprintf("to: %s, route: %s\n", key, strings.Join(s, " ")))
}
h.routeMapUDP.Range(func(key string, value net.Addr) {
log.Debugf("to: %s, route: %s", key, value.String())
})
log.Debug(sb.String())
log.Debug(i)
}
}
}
@@ -247,7 +179,7 @@ func (d *Device) Close() {
util.SafeClose(d.tunInbound)
util.SafeClose(d.tunOutbound)
util.SafeClose(d.tunInboundRaw)
util.SafeClose(Chan)
util.SafeClose(TCPPacketChan)
}
func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
@@ -266,6 +198,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
if config.RouterIP6.To4().Equal(srcIPv6) {
return
}
var dstIPv4, dstIPv6 = net.IPv4zero, net.IPv6zero
if config.CIDR.Contains(srcIPv4) {
dstIPv4, dstIPv6 = config.RouterIP, config.RouterIP6
} else if config.DockerCIDR.Contains(srcIPv4) {
dstIPv4 = config.DockerRouterIP
}
var bytes []byte
var bytes6 []byte
@@ -282,14 +220,14 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
for i := 0; i < 4; i++ {
if bytes == nil {
bytes, err = genICMPPacket(srcIPv4, config.RouterIP)
bytes, err = genICMPPacket(srcIPv4, dstIPv4)
if err != nil {
log.Errorf("generate ipv4 packet error: %s", err.Error())
continue
}
}
if bytes6 == nil {
bytes6, err = genICMPPacketIPv6(srcIPv6, config.RouterIP6)
bytes6, err = genICMPPacketIPv6(srcIPv6, dstIPv6)
if err != nil {
log.Errorf("generate ipv6 packet error: %s", err.Error())
continue
@@ -300,9 +238,12 @@ func heartbeats(ctx context.Context, tun net.Conn, in chan<- *DataElem) {
length := copy(data, i2)
var src, dst net.IP
if index == 0 {
src, dst = srcIPv4, config.RouterIP
src, dst = srcIPv4, dstIPv4
} else {
src, dst = srcIPv6, config.RouterIP6
src, dst = srcIPv6, dstIPv6
}
if dst.IsUnspecified() {
continue
}
util.SafeWrite(in, &DataElem{
data: data[:],
@@ -403,7 +344,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) {
log.Debugf("[udp] can not listen %s, err: %v", h.node.Addr, err)
return
}
err = transportTun(ctx, tunInbound, tunOutbound, packetConn, h.routeNAT, h.routeConnNAT)
err = transportTun(ctx, tunInbound, tunOutbound, packetConn, h.routeMapUDP, h.routeMapTCP)
if err != nil {
log.Debugf("[tun] %s: %v", tun.LocalAddr(), err)
}
@@ -455,10 +396,10 @@ type Peer struct {
tunInbound <-chan *DataElem
tunOutbound chan<- *DataElem
routeNAT *NAT
// map[srcIP]net.Conn
// routeConnNAT sync.Map
routeConnNAT *sync.Map
// map[srcIP.String()]net.Addr for udp
routeMapUDP *RouteMap
// map[srcIP.String()]net.Conn for tcp
routeMapTCP *sync.Map
errChan chan error
}
@@ -487,7 +428,7 @@ func (p *Peer) readFromConn() {
}
func (p *Peer) readFromTCPConn() {
for packet := range Chan {
for packet := range TCPPacketChan {
u := &udpElem{
data: packet.Data[:],
length: int(packet.DataLength),
@@ -533,7 +474,7 @@ func (p *Peer) parseHeader() {
} else {
firstIPv6 = false
}
if _, loaded := p.routeNAT.LoadOrStore(e.src, e.from); loaded {
if _, loaded := p.routeMapUDP.LoadOrStore(e.src, e.from); loaded {
log.Debugf("[tun] find route: %s -> %s", e.src, e.from)
} else {
log.Debugf("[tun] new route: %s -> %s", e.src, e.from)
@@ -545,7 +486,7 @@ func (p *Peer) parseHeader() {
func (p *Peer) routePeer() {
for e := range p.parsedConnInfo {
if routeToAddr := p.routeNAT.RouteTo(e.dst); routeToAddr != nil {
if routeToAddr := p.routeMapUDP.RouteTo(e.dst); routeToAddr != nil {
log.Debugf("[tun] find route: %s -> %s", e.dst, routeToAddr)
_, err := p.conn.WriteTo(e.data[:e.length], routeToAddr)
config.LPool.Put(e.data[:])
@@ -553,7 +494,7 @@ func (p *Peer) routePeer() {
p.sendErr(err)
return
}
} else if conn, ok := p.routeConnNAT.Load(e.dst.String()); ok {
} else if conn, ok := p.routeMapTCP.Load(e.dst.String()); ok {
dgram := newDatagramPacket(e.data[:e.length])
if err := dgram.Write(conn.(net.Conn)); err != nil {
log.Debugf("[tcpserver] udp-tun %s <- %s : %s", conn.(net.Conn).RemoteAddr(), dgram.Addr(), err)
@@ -574,7 +515,7 @@ func (p *Peer) routePeer() {
func (p *Peer) routeTUN() {
for e := range p.tunInbound {
if addr := p.routeNAT.RouteTo(e.dst); addr != nil {
if addr := p.routeMapUDP.RouteTo(e.dst); addr != nil {
log.Debugf("[tun] find route: %s -> %s", e.dst, addr)
_, err := p.conn.WriteTo(e.data[:e.length], addr)
config.LPool.Put(e.data[:])
@@ -583,7 +524,7 @@ func (p *Peer) routeTUN() {
p.sendErr(err)
return
}
} else if conn, ok := p.routeConnNAT.Load(e.dst.String()); ok {
} else if conn, ok := p.routeMapTCP.Load(e.dst.String()); ok {
dgram := newDatagramPacket(e.data[:e.length])
err := dgram.Write(conn.(net.Conn))
config.LPool.Put(e.data[:])
@@ -611,15 +552,15 @@ func (p *Peer) Close() {
p.conn.Close()
}
func transportTun(ctx context.Context, tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem, packetConn net.PacketConn, nat *NAT, connNAT *sync.Map) error {
func transportTun(ctx context.Context, tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem, packetConn net.PacketConn, routeMapUDP *RouteMap, routeMapTCP *sync.Map) error {
p := &Peer{
conn: packetConn,
connInbound: make(chan *udpElem, MaxSize),
parsedConnInfo: make(chan *udpElem, MaxSize),
tunInbound: tunInbound,
tunOutbound: tunOutbound,
routeNAT: nat,
routeConnNAT: connNAT,
routeMapUDP: routeMapUDP,
routeMapTCP: routeMapTCP,
errChan: make(chan error, 2),
}