mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-06 17:26:58 +08:00
Fix(stack): update udp transport
This commit is contained in:
@@ -86,33 +86,30 @@ func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
|
|||||||
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
||||||
}
|
}
|
||||||
|
|
||||||
route, err := p.s.FindRoute(p.nicID, p.netHdr.DestinationAddress(), p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
var (
|
||||||
|
localAddress tcpip.Address
|
||||||
|
localPort uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
|
||||||
|
localAddress = p.netHdr.DestinationAddress()
|
||||||
|
localPort = p.id.LocalPort
|
||||||
|
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
||||||
|
localAddress = tcpip.Address(ipv4)
|
||||||
|
localPort = uint16(udpAddr.Port)
|
||||||
|
} else {
|
||||||
|
localAddress = tcpip.Address(udpAddr.IP)
|
||||||
|
localPort = uint16(udpAddr.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
||||||
}
|
}
|
||||||
defer route.Release()
|
defer route.Release()
|
||||||
|
|
||||||
data := v.ToVectorisedView()
|
data := v.ToVectorisedView()
|
||||||
// if addr is not provided, write back use original dst Addr as src Addr.
|
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
||||||
if addr == nil {
|
|
||||||
if err = sendUDP(route, data, p.id.LocalPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
|
||||||
return 0, fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
return data.Size(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
udpAddr, ok := addr.(*net.UDPAddr)
|
|
||||||
if !ok {
|
|
||||||
return 0, fmt.Errorf("type %T is not a valid udp address", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
|
||||||
route.LocalAddress = tcpip.Address(ipv4)
|
|
||||||
} else {
|
|
||||||
route.LocalAddress = tcpip.Address(udpAddr.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = sendUDP(route, data, uint16(udpAddr.Port), p.id.RemotePort, udpNoChecksum); err != nil {
|
|
||||||
return 0, fmt.Errorf("%v", err)
|
return 0, fmt.Errorf("%v", err)
|
||||||
}
|
}
|
||||||
return data.Size(), nil
|
return data.Size(), nil
|
||||||
@@ -142,7 +139,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
|
|||||||
// transmitter skipped the checksum generation (RFC768).
|
// transmitter skipped the checksum generation (RFC768).
|
||||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
||||||
if r.RequiresTXTransportChecksum() &&
|
if r.RequiresTXTransportChecksum() &&
|
||||||
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
|
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
|
||||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
||||||
for _, v := range data.Views() {
|
for _, v := range data.Views() {
|
||||||
xsum = header.Checksum(v, xsum)
|
xsum = header.Checksum(v, xsum)
|
||||||
@@ -152,7 +149,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
|
|||||||
|
|
||||||
ttl := r.DefaultTTL()
|
ttl := r.DefaultTTL()
|
||||||
|
|
||||||
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
|
if err := r.WritePacket(stack.NetworkHeaderParams{
|
||||||
Protocol: udp.ProtocolNumber,
|
Protocol: udp.ProtocolNumber,
|
||||||
TTL: ttl,
|
TTL: ttl,
|
||||||
TOS: 0, /* default */
|
TOS: 0, /* default */
|
||||||
|
Reference in New Issue
Block a user