Chore: make set operation atomic

This commit is contained in:
xjasonlyu
2021-03-13 11:50:14 +08:00
parent f41b844e43
commit 094199656b
2 changed files with 17 additions and 9 deletions

View File

@@ -8,6 +8,8 @@ import (
M "github.com/xjasonlyu/tun2socks/constant" M "github.com/xjasonlyu/tun2socks/constant"
"github.com/xjasonlyu/tun2socks/proxy/proto" "github.com/xjasonlyu/tun2socks/proxy/proto"
"go.uber.org/atomic"
) )
const ( const (
@@ -15,7 +17,7 @@ const (
) )
var ( var (
_defaultDialer Dialer = &Base{} _defaultDialer atomic.Value
) )
type Dialer interface { type Dialer interface {
@@ -29,24 +31,28 @@ type Proxy interface {
Proto() proto.Proto Proto() proto.Proto
} }
func init() {
_defaultDialer.Store(&Base{})
}
// SetDialer sets default Dialer. // SetDialer sets default Dialer.
func SetDialer(d Dialer) { func SetDialer(d Dialer) {
_defaultDialer = d _defaultDialer.Store(d)
} }
// Dial uses default Dialer to dial TCP. // Dial uses default Dialer to dial TCP.
func Dial(metadata *M.Metadata) (net.Conn, error) { func Dial(metadata *M.Metadata) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel() defer cancel()
return _defaultDialer.DialContext(ctx, metadata) return _defaultDialer.Load().(Dialer).DialContext(ctx, metadata)
} }
// DialContext uses default Dialer to dial TCP with context. // DialContext uses default Dialer to dial TCP with context.
func DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) { func DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) {
return _defaultDialer.DialContext(ctx, metadata) return _defaultDialer.Load().(Dialer).DialContext(ctx, metadata)
} }
// DialUDP uses default Dialer to dial UDP. // DialUDP uses default Dialer to dial UDP.
func DialUDP(metadata *M.Metadata) (net.PacketConn, error) { func DialUDP(metadata *M.Metadata) (net.PacketConn, error) {
return _defaultDialer.DialUDP(metadata) return _defaultDialer.Load().(Dialer).DialUDP(metadata)
} }

View File

@@ -14,6 +14,8 @@ import (
"github.com/xjasonlyu/tun2socks/log" "github.com/xjasonlyu/tun2socks/log"
"github.com/xjasonlyu/tun2socks/proxy" "github.com/xjasonlyu/tun2socks/proxy"
"github.com/xjasonlyu/tun2socks/tunnel/statistic" "github.com/xjasonlyu/tun2socks/tunnel/statistic"
"go.uber.org/atomic"
) )
var ( var (
@@ -23,11 +25,11 @@ var (
// _udpSessionTimeout is the default timeout for // _udpSessionTimeout is the default timeout for
// each UDP session. // each UDP session.
_udpSessionTimeout = 60 * time.Second _udpSessionTimeout = atomic.NewInt64(int64(60 * time.Second))
) )
func SetUDPTimeout(v int) { func SetUDPTimeout(v int) {
_udpSessionTimeout = time.Duration(v) * time.Second _udpSessionTimeout.Store(int64(time.Duration(v) * time.Second))
} }
func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
@@ -119,7 +121,7 @@ func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr
if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil { if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil {
log.Warnf("[UDP] write to %s error: %v", remote, err) log.Warnf("[UDP] write to %s error: %v", remote, err)
} }
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */ pc.SetReadDeadline(time.Now().Add(time.Duration(_udpSessionTimeout.Load()))) /* reset timeout */
log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote) log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote)
} }
@@ -129,7 +131,7 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
defer pool.Put(buf) defer pool.Put(buf)
for /* just loop */ { for /* just loop */ {
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) pc.SetReadDeadline(time.Now().Add(time.Duration(_udpSessionTimeout.Load())))
n, from, err := pc.ReadFrom(buf) n, from, err := pc.ReadFrom(buf)
if err != nil { if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) /* ignore i/o timeout */ { if !errors.Is(err, os.ErrDeadlineExceeded) /* ignore i/o timeout */ {