mirror of
https://github.com/ICKelin/opennotr.git
synced 2025-09-26 20:01:13 +08:00
feature: support udp tproxy
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
"github.com/ICKelin/opennotr/pkg/proto"
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
@@ -106,7 +107,7 @@ func (c *Client) handleStream(stream *yamux.Stream) {
|
||||
proxyProtocol := proto.ProxyProtocol{}
|
||||
err = json.Unmarshal(buf[:nr], &proxyProtocol)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
log.Println("unmarshal fail: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -123,6 +124,7 @@ func (c *Client) tcpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
|
||||
remoteConn, err := net.DialTimeout("tcp", addr, time.Second*10)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -139,4 +141,67 @@ func (c *Client) tcpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *Client) udpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {}
|
||||
func (c *Client) udpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
|
||||
addr := fmt.Sprintf("%s:%s", p.DstIP, p.DstPort)
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
remoteConn, err := net.DialUDP("udp", nil, raddr)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer remoteConn.Close()
|
||||
defer stream.Close()
|
||||
hdr := make([]byte, 2)
|
||||
for {
|
||||
_, err := io.ReadFull(stream, hdr)
|
||||
if err != nil {
|
||||
logs.Error("read stream fail %v", err)
|
||||
break
|
||||
}
|
||||
nlen := binary.BigEndian.Uint16(hdr)
|
||||
buf := make([]byte, nlen)
|
||||
_, err = io.ReadFull(stream, buf)
|
||||
if err != nil {
|
||||
logs.Error("read stream body fail: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
remoteConn.Write(buf)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer remoteConn.Close()
|
||||
defer stream.Close()
|
||||
buf := make([]byte, 64*1024)
|
||||
for {
|
||||
nr, err := remoteConn.Read(buf)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
bytes := encode(buf[:nr])
|
||||
_, err = stream.Write(bytes)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func encode(raw []byte) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(raw)))
|
||||
buf = append(buf, raw...)
|
||||
return buf
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ICKelin/opennotr/opennotrd/plugin"
|
||||
"github.com/ICKelin/opennotr/pkg/logs"
|
||||
@@ -93,7 +95,8 @@ func (s *Server) ListenAndServe() error {
|
||||
func (s *Server) onConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// authorize
|
||||
// auth key verify
|
||||
// currently we use auth key which configured in notrd.yaml
|
||||
auth := proto.C2SAuth{}
|
||||
err := proto.ReadJSON(conn, &auth)
|
||||
if err != nil {
|
||||
@@ -106,10 +109,15 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// it client without domain
|
||||
// generate random domain base on time nano
|
||||
if len(auth.Domain) <= 0 {
|
||||
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().Unix()), s.domain)
|
||||
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().UnixNano()), s.domain)
|
||||
}
|
||||
|
||||
// select a virtual ip for client.
|
||||
// a virtual ip is the ip address which can be use in our system
|
||||
// but cannot be used by other networks
|
||||
vip, err := s.dhcp.SelectIP()
|
||||
if err != nil {
|
||||
logs.Error("dhcp select ip fail: %v", err)
|
||||
@@ -173,7 +181,6 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
// tunnel session
|
||||
sess := newSession(mux, conn.RemoteAddr().String())
|
||||
s.sess.Store(vip, sess)
|
||||
defer s.sess.Delete(vip)
|
||||
@@ -184,6 +191,7 @@ func (s *Server) onConn(conn net.Conn) {
|
||||
case <-mux.CloseChan():
|
||||
logs.Info("session %v close", sess.conn.RemoteAddr().String())
|
||||
return
|
||||
|
||||
case <-rttInterval.C:
|
||||
rx := atomic.SwapUint64(&sess.rxbytes, 0)
|
||||
tx := atomic.SwapUint64(&sess.txbytes, 0)
|
||||
@@ -289,7 +297,202 @@ func (s *Server) tcpProxy(conn net.Conn) {
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) tproxyUDP(listenAddr string) {}
|
||||
func (s *Server) tproxyUDP(listenAddr string) error {
|
||||
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
logs.Error("resolve udp fail: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
lconn, err := net.ListenUDP("udp", laddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set socket with ip transparent option
|
||||
file, err := lconn.File()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set socket with recv origin dst option
|
||||
err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rawfd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
|
||||
if err != nil || rawfd < 0 {
|
||||
logs.Error("call socket fail: %v", err)
|
||||
return err
|
||||
}
|
||||
defer syscall.Close(rawfd)
|
||||
|
||||
err = syscall.SetsockoptInt(rawfd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
streams := sync.Map{}
|
||||
defer func() {
|
||||
streams.Range(func(k, v interface{}) bool {
|
||||
v.(*yamux.Stream).Close()
|
||||
return true
|
||||
})
|
||||
}()
|
||||
|
||||
buf := make([]byte, 64*1024)
|
||||
oob := make([]byte, 1024)
|
||||
for {
|
||||
nr, oobn, _, raddr, err := lconn.ReadMsgUDP(buf, oob)
|
||||
if err != nil {
|
||||
logs.Error("read from udp fail: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
origindst, err := getOriginDst(oob[:oobn])
|
||||
if err != nil {
|
||||
logs.Error("%v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
dip, dport, _ := net.SplitHostPort(origindst.String())
|
||||
sip, sport, _ := net.SplitHostPort(raddr.String())
|
||||
|
||||
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
|
||||
val, ok := streams.Load(key)
|
||||
if !ok {
|
||||
val, ok := s.sess.Load(dip)
|
||||
if !ok {
|
||||
logs.Error("no route to host: %s", dip)
|
||||
continue
|
||||
}
|
||||
|
||||
stream, err := val.(*Session).conn.OpenStream()
|
||||
if err != nil {
|
||||
logs.Error("open stream fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
streams.Store(key, stream)
|
||||
|
||||
// write proxy protocol
|
||||
proxyProtocol := &proto.ProxyProtocol{
|
||||
Protocol: "udp",
|
||||
SrcIP: sip,
|
||||
SrcPort: sport,
|
||||
// DstIP: dip,
|
||||
DstIP: "127.0.0.1", // may change to client setting
|
||||
DstPort: dport,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(proxyProtocol)
|
||||
if err != nil {
|
||||
logs.Error("json marshal fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
bytes := encode(body)
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logs.Error("stream write fail: %v", err)
|
||||
continue
|
||||
}
|
||||
go s.udpProxy(stream, rawfd, origindst, raddr)
|
||||
}
|
||||
|
||||
val, ok = streams.Load(key)
|
||||
if !ok {
|
||||
logs.Error("get stream for %s fail", key)
|
||||
continue
|
||||
}
|
||||
|
||||
stream := val.(*yamux.Stream)
|
||||
bytes := encode(buf[:nr])
|
||||
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
|
||||
_, err = stream.Write(bytes)
|
||||
stream.SetWriteDeadline(time.Time{})
|
||||
if err != nil {
|
||||
logs.Error("stream write fail: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) udpProxy(stream *yamux.Stream, tofd int, fromaddr, toaddr *net.UDPAddr) {
|
||||
hdr := make([]byte, 2)
|
||||
for {
|
||||
_, err := io.ReadFull(stream, hdr)
|
||||
if err != nil {
|
||||
logs.Error("read stream fail %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
nlen := binary.BigEndian.Uint16(hdr)
|
||||
buf := make([]byte, nlen)
|
||||
_, err = io.ReadFull(stream, buf)
|
||||
if err != nil {
|
||||
logs.Error("read stream body fail: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
err = sendUDPViaRaw(tofd, fromaddr, toaddr, buf)
|
||||
if err != nil {
|
||||
logs.Error("send via raw socket fail: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func encode(raw []byte) []byte {
|
||||
buf := make([]byte, 2)
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(raw)))
|
||||
buf = append(buf, raw...)
|
||||
return buf
|
||||
}
|
||||
|
||||
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
|
||||
msgs, err := syscall.ParseSocketControlMessage(hdr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var origindst *net.UDPAddr
|
||||
for _, msg := range msgs {
|
||||
if msg.Header.Level == syscall.SOL_IP &&
|
||||
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
|
||||
originDstRaw := &syscall.RawSockaddrInet4{}
|
||||
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
|
||||
if err != nil {
|
||||
logs.Error("read origin dst fail: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// only support for ipv4
|
||||
if originDstRaw.Family == syscall.AF_INET {
|
||||
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
|
||||
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
|
||||
origindst = &net.UDPAddr{
|
||||
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
|
||||
Port: int(p[0])<<8 + int(p[1]),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if origindst == nil {
|
||||
return nil, fmt.Errorf("get origin dst fail")
|
||||
}
|
||||
|
||||
return origindst, nil
|
||||
}
|
||||
|
||||
// randomDomain generate random domain for client
|
||||
func randomDomain(num int64) string {
|
||||
|
75
opennotrd/core/udpproxy.go
Normal file
75
opennotrd/core/udpproxy.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func checksum_add(buf []byte, seed uint32) uint32 {
|
||||
sum := seed
|
||||
for i, l := 0, len(buf); i < l; i += 2 {
|
||||
j := i + 1
|
||||
if j == l {
|
||||
sum += uint32(buf[i]) << 8
|
||||
break
|
||||
}
|
||||
sum += uint32(buf[i])<<8 | uint32(buf[j])
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func checksum_warp(seed uint32) uint16 {
|
||||
sum := seed
|
||||
for sum > 0xffff {
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
}
|
||||
csum := ^uint16(sum)
|
||||
|
||||
// RFC 768
|
||||
if csum == 0 {
|
||||
csum = 0xffff
|
||||
}
|
||||
return csum
|
||||
}
|
||||
|
||||
func CheckSum(buf []byte) uint16 {
|
||||
return checksum_warp(checksum_add(buf, 0))
|
||||
}
|
||||
|
||||
func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
|
||||
iplen, ulen := uint16(28+len(payload)), uint16(8+len(payload))
|
||||
if iplen > 65535 {
|
||||
return fmt.Errorf("too big packet")
|
||||
}
|
||||
|
||||
// UDP checksum: sip + dip + udp-head + payload + PROTO + ulen
|
||||
data := make([]byte, iplen)
|
||||
data[9] = syscall.IPPROTO_UDP
|
||||
copy(data[12:16], src.IP.To4())
|
||||
copy(data[16:20], dst.IP.To4())
|
||||
data[20] = byte(src.Port >> 8)
|
||||
data[21] = byte(src.Port)
|
||||
data[22] = byte(dst.Port >> 8)
|
||||
data[23] = byte(dst.Port)
|
||||
data[24] = byte(ulen >> 8)
|
||||
data[25] = byte(ulen)
|
||||
copy(data[28:], payload)
|
||||
|
||||
uc := checksum_warp(checksum_add(data, uint32(ulen)))
|
||||
data[26] = byte(uc >> 8)
|
||||
data[27] = byte(uc)
|
||||
|
||||
data[0] = 0x45
|
||||
data[2] = byte(iplen >> 8)
|
||||
data[3] = byte(iplen)
|
||||
data[6] = 0x40
|
||||
data[8] = 64
|
||||
ipc := CheckSum(data[:20])
|
||||
data[10] = byte(ipc >> 8)
|
||||
data[11] = byte(ipc)
|
||||
|
||||
addr := syscall.SockaddrInet4{Port: dst.Port}
|
||||
copy(addr.Addr[:], data[16:20])
|
||||
return syscall.Sendto(fd, data, 0, &addr)
|
||||
}
|
Reference in New Issue
Block a user