diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index fe9134b9..f3b6bd25 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -129,12 +129,15 @@ func (d *Device) Close() { util.SafeClose(TCPPacketChan) } -func (d *Device) transport(ctx context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) { - for ctx.Err() == nil { +func (d *Device) transport(ctx1 context.Context, addr string, routeMapUDP *sync.Map, routeMapTCP *sync.Map) { + for ctx1.Err() == nil { func() { + ctx, cancelFunc := context.WithCancel(ctx1) + defer cancelFunc() + packetConn, err := (&net.ListenConfig{}).ListenPacket(ctx, "udp", addr) if err != nil { - plog.G(ctx).Errorf("[UDP] Failed to listen %s: %v", addr, err) + plog.G(ctx1).Errorf("[UDP] Failed to listen %s: %v", addr, err) return } @@ -156,7 +159,7 @@ func (d *Device) transport(ctx context.Context, addr string, routeMapUDP *sync.M select { case err = <-p.errChan: - plog.G(ctx).Errorf("[TUN] %s: %v", d.tun.LocalAddr(), err) + plog.G(ctx1).Errorf("[TUN] %s: %v", d.tun.LocalAddr(), err) return case <-ctx.Done(): return @@ -214,7 +217,7 @@ func (p *Peer) sendErr(err error) { func (p *Peer) readFromConn(ctx context.Context) { defer util.HandleCrash() - for { + for ctx.Err() == nil { buf := config.LPool.Get().([]byte)[:] n, from, err := p.conn.ReadFrom(buf[:]) if err != nil { @@ -249,43 +252,53 @@ func (p *Peer) readFromConn(ctx context.Context) { func (p *Peer) readFromTCPConn(ctx context.Context) { defer util.HandleCrash() - for packet := range TCPPacketChan { - src, dst, err := util.ParseIP(packet.Data) - if err != nil { - plog.G(ctx).Errorf("[TCP] Unknown packet") - config.LPool.Put(packet.Data[:]) - continue - } - plog.G(ctx).Debugf("[TCP] SRC: %s > DST: %s Length: %d", src, dst, packet.DataLength) - p.tcpInbound <- &Packet{ - data: packet.Data[:], - length: int(packet.DataLength), - src: src, - dst: dst, + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case packet := <-TCPPacketChan: + src, dst, err := util.ParseIP(packet.Data) + if err != nil { + plog.G(ctx).Errorf("[TCP] Unknown packet") + config.LPool.Put(packet.Data[:]) + continue + } + plog.G(ctx).Debugf("[TCP] SRC: %s > DST: %s Length: %d", src, dst, packet.DataLength) + p.tcpInbound <- &Packet{ + data: packet.Data[:], + length: int(packet.DataLength), + src: src, + dst: dst, + } } } } func (p *Peer) routeTCP(ctx context.Context) { defer util.HandleCrash() - for packet := range p.tcpInbound { - if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { - plog.G(ctx).Debugf("[TCP] Find TCP route SRC: %s to DST: %s -> %s", packet.src.String(), packet.dst.String(), conn.(net.Conn).RemoteAddr()) - dgram := newDatagramPacket(packet.data[:packet.length]) - err := dgram.Write(conn.(net.Conn)) - config.LPool.Put(packet.data[:]) - if err != nil { - plog.G(ctx).Errorf("[TCP] Failed to write to %s <- %s : %s", conn.(net.Conn).RemoteAddr(), conn.(net.Conn).LocalAddr(), err) - p.sendErr(err) - return - } - } else { - plog.G(ctx).Debugf("[TCP] Not found route, write to TUN device. SRC: %s, DST: %s", packet.src.String(), packet.dst.String()) - p.tunOutbound <- &Packet{ - data: packet.data, - length: packet.length, - src: packet.src, - dst: packet.dst, + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case packet := <-p.tcpInbound: + if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { + plog.G(ctx).Debugf("[TCP] Find TCP route SRC: %s to DST: %s -> %s", packet.src.String(), packet.dst.String(), conn.(net.Conn).RemoteAddr()) + dgram := newDatagramPacket(packet.data[:packet.length]) + err := dgram.Write(conn.(net.Conn)) + config.LPool.Put(packet.data[:]) + if err != nil { + plog.G(ctx).Errorf("[TCP] Failed to write to %s <- %s : %s", conn.(net.Conn).RemoteAddr(), conn.(net.Conn).LocalAddr(), err) + p.sendErr(err) + return + } + } else { + plog.G(ctx).Debugf("[TCP] Not found route, write to TUN device. SRC: %s, DST: %s", packet.src.String(), packet.dst.String()) + p.tunOutbound <- &Packet{ + data: packet.data, + length: packet.length, + src: packet.src, + dst: packet.dst, + } } } } @@ -293,29 +306,34 @@ func (p *Peer) routeTCP(ctx context.Context) { func (p *Peer) routeTUN(ctx context.Context) { defer util.HandleCrash() - for packet := range p.tunInbound { - if addr, ok := p.routeMapUDP.Load(packet.dst.String()); ok { - plog.G(ctx).Debugf("[TUN] Find UDP route to DST: %s -> %s, SRC: %s, DST: %s", packet.dst, addr, packet.src.String(), packet.dst.String()) - _, err := p.conn.WriteTo(packet.data[:packet.length], addr.(net.Addr)) - config.LPool.Put(packet.data[:]) - if err != nil { - plog.G(ctx).Errorf("[TUN] Failed wirte to route dst: %s -> %s", packet.dst, addr) - p.sendErr(err) - return + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case packet := <-p.tunInbound: + if addr, ok := p.routeMapUDP.Load(packet.dst.String()); ok { + plog.G(ctx).Debugf("[TUN] Find UDP route to DST: %s -> %s, SRC: %s, DST: %s", packet.dst, addr, packet.src.String(), packet.dst.String()) + _, err := p.conn.WriteTo(packet.data[:packet.length], addr.(net.Addr)) + config.LPool.Put(packet.data[:]) + if err != nil { + plog.G(ctx).Errorf("[TUN] Failed wirte to route dst: %s -> %s", packet.dst, addr) + p.sendErr(err) + return + } + } else if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { + plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) + dgram := newDatagramPacket(packet.data[:packet.length]) + err := dgram.Write(conn.(net.Conn)) + config.LPool.Put(packet.data[:]) + if err != nil { + plog.G(ctx).Errorf("[TUN] Failed to write TCP %s <- %s : %s", conn.(net.Conn).RemoteAddr(), conn.(net.Conn).LocalAddr(), err) + p.sendErr(err) + return + } + } else { + plog.G(ctx).Warnf("[TUN] No route for src: %s -> dst: %s, drop it", packet.src, packet.dst) + config.LPool.Put(packet.data[:]) } - } else if conn, ok := p.routeMapTCP.Load(packet.dst.String()); ok { - plog.G(ctx).Debugf("[TUN] Find TCP route to dst: %s -> %s", packet.dst.String(), conn.(net.Conn).RemoteAddr()) - dgram := newDatagramPacket(packet.data[:packet.length]) - err := dgram.Write(conn.(net.Conn)) - config.LPool.Put(packet.data[:]) - if err != nil { - plog.G(ctx).Errorf("[TUN] Failed to write TCP %s <- %s : %s", conn.(net.Conn).RemoteAddr(), conn.(net.Conn).LocalAddr(), err) - p.sendErr(err) - return - } - } else { - plog.G(ctx).Warnf("[TUN] No route for src: %s -> dst: %s, drop it", packet.src, packet.dst) - config.LPool.Put(packet.data[:]) } } }