mirror of
https://github.com/ICKelin/opennotr.git
synced 2025-09-26 20:01:13 +08:00
refactor: code optimize
This commit is contained in:
@@ -2,14 +2,13 @@ server:
|
||||
listen: ":10100"
|
||||
authKey: "client server exchange key"
|
||||
domain: "open.notr.tech"
|
||||
tcplisten: ":4398"
|
||||
udplisten: ":4399"
|
||||
|
||||
dhcp:
|
||||
cidr: "100.64.242.1/24"
|
||||
ip: "100.64.242.1"
|
||||
|
||||
upstream:
|
||||
remoteAddr: "http://127.0.0.1:81/upstreams"
|
||||
|
||||
plugin:
|
||||
tcp: |
|
||||
{
|
||||
@@ -32,9 +31,8 @@ plugin:
|
||||
{
|
||||
"adminUrl": "http://127.0.0.1:81/upstreams"
|
||||
}
|
||||
|
||||
|
||||
h2c: |
|
||||
{
|
||||
"adminUrl": "http://127.0.0.1:81/upstreams"
|
||||
}
|
||||
|
@@ -1,21 +0,0 @@
|
||||
package core
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Packet []byte
|
||||
|
||||
func (p Packet) Invalid() bool {
|
||||
return len(p) < 20
|
||||
}
|
||||
|
||||
func (p Packet) Version() int {
|
||||
return int((p[0] >> 4))
|
||||
}
|
||||
|
||||
func (p Packet) Dst() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", p[16], p[17], p[18], p[19])
|
||||
}
|
||||
|
||||
func (p Packet) Src() string {
|
||||
return fmt.Sprintf("%d.%d.%d.%d", p[12], p[13], p[14], p[15])
|
||||
}
|
@@ -1,17 +1,19 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
"github.com/ICKelin/opennotr/pkg/proto"
|
||||
)
|
||||
|
||||
func checksum_add(buf []byte, seed uint32) uint32 {
|
||||
func checksumAdd(buf []byte, seed uint32) uint32 {
|
||||
sum := seed
|
||||
for i, l := 0, len(buf); i < l; i += 2 {
|
||||
j := i + 1
|
||||
@@ -24,7 +26,7 @@ func checksum_add(buf []byte, seed uint32) uint32 {
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksum_wrap(seed uint32) uint16 {
|
||||
func checksumWrapper(seed uint32) uint16 {
|
||||
sum := seed
|
||||
for sum > 0xffff {
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
@@ -39,7 +41,7 @@ func checksum_wrap(seed uint32) uint16 {
|
||||
}
|
||||
|
||||
func CheckSum(buf []byte) uint16 {
|
||||
return checksum_wrap(checksum_add(buf, 0))
|
||||
return checksumWrapper(checksumAdd(buf, 0))
|
||||
}
|
||||
|
||||
func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
|
||||
@@ -61,7 +63,7 @@ func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
|
||||
data[25] = byte(ulen)
|
||||
copy(data[28:], payload)
|
||||
|
||||
uc := checksum_wrap(checksum_add(data, uint32(ulen)))
|
||||
uc := checksumWrapper(checksumAdd(data, uint32(ulen)))
|
||||
data[26] = byte(uc >> 8)
|
||||
data[27] = byte(uc)
|
||||
|
||||
@@ -95,11 +97,42 @@ func encodeProxyProtocol(protocol, sip, sport, dip, dport string) []byte {
|
||||
DstPort: dport,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(proxyProtocol)
|
||||
body, _ := json.Marshal(proxyProtocol)
|
||||
return encode(body)
|
||||
}
|
||||
|
||||
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
|
||||
msgs, err := syscall.ParseSocketControlMessage(hdr)
|
||||
if err != nil {
|
||||
logs.Error("json marshal fail: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bytes := encode(body)
|
||||
return bytes
|
||||
var origindst *net.UDPAddr
|
||||
for _, msg := range msgs {
|
||||
if msg.Header.Level == syscall.SOL_IP &&
|
||||
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
|
||||
originDstRaw := &syscall.RawSockaddrInet4{}
|
||||
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
|
||||
if err != nil {
|
||||
logs.Error("read origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// only support for ipv4
|
||||
if originDstRaw.Family == syscall.AF_INET {
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
|
||||
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
|
||||
origindst = &net.UDPAddr{
|
||||
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
|
||||
Port: int(p[0])<<8 + int(p[1]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if origindst == nil {
|
||||
return nil, fmt.Errorf("get origin dst fail")
|
||||
}
|
||||
|
||||
return origindst, nil
|
||||
}
|
@@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ICKelin/opennotr/opennotrd/plugin"
|
||||
@@ -84,7 +83,7 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// it client without domain
|
||||
// if client without domain
|
||||
// generate random domain base on time nano
|
||||
if len(auth.Domain) <= 0 {
|
||||
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().UnixNano()), s.domain)
|
||||
@@ -100,9 +99,8 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
}
|
||||
|
||||
reply := &proto.S2CAuth{
|
||||
Vip: vip,
|
||||
Gateway: s.dhcp.GetCIDR(),
|
||||
Domain: auth.Domain,
|
||||
Vip: vip,
|
||||
Domain: auth.Domain,
|
||||
}
|
||||
|
||||
err = proto.WriteJSON(conn, proto.CmdAuth, reply)
|
||||
@@ -125,23 +123,23 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
logs.Info("select domain: %s", auth.Domain)
|
||||
|
||||
// create forward
|
||||
// $localPort => vip:$upstreamPort
|
||||
// $publicPort => vip:$localPort
|
||||
// 1. for from address, we listen 0.0.0.0:$inport
|
||||
// from member is not used for restyproxy
|
||||
// 2. for to address, we use $vip:$upstreamPort
|
||||
// the vip is the virtual lan ip address
|
||||
// Domain is only use for restyproxy
|
||||
for _, forward := range auth.Forward {
|
||||
for localPort, upstreamPort := range forward.Ports {
|
||||
for publicPort, localPort := range forward.Ports {
|
||||
item := &plugin.PluginMeta{
|
||||
Protocol: forward.Protocol,
|
||||
From: fmt.Sprintf("0.0.0.0:%d", localPort),
|
||||
To: fmt.Sprintf("%s:%d", vip, upstreamPort),
|
||||
From: fmt.Sprintf("0.0.0.0:%d", publicPort),
|
||||
To: fmt.Sprintf("%s:%s", vip, localPort),
|
||||
Domain: auth.Domain,
|
||||
RecycleSignal: make(chan struct{}),
|
||||
}
|
||||
|
||||
err := s.pluginMgr.AddProxy(item)
|
||||
err = s.pluginMgr.AddProxy(item)
|
||||
if err != nil {
|
||||
logs.Error("add proxy fail: %v", err)
|
||||
return
|
||||
@@ -156,7 +154,7 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
sess := newSession(mux, conn.RemoteAddr().String())
|
||||
sess := newSession(mux, vip)
|
||||
s.sessMgr.AddSession(vip, sess)
|
||||
defer s.sessMgr.DeleteSession(vip)
|
||||
|
||||
@@ -168,10 +166,10 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
return
|
||||
|
||||
case <-rttInterval.C:
|
||||
rx := atomic.SwapUint64(&sess.rxbytes, 0)
|
||||
tx := atomic.SwapUint64(&sess.txbytes, 0)
|
||||
rx := sess.ResetRx()
|
||||
tx := sess.ResetTx()
|
||||
rtt, _ := mux.Ping()
|
||||
logs.Debug("session %s rtt %d, rx %d tx %d",
|
||||
logs.Debug("session %s rtt %dms, rx %d tx %d",
|
||||
sess.conn.RemoteAddr().String(), rtt.Milliseconds(), rx, tx)
|
||||
}
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package core
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
@@ -17,19 +18,33 @@ func GetSessionManager() *SessionManager {
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
conn *yamux.Session
|
||||
clientAddr string
|
||||
rxbytes uint64
|
||||
txbytes uint64
|
||||
conn *yamux.Session
|
||||
rxbytes uint64
|
||||
txbytes uint64
|
||||
}
|
||||
|
||||
func newSession(conn *yamux.Session, clientAddr string) *Session {
|
||||
func newSession(conn *yamux.Session, vip string) *Session {
|
||||
return &Session{
|
||||
conn: conn,
|
||||
clientAddr: clientAddr,
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) ResetRx() uint64 {
|
||||
return atomic.SwapUint64(&s.rxbytes, 0)
|
||||
}
|
||||
|
||||
func (s *Session) ResetTx() uint64 {
|
||||
return atomic.SwapUint64(&s.txbytes, 0)
|
||||
}
|
||||
|
||||
func (s *Session) IncRx(nb uint64) {
|
||||
atomic.AddUint64(&s.rxbytes, nb)
|
||||
}
|
||||
|
||||
func (s *Session) IncTx(nb uint64) {
|
||||
atomic.AddUint64(&s.txbytes, nb)
|
||||
}
|
||||
|
||||
func (mgr *SessionManager) AddSession(vip string, sess *Session) {
|
||||
mgr.sessions.Store(vip, sess)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package core
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -69,7 +70,9 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
bytes := encodeProxyProtocol("tcp", sip, sport, "127.0.0.1", dport)
|
||||
// todo rewrite to client configuration
|
||||
targetIP := "127.0.0.1"
|
||||
bytes := encodeProxyProtocol("tcp", sip, sport, targetIP, dport)
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
@@ -80,15 +83,22 @@ func (f *TCPForward) forwardTCP(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
defer wg.Wait()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer stream.Close()
|
||||
defer conn.Close()
|
||||
io.Copy(stream, conn)
|
||||
buf := make([]byte, 4096)
|
||||
io.CopyBuffer(stream, conn, buf)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer stream.Close()
|
||||
defer conn.Close()
|
||||
io.Copy(conn, stream)
|
||||
}()
|
||||
// todo: optimize mem alloc
|
||||
// one session will cause 4KB + 4KB buffer for io copy
|
||||
// and two goroutine 4KB mem used
|
||||
buf := make([]byte, 4096)
|
||||
io.CopyBuffer(conn, stream, buf)
|
||||
stream.Close()
|
||||
conn.Close()
|
||||
}
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -9,7 +8,6 @@ import (
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
"github.com/hashicorp/yamux"
|
||||
@@ -88,7 +86,7 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
|
||||
origindst, err := getOriginDst(oob[:oobn])
|
||||
if err != nil {
|
||||
logs.Error("%v", err)
|
||||
logs.Error("get origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -111,7 +109,8 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
}
|
||||
streams.Store(key, stream)
|
||||
|
||||
bytes := encodeProxyProtocol("udp", sip, sport, "127.0.0.1", dport)
|
||||
targetIP := "127.0.0.1"
|
||||
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
@@ -119,13 +118,13 @@ func (f *UDPForward) ListenAndServe(listenAddr string) error {
|
||||
logs.Error("stream write fail: %v", err)
|
||||
continue
|
||||
}
|
||||
go f.forwardUDP(stream, rawfd, origindst, raddr)
|
||||
}
|
||||
|
||||
val, ok = streams.Load(key)
|
||||
if !ok {
|
||||
logs.Error("get stream for %s fail", key)
|
||||
continue
|
||||
go f.forwardUDP(stream, rawfd, origindst, raddr)
|
||||
val, ok = streams.Load(key)
|
||||
if !ok {
|
||||
logs.Error("get stream for %s fail", key)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stream := val.(*yamux.Stream)
|
||||
@@ -163,39 +162,3 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, tofd int, fromaddr, toaddr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
|
||||
msgs, err := syscall.ParseSocketControlMessage(hdr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var origindst *net.UDPAddr
|
||||
for _, msg := range msgs {
|
||||
if msg.Header.Level == syscall.SOL_IP &&
|
||||
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
|
||||
originDstRaw := &syscall.RawSockaddrInet4{}
|
||||
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
|
||||
if err != nil {
|
||||
logs.Error("read origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// only support for ipv4
|
||||
if originDstRaw.Family == syscall.AF_INET {
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
|
||||
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
|
||||
origindst = &net.UDPAddr{
|
||||
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
|
||||
Port: int(p[0])<<8 + int(p[1]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if origindst == nil {
|
||||
return nil, fmt.Errorf("get origin dst fail")
|
||||
}
|
||||
|
||||
return origindst, nil
|
||||
}
|
||||
|
@@ -25,15 +25,23 @@ type C2SAuth struct {
|
||||
}
|
||||
|
||||
type ForwardItem struct {
|
||||
Protocol string `json:"protocol"` // forward protocol
|
||||
Ports map[int]int `json:"ports"` // forward to local ports
|
||||
// ... add other item to controller forward
|
||||
// forward protocol. eg: tcp, udp, https, http
|
||||
Protocol string `json:"protocol"`
|
||||
|
||||
// forward ports
|
||||
// key is the port opennotrd listen
|
||||
// value is local port
|
||||
Ports map[int]string `json:"ports"`
|
||||
|
||||
// local ip, default is 127.0.0.1
|
||||
// the traffic will be forward to $LocalIP:$LocalPort
|
||||
// for example: 127.0.0.1:8080. 192.168.31.65:8080
|
||||
LocalIP string `json:"localIP"`
|
||||
}
|
||||
|
||||
type S2CAuth struct {
|
||||
Domain string `json:"domain"` // 分配域名
|
||||
Vip string `json:"vip"` // 分配虚拟ip地址
|
||||
Gateway string `json:"gateway"` // 网关地址(cidr)
|
||||
Domain string `json:"domain"` // 分配域名
|
||||
Vip string `json:"vip"` // 分配虚拟ip地址
|
||||
}
|
||||
|
||||
type ProxyProtocol struct {
|
||||
|
Reference in New Issue
Block a user