mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-06 17:26:58 +08:00
Fix(stack): update udp transport
This commit is contained in:
@@ -86,33 +86,30 @@ func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
|
||||
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
||||
}
|
||||
|
||||
route, err := p.s.FindRoute(p.nicID, p.netHdr.DestinationAddress(), p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
||||
var (
|
||||
localAddress tcpip.Address
|
||||
localPort uint16
|
||||
)
|
||||
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
|
||||
localAddress = p.netHdr.DestinationAddress()
|
||||
localPort = p.id.LocalPort
|
||||
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
||||
localAddress = tcpip.Address(ipv4)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
} else {
|
||||
localAddress = tcpip.Address(udpAddr.IP)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
}
|
||||
|
||||
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
data := v.ToVectorisedView()
|
||||
// if addr is not provided, write back use original dst Addr as src Addr.
|
||||
if addr == nil {
|
||||
if err = sendUDP(route, data, p.id.LocalPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
||||
return 0, fmt.Errorf("%v", err)
|
||||
}
|
||||
return data.Size(), nil
|
||||
}
|
||||
|
||||
udpAddr, ok := addr.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("type %T is not a valid udp address", addr)
|
||||
}
|
||||
|
||||
if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
||||
route.LocalAddress = tcpip.Address(ipv4)
|
||||
} else {
|
||||
route.LocalAddress = tcpip.Address(udpAddr.IP)
|
||||
}
|
||||
|
||||
if err = sendUDP(route, data, uint16(udpAddr.Port), p.id.RemotePort, udpNoChecksum); err != nil {
|
||||
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
||||
return 0, fmt.Errorf("%v", err)
|
||||
}
|
||||
return data.Size(), nil
|
||||
@@ -142,7 +139,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
|
||||
// transmitter skipped the checksum generation (RFC768).
|
||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
||||
if r.RequiresTXTransportChecksum() &&
|
||||
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
|
||||
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
|
||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
||||
for _, v := range data.Views() {
|
||||
xsum = header.Checksum(v, xsum)
|
||||
@@ -152,7 +149,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
|
||||
|
||||
ttl := r.DefaultTTL()
|
||||
|
||||
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
|
||||
if err := r.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: udp.ProtocolNumber,
|
||||
TTL: ttl,
|
||||
TOS: 0, /* default */
|
||||
|
Reference in New Issue
Block a user