refactor: code optimize

This commit is contained in:
ICKelin
2021-05-05 20:34:50 +08:00
parent f122d9d791
commit 604ae9b53c
8 changed files with 118 additions and 114 deletions

View File

@@ -2,14 +2,13 @@ server:
listen: ":10100"
authKey: "client server exchange key"
domain: "open.notr.tech"
tcplisten: ":4398"
udplisten: ":4399"
dhcp:
cidr: "100.64.242.1/24"
ip: "100.64.242.1"
upstream:
remoteAddr: "http://127.0.0.1:81/upstreams"
plugin:
tcp: |
{
@@ -32,9 +31,8 @@ plugin:
{
"adminUrl": "http://127.0.0.1:81/upstreams"
}
h2c: |
{
"adminUrl": "http://127.0.0.1:81/upstreams"
}

View File

@@ -1,21 +0,0 @@
package core
import "fmt"
type Packet []byte
func (p Packet) Invalid() bool {
return len(p) < 20
}
func (p Packet) Version() int {
return int((p[0] >> 4))
}
func (p Packet) Dst() string {
return fmt.Sprintf("%d.%d.%d.%d", p[16], p[17], p[18], p[19])
}
func (p Packet) Src() string {
return fmt.Sprintf("%d.%d.%d.%d", p[12], p[13], p[14], p[15])
}

View File

@@ -1,17 +1,19 @@
package core
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"net"
"syscall"
"unsafe"
"github.com/ICKelin/opennotr/pkg/logs"
"github.com/ICKelin/opennotr/pkg/proto"
)
func checksum_add(buf []byte, seed uint32) uint32 {
func checksumAdd(buf []byte, seed uint32) uint32 {
sum := seed
for i, l := 0, len(buf); i < l; i += 2 {
j := i + 1
@@ -24,7 +26,7 @@ func checksum_add(buf []byte, seed uint32) uint32 {
return sum
}
func checksum_wrap(seed uint32) uint16 {
func checksumWrapper(seed uint32) uint16 {
sum := seed
for sum > 0xffff {
sum = (sum >> 16) + (sum & 0xffff)
@@ -39,7 +41,7 @@ func checksum_wrap(seed uint32) uint16 {
}
func CheckSum(buf []byte) uint16 {
return checksum_wrap(checksum_add(buf, 0))
return checksumWrapper(checksumAdd(buf, 0))
}
func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
@@ -61,7 +63,7 @@ func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
data[25] = byte(ulen)
copy(data[28:], payload)
uc := checksum_wrap(checksum_add(data, uint32(ulen)))
uc := checksumWrapper(checksumAdd(data, uint32(ulen)))
data[26] = byte(uc >> 8)
data[27] = byte(uc)
@@ -95,11 +97,42 @@ func encodeProxyProtocol(protocol, sip, sport, dip, dport string) []byte {
DstPort: dport,
}
body, err := json.Marshal(proxyProtocol)
body, _ := json.Marshal(proxyProtocol)
return encode(body)
}
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
msgs, err := syscall.ParseSocketControlMessage(hdr)
if err != nil {
logs.Error("json marshal fail: %v", err)
return nil, err
}
bytes := encode(body)
return bytes
var origindst *net.UDPAddr
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP &&
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
originDstRaw := &syscall.RawSockaddrInet4{}
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
if err != nil {
logs.Error("read origin dst fail: %v", err)
continue
}
// only support for ipv4
if originDstRaw.Family == syscall.AF_INET {
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
origindst = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
}
}
}
if origindst == nil {
return nil, fmt.Errorf("get origin dst fail")
}
return origindst, nil
}

View File

@@ -6,7 +6,6 @@ import (
"net"
"net/http"
"strings"
"sync/atomic"
"time"
"github.com/ICKelin/opennotr/opennotrd/plugin"
@@ -84,7 +83,7 @@ func (s *Server) onConn(conn net.Conn) {
return
}
// it client without domain
// if client without domain
// generate random domain base on time nano
if len(auth.Domain) <= 0 {
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().UnixNano()), s.domain)
@@ -100,9 +99,8 @@ func (s *Server) onConn(conn net.Conn) {
}
reply := &proto.S2CAuth{
Vip: vip,
Gateway: s.dhcp.GetCIDR(),
Domain: auth.Domain,
Vip: vip,
Domain: auth.Domain,
}
err = proto.WriteJSON(conn, proto.CmdAuth, reply)
@@ -125,23 +123,23 @@ func (s *Server) onConn(conn net.Conn) {
logs.Info("select domain: %s", auth.Domain)
// create forward
// $localPort => vip:$upstreamPort
// $publicPort => vip:$localPort
// 1. for from address, we listen 0.0.0.0:$inport
// from member is not used for restyproxy
// 2. for to address, we use $vip:$upstreamPort
// the vip is the virtual lan ip address
// Domain is only use for restyproxy
for _, forward := range auth.Forward {
for localPort, upstreamPort := range forward.Ports {
for publicPort, localPort := range forward.Ports {
item := &plugin.PluginMeta{
Protocol: forward.Protocol,
From: fmt.Sprintf("0.0.0.0:%d", localPort),
To: fmt.Sprintf("%s:%d", vip, upstreamPort),
From: fmt.Sprintf("0.0.0.0:%d", publicPort),
To: fmt.Sprintf("%s:%s", vip, localPort),
Domain: auth.Domain,
RecycleSignal: make(chan struct{}),
}
err := s.pluginMgr.AddProxy(item)
err = s.pluginMgr.AddProxy(item)
if err != nil {
logs.Error("add proxy fail: %v", err)
return
@@ -156,7 +154,7 @@ func (s *Server) onConn(conn net.Conn) {
return
}
sess := newSession(mux, conn.RemoteAddr().String())
sess := newSession(mux, vip)
s.sessMgr.AddSession(vip, sess)
defer s.sessMgr.DeleteSession(vip)
@@ -168,10 +166,10 @@ func (s *Server) onConn(conn net.Conn) {
return
case <-rttInterval.C:
rx := atomic.SwapUint64(&sess.rxbytes, 0)
tx := atomic.SwapUint64(&sess.txbytes, 0)
rx := sess.ResetRx()
tx := sess.ResetTx()
rtt, _ := mux.Ping()
logs.Debug("session %s rtt %d, rx %d tx %d",
logs.Debug("session %s rtt %dms, rx %d tx %d",
sess.conn.RemoteAddr().String(), rtt.Milliseconds(), rx, tx)
}
}

View File

@@ -2,6 +2,7 @@ package core
import (
"sync"
"sync/atomic"
"github.com/hashicorp/yamux"
)
@@ -17,19 +18,33 @@ func GetSessionManager() *SessionManager {
}
type Session struct {
conn *yamux.Session
clientAddr string
rxbytes uint64
txbytes uint64
conn *yamux.Session
rxbytes uint64
txbytes uint64
}
func newSession(conn *yamux.Session, clientAddr string) *Session {
func newSession(conn *yamux.Session, vip string) *Session {
return &Session{
conn: conn,
clientAddr: clientAddr,
conn: conn,
}
}
func (s *Session) ResetRx() uint64 {
return atomic.SwapUint64(&s.rxbytes, 0)
}
func (s *Session) ResetTx() uint64 {
return atomic.SwapUint64(&s.txbytes, 0)
}
func (s *Session) IncRx(nb uint64) {
atomic.AddUint64(&s.rxbytes, nb)
}
func (s *Session) IncTx(nb uint64) {
atomic.AddUint64(&s.txbytes, nb)
}
func (mgr *SessionManager) AddSession(vip string, sess *Session) {
mgr.sessions.Store(vip, sess)
}

View File

@@ -3,6 +3,7 @@ package core
import (
"io"
"net"
"sync"
"syscall"
"time"
@@ -69,7 +70,9 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
return
}
bytes := encodeProxyProtocol("tcp", sip, sport, "127.0.0.1", dport)
// todo rewrite to client configuration
targetIP := "127.0.0.1"
bytes := encodeProxyProtocol("tcp", sip, sport, targetIP, dport)
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
@@ -80,15 +83,22 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
return
}
wg := &sync.WaitGroup{}
wg.Add(1)
defer wg.Wait()
go func() {
defer wg.Done()
defer stream.Close()
defer conn.Close()
io.Copy(stream, conn)
buf := make([]byte, 4096)
io.CopyBuffer(stream, conn, buf)
}()
go func() {
defer stream.Close()
defer conn.Close()
io.Copy(conn, stream)
}()
// todo: optimize mem alloc
// one session will cause 4KB + 4KB buffer for io copy
// and two goroutine 4KB mem used
buf := make([]byte, 4096)
io.CopyBuffer(conn, stream, buf)
stream.Close()
conn.Close()
}

View File

@@ -1,7 +1,6 @@
package core
import (
"bytes"
"encoding/binary"
"fmt"
"io"
@@ -9,7 +8,6 @@ import (
"sync"
"syscall"
"time"
"unsafe"
"github.com/ICKelin/opennotr/pkg/logs"
"github.com/hashicorp/yamux"
@@ -88,7 +86,7 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
origindst, err := getOriginDst(oob[:oobn])
if err != nil {
logs.Error("%v", err)
logs.Error("get origin dst fail: %v", err)
continue
}
@@ -111,7 +109,8 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
}
streams.Store(key, stream)
bytes := encodeProxyProtocol("udp", sip, sport, "127.0.0.1", dport)
targetIP := "127.0.0.1"
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
@@ -119,13 +118,13 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
logs.Error("stream write fail: %v", err)
continue
}
go f.forwardUDP(stream, rawfd, origindst, raddr)
}
val, ok = streams.Load(key)
if !ok {
logs.Error("get stream for %s fail", key)
continue
go f.forwardUDP(stream, rawfd, origindst, raddr)
val, ok = streams.Load(key)
if !ok {
logs.Error("get stream for %s fail", key)
continue
}
}
stream := val.(*yamux.Stream)
@@ -163,39 +162,3 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr
}
}
}
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
msgs, err := syscall.ParseSocketControlMessage(hdr)
if err != nil {
return nil, err
}
var origindst *net.UDPAddr
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP &&
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
originDstRaw := &syscall.RawSockaddrInet4{}
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
if err != nil {
logs.Error("read origin dst fail: %v", err)
continue
}
// only support for ipv4
if originDstRaw.Family == syscall.AF_INET {
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
origindst = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
}
}
}
if origindst == nil {
return nil, fmt.Errorf("get origin dst fail")
}
return origindst, nil
}

View File

@@ -25,15 +25,23 @@ type C2SAuth struct {
}
type ForwardItem struct {
Protocol string `json:"protocol"` // forward protocol
Ports map[int]int `json:"ports"` // forward to local ports
// ... add other item to controller forward
// forward protocol. eg: tcp, udp, https, http
Protocol string `json:"protocol"`
// forward ports
// key is the port opennotrd listen
// value is local port
Ports map[int]string `json:"ports"`
// local ip, default is 127.0.0.1
// the traffic will be forward to $LocalIP:$LocalPort
// for example: 127.0.0.1:8080. 192.168.31.65:8080
LocalIP string `json:"localIP"`
}
type S2CAuth struct {
Domain string `json:"domain"` // 分配域名
Vip string `json:"vip"` // 分配虚拟ip地址
Gateway string `json:"gateway"` // 网关地址(cidr)
Domain string `json:"domain"` // 分配域名
Vip string `json:"vip"` // 分配虚拟ip地址
}
type ProxyProtocol struct {