mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 11:56:22 +08:00 
			
		
		
		
	all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which to receive packets (an IPv4 source and an IPv6 source), allow the conn.Bind to specify a set of sources. Beneficial consequences: * If there's no IPv6 support on a system, conn.Bind.Open can choose not to return a receive function for it, which is simpler than tracking that state in the bind. This simplification removes existing data races from both conn.StdNetBind and bindtest.ChannelBind. * If there are more than two sources on a system, the conn.Bind no longer needs to add a separate muxing layer. Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
This commit is contained in:
		 Josh Bleecher Snyder
					Josh Bleecher Snyder
				
			
				
					committed by
					
						 Jason A. Donenfeld
						Jason A. Donenfeld
					
				
			
			
				
	
			
			
			 Jason A. Donenfeld
						Jason A. Donenfeld
					
				
			
						parent
						
							8ed83e0427
						
					
				
				
					commit
					10533c3e73
				
			| @@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { | |||||||
|  |  | ||||||
| // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. | // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. | ||||||
| type LinuxSocketBind struct { | type LinuxSocketBind struct { | ||||||
| 	sock4    int | 	// mu guards sock4 and sock6 and the associated fds. | ||||||
| 	sock6    int | 	// As long as someone holds mu (read or write), the associated fds are valid. | ||||||
| 	lastMark uint32 | 	mu    sync.RWMutex | ||||||
| 	closing  sync.RWMutex | 	sock4 int | ||||||
|  | 	sock6 int | ||||||
| } | } | ||||||
|  |  | ||||||
| func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } | func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } | ||||||
| @@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { | |||||||
| 	return nil, errors.New("invalid IP address") | 	return nil, errors.New("invalid IP address") | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { | func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
|  |  | ||||||
| 	var err error | 	var err error | ||||||
| 	var newPort uint16 | 	var newPort uint16 | ||||||
| 	var tries int | 	var tries int | ||||||
|  |  | ||||||
| 	if bind.sock4 != -1 || bind.sock6 != -1 { | 	if bind.sock4 != -1 || bind.sock6 != -1 { | ||||||
| 		return 0, ErrBindAlreadyOpen | 		return nil, 0, ErrBindAlreadyOpen | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	originalPort := port | 	originalPort := port | ||||||
|  |  | ||||||
| again: | again: | ||||||
| 	port = originalPort | 	port = originalPort | ||||||
|  | 	var sock4, sock6 int | ||||||
| 	// Attempt ipv6 bind, update port if successful. | 	// Attempt ipv6 bind, update port if successful. | ||||||
| 	bind.sock6, newPort, err = create6(port) | 	sock6, newPort, err = create6(port) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if err != syscall.EAFNOSUPPORT { | 		if !errors.Is(err, syscall.EAFNOSUPPORT) { | ||||||
| 			return 0, err | 			return nil, 0, err | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		port = newPort | 		port = newPort | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Attempt ipv4 bind, update port if successful. | 	// Attempt ipv4 bind, update port if successful. | ||||||
| 	bind.sock4, newPort, err = create4(port) | 	sock4, newPort, err = create4(port) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { | 		if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { | ||||||
| 			unix.Close(bind.sock6) | 			unix.Close(sock6) | ||||||
| 			tries++ | 			tries++ | ||||||
| 			goto again | 			goto again | ||||||
| 		} | 		} | ||||||
| 		if err != syscall.EAFNOSUPPORT { | 		if !errors.Is(err, syscall.EAFNOSUPPORT) { | ||||||
| 			unix.Close(bind.sock6) | 			unix.Close(sock6) | ||||||
| 			return 0, err | 			return nil, 0, err | ||||||
| 		} | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		port = newPort | 		port = newPort | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if bind.sock4 == -1 && bind.sock6 == -1 { | 	var fns []ReceiveFunc | ||||||
| 		return 0, syscall.EAFNOSUPPORT | 	if sock4 != -1 { | ||||||
|  | 		fns = append(fns, makeReceiveIPv4(sock4)) | ||||||
|  | 		bind.sock4 = sock4 | ||||||
| 	} | 	} | ||||||
| 	return port, nil | 	if sock6 != -1 { | ||||||
|  | 		fns = append(fns, makeReceiveIPv6(sock6)) | ||||||
|  | 		bind.sock6 = sock6 | ||||||
|  | 	} | ||||||
|  | 	if len(fns) == 0 { | ||||||
|  | 		return nil, 0, syscall.EAFNOSUPPORT | ||||||
|  | 	} | ||||||
|  | 	return fns, port, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) SetMark(value uint32) error { | func (bind *LinuxSocketBind) SetMark(value uint32) error { | ||||||
| 	bind.closing.RLock() | 	bind.mu.RLock() | ||||||
| 	defer bind.closing.RUnlock() | 	defer bind.mu.RUnlock() | ||||||
|  |  | ||||||
| 	if bind.sock6 != -1 { | 	if bind.sock6 != -1 { | ||||||
| 		err := unix.SetsockoptInt( | 		err := unix.SetsockoptInt( | ||||||
| @@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error { | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	bind.lastMark = value |  | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) Close() error { | func (bind *LinuxSocketBind) Close() error { | ||||||
| 	var err1, err2 error | 	// Take a readlock to shut down the sockets... | ||||||
| 	bind.closing.RLock() | 	bind.mu.RLock() | ||||||
| 	if bind.sock6 != -1 { | 	if bind.sock6 != -1 { | ||||||
| 		unix.Shutdown(bind.sock6, unix.SHUT_RDWR) | 		unix.Shutdown(bind.sock6, unix.SHUT_RDWR) | ||||||
| 	} | 	} | ||||||
| 	if bind.sock4 != -1 { | 	if bind.sock4 != -1 { | ||||||
| 		unix.Shutdown(bind.sock4, unix.SHUT_RDWR) | 		unix.Shutdown(bind.sock4, unix.SHUT_RDWR) | ||||||
| 	} | 	} | ||||||
| 	bind.closing.RUnlock() | 	bind.mu.RUnlock() | ||||||
| 	bind.closing.Lock() | 	// ...and a write lock to close the fd. | ||||||
|  | 	// This ensures that no one else is using the fd. | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
|  | 	var err1, err2 error | ||||||
| 	if bind.sock6 != -1 { | 	if bind.sock6 != -1 { | ||||||
| 		err1 = unix.Close(bind.sock6) | 		err1 = unix.Close(bind.sock6) | ||||||
| 		bind.sock6 = -1 | 		bind.sock6 = -1 | ||||||
| @@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error { | |||||||
| 		err2 = unix.Close(bind.sock4) | 		err2 = unix.Close(bind.sock4) | ||||||
| 		bind.sock4 = -1 | 		bind.sock4 = -1 | ||||||
| 	} | 	} | ||||||
| 	bind.closing.Unlock() |  | ||||||
|  |  | ||||||
| 	if err1 != nil { | 	if err1 != nil { | ||||||
| 		return err1 | 		return err1 | ||||||
| @@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error { | |||||||
| 	return err2 | 	return err2 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { | func makeReceiveIPv6(sock int) ReceiveFunc { | ||||||
| 	bind.closing.RLock() | 	return func(buff []byte) (int, Endpoint, error) { | ||||||
| 	defer bind.closing.RUnlock() | 		var end LinuxSocketEndpoint | ||||||
|  | 		n, err := receive6(sock, buff, &end) | ||||||
| 	var end LinuxSocketEndpoint | 		return n, &end, err | ||||||
| 	if bind.sock6 == -1 { |  | ||||||
| 		return 0, nil, net.ErrClosed |  | ||||||
| 	} | 	} | ||||||
| 	n, err := receive6( |  | ||||||
| 		bind.sock6, |  | ||||||
| 		buff, |  | ||||||
| 		&end, |  | ||||||
| 	) |  | ||||||
| 	return n, &end, err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { | func makeReceiveIPv4(sock int) ReceiveFunc { | ||||||
| 	bind.closing.RLock() | 	return func(buff []byte) (int, Endpoint, error) { | ||||||
| 	defer bind.closing.RUnlock() | 		var end LinuxSocketEndpoint | ||||||
|  | 		n, err := receive4(sock, buff, &end) | ||||||
| 	var end LinuxSocketEndpoint | 		return n, &end, err | ||||||
| 	if bind.sock4 == -1 { |  | ||||||
| 		return 0, nil, net.ErrClosed |  | ||||||
| 	} | 	} | ||||||
| 	n, err := receive4( |  | ||||||
| 		bind.sock4, |  | ||||||
| 		buff, |  | ||||||
| 		&end, |  | ||||||
| 	) |  | ||||||
| 	return n, &end, err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { | func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { | ||||||
| 	bind.closing.RLock() |  | ||||||
| 	defer bind.closing.RUnlock() |  | ||||||
|  |  | ||||||
| 	nend, ok := end.(*LinuxSocketEndpoint) | 	nend, ok := end.(*LinuxSocketEndpoint) | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrWrongEndpointType | 		return ErrWrongEndpointType | ||||||
| 	} | 	} | ||||||
|  | 	bind.mu.RLock() | ||||||
|  | 	defer bind.mu.RUnlock() | ||||||
| 	if !nend.isV6 { | 	if !nend.isV6 { | ||||||
| 		if bind.sock4 == -1 { | 		if bind.sock4 == -1 { | ||||||
| 			return net.ErrClosed | 			return net.ErrClosed | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ package conn | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"sync" | ||||||
| 	"syscall" | 	"syscall" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -16,6 +17,7 @@ import ( | |||||||
| // It uses the Go's net package to implement networking. | // It uses the Go's net package to implement networking. | ||||||
| // See LinuxSocketBind for a proper implementation on the Linux platform. | // See LinuxSocketBind for a proper implementation on the Linux platform. | ||||||
| type StdNetBind struct { | type StdNetBind struct { | ||||||
|  | 	mu         sync.Mutex // protects following fields | ||||||
| 	ipv4       *net.UDPConn | 	ipv4       *net.UDPConn | ||||||
| 	ipv6       *net.UDPConn | 	ipv6       *net.UDPConn | ||||||
| 	blackhole4 bool | 	blackhole4 bool | ||||||
| @@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { | |||||||
| 	return conn, uaddr.Port, nil | 	return conn, uaddr.Port, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) Open(uport uint16) (uint16, error) { | func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
|  |  | ||||||
| 	var err error | 	var err error | ||||||
| 	var tries int | 	var tries int | ||||||
|  |  | ||||||
| 	if bind.ipv4 != nil || bind.ipv6 != nil { | 	if bind.ipv4 != nil || bind.ipv6 != nil { | ||||||
| 		return 0, ErrBindAlreadyOpen | 		return nil, 0, ErrBindAlreadyOpen | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Attempt to open ipv4 and ipv6 listeners on the same port. | 	// Attempt to open ipv4 and ipv6 listeners on the same port. | ||||||
| @@ -97,7 +102,7 @@ again: | |||||||
|  |  | ||||||
| 	ipv4, port, err = listenNet("udp4", port) | 	ipv4, port, err = listenNet("udp4", port) | ||||||
| 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { | 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { | ||||||
| 		return 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// Listen on the same port as we're using for ipv4. | 	// Listen on the same port as we're using for ipv4. | ||||||
| @@ -109,17 +114,27 @@ again: | |||||||
| 	} | 	} | ||||||
| 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { | 	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { | ||||||
| 		ipv4.Close() | 		ipv4.Close() | ||||||
| 		return 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
| 	if ipv4 == nil && ipv6 == nil { | 	var fns []ReceiveFunc | ||||||
| 		return 0, syscall.EAFNOSUPPORT | 	if ipv4 != nil { | ||||||
|  | 		fns = append(fns, makeReceiveFunc(ipv4, true)) | ||||||
|  | 		bind.ipv4 = ipv4 | ||||||
| 	} | 	} | ||||||
| 	bind.ipv4 = ipv4 | 	if ipv6 != nil { | ||||||
| 	bind.ipv6 = ipv6 | 		fns = append(fns, makeReceiveFunc(ipv6, false)) | ||||||
| 	return uint16(port), nil | 		bind.ipv6 = ipv6 | ||||||
|  | 	} | ||||||
|  | 	if len(fns) == 0 { | ||||||
|  | 		return nil, 0, syscall.EAFNOSUPPORT | ||||||
|  | 	} | ||||||
|  | 	return fns, uint16(port), nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) Close() error { | func (bind *StdNetBind) Close() error { | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
|  |  | ||||||
| 	var err1, err2 error | 	var err1, err2 error | ||||||
| 	if bind.ipv4 != nil { | 	if bind.ipv4 != nil { | ||||||
| 		err1 = bind.ipv4.Close() | 		err1 = bind.ipv4.Close() | ||||||
| @@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error { | |||||||
| 	return err2 | 	return err2 | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { | func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc { | ||||||
| 	if bind.ipv4 == nil { | 	return func(buff []byte) (int, Endpoint, error) { | ||||||
| 		return 0, nil, syscall.EAFNOSUPPORT | 		n, endpoint, err := conn.ReadFromUDP(buff) | ||||||
|  | 		if isIPv4 && endpoint != nil { | ||||||
|  | 			endpoint.IP = endpoint.IP.To4() | ||||||
|  | 		} | ||||||
|  | 		return n, (*StdNetEndpoint)(endpoint), err | ||||||
| 	} | 	} | ||||||
| 	n, endpoint, err := bind.ipv4.ReadFromUDP(buff) |  | ||||||
| 	if endpoint != nil { |  | ||||||
| 		endpoint.IP = endpoint.IP.To4() |  | ||||||
| 	} |  | ||||||
| 	return n, (*StdNetEndpoint)(endpoint), err |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { |  | ||||||
| 	if bind.ipv6 == nil { |  | ||||||
| 		return 0, nil, syscall.EAFNOSUPPORT |  | ||||||
| 	} |  | ||||||
| 	n, endpoint, err := bind.ipv6.ReadFromUDP(buff) |  | ||||||
| 	return n, (*StdNetEndpoint)(endpoint), err |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { | func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { | ||||||
| @@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { | |||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return ErrWrongEndpointType | 		return ErrWrongEndpointType | ||||||
| 	} | 	} | ||||||
| 	var conn *net.UDPConn |  | ||||||
| 	var blackhole bool | 	bind.mu.Lock() | ||||||
| 	if nend.IP.To4() != nil { | 	blackhole := bind.blackhole4 | ||||||
| 		blackhole = bind.blackhole4 | 	conn := bind.ipv4 | ||||||
| 		conn = bind.ipv4 | 	if nend.IP.To4() == nil { | ||||||
| 	} else { |  | ||||||
| 		blackhole = bind.blackhole6 | 		blackhole = bind.blackhole6 | ||||||
| 		conn = bind.ipv6 | 		conn = bind.ipv6 | ||||||
| 	} | 	} | ||||||
|  | 	bind.mu.Unlock() | ||||||
|  |  | ||||||
| 	if blackhole { | 	if blackhole { | ||||||
| 		return nil | 		return nil | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock | |||||||
| 	return sa, nil | 	return sa, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { | func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { | ||||||
| 	bind.mu.Lock() | 	bind.mu.Lock() | ||||||
| 	defer bind.mu.Unlock() | 	defer bind.mu.Unlock() | ||||||
| 	defer func() { | 	defer func() { | ||||||
| @@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 	if atomic.LoadUint32(&bind.isOpen) != 0 { | 	if atomic.LoadUint32(&bind.isOpen) != 0 { | ||||||
| 		return 0, ErrBindAlreadyOpen | 		return nil, 0, ErrBindAlreadyOpen | ||||||
| 	} | 	} | ||||||
| 	var sa windows.Sockaddr | 	var sa windows.Sockaddr | ||||||
| 	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) | 	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
| 	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) | 	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
| 	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) | 	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) | ||||||
| 	for i := 0; i < packetsPerRing; i++ { | 	for i := 0; i < packetsPerRing; i++ { | ||||||
| 		err = bind.v4.InsertReceiveRequest() | 		err = bind.v4.InsertReceiveRequest() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return 0, err | 			return nil, 0, err | ||||||
| 		} | 		} | ||||||
| 		err = bind.v6.InsertReceiveRequest() | 		err = bind.v6.InsertReceiveRequest() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return 0, err | 			return nil, 0, err | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 	atomic.StoreUint32(&bind.isOpen, 1) | 	atomic.StoreUint32(&bind.isOpen, 1) | ||||||
| 	return | 	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *WinRingBind) Close() error { | func (bind *WinRingBind) Close() error { | ||||||
| @@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e | |||||||
| 	return n, &ep, nil | 	return n, &ep, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) { | func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { | ||||||
| 	bind.mu.RLock() | 	bind.mu.RLock() | ||||||
| 	defer bind.mu.RUnlock() | 	defer bind.mu.RUnlock() | ||||||
| 	return bind.v4.Receive(buf, &bind.isOpen) | 	return bind.v4.Receive(buf, &bind.isOpen) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) { | func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { | ||||||
| 	bind.mu.RLock() | 	bind.mu.RLock() | ||||||
| 	defer bind.mu.RUnlock() | 	defer bind.mu.RUnlock() | ||||||
| 	return bind.v6.Receive(buf, &bind.isOpen) | 	return bind.v6.Receive(buf, &bind.isOpen) | ||||||
| @@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { | func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
| 	sysconn, err := bind.ipv4.SyscallConn() | 	sysconn, err := bind.ipv4.SyscallConn() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| @@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole | |||||||
| } | } | ||||||
|  |  | ||||||
| func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { | func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { | ||||||
|  | 	bind.mu.Lock() | ||||||
|  | 	defer bind.mu.Unlock() | ||||||
| 	sysconn, err := bind.ipv6.SyscallConn() | 	sysconn, err := bind.ipv6.SyscallConn() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
|   | |||||||
| @@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) } | |||||||
|  |  | ||||||
| func (c ChannelEndpoint) SrcIP() net.IP { return nil } | func (c ChannelEndpoint) SrcIP() net.IP { return nil } | ||||||
|  |  | ||||||
| func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) { | func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { | ||||||
| 	c.closeSignal = make(chan bool) | 	c.closeSignal = make(chan bool) | ||||||
|  | 	fns = append(fns, c.makeReceiveFunc(*c.rx4)) | ||||||
|  | 	fns = append(fns, c.makeReceiveFunc(*c.rx6)) | ||||||
| 	if rand.Uint32()&1 == 0 { | 	if rand.Uint32()&1 == 0 { | ||||||
| 		return uint16(c.source4), nil | 		return fns, uint16(c.source4), nil | ||||||
| 	} else { | 	} else { | ||||||
| 		return uint16(c.source6), nil | 		return fns, uint16(c.source6), nil | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error { | |||||||
|  |  | ||||||
| func (c *ChannelBind) SetMark(mark uint32) error { return nil } | func (c *ChannelBind) SetMark(mark uint32) error { return nil } | ||||||
|  |  | ||||||
| func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) { | func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { | ||||||
| 	select { | 	return func(b []byte) (n int, ep conn.Endpoint, err error) { | ||||||
| 	case <-c.closeSignal: | 		select { | ||||||
| 		return 0, nil, net.ErrClosed | 		case <-c.closeSignal: | ||||||
| 	case rx := <-*c.rx6: | 			return 0, nil, net.ErrClosed | ||||||
| 		return copy(b, rx), c.target6, nil | 		case rx := <-ch: | ||||||
| 	} | 			return copy(b, rx), c.target6, nil | ||||||
| } | 		} | ||||||
|  |  | ||||||
| func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) { |  | ||||||
| 	select { |  | ||||||
| 	case <-c.closeSignal: |  | ||||||
| 		return 0, nil, net.ErrClosed |  | ||||||
| 	case rx := <-*c.rx4: |  | ||||||
| 		return copy(b, rx), c.target4, nil |  | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
							
								
								
									
										17
									
								
								conn/conn.go
									
									
									
									
									
								
							
							
						
						
									
										17
									
								
								conn/conn.go
									
									
									
									
									
								
							| @@ -12,6 +12,11 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | // A ReceiveFunc receives a single inbound packet from the network. | ||||||
|  | // It writes the data into b. n is the length of the packet. | ||||||
|  | // ep is the remote endpoint. | ||||||
|  | type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error) | ||||||
|  |  | ||||||
| // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. | // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. | ||||||
| // | // | ||||||
| // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, | // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, | ||||||
| @@ -19,23 +24,17 @@ import ( | |||||||
| type Bind interface { | type Bind interface { | ||||||
| 	// Open puts the Bind into a listening state on a given port and reports the actual | 	// Open puts the Bind into a listening state on a given port and reports the actual | ||||||
| 	// port that it bound to. Passing zero results in a random selection. | 	// port that it bound to. Passing zero results in a random selection. | ||||||
| 	Open(port uint16) (actualPort uint16, err error) | 	// fns is the set of functions that will be called to receive packets. | ||||||
|  | 	Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) | ||||||
|  |  | ||||||
| 	// Close closes the Bind listener. | 	// Close closes the Bind listener. | ||||||
|  | 	// All fns returned by Open must return net.ErrClosed after a call to Close. | ||||||
| 	Close() error | 	Close() error | ||||||
|  |  | ||||||
| 	// SetMark sets the mark for each packet sent through this Bind. | 	// SetMark sets the mark for each packet sent through this Bind. | ||||||
| 	// This mark is passed to the kernel as the socket option SO_MARK. | 	// This mark is passed to the kernel as the socket option SO_MARK. | ||||||
| 	SetMark(mark uint32) error | 	SetMark(mark uint32) error | ||||||
|  |  | ||||||
| 	// ReceiveIPv6 reads an IPv6 UDP packet into b.  It reports the number of bytes read, |  | ||||||
| 	// n, the packet source address ep, and any error. |  | ||||||
| 	ReceiveIPv6(b []byte) (n int, ep Endpoint, err error) |  | ||||||
|  |  | ||||||
| 	// ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read, |  | ||||||
| 	// n, the packet source address ep, and any error. |  | ||||||
| 	ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) |  | ||||||
|  |  | ||||||
| 	// Send writes a packet b to address ep. | 	// Send writes a packet b to address ep. | ||||||
| 	Send(b []byte, ep Endpoint) error | 	Send(b []byte, ep Endpoint) error | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,9 +11,6 @@ import ( | |||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"golang.org/x/net/ipv4" |  | ||||||
| 	"golang.org/x/net/ipv6" |  | ||||||
|  |  | ||||||
| 	"golang.zx2c4.com/wireguard/conn" | 	"golang.zx2c4.com/wireguard/conn" | ||||||
| 	"golang.zx2c4.com/wireguard/ratelimiter" | 	"golang.zx2c4.com/wireguard/ratelimiter" | ||||||
| 	"golang.zx2c4.com/wireguard/rwcancel" | 	"golang.zx2c4.com/wireguard/rwcancel" | ||||||
| @@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error { | |||||||
|  |  | ||||||
| 	// bind to new port | 	// bind to new port | ||||||
| 	var err error | 	var err error | ||||||
|  | 	var recvFns []conn.ReceiveFunc | ||||||
| 	netc := &device.net | 	netc := &device.net | ||||||
| 	netc.port, err = netc.bind.Open(netc.port) | 	recvFns, netc.port, err = netc.bind.Open(netc.port) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		netc.port = 0 | 		netc.port = 0 | ||||||
| 		return err | 		return err | ||||||
| @@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error { | |||||||
| 	device.peers.RUnlock() | 	device.peers.RUnlock() | ||||||
|  |  | ||||||
| 	// start receiving routines | 	// start receiving routines | ||||||
| 	device.net.stopping.Add(2) | 	device.net.stopping.Add(len(recvFns)) | ||||||
| 	device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption | 	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption | ||||||
| 	device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake | 	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake | ||||||
| 	go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) | 	for _, fn := range recvFns { | ||||||
| 	go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) | 		go device.RoutineReceiveIncoming(fn) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	device.log.Verbosef("UDP bind has been updated") | 	device.log.Verbosef("UDP bind has been updated") | ||||||
| 	return nil | 	return nil | ||||||
|   | |||||||
| @@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() { | |||||||
|  * Every time the bind is updated a new routine is started for |  * Every time the bind is updated a new routine is started for | ||||||
|  * IPv4 and IPv6 (separately) |  * IPv4 and IPv6 (separately) | ||||||
|  */ |  */ | ||||||
| func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { | func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { | ||||||
| 	defer func() { | 	defer func() { | ||||||
| 		device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) | 		device.log.Verbosef("Routine: receive incoming %p - stopped", recv) | ||||||
| 		device.queue.decryption.wg.Done() | 		device.queue.decryption.wg.Done() | ||||||
| 		device.queue.handshake.wg.Done() | 		device.queue.handshake.wg.Done() | ||||||
| 		device.net.stopping.Done() | 		device.net.stopping.Done() | ||||||
| 	}() | 	}() | ||||||
|  |  | ||||||
| 	device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) | 	device.log.Verbosef("Routine: receive incoming %p - started", recv) | ||||||
|  |  | ||||||
| 	// receive datagrams until conn is closed | 	// receive datagrams until conn is closed | ||||||
|  |  | ||||||
| @@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { | |||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		switch IP { | 		size, endpoint, err = recv(buffer[:]) | ||||||
| 		case ipv4.Version: |  | ||||||
| 			size, endpoint, err = bind.ReceiveIPv4(buffer[:]) |  | ||||||
| 		case ipv6.Version: |  | ||||||
| 			size, endpoint, err = bind.ReceiveIPv6(buffer[:]) |  | ||||||
| 		default: |  | ||||||
| 			panic("invalid IP version") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			device.PutMessageBuffer(buffer) | 			device.PutMessageBuffer(buffer) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user