mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-05 08:36:57 +08:00
device: test up/down using virtual conn
This prevents port clashing bugs. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
136
conn/bindtest/bindtest.go
Normal file
136
conn/bindtest/bindtest.go
Normal file
@@ -0,0 +1,136 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package bindtest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type ChannelBind struct {
|
||||
rx4, tx4 *chan []byte
|
||||
rx6, tx6 *chan []byte
|
||||
closeSignal chan bool
|
||||
source4, source6 ChannelEndpoint
|
||||
target4, target6 ChannelEndpoint
|
||||
}
|
||||
|
||||
type ChannelEndpoint uint16
|
||||
|
||||
var _ conn.Bind = (*ChannelBind)(nil)
|
||||
var _ conn.Endpoint = (*ChannelEndpoint)(nil)
|
||||
|
||||
func NewChannelBinds() [2]conn.Bind {
|
||||
arx4 := make(chan []byte, 8192)
|
||||
brx4 := make(chan []byte, 8192)
|
||||
arx6 := make(chan []byte, 8192)
|
||||
brx6 := make(chan []byte, 8192)
|
||||
var binds [2]ChannelBind
|
||||
binds[0].rx4 = &arx4
|
||||
binds[0].tx4 = &brx4
|
||||
binds[1].rx4 = &brx4
|
||||
binds[1].tx4 = &arx4
|
||||
binds[0].rx6 = &arx6
|
||||
binds[0].tx6 = &brx6
|
||||
binds[1].rx6 = &brx6
|
||||
binds[1].tx6 = &arx6
|
||||
binds[0].target4 = ChannelEndpoint(1)
|
||||
binds[1].target4 = ChannelEndpoint(2)
|
||||
binds[0].target6 = ChannelEndpoint(3)
|
||||
binds[1].target6 = ChannelEndpoint(4)
|
||||
binds[0].source4 = binds[1].target4
|
||||
binds[0].source6 = binds[1].target6
|
||||
binds[1].source4 = binds[0].target4
|
||||
binds[1].source6 = binds[0].target6
|
||||
return [2]conn.Bind{&binds[0], &binds[1]}
|
||||
}
|
||||
|
||||
func (c ChannelEndpoint) ClearSrc() {}
|
||||
|
||||
func (c ChannelEndpoint) SrcToString() string { return "" }
|
||||
|
||||
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
|
||||
|
||||
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
||||
|
||||
func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
|
||||
|
||||
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
|
||||
|
||||
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
|
||||
c.closeSignal = make(chan bool)
|
||||
if rand.Uint32()&1 == 0 {
|
||||
return uint16(c.source4), nil
|
||||
} else {
|
||||
return uint16(c.source6), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) Close() error {
|
||||
if c.closeSignal != nil {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
default:
|
||||
close(c.closeSignal)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
||||
|
||||
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return 0, nil, net.ErrClosed
|
||||
case rx := <-*c.rx6:
|
||||
return copy(b, rx), c.target6, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return 0, nil, net.ErrClosed
|
||||
case rx := <-*c.rx4:
|
||||
return copy(b, rx), c.target4, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {
|
||||
select {
|
||||
case <-c.closeSignal:
|
||||
return net.ErrClosed
|
||||
default:
|
||||
bc := make([]byte, len(b))
|
||||
copy(bc, b)
|
||||
if ep.(ChannelEndpoint) == c.target4 {
|
||||
*c.tx4 <- bc
|
||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
||||
*c.tx6 <- bc
|
||||
} else {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
||||
_, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ChannelEndpoint(i), nil
|
||||
}
|
Reference in New Issue
Block a user