diff --git a/core/stack/udp.go b/core/stack/udp.go index ab5b762..74763e9 100755 --- a/core/stack/udp.go +++ b/core/stack/udp.go @@ -86,33 +86,30 @@ func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) { return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{}) } - route, err := p.s.FindRoute(p.nicID, p.netHdr.DestinationAddress(), p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */) + var ( + localAddress tcpip.Address + localPort uint16 + ) + + if udpAddr, ok := addr.(*net.UDPAddr); !ok { + localAddress = p.netHdr.DestinationAddress() + localPort = p.id.LocalPort + } else if ipv4 := udpAddr.IP.To4(); ipv4 != nil { + localAddress = tcpip.Address(ipv4) + localPort = uint16(udpAddr.Port) + } else { + localAddress = tcpip.Address(udpAddr.IP) + localPort = uint16(udpAddr.Port) + } + + route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */) if err != nil { return 0, fmt.Errorf("%#v find route: %s", p.id, err) } defer route.Release() data := v.ToVectorisedView() - // if addr is not provided, write back use original dst Addr as src Addr. - if addr == nil { - if err = sendUDP(route, data, p.id.LocalPort, p.id.RemotePort, udpNoChecksum); err != nil { - return 0, fmt.Errorf("%v", err) - } - return data.Size(), nil - } - - udpAddr, ok := addr.(*net.UDPAddr) - if !ok { - return 0, fmt.Errorf("type %T is not a valid udp address", addr) - } - - if ipv4 := udpAddr.IP.To4(); ipv4 != nil { - route.LocalAddress = tcpip.Address(ipv4) - } else { - route.LocalAddress = tcpip.Address(udpAddr.IP) - } - - if err = sendUDP(route, data, uint16(udpAddr.Port), p.id.RemotePort, udpNoChecksum); err != nil { + if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil { return 0, fmt.Errorf("%v", err) } return data.Size(), nil @@ -142,7 +139,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { + (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -152,7 +149,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u ttl := r.DefaultTTL() - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + if err := r.WritePacket(stack.NetworkHeaderParams{ Protocol: udp.ProtocolNumber, TTL: ttl, TOS: 0, /* default */