Refactor(metadata): replace net.IP with netip.Addr (#396)

This commit is contained in:
Jason Lyu
2024-09-01 02:57:09 +08:00
committed by GitHub
parent 1f09b4d42d
commit bd37a1a4c6
6 changed files with 57 additions and 44 deletions

View File

@@ -2,26 +2,34 @@ package metadata
import ( import (
"net" "net"
"strconv" "net/netip"
) )
// Metadata contains metadata of transport protocol sessions. // Metadata contains metadata of transport protocol sessions.
type Metadata struct { type Metadata struct {
Network Network `json:"network"` Network Network `json:"network"`
SrcIP net.IP `json:"sourceIP"` SrcIP netip.Addr `json:"sourceIP"`
MidIP net.IP `json:"dialerIP"` MidIP netip.Addr `json:"dialerIP"`
DstIP net.IP `json:"destinationIP"` DstIP netip.Addr `json:"destinationIP"`
SrcPort uint16 `json:"sourcePort"` SrcPort uint16 `json:"sourcePort"`
MidPort uint16 `json:"dialerPort"` MidPort uint16 `json:"dialerPort"`
DstPort uint16 `json:"destinationPort"` DstPort uint16 `json:"destinationPort"`
}
func (m *Metadata) DestinationAddrPort() netip.AddrPort {
return netip.AddrPortFrom(m.DstIP, m.DstPort)
} }
func (m *Metadata) DestinationAddress() string { func (m *Metadata) DestinationAddress() string {
return net.JoinHostPort(m.DstIP.String(), strconv.FormatUint(uint64(m.DstPort), 10)) return m.DestinationAddrPort().String()
}
func (m *Metadata) SourceAddrPort() netip.AddrPort {
return netip.AddrPortFrom(m.SrcIP, m.SrcPort)
} }
func (m *Metadata) SourceAddress() string { func (m *Metadata) SourceAddress() string {
return net.JoinHostPort(m.SrcIP.String(), strconv.FormatUint(uint64(m.SrcPort), 10)) return m.SourceAddrPort().String()
} }
func (m *Metadata) Addr() net.Addr { func (m *Metadata) Addr() net.Addr {
@@ -29,23 +37,17 @@ func (m *Metadata) Addr() net.Addr {
} }
func (m *Metadata) TCPAddr() *net.TCPAddr { func (m *Metadata) TCPAddr() *net.TCPAddr {
if m.Network != TCP || m.DstIP == nil { if m.Network != TCP || !m.DstIP.IsValid() {
return nil return nil
} }
return &net.TCPAddr{ return net.TCPAddrFromAddrPort(m.DestinationAddrPort())
IP: m.DstIP,
Port: int(m.DstPort),
}
} }
func (m *Metadata) UDPAddr() *net.UDPAddr { func (m *Metadata) UDPAddr() *net.UDPAddr {
if m.Network != UDP || m.DstIP == nil { if m.Network != UDP || !m.DstIP.IsValid() {
return nil return nil
} }
return &net.UDPAddr{ return net.UDPAddrFromAddrPort(m.DestinationAddrPort())
IP: m.DstIP,
Port: int(m.DstPort),
}
} }
// Addr implements the net.Addr interface. // Addr implements the net.Addr interface.

View File

@@ -243,7 +243,7 @@ func serializeRelayAddr(m *M.Metadata) *relay.AddrFeature {
Host: m.DstIP.String(), Host: m.DstIP.String(),
Port: m.DstPort, Port: m.DstPort,
} }
if m.DstIP.To4() != nil { if m.DstIP.Is4() {
af.AType = relay.AddrIPv4 af.AType = relay.AddrIPv4
} else { } else {
af.AType = relay.AddrIPv6 af.AType = relay.AddrIPv6

View File

@@ -186,5 +186,5 @@ func (pc *socksPacketConn) Close() error {
} }
func serializeSocksAddr(m *M.Metadata) socks5.Addr { func serializeSocksAddr(m *M.Metadata) socks5.Addr {
return socks5.SerializeAddr("", m.DstIP, m.DstPort) return socks5.SerializeAddr("", m.DstIP.AsSlice(), m.DstPort)
} }

View File

@@ -2,26 +2,37 @@ package tunnel
import ( import (
"net" "net"
"strconv" "net/netip"
"gvisor.dev/gvisor/pkg/tcpip"
) )
// parseAddr parses net.Addr to IP and port. // parseNetAddr parses net.Addr to IP and port.
func parseAddr(addr net.Addr) (net.IP, uint16) { func parseNetAddr(addr net.Addr) (netip.Addr, uint16) {
switch v := addr.(type) { if addr == nil {
case *net.TCPAddr: return netip.Addr{}, 0
return v.IP, uint16(v.Port)
case *net.UDPAddr:
return v.IP, uint16(v.Port)
case nil:
return nil, 0
default:
return parseAddrString(addr.String())
} }
if v, ok := addr.(interface {
AddrPort() netip.AddrPort
}); ok {
ap := v.AddrPort()
return ap.Addr(), ap.Port()
}
return parseAddrString(addr.String())
} }
// parseAddrString parses address string to IP and port. // parseAddrString parses address string to IP and port.
func parseAddrString(addr string) (net.IP, uint16) { // It doesn't do any name resolution.
host, port, _ := net.SplitHostPort(addr) func parseAddrString(s string) (netip.Addr, uint16) {
portInt, _ := strconv.ParseUint(port, 10, 16) ap, err := netip.ParseAddrPort(s)
return net.ParseIP(host), uint16(portInt) if err != nil {
return netip.Addr{}, 0
}
return ap.Addr(), ap.Port()
}
// parseTCPIPAddress parses tcpip.Address to netip.Addr.
func parseTCPIPAddress(addr tcpip.Address) netip.Addr {
ip, _ := netip.AddrFromSlice(addr.AsSlice())
return ip
} }

View File

@@ -20,9 +20,9 @@ func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) {
id := originConn.ID() id := originConn.ID()
metadata := &M.Metadata{ metadata := &M.Metadata{
Network: M.TCP, Network: M.TCP,
SrcIP: net.IP(id.RemoteAddress.AsSlice()), SrcIP: parseTCPIPAddress(id.RemoteAddress),
SrcPort: id.RemotePort, SrcPort: id.RemotePort,
DstIP: net.IP(id.LocalAddress.AsSlice()), DstIP: parseTCPIPAddress(id.LocalAddress),
DstPort: id.LocalPort, DstPort: id.LocalPort,
} }
@@ -34,7 +34,7 @@ func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) {
log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err)
return return
} }
metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr()) metadata.MidIP, metadata.MidPort = parseNetAddr(remoteConn.LocalAddr())
remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager) remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager)
defer remoteConn.Close() defer remoteConn.Close()

View File

@@ -20,9 +20,9 @@ func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) {
id := uc.ID() id := uc.ID()
metadata := &M.Metadata{ metadata := &M.Metadata{
Network: M.UDP, Network: M.UDP,
SrcIP: net.IP(id.RemoteAddress.AsSlice()), SrcIP: parseTCPIPAddress(id.RemoteAddress),
SrcPort: id.RemotePort, SrcPort: id.RemotePort,
DstIP: net.IP(id.LocalAddress.AsSlice()), DstIP: parseTCPIPAddress(id.LocalAddress),
DstPort: id.LocalPort, DstPort: id.LocalPort,
} }
@@ -31,7 +31,7 @@ func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) {
log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err)
return return
} }
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) metadata.MidIP, metadata.MidPort = parseNetAddr(pc.LocalAddr())
pc = statistic.NewUDPTracker(pc, metadata, t.manager) pc = statistic.NewUDPTracker(pc, metadata, t.manager)
defer pc.Close() defer pc.Close()