Files
tun2socks/core/stack/udp.go
2021-03-21 20:57:18 +08:00

185 lines
5.3 KiB
Go
Executable File

package stack
import (
"fmt"
"net"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// udpNoChecksum is the noChecksum argument WriteBack passes to sendUDP.
// When true, sendUDP leaves the UDP checksum field zero on IPv4 packets,
// which RFC 768 permits. IPv6 packets are still checksummed, because the
// checksum is mandatory there (RFC 2460 Section 8.1).
const udpNoChecksum = true
// withUDPHandler returns an Option that installs the stack's UDP transport
// protocol handler.
//
// The handler validates each inbound UDP packet (header length against the
// available payload, then the checksum). Invalid packets only bump the
// matching stack UDP error counter. Valid packets bump PacketsReceived, are
// wrapped in a udpPacket and are handed to s.handler.AddPacket. The handler
// always reports the packet as handled.
func withUDPHandler() Option {
	return func(s *Stack) error {
		handle := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
			// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket
			hdr := header.UDP(pkt.TransportHeader().View())

			switch {
			case int(hdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize:
				// The header claims more bytes than the packet carries.
				s.Stats().UDP.MalformedPacketsReceived.Increment()
				return true
			case !verifyChecksum(hdr, pkt):
				s.Stats().UDP.ChecksumErrors.Increment()
				return true
			}

			s.Stats().UDP.PacketsReceived.Increment()
			s.handler.AddPacket(&udpPacket{
				s:        s,
				id:       &id,
				data:     pkt.Data().ExtractVV(),
				nicID:    pkt.NICID,
				netHdr:   pkt.Network(),
				netProto: pkt.NetworkProtocolNumber,
			})
			return true
		}

		s.SetTransportProtocolHandler(udp.ProtocolNumber, handle)
		return nil
	}
}
// udpPacket is one inbound UDP datagram, as built by the UDP handler
// installed in withUDPHandler. It also keeps the state that WriteBack needs
// to route a reply back to the sender.
type udpPacket struct {
	// s is the stack the packet arrived on; WriteBack uses it to find a route.
	s *Stack
	// id holds the endpoint addresses and ports. Local* is the address the
	// sender targeted, and Remote* is the sender itself.
	id *stack.TransportEndpointID
	// data is the UDP payload, taken from the packet buffer with ExtractVV.
	data buffer.VectorisedView
	// nicID is the NIC the packet was received on.
	nicID tcpip.NICID
	// netHdr is the packet's network-layer header. Its destination and source
	// become the local and remote addresses of the reply route.
	netHdr header.Network
	// netProto is the packet's network protocol number.
	netProto tcpip.NetworkProtocolNumber
}
// Data returns the UDP payload as a single contiguous byte slice.
func (p *udpPacket) Data() []byte {
	var payload buffer.View = p.data.ToView()
	return payload
}
// Drop is a no-op: udpPacket holds nothing that this method releases.
func (*udpPacket) Drop() {}
// ID returns the endpoint identifier (addresses and ports) the packet
// arrived with. The pointer is shared with the packet, not copied.
func (p *udpPacket) ID() *stack.TransportEndpointID {
	endpoint := p.id
	return endpoint
}
// LocalAddr returns the address and port the sender targeted, as a
// *net.UDPAddr.
func (p *udpPacket) LocalAddr() net.Addr {
	endpoint := p.id
	ip := net.IP(endpoint.LocalAddress)
	port := int(endpoint.LocalPort)
	return &net.UDPAddr{IP: ip, Port: port}
}
// RemoteAddr returns the sender's address and port as a *net.UDPAddr.
func (p *udpPacket) RemoteAddr() net.Addr {
	endpoint := p.id
	ip := net.IP(endpoint.RemoteAddress)
	port := int(endpoint.RemotePort)
	return &net.UDPAddr{IP: ip, Port: port}
}
// WriteBack sends b as a UDP datagram back to the sender of this packet. It
// returns the number of payload bytes written.
//
// The reply is routed from the packet's original destination address to its
// source address, on the NIC the packet arrived on. The destination port is
// the sender's port.
//
// If addr is nil, the reply's source port is the packet's original
// destination port. Otherwise addr must be a *net.UDPAddr, and its IP and
// port become the reply's source address. The IP must be expressible in the
// packet's address family: a 4-byte form for IPv4, a non-IPv4 16-byte
// address for IPv6.
func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
	v := buffer.View(b)
	// The 16-bit UDP length field covers the 8-byte header plus the payload.
	// A larger payload would make the encoded length wrap around.
	if len(v) > header.UDPMaximumPacketSize-header.UDPMinimumSize {
		// Payload can't possibly fit in a packet.
		return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
	}

	route, err := p.s.FindRoute(p.nicID, p.netHdr.DestinationAddress(), p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
	if err != nil {
		return 0, fmt.Errorf("%#v find route: %s", p.id, err)
	}
	defer route.Release()

	// If addr is not provided, write back using the original dst addr/port
	// as the src addr/port (the route's local address is already the dst).
	srcPort := p.id.LocalPort
	if addr != nil {
		udpAddr, ok := addr.(*net.UDPAddr)
		if !ok {
			return 0, fmt.Errorf("type %T is not a valid udp address", addr)
		}
		if udpAddr.Port < 0 || udpAddr.Port > 0xffff {
			return 0, fmt.Errorf("port %d out of range", udpAddr.Port)
		}

		// The source address must match the route's address family. Otherwise
		// the network header would be encoded with a wrong-length address.
		var srcIP net.IP
		ip4 := udpAddr.IP.To4()
		switch {
		case p.netProto == header.IPv4ProtocolNumber:
			srcIP = ip4
		case ip4 == nil:
			srcIP = udpAddr.IP.To16()
		}
		if srcIP == nil {
			return 0, fmt.Errorf("%v is not a valid source address for network protocol %d", udpAddr.IP, p.netProto)
		}

		route.LocalAddress = tcpip.Address(srcIP)
		srcPort = uint16(udpAddr.Port)
	}

	data := v.ToVectorisedView()
	if err = sendUDP(route, data, srcPort, p.id.RemotePort, udpNoChecksum); err != nil {
		return 0, fmt.Errorf("%v", err)
	}
	return data.Size(), nil
}
// sendUDP builds a UDP segment carrying data, with the given source
// (localPort) and destination (remotePort) ports, and writes it out through
// route r. The route supplies the network addresses.
//
// The checksum is filled in only when the route does not offload TX
// checksumming, and either noChecksum is false or the route is IPv6. On IPv4
// a zero checksum means the transmitter skipped it (RFC 768). On IPv6 the
// checksum is mandatory (RFC 2460 Section 8.1).
//
// It returns ErrMessageTooLong if header plus payload would not fit in the
// 16-bit UDP length field. Otherwise it returns any error from
// r.WritePacket, and it updates the stack's UDP send counters.
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error {
	// Refuse oversized payloads instead of letting uint16(pkt.Size()) below
	// silently wrap to a bogus length.
	if data.Size() > header.UDPMaximumPacketSize-header.UDPMinimumSize {
		return &tcpip.ErrMessageTooLong{}
	}

	pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
		ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
		Data:               data,
	})

	// Initialize the UDP header.
	udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
	pkt.TransportProtocolNumber = udp.ProtocolNumber

	length := uint16(pkt.Size())
	udpHdr.Encode(&header.UDPFields{
		SrcPort: localPort,
		DstPort: remotePort,
		Length:  length,
	})

	// Set the checksum field unless TX checksum offload is enabled.
	// On IPv4, UDP checksum is optional, and a zero value indicates the
	// transmitter skipped the checksum generation (RFC768).
	// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
	if r.RequiresTXTransportChecksum() &&
		(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
		xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
		for _, v := range data.Views() {
			xsum = header.Checksum(v, xsum)
		}
		udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
	}

	ttl := r.DefaultTTL()
	if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
		Protocol: udp.ProtocolNumber,
		TTL:      ttl,
		TOS:      0, /* default */
	}, pkt); err != nil {
		r.Stats().UDP.PacketSendErrors.Increment()
		return err
	}

	// Track count of packets sent.
	r.Stats().UDP.PacketsSent.Increment()
	return nil
}
// verifyChecksum reports whether the UDP checksum of pkt is acceptable.
//
// It accepts the packet without computing anything in two cases:
//   - the NIC has already validated the transport checksum (RX offload);
//   - the packet is IPv4 and the checksum field is zero, meaning the sender
//     omitted it (RFC 768).
//
// On IPv6 the checksum is mandatory (RFC 2460 Section 8.1), so a zero field
// is verified like any other value. Verification sums the pseudo-header and
// every payload view; the result must fold to 0xffff.
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
	if pkt.RXTransportChecksumValidated {
		return true
	}
	if hdr.Checksum() == 0 && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
		return true
	}

	nh := pkt.Network()
	sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, nh.DestinationAddress(), nh.SourceAddress(), hdr.Length())
	for _, view := range pkt.Data().Views() {
		sum = header.Checksum(view, sum)
	}
	return hdr.CalculateChecksum(sum) == 0xffff
}