Fix(stack): update udp transport

This commit is contained in:
xjasonlyu
2021-05-13 13:29:26 +08:00
parent b51ad6feb1
commit 70afaf3d55

View File

@@ -86,33 +86,30 @@ func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
}
route, err := p.s.FindRoute(p.nicID, p.netHdr.DestinationAddress(), p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
var (
localAddress tcpip.Address
localPort uint16
)
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
localAddress = p.netHdr.DestinationAddress()
localPort = p.id.LocalPort
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
localAddress = tcpip.Address(ipv4)
localPort = uint16(udpAddr.Port)
} else {
localAddress = tcpip.Address(udpAddr.IP)
localPort = uint16(udpAddr.Port)
}
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
if err != nil {
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
}
defer route.Release()
data := v.ToVectorisedView()
// if addr is not provided, write back use original dst Addr as src Addr.
if addr == nil {
if err = sendUDP(route, data, p.id.LocalPort, p.id.RemotePort, udpNoChecksum); err != nil {
return 0, fmt.Errorf("%v", err)
}
return data.Size(), nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, fmt.Errorf("type %T is not a valid udp address", addr)
}
if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
route.LocalAddress = tcpip.Address(ipv4)
} else {
route.LocalAddress = tcpip.Address(udpAddr.IP)
}
if err = sendUDP(route, data, uint16(udpAddr.Port), p.id.RemotePort, udpNoChecksum); err != nil {
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
return 0, fmt.Errorf("%v", err)
}
return data.Size(), nil
@@ -142,7 +139,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
// transmitter skipped the checksum generation (RFC768).
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
if r.RequiresTXTransportChecksum() &&
(!noChecksum || r.NetProto == header.IPv6ProtocolNumber) {
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
@@ -152,7 +149,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
ttl := r.DefaultTTL()
if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
if err := r.WritePacket(stack.NetworkHeaderParams{
Protocol: udp.ProtocolNumber,
TTL: ttl,
TOS: 0, /* default */