mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 20:02:37 +08:00 
			
		
		
		
	Begin generic Bind implementation
This commit is contained in:
		| @@ -24,11 +24,9 @@ type Bind interface { | |||||||
|  */ |  */ | ||||||
| type Endpoint interface { | type Endpoint interface { | ||||||
| 	ClearSrc()           // clears the source address | 	ClearSrc()           // clears the source address | ||||||
| 	ClearDst()           // clears the destination address |  | ||||||
| 	SrcToString() string // returns the local source address (ip:port) | 	SrcToString() string // returns the local source address (ip:port) | ||||||
| 	DstToString() string // returns the destination address (ip:port) | 	DstToString() string // returns the destination address (ip:port) | ||||||
| 	DstToBytes() []byte  // used for mac2 cookie calculations | 	DstToBytes() []byte  // used for mac2 cookie calculations | ||||||
| 	SetDst(string) error // used for manually setting the endpoint (uapi) |  | ||||||
| 	DstIP() net.IP | 	DstIP() net.IP | ||||||
| 	SrcIP() net.IP | 	SrcIP() net.IP | ||||||
| } | } | ||||||
| @@ -92,7 +90,7 @@ func UpdateUDPListener(device *Device) error { | |||||||
| 		// bind to new port | 		// bind to new port | ||||||
|  |  | ||||||
| 		var err error | 		var err error | ||||||
| 		netc.bind, netc.port, err = CreateUDPBind(netc.port) | 		netc.bind, netc.port, err = CreateBind(netc.port) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			netc.bind = nil | 			netc.bind = nil | ||||||
| 			return err | 			return err | ||||||
|   | |||||||
| @@ -13,11 +13,68 @@ import ( | |||||||
|  * See conn_linux.go for an implementation on the linux platform. |  * See conn_linux.go for an implementation on the linux platform. | ||||||
|  */ |  */ | ||||||
|  |  | ||||||
| type Endpoint *net.UDPAddr | type NativeBind struct { | ||||||
|  | 	ipv4 *net.UDPConn | ||||||
|  | 	ipv6 *net.UDPConn | ||||||
|  | } | ||||||
|  |  | ||||||
| type NativeBind *net.UDPConn | type NativeEndpoint net.UDPAddr | ||||||
|  |  | ||||||
| func CreateUDPBind(port uint16) (UDPBind, uint16, error) { | var _ Bind = (*NativeBind)(nil) | ||||||
|  | var _ Endpoint = (*NativeEndpoint)(nil) | ||||||
|  |  | ||||||
|  | func CreateEndpoint(s string) (Endpoint, error) { | ||||||
|  | 	addr, err := parseEndpoint(s) | ||||||
|  | 	return (addr).(*NativeEndpoint), err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (_ *NativeEndpoint) ClearSrc() {} | ||||||
|  |  | ||||||
|  | func (e *NativeEndpoint) DstIP() net.IP { | ||||||
|  | 	return (*net.UDPAddr)(e).IP | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e *NativeEndpoint) SrcIP() net.IP { | ||||||
|  | 	return nil // not supported | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e *NativeEndpoint) DstToBytes() []byte { | ||||||
|  | 	addr := (*net.UDPAddr)(e) | ||||||
|  | 	out := addr.IP.([]byte) | ||||||
|  | 	out = append(out, byte(addr.Port&0xff)) | ||||||
|  | 	out = append(out, byte((addr.Port>>8)&0xff)) | ||||||
|  | 	return out | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e *NativeEndpoint) DstToString() string { | ||||||
|  | 	return (*net.UDPAddr)(e).String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (e *NativeEndpoint) SrcToString() string { | ||||||
|  | 	return "" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func listenNet(net string, port int) (*net.UDPConn, int, error) { | ||||||
|  |  | ||||||
|  | 	// listen | ||||||
|  |  | ||||||
|  | 	conn, err := net.ListenUDP("udp", &UDPAddr{Port: port}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, 0, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// retrieve port | ||||||
|  |  | ||||||
|  | 	laddr := conn.LocalAddr() | ||||||
|  | 	uaddr, _ = net.ResolveUDPAddr( | ||||||
|  | 		laddr.Network(), | ||||||
|  | 		laddr.String(), | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	return conn, uaddr.Port, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CreateBind(port uint16) (Bind, uint16, error) { | ||||||
|  |  | ||||||
| 	// listen | 	// listen | ||||||
|  |  | ||||||
| @@ -38,9 +95,3 @@ func CreateUDPBind(port uint16) (UDPBind, uint16, error) { | |||||||
| 	) | 	) | ||||||
| 	return uaddr.Port | 	return uaddr.Port | ||||||
| } | } | ||||||
|  |  | ||||||
| func (_ Endpoint) ClearSrc() {} |  | ||||||
|  |  | ||||||
| func SetMark(conn *net.UDPConn, value uint32) error { |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -50,11 +50,44 @@ func ntohs(val uint16) uint16 { | |||||||
| 	return binary.BigEndian.Uint16((*tmp)[:]) | 	return binary.BigEndian.Uint16((*tmp)[:]) | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewEndpoint() Endpoint { | func CreateEndpoint(s string) (Endpoint, error) { | ||||||
| 	return &NativeEndpoint{} | 	var end NativeEndpoint | ||||||
|  | 	addr, err := parseEndpoint(s) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ipv4 := addr.IP.To4() | ||||||
|  | 	if ipv4 != nil { | ||||||
|  | 		dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) | ||||||
|  | 		dst.Family = unix.AF_INET | ||||||
|  | 		dst.Port = htons(uint16(addr.Port)) | ||||||
|  | 		dst.Zero = [8]byte{} | ||||||
|  | 		copy(dst.Addr[:], ipv4) | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 		return &end, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ipv6 := addr.IP.To16() | ||||||
|  | 	if ipv6 != nil { | ||||||
|  | 		zone, err := zoneToUint32(addr.Zone) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		dst := &end.dst | ||||||
|  | 		dst.Family = unix.AF_INET6 | ||||||
|  | 		dst.Port = htons(uint16(addr.Port)) | ||||||
|  | 		dst.Flowinfo = 0 | ||||||
|  | 		dst.Scope_id = zone | ||||||
|  | 		copy(dst.Addr[:], ipv6[:]) | ||||||
|  | 		end.ClearSrc() | ||||||
|  | 		return &end, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil, errors.New("Failed to recognize IP address format") | ||||||
| } | } | ||||||
|  |  | ||||||
| func CreateUDPBind(port uint16) (Bind, uint16, error) { | func CreateBind(port uint16) (Bind, uint16, error) { | ||||||
| 	var err error | 	var err error | ||||||
| 	var bind NativeBind | 	var bind NativeBind | ||||||
|  |  | ||||||
| @@ -325,42 +358,6 @@ func create6(port uint16) (int, uint16, error) { | |||||||
| 	return fd, uint16(addr.Port), err | 	return fd, uint16(addr.Port), err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (end *NativeEndpoint) SetDst(s string) error { |  | ||||||
| 	addr, err := parseEndpoint(s) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ipv4 := addr.IP.To4() |  | ||||||
| 	if ipv4 != nil { |  | ||||||
| 		dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) |  | ||||||
| 		dst.Family = unix.AF_INET |  | ||||||
| 		dst.Port = htons(uint16(addr.Port)) |  | ||||||
| 		dst.Zero = [8]byte{} |  | ||||||
| 		copy(dst.Addr[:], ipv4) |  | ||||||
| 		end.ClearSrc() |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ipv6 := addr.IP.To16() |  | ||||||
| 	if ipv6 != nil { |  | ||||||
| 		zone, err := zoneToUint32(addr.Zone) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return err |  | ||||||
| 		} |  | ||||||
| 		dst := &end.dst |  | ||||||
| 		dst.Family = unix.AF_INET6 |  | ||||||
| 		dst.Port = htons(uint16(addr.Port)) |  | ||||||
| 		dst.Flowinfo = 0 |  | ||||||
| 		dst.Scope_id = zone |  | ||||||
| 		copy(dst.Addr[:], ipv6[:]) |  | ||||||
| 		end.ClearSrc() |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return errors.New("Failed to recognize IP address format") |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func send6(sock int, end *NativeEndpoint, buff []byte) error { | func send6(sock int, end *NativeEndpoint, buff []byte) error { | ||||||
|  |  | ||||||
| 	// construct message header | 	// construct message header | ||||||
|   | |||||||
| @@ -260,9 +260,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { | |||||||
| 				err := func() error { | 				err := func() error { | ||||||
| 					peer.mutex.Lock() | 					peer.mutex.Lock() | ||||||
| 					defer peer.mutex.Unlock() | 					defer peer.mutex.Unlock() | ||||||
|  | 					endpoint, err := CreateEndpoint(value) | ||||||
| 					endpoint := NewEndpoint() | 					if err != nil { | ||||||
| 					if err := endpoint.SetDst(value); err != nil { |  | ||||||
| 						return err | 						return err | ||||||
| 					} | 					} | ||||||
| 					peer.endpoint = endpoint | 					peer.endpoint = endpoint | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Mathias Hall-Andersen
					Mathias Hall-Andersen