mirror of
https://github.com/ICKelin/opennotr.git
synced 2025-09-26 20:01:13 +08:00
refactor: optimize forwarder
This commit is contained in:
Binary file not shown.
Before Width: | Height: | Size: 4.0 KiB After Width: | Height: | Size: 7.0 KiB |
@@ -2,8 +2,12 @@ server:
|
||||
listen: ":10100"
|
||||
authKey: "client server exchange key"
|
||||
domain: "open.notr.tech"
|
||||
tcplisten: ":4398"
|
||||
udplisten: ":4399"
|
||||
|
||||
tcpforward:
|
||||
listen: ":4398"
|
||||
|
||||
udpforward:
|
||||
listen: ":4399"
|
||||
|
||||
dhcp:
|
||||
cidr: "100.64.242.1/24"
|
||||
|
@@ -8,18 +8,31 @@ import (
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
ServerConfig ServerConfig `yaml:"server"`
|
||||
DHCPConfig DHCPConfig `yaml:"dhcp"`
|
||||
ResolverConfig ResolverConfig `yaml:"resolver"`
|
||||
Plugins map[string]string `yaml:"plugin"`
|
||||
ServerConfig ServerConfig `yaml:"server"`
|
||||
DHCPConfig DHCPConfig `yaml:"dhcp"`
|
||||
ResolverConfig ResolverConfig `yaml:"resolver"`
|
||||
TCPForwardConfig TCPForwardConfig `yaml:"tcpforward"`
|
||||
UDPForwardConfig UDPForwardConfig `yaml:"udpforward"`
|
||||
Plugins map[string]string `yaml:"plugin"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
ListenAddr string `yaml:"listen"`
|
||||
AuthKey string `yaml:"authKey"`
|
||||
Domain string `yaml:"domain"`
|
||||
TCPForwardListen string `yaml:"tcplisten"`
|
||||
UDPForwardListen string `yaml:"udplisten"`
|
||||
ListenAddr string `yaml:"listen"`
|
||||
AuthKey string `yaml:"authKey"`
|
||||
Domain string `yaml:"domain"`
|
||||
}
|
||||
|
||||
type TCPForwardConfig struct {
|
||||
ListenAddr string `yaml:"listen"`
|
||||
ReadTimeout int `yaml:"readTimeout"`
|
||||
WriteTimeout int `yaml:"writeTimeout"`
|
||||
}
|
||||
|
||||
type UDPForwardConfig struct {
|
||||
ListenAddr string `yaml:"listen"`
|
||||
ReadTimeout int `yaml:"readTimeout"`
|
||||
WriteTimeout int `yaml:"writeTimeout"`
|
||||
SessionTimeout int `yaml:"sessionTimeout"`
|
||||
}
|
||||
|
||||
type DHCPConfig struct {
|
||||
|
@@ -1,15 +1,12 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
"github.com/ICKelin/opennotr/pkg/proto"
|
||||
)
|
||||
|
||||
@@ -100,39 +97,3 @@ func encodeProxyProtocol(protocol, sip, sport, dip, dport string) []byte {
|
||||
body, _ := json.Marshal(proxyProtocol)
|
||||
return encode(body)
|
||||
}
|
||||
|
||||
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
|
||||
msgs, err := syscall.ParseSocketControlMessage(hdr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var origindst *net.UDPAddr
|
||||
for _, msg := range msgs {
|
||||
if msg.Header.Level == syscall.SOL_IP &&
|
||||
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
|
||||
originDstRaw := &syscall.RawSockaddrInet4{}
|
||||
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
|
||||
if err != nil {
|
||||
logs.Error("read origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// only support for ipv4
|
||||
if originDstRaw.Family == syscall.AF_INET {
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
|
||||
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
|
||||
origindst = &net.UDPAddr{
|
||||
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
|
||||
Port: int(p[0])<<8 + int(p[1]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if origindst == nil {
|
||||
return nil, fmt.Errorf("get origin dst fail")
|
||||
}
|
||||
|
||||
return origindst, nil
|
||||
}
|
||||
|
@@ -10,18 +10,43 @@ import (
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
)
|
||||
|
||||
var (
|
||||
// default tcp timeout(read, write), 10 seconds
|
||||
defaultTCPTimeout = 10
|
||||
)
|
||||
|
||||
type TCPForward struct {
|
||||
sessMgr *SessionManager
|
||||
listenAddr string
|
||||
// writeTimeout defines the tcp connection write timeout in second
|
||||
// default value set to 10 seconds
|
||||
writeTimeout time.Duration
|
||||
|
||||
// readTimeout defines the tcp connection write timeout in second
|
||||
// default value set to 10 seconds
|
||||
readTimeout time.Duration
|
||||
sessMgr *SessionManager
|
||||
}
|
||||
|
||||
func NewTCPForward() *TCPForward {
|
||||
func NewTCPForward(cfg TCPForwardConfig) *TCPForward {
|
||||
tcpReadTimeout := cfg.ReadTimeout
|
||||
if tcpReadTimeout <= 0 {
|
||||
tcpReadTimeout = defaultTCPTimeout
|
||||
}
|
||||
|
||||
tcpWriteTimeout := cfg.WriteTimeout
|
||||
if tcpWriteTimeout <= 0 {
|
||||
tcpWriteTimeout = int(defaultTCPTimeout)
|
||||
}
|
||||
return &TCPForward{
|
||||
sessMgr: GetSessionManager(),
|
||||
listenAddr: cfg.ListenAddr,
|
||||
writeTimeout: time.Duration(tcpWriteTimeout) * time.Second,
|
||||
readTimeout: time.Duration(tcpReadTimeout) * time.Second,
|
||||
sessMgr: GetSessionManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *TCPForward) Listen(listenAddr string) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
func (f *TCPForward) Listen() (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", f.listenAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -77,7 +102,7 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
|
||||
// todo rewrite to client configuration
|
||||
targetIP := "127.0.0.1"
|
||||
bytes := encodeProxyProtocol("tcp", sip, sport, targetIP, dport)
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
|
@@ -68,7 +68,6 @@ func runBackend() {
|
||||
for {
|
||||
nr, err := stream.Read(buf)
|
||||
if err != nil {
|
||||
fmt.Println("read stream fail:", err)
|
||||
break
|
||||
}
|
||||
stream.Write(buf[:nr])
|
||||
@@ -115,8 +114,10 @@ func runtproxy(tcpfw *TCPForward, listener net.Listener) {
|
||||
|
||||
func TestTCPForward(t *testing.T) {
|
||||
// listen tproxy
|
||||
tcpfw := NewTCPForward()
|
||||
listener, err := tcpfw.Listen(tproxyAddr)
|
||||
tcpfw := NewTCPForward(TCPForwardConfig{
|
||||
ListenAddr: tproxyAddr,
|
||||
})
|
||||
listener, err := tcpfw.Listen()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -164,8 +165,10 @@ func TestTCPForward(t *testing.T) {
|
||||
|
||||
func benchmark(t *testing.B, nconn int) {
|
||||
// listen tproxy
|
||||
tcpfw := NewTCPForward()
|
||||
listener, err := tcpfw.Listen(tproxyAddr)
|
||||
tcpfw := NewTCPForward(TCPForwardConfig{
|
||||
ListenAddr: tproxyAddr,
|
||||
})
|
||||
listener, err := tcpfw.Listen()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@@ -226,3 +229,11 @@ func Benchmark4K(b *testing.B) {
|
||||
func Benchmark8K(b *testing.B) {
|
||||
benchmark(b, 1024*8)
|
||||
}
|
||||
|
||||
func Benchmark10K(b *testing.B) {
|
||||
benchmark(b, 1024*10)
|
||||
}
|
||||
|
||||
func Benchmark14K(b *testing.B) {
|
||||
benchmark(b, 1024*14)
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -8,23 +9,61 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
type UDPForward struct {
|
||||
sessMgr *SessionManager
|
||||
var (
|
||||
// default udp timeout(read, write)(seconds)
|
||||
defaultUDPTimeout = 10
|
||||
|
||||
// default udp session timeout(seconds)
|
||||
defaultUDPSessionTimeout = 30
|
||||
)
|
||||
|
||||
type udpSession struct {
|
||||
stream *yamux.Stream
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
func NewUDPForward() *UDPForward {
|
||||
type UDPForward struct {
|
||||
listenAddr string
|
||||
sessionTimeout int
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
sessMgr *SessionManager
|
||||
udpSessions sync.Map
|
||||
}
|
||||
|
||||
func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
|
||||
readTimeout := cfg.ReadTimeout
|
||||
if readTimeout <= 0 {
|
||||
readTimeout = defaultUDPTimeout
|
||||
}
|
||||
|
||||
writeTimeout := cfg.WriteTimeout
|
||||
if writeTimeout <= 0 {
|
||||
writeTimeout = defaultUDPTimeout
|
||||
}
|
||||
|
||||
sessionTimeout := cfg.SessionTimeout
|
||||
if sessionTimeout <= 0 {
|
||||
sessionTimeout = defaultUDPSessionTimeout
|
||||
}
|
||||
|
||||
return &UDPForward{
|
||||
sessMgr: GetSessionManager(),
|
||||
listenAddr: cfg.ListenAddr,
|
||||
readTimeout: time.Duration(readTimeout) * time.Second,
|
||||
writeTimeout: time.Duration(writeTimeout) * time.Second,
|
||||
sessionTimeout: sessionTimeout,
|
||||
sessMgr: GetSessionManager(),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
func (f *UDPForward) ListenAndServe() error {
|
||||
laddr, err := net.ResolveUDPAddr("udp", f.listenAddr)
|
||||
if err != nil {
|
||||
logs.Error("resolve udp fail: %v", err)
|
||||
return err
|
||||
@@ -67,24 +106,19 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
streams := sync.Map{}
|
||||
defer func() {
|
||||
streams.Range(func(k, v interface{}) bool {
|
||||
v.(*yamux.Stream).Close()
|
||||
return true
|
||||
})
|
||||
}()
|
||||
|
||||
go f.recyeleSession()
|
||||
buf := make([]byte, 64*1024)
|
||||
oob := make([]byte, 1024)
|
||||
for {
|
||||
// udp is not connect oriented, it should use read message
|
||||
// and read the origin dst ip and port from msghdr
|
||||
nr, oobn, _, raddr, err := lconn.ReadMsgUDP(buf, oob)
|
||||
if err != nil {
|
||||
logs.Error("read from udp fail: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
origindst, err := getOriginDst(oob[:oobn])
|
||||
origindst, err := f.getOriginDst(oob[:oobn])
|
||||
if err != nil {
|
||||
logs.Error("get origin dst fail: %v", err)
|
||||
continue
|
||||
@@ -94,7 +128,7 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
sip, sport, _ := net.SplitHostPort(raddr.String())
|
||||
|
||||
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
|
||||
val, ok := streams.Load(key)
|
||||
val, ok := f.udpSessions.Load(key)
|
||||
if !ok {
|
||||
sess := f.sessMgr.GetSession(dip)
|
||||
if sess == nil {
|
||||
@@ -107,11 +141,11 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
logs.Error("open stream fail: %v", err)
|
||||
continue
|
||||
}
|
||||
streams.Store(key, stream)
|
||||
f.udpSessions.Store(key, &udpSession{stream, time.Now()})
|
||||
|
||||
targetIP := "127.0.0.1"
|
||||
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
@@ -119,17 +153,25 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
continue
|
||||
}
|
||||
|
||||
go f.forwardUDP(stream, rawfd, origindst, raddr)
|
||||
val, ok = streams.Load(key)
|
||||
go f.forwardUDP(stream, key, rawfd, origindst, raddr)
|
||||
val, ok = f.udpSessions.Load(key)
|
||||
if !ok {
|
||||
logs.Error("get stream for %s fail", key)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stream := val.(*yamux.Stream)
|
||||
udpsess, ok := val.(*udpSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// update active time to avoid session recycle
|
||||
udpsess.lastActive = time.Now()
|
||||
stream := udpsess.stream
|
||||
|
||||
bytes := encode(buf[:nr])
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
@@ -140,18 +182,28 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
}
|
||||
|
||||
// forwardUDP reads from stream and write to tofd via rawsocket
|
||||
func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr *net.UDPAddr) {
|
||||
func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd int, fromaddr, toaddr *net.UDPAddr) {
|
||||
defer stream.Close()
|
||||
defer f.udpSessions.Delete(sessionKey)
|
||||
hdr := make([]byte, 2)
|
||||
for {
|
||||
_, err := io.ReadFull(stream, hdr)
|
||||
nr, err := stream.Read(hdr)
|
||||
if err != nil {
|
||||
logs.Error("read stream fail %v", err)
|
||||
if err != io.EOF {
|
||||
logs.Error("read stream fail %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
if nr != 2 {
|
||||
logs.Error("invalid bodylen: %d", nr)
|
||||
continue
|
||||
}
|
||||
|
||||
nlen := binary.BigEndian.Uint16(hdr)
|
||||
buf := make([]byte, nlen)
|
||||
stream.SetReadDeadline(time.Now().Add(f.readTimeout))
|
||||
_, err = io.ReadFull(stream, buf)
|
||||
stream.SetReadDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logs.Error("read stream body fail: %v", err)
|
||||
break
|
||||
@@ -163,3 +215,63 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForward) recyeleSession() {
|
||||
tick := time.NewTicker(time.Second * 5)
|
||||
for range tick.C {
|
||||
total, timeout := 0, 0
|
||||
f.udpSessions.Range(func(k, v interface{}) bool {
|
||||
total += 1
|
||||
s, ok := v.(*udpSession)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if time.Now().Sub(s.lastActive).Seconds() > float64(f.sessionTimeout) {
|
||||
logs.Warn("remove udp %v session, lastActive: %v", k, s.lastActive)
|
||||
f.udpSessions.Delete(k)
|
||||
s.stream.Close()
|
||||
timeout += 1
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
logs.Debug("total %d, timeout %d, left: %d", total, timeout, total-timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForward) getOriginDst(hdr []byte) (*net.UDPAddr, error) {
|
||||
msgs, err := syscall.ParseSocketControlMessage(hdr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var origindst *net.UDPAddr
|
||||
for _, msg := range msgs {
|
||||
if msg.Header.Level == syscall.SOL_IP &&
|
||||
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
|
||||
originDstRaw := &syscall.RawSockaddrInet4{}
|
||||
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
|
||||
if err != nil {
|
||||
logs.Error("read origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// only support for ipv4
|
||||
if originDstRaw.Family == syscall.AF_INET {
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
|
||||
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
|
||||
origindst = &net.UDPAddr{
|
||||
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
|
||||
Port: int(p[0])<<8 + int(p[1]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if origindst == nil {
|
||||
return nil, fmt.Errorf("get origin dst fail")
|
||||
}
|
||||
|
||||
return origindst, nil
|
||||
}
|
||||
|
13
opennotrd/core/udpforward_test.go
Normal file
13
opennotrd/core/udpforward_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type mockUDPForward struct {
|
||||
*UDPForward
|
||||
}
|
||||
|
||||
func (f *mockUDPForward) getOriginDst([]byte) (*net.UDPAddr, error) {
|
||||
return nil, nil
|
||||
}
|
@@ -51,8 +51,8 @@ func Run() {
|
||||
|
||||
// up local tcp,udp service
|
||||
// we use tproxy to route traffic to the tcp port and udp port here.
|
||||
tcpfw := core.NewTCPForward()
|
||||
listener, err := tcpfw.Listen(cfg.ServerConfig.TCPForwardListen)
|
||||
tcpfw := core.NewTCPForward(cfg.TCPForwardConfig)
|
||||
listener, err := tcpfw.Listen()
|
||||
if err != nil {
|
||||
logs.Error("listen tproxy tcp fail: %v", err)
|
||||
return
|
||||
@@ -60,7 +60,7 @@ func Run() {
|
||||
|
||||
go tcpfw.Serve(listener)
|
||||
|
||||
go core.NewUDPForward().ListenAndServe(cfg.ServerConfig.UDPForwardListen)
|
||||
go core.NewUDPForward(cfg.UDPForwardConfig).ListenAndServe()
|
||||
|
||||
// server provides tcp server for opennotr client
|
||||
s := core.NewServer(cfg.ServerConfig, dhcp, resolver)
|
||||
|
Reference in New Issue
Block a user