mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-09 02:30:05 +08:00
Refactor: optimize UDP module
Symmetric NAT support for now.
This commit is contained in:
@@ -1,192 +1,24 @@
|
||||
package stack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// udpNoChecksum disables UDP checksum if set to true.
|
||||
udpNoChecksum = true
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
func withUDPHandler() Option {
|
||||
return func(s *Stack) error {
|
||||
udpHandlePacket := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket.
|
||||
udpHdr := header.UDP(pkt.TransportHeader().View())
|
||||
if int(udpHdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
|
||||
// Malformed packet.
|
||||
s.Stats().UDP.MalformedPacketsReceived.Increment()
|
||||
return true
|
||||
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// TODO: handler errors in the future.
|
||||
return
|
||||
}
|
||||
|
||||
if !verifyChecksum(udpHdr, pkt) {
|
||||
// Checksum error.
|
||||
s.Stats().UDP.ChecksumErrors.Increment()
|
||||
return true
|
||||
}
|
||||
|
||||
s.Stats().UDP.PacketsReceived.Increment()
|
||||
|
||||
packet := &udpPacket{
|
||||
s: s,
|
||||
id: &id,
|
||||
data: pkt.Data().ExtractVV(),
|
||||
nicID: pkt.NICID,
|
||||
netHdr: pkt.Network(),
|
||||
netProto: pkt.NetworkProtocolNumber,
|
||||
}
|
||||
|
||||
s.handler.AddPacket(packet)
|
||||
return true
|
||||
}
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpHandlePacket)
|
||||
s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep))
|
||||
})
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
s *Stack
|
||||
id *stack.TransportEndpointID
|
||||
data buffer.VectorisedView
|
||||
nicID tcpip.NICID
|
||||
netHdr header.Network
|
||||
netProto tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (p *udpPacket) Data() []byte {
|
||||
return p.data.ToView()
|
||||
}
|
||||
|
||||
func (p *udpPacket) Drop() {}
|
||||
|
||||
func (p *udpPacket) ID() *stack.TransportEndpointID {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *udpPacket) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP(p.id.LocalAddress), Port: int(p.id.LocalPort)}
|
||||
}
|
||||
|
||||
func (p *udpPacket) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP(p.id.RemoteAddress), Port: int(p.id.RemotePort)}
|
||||
}
|
||||
|
||||
func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
|
||||
v := buffer.View(b)
|
||||
if len(v) > header.UDPMaximumPacketSize {
|
||||
// Payload can't possibly fit in a packet.
|
||||
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
||||
}
|
||||
|
||||
var (
|
||||
localAddress tcpip.Address
|
||||
localPort uint16
|
||||
)
|
||||
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
|
||||
localAddress = p.netHdr.DestinationAddress()
|
||||
localPort = p.id.LocalPort
|
||||
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
||||
localAddress = tcpip.Address(ipv4)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
} else {
|
||||
localAddress = tcpip.Address(udpAddr.IP)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
}
|
||||
|
||||
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
data := v.ToVectorisedView()
|
||||
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
||||
return 0, fmt.Errorf("%v", err)
|
||||
}
|
||||
return data.Size(), nil
|
||||
}
|
||||
|
||||
// sendUDP sends a UDP segment via the provided network endpoint and under the
|
||||
// provided identity.
|
||||
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error {
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
|
||||
Data: data,
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
// Initialize the UDP header.
|
||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||
pkt.TransportProtocolNumber = udp.ProtocolNumber
|
||||
|
||||
length := uint16(pkt.Size())
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: localPort,
|
||||
DstPort: remotePort,
|
||||
Length: length,
|
||||
})
|
||||
|
||||
// Set the checksum field unless TX checksum offload is enabled.
|
||||
// On IPv4, UDP checksum is optional, and a zero value indicates the
|
||||
// transmitter skipped the checksum generation (RFC768).
|
||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
||||
if r.RequiresTXTransportChecksum() &&
|
||||
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
|
||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
||||
for _, v := range data.Views() {
|
||||
xsum = header.Checksum(v, xsum)
|
||||
}
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
}
|
||||
|
||||
ttl := r.DefaultTTL()
|
||||
|
||||
if err := r.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: udp.ProtocolNumber,
|
||||
TTL: ttl,
|
||||
TOS: 0, /* default */
|
||||
}, pkt); err != nil {
|
||||
r.Stats().UDP.PacketSendErrors.Increment()
|
||||
return err
|
||||
}
|
||||
|
||||
// Track count of packets sent.
|
||||
r.Stats().UDP.PacketsSent.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go verifyChecksum.
|
||||
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
|
||||
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
|
||||
if pkt.RXTransportChecksumValidated {
|
||||
return true
|
||||
}
|
||||
|
||||
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
|
||||
// omitted the checksum generation, as per RFC 768:
|
||||
//
|
||||
// An all zero transmitted checksum value means that the transmitter
|
||||
// generated no checksum (for debugging or for higher level protocols that
|
||||
// don't care).
|
||||
//
|
||||
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
|
||||
//
|
||||
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
|
||||
// checksum is not optional.
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
netHdr := pkt.Network()
|
||||
payloadChecksum := pkt.Data().AsRange().Checksum()
|
||||
return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
|
||||
}
|
||||
|
Reference in New Issue
Block a user