refactor: optimize forwarder

This commit is contained in:
ICKelin
2021-05-09 09:40:49 +08:00
parent 1fd1222a1b
commit 51bcd4dd44
9 changed files with 228 additions and 89 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.0 KiB

After

Width:  |  Height:  |  Size: 7.0 KiB

View File

@@ -2,8 +2,12 @@ server:
listen: ":10100"
authKey: "client server exchange key"
domain: "open.notr.tech"
tcplisten: ":4398"
udplisten: ":4399"
tcpforward:
listen: ":4398"
udpforward:
listen: ":4399"
dhcp:
cidr: "100.64.242.1/24"

View File

@@ -8,18 +8,31 @@ import (
)
type Config struct {
ServerConfig ServerConfig `yaml:"server"`
DHCPConfig DHCPConfig `yaml:"dhcp"`
ResolverConfig ResolverConfig `yaml:"resolver"`
Plugins map[string]string `yaml:"plugin"`
ServerConfig ServerConfig `yaml:"server"`
DHCPConfig DHCPConfig `yaml:"dhcp"`
ResolverConfig ResolverConfig `yaml:"resolver"`
TCPForwardConfig TCPForwardConfig `yaml:"tcpforward"`
UDPForwardConfig UDPForwardConfig `yaml:"udpforward"`
Plugins map[string]string `yaml:"plugin"`
}
type ServerConfig struct {
ListenAddr string `yaml:"listen"`
AuthKey string `yaml:"authKey"`
Domain string `yaml:"domain"`
TCPForwardListen string `yaml:"tcplisten"`
UDPForwardListen string `yaml:"udplisten"`
ListenAddr string `yaml:"listen"`
AuthKey string `yaml:"authKey"`
Domain string `yaml:"domain"`
}
type TCPForwardConfig struct {
ListenAddr string `yaml:"listen"`
ReadTimeout int `yaml:"readTimeout"`
WriteTimeout int `yaml:"writeTimeout"`
}
type UDPForwardConfig struct {
ListenAddr string `yaml:"listen"`
ReadTimeout int `yaml:"readTimeout"`
WriteTimeout int `yaml:"writeTimeout"`
SessionTimeout int `yaml:"sessionTimeout"`
}
type DHCPConfig struct {

View File

@@ -1,15 +1,12 @@
package core
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"net"
"syscall"
"unsafe"
"github.com/ICKelin/opennotr/pkg/logs"
"github.com/ICKelin/opennotr/pkg/proto"
)
@@ -100,39 +97,3 @@ func encodeProxyProtocol(protocol, sip, sport, dip, dport string) []byte {
body, _ := json.Marshal(proxyProtocol)
return encode(body)
}
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
msgs, err := syscall.ParseSocketControlMessage(hdr)
if err != nil {
return nil, err
}
var origindst *net.UDPAddr
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP &&
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
originDstRaw := &syscall.RawSockaddrInet4{}
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
if err != nil {
logs.Error("read origin dst fail: %v", err)
continue
}
// only support for ipv4
if originDstRaw.Family == syscall.AF_INET {
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
origindst = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
}
}
}
if origindst == nil {
return nil, fmt.Errorf("get origin dst fail")
}
return origindst, nil
}

View File

@@ -10,18 +10,43 @@ import (
"github.com/ICKelin/opennotr/pkg/logs"
)
var (
// default tcp timeout(read, write), 10 seconds
defaultTCPTimeout = 10
)
type TCPForward struct {
sessMgr *SessionManager
listenAddr string
// writeTimeout defines the tcp connection write timeout in second
// default value set to 10 seconds
writeTimeout time.Duration
// readTimeout defines the tcp connection write timeout in second
// default value set to 10 seconds
readTimeout time.Duration
sessMgr *SessionManager
}
func NewTCPForward() *TCPForward {
func NewTCPForward(cfg TCPForwardConfig) *TCPForward {
tcpReadTimeout := cfg.ReadTimeout
if tcpReadTimeout <= 0 {
tcpReadTimeout = defaultTCPTimeout
}
tcpWriteTimeout := cfg.WriteTimeout
if tcpWriteTimeout <= 0 {
tcpWriteTimeout = int(defaultTCPTimeout)
}
return &TCPForward{
sessMgr: GetSessionManager(),
listenAddr: cfg.ListenAddr,
writeTimeout: time.Duration(tcpWriteTimeout) * time.Second,
readTimeout: time.Duration(tcpReadTimeout) * time.Second,
sessMgr: GetSessionManager(),
}
}
func (f *TCPForward) Listen(listenAddr string) (net.Listener, error) {
listener, err := net.Listen("tcp", listenAddr)
func (f *TCPForward) Listen() (net.Listener, error) {
listener, err := net.Listen("tcp", f.listenAddr)
if err != nil {
return nil, err
}
@@ -77,7 +102,7 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
// todo rewrite to client configuration
targetIP := "127.0.0.1"
bytes := encodeProxyProtocol("tcp", sip, sport, targetIP, dport)
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
if err != nil {

View File

@@ -68,7 +68,6 @@ func runBackend() {
for {
nr, err := stream.Read(buf)
if err != nil {
fmt.Println("read stream fail:", err)
break
}
stream.Write(buf[:nr])
@@ -115,8 +114,10 @@ func runtproxy(tcpfw *TCPForward, listener net.Listener) {
func TestTCPForward(t *testing.T) {
// listen tproxy
tcpfw := NewTCPForward()
listener, err := tcpfw.Listen(tproxyAddr)
tcpfw := NewTCPForward(TCPForwardConfig{
ListenAddr: tproxyAddr,
})
listener, err := tcpfw.Listen()
if err != nil {
t.Error(err)
return
@@ -164,8 +165,10 @@ func TestTCPForward(t *testing.T) {
func benchmark(t *testing.B, nconn int) {
// listen tproxy
tcpfw := NewTCPForward()
listener, err := tcpfw.Listen(tproxyAddr)
tcpfw := NewTCPForward(TCPForwardConfig{
ListenAddr: tproxyAddr,
})
listener, err := tcpfw.Listen()
if err != nil {
t.Error(err)
return
@@ -226,3 +229,11 @@ func Benchmark4K(b *testing.B) {
func Benchmark8K(b *testing.B) {
benchmark(b, 1024*8)
}
func Benchmark10K(b *testing.B) {
benchmark(b, 1024*10)
}
func Benchmark14K(b *testing.B) {
benchmark(b, 1024*14)
}

View File

@@ -1,6 +1,7 @@
package core
import (
"bytes"
"encoding/binary"
"fmt"
"io"
@@ -8,23 +9,61 @@ import (
"sync"
"syscall"
"time"
"unsafe"
"github.com/ICKelin/opennotr/pkg/logs"
"github.com/hashicorp/yamux"
)
type UDPForward struct {
sessMgr *SessionManager
var (
// default udp timeout(read, write)(seconds)
defaultUDPTimeout = 10
// default udp session timeout(seconds)
defaultUDPSessionTimeout = 30
)
type udpSession struct {
stream *yamux.Stream
lastActive time.Time
}
func NewUDPForward() *UDPForward {
type UDPForward struct {
listenAddr string
sessionTimeout int
readTimeout time.Duration
writeTimeout time.Duration
sessMgr *SessionManager
udpSessions sync.Map
}
func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
readTimeout := cfg.ReadTimeout
if readTimeout <= 0 {
readTimeout = defaultUDPTimeout
}
writeTimeout := cfg.WriteTimeout
if writeTimeout <= 0 {
writeTimeout = defaultUDPTimeout
}
sessionTimeout := cfg.SessionTimeout
if sessionTimeout <= 0 {
sessionTimeout = defaultUDPSessionTimeout
}
return &UDPForward{
sessMgr: GetSessionManager(),
listenAddr: cfg.ListenAddr,
readTimeout: time.Duration(readTimeout) * time.Second,
writeTimeout: time.Duration(writeTimeout) * time.Second,
sessionTimeout: sessionTimeout,
sessMgr: GetSessionManager(),
}
}
func (f *UDPForward) ListenAndServe(listenAddr string) error {
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
func (f *UDPForward) ListenAndServe() error {
laddr, err := net.ResolveUDPAddr("udp", f.listenAddr)
if err != nil {
logs.Error("resolve udp fail: %v", err)
return err
@@ -67,24 +106,19 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
return err
}
streams := sync.Map{}
defer func() {
streams.Range(func(k, v interface{}) bool {
v.(*yamux.Stream).Close()
return true
})
}()
go f.recyeleSession()
buf := make([]byte, 64*1024)
oob := make([]byte, 1024)
for {
// udp is not connect oriented, it should use read message
// and read the origin dst ip and port from msghdr
nr, oobn, _, raddr, err := lconn.ReadMsgUDP(buf, oob)
if err != nil {
logs.Error("read from udp fail: %v", err)
break
}
origindst, err := getOriginDst(oob[:oobn])
origindst, err := f.getOriginDst(oob[:oobn])
if err != nil {
logs.Error("get origin dst fail: %v", err)
continue
@@ -94,7 +128,7 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
sip, sport, _ := net.SplitHostPort(raddr.String())
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
val, ok := streams.Load(key)
val, ok := f.udpSessions.Load(key)
if !ok {
sess := f.sessMgr.GetSession(dip)
if sess == nil {
@@ -107,11 +141,11 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
logs.Error("open stream fail: %v", err)
continue
}
streams.Store(key, stream)
f.udpSessions.Store(key, &udpSession{stream, time.Now()})
targetIP := "127.0.0.1"
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
if err != nil {
@@ -119,17 +153,25 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
continue
}
go f.forwardUDP(stream, rawfd, origindst, raddr)
val, ok = streams.Load(key)
go f.forwardUDP(stream, key, rawfd, origindst, raddr)
val, ok = f.udpSessions.Load(key)
if !ok {
logs.Error("get stream for %s fail", key)
continue
}
}
stream := val.(*yamux.Stream)
udpsess, ok := val.(*udpSession)
if !ok {
continue
}
// update active time to avoid session recycle
udpsess.lastActive = time.Now()
stream := udpsess.stream
bytes := encode(buf[:nr])
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
if err != nil {
@@ -140,18 +182,28 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
}
// forwardUDP reads from stream and write to tofd via rawsocket
func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr *net.UDPAddr) {
func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd int, fromaddr, toaddr *net.UDPAddr) {
defer stream.Close()
defer f.udpSessions.Delete(sessionKey)
hdr := make([]byte, 2)
for {
_, err := io.ReadFull(stream, hdr)
nr, err := stream.Read(hdr)
if err != nil {
logs.Error("read stream fail %v", err)
if err != io.EOF {
logs.Error("read stream fail %v", err)
}
break
}
if nr != 2 {
logs.Error("invalid bodylen: %d", nr)
continue
}
nlen := binary.BigEndian.Uint16(hdr)
buf := make([]byte, nlen)
stream.SetReadDeadline(time.Now().Add(f.readTimeout))
_, err = io.ReadFull(stream, buf)
stream.SetReadDeadline(time.Time{})
if err != nil {
logs.Error("read stream body fail: %v", err)
break
@@ -163,3 +215,63 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr
}
}
}
func (f *UDPForward) recyeleSession() {
tick := time.NewTicker(time.Second * 5)
for range tick.C {
total, timeout := 0, 0
f.udpSessions.Range(func(k, v interface{}) bool {
total += 1
s, ok := v.(*udpSession)
if !ok {
return true
}
if time.Now().Sub(s.lastActive).Seconds() > float64(f.sessionTimeout) {
logs.Warn("remove udp %v session, lastActive: %v", k, s.lastActive)
f.udpSessions.Delete(k)
s.stream.Close()
timeout += 1
}
return true
})
logs.Debug("total %d, timeout %d, left: %d", total, timeout, total-timeout)
}
}
func (f *UDPForward) getOriginDst(hdr []byte) (*net.UDPAddr, error) {
msgs, err := syscall.ParseSocketControlMessage(hdr)
if err != nil {
return nil, err
}
var origindst *net.UDPAddr
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP &&
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
originDstRaw := &syscall.RawSockaddrInet4{}
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
if err != nil {
logs.Error("read origin dst fail: %v", err)
continue
}
// only support for ipv4
if originDstRaw.Family == syscall.AF_INET {
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
origindst = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
}
}
}
if origindst == nil {
return nil, fmt.Errorf("get origin dst fail")
}
return origindst, nil
}

View File

@@ -0,0 +1,13 @@
package core
import (
"net"
)
type mockUDPForward struct {
*UDPForward
}
func (f *mockUDPForward) getOriginDst([]byte) (*net.UDPAddr, error) {
return nil, nil
}

View File

@@ -51,8 +51,8 @@ func Run() {
// up local tcp,udp service
// we use tproxy to route traffic to the tcp port and udp port here.
tcpfw := core.NewTCPForward()
listener, err := tcpfw.Listen(cfg.ServerConfig.TCPForwardListen)
tcpfw := core.NewTCPForward(cfg.TCPForwardConfig)
listener, err := tcpfw.Listen()
if err != nil {
logs.Error("listen tproxy tcp fail: %v", err)
return
@@ -60,7 +60,7 @@ func Run() {
go tcpfw.Serve(listener)
go core.NewUDPForward().ListenAndServe(cfg.ServerConfig.UDPForwardListen)
go core.NewUDPForward(cfg.UDPForwardConfig).ListenAndServe()
// server provides tcp server for opennotr client
s := core.NewServer(cfg.ServerConfig, dhcp, resolver)