feature: support udp tproxy

This commit is contained in:
ICKelin
2021-05-05 11:54:17 +08:00
parent b53287d522
commit 54e4b0f24a
3 changed files with 349 additions and 6 deletions

View File

@@ -9,6 +9,7 @@ import (
"net"
"time"
"github.com/ICKelin/opennotr/pkg/logs"
"github.com/ICKelin/opennotr/pkg/proto"
"github.com/hashicorp/yamux"
)
@@ -106,7 +107,7 @@ func (c *Client) handleStream(stream *yamux.Stream) {
proxyProtocol := proto.ProxyProtocol{}
err = json.Unmarshal(buf[:nr], &proxyProtocol)
if err != nil {
log.Println(err)
log.Println("unmarshal fail: ", err)
return
}
@@ -123,6 +124,7 @@ func (c *Client) tcpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
remoteConn, err := net.DialTimeout("tcp", addr, time.Second*10)
if err != nil {
log.Println(err)
stream.Close()
return
}
@@ -139,4 +141,67 @@ func (c *Client) tcpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
}()
}
func (c *Client) udpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {}
func (c *Client) udpProxy(stream *yamux.Stream, p *proto.ProxyProtocol) {
addr := fmt.Sprintf("%s:%s", p.DstIP, p.DstPort)
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
log.Println(err)
stream.Close()
return
}
remoteConn, err := net.DialUDP("udp", nil, raddr)
if err != nil {
log.Println(err)
return
}
go func() {
defer remoteConn.Close()
defer stream.Close()
hdr := make([]byte, 2)
for {
_, err := io.ReadFull(stream, hdr)
if err != nil {
logs.Error("read stream fail %v", err)
break
}
nlen := binary.BigEndian.Uint16(hdr)
buf := make([]byte, nlen)
_, err = io.ReadFull(stream, buf)
if err != nil {
logs.Error("read stream body fail: %v", err)
break
}
remoteConn.Write(buf)
}
}()
go func() {
defer remoteConn.Close()
defer stream.Close()
buf := make([]byte, 64*1024)
for {
nr, err := remoteConn.Read(buf)
if err != nil {
log.Println(err)
break
}
bytes := encode(buf[:nr])
_, err = stream.Write(bytes)
if err != nil {
log.Println(err)
break
}
}
}()
}
func encode(raw []byte) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(len(raw)))
buf = append(buf, raw...)
return buf
}

View File

@@ -1,6 +1,7 @@
package core
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
@@ -13,6 +14,7 @@ import (
"sync/atomic"
"syscall"
"time"
"unsafe"
"github.com/ICKelin/opennotr/opennotrd/plugin"
"github.com/ICKelin/opennotr/pkg/logs"
@@ -93,7 +95,8 @@ func (s *Server) ListenAndServe() error {
func (s *Server) onConn(conn net.Conn) {
defer conn.Close()
// authorize
// auth key verify
// currently we use auth key which configured in notrd.yaml
auth := proto.C2SAuth{}
err := proto.ReadJSON(conn, &auth)
if err != nil {
@@ -106,10 +109,15 @@ func (s *Server) onConn(conn net.Conn) {
return
}
// it client without domain
// generate random domain base on time nano
if len(auth.Domain) <= 0 {
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().Unix()), s.domain)
auth.Domain = fmt.Sprintf("%s.%s", randomDomain(time.Now().UnixNano()), s.domain)
}
// select a virtual ip for client.
// a virtual ip is the ip address which can be use in our system
// but cannot be used by other networks
vip, err := s.dhcp.SelectIP()
if err != nil {
logs.Error("dhcp select ip fail: %v", err)
@@ -173,7 +181,6 @@ func (s *Server) onConn(conn net.Conn) {
return
}
// tunnel session
sess := newSession(mux, conn.RemoteAddr().String())
s.sess.Store(vip, sess)
defer s.sess.Delete(vip)
@@ -184,6 +191,7 @@ func (s *Server) onConn(conn net.Conn) {
case <-mux.CloseChan():
logs.Info("session %v close", sess.conn.RemoteAddr().String())
return
case <-rttInterval.C:
rx := atomic.SwapUint64(&sess.rxbytes, 0)
tx := atomic.SwapUint64(&sess.txbytes, 0)
@@ -289,7 +297,202 @@ func (s *Server) tcpProxy(conn net.Conn) {
}()
}
func (s *Server) tproxyUDP(listenAddr string) {}
func (s *Server) tproxyUDP(listenAddr string) error {
laddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
logs.Error("resolve udp fail: %v", err)
return err
}
lconn, err := net.ListenUDP("udp", laddr)
if err != nil {
return err
}
// set socket with ip transparent option
file, err := lconn.File()
if err != nil {
return err
}
defer file.Close()
err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_IP, syscall.IP_TRANSPARENT, 1)
if err != nil {
return err
}
// set socket with recv origin dst option
err = syscall.SetsockoptInt(int(file.Fd()), syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1)
if err != nil {
return err
}
rawfd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil || rawfd < 0 {
logs.Error("call socket fail: %v", err)
return err
}
defer syscall.Close(rawfd)
err = syscall.SetsockoptInt(rawfd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return err
}
streams := sync.Map{}
defer func() {
streams.Range(func(k, v interface{}) bool {
v.(*yamux.Stream).Close()
return true
})
}()
buf := make([]byte, 64*1024)
oob := make([]byte, 1024)
for {
nr, oobn, _, raddr, err := lconn.ReadMsgUDP(buf, oob)
if err != nil {
logs.Error("read from udp fail: %v", err)
break
}
origindst, err := getOriginDst(oob[:oobn])
if err != nil {
logs.Error("%v", err)
continue
}
dip, dport, _ := net.SplitHostPort(origindst.String())
sip, sport, _ := net.SplitHostPort(raddr.String())
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
val, ok := streams.Load(key)
if !ok {
val, ok := s.sess.Load(dip)
if !ok {
logs.Error("no route to host: %s", dip)
continue
}
stream, err := val.(*Session).conn.OpenStream()
if err != nil {
logs.Error("open stream fail: %v", err)
continue
}
streams.Store(key, stream)
// write proxy protocol
proxyProtocol := &proto.ProxyProtocol{
Protocol: "udp",
SrcIP: sip,
SrcPort: sport,
// DstIP: dip,
DstIP: "127.0.0.1", // may change to client setting
DstPort: dport,
}
body, err := json.Marshal(proxyProtocol)
if err != nil {
logs.Error("json marshal fail: %v", err)
continue
}
bytes := encode(body)
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
if err != nil {
logs.Error("stream write fail: %v", err)
continue
}
go s.udpProxy(stream, rawfd, origindst, raddr)
}
val, ok = streams.Load(key)
if !ok {
logs.Error("get stream for %s fail", key)
continue
}
stream := val.(*yamux.Stream)
bytes := encode(buf[:nr])
stream.SetWriteDeadline(time.Now().Add(time.Second * 10))
_, err = stream.Write(bytes)
stream.SetWriteDeadline(time.Time{})
if err != nil {
logs.Error("stream write fail: %v", err)
}
}
return nil
}
func (s *Server) udpProxy(stream *yamux.Stream, tofd int, fromaddr, toaddr *net.UDPAddr) {
hdr := make([]byte, 2)
for {
_, err := io.ReadFull(stream, hdr)
if err != nil {
logs.Error("read stream fail %v", err)
break
}
nlen := binary.BigEndian.Uint16(hdr)
buf := make([]byte, nlen)
_, err = io.ReadFull(stream, buf)
if err != nil {
logs.Error("read stream body fail: %v", err)
break
}
err = sendUDPViaRaw(tofd, fromaddr, toaddr, buf)
if err != nil {
logs.Error("send via raw socket fail: %v", err)
}
}
}
func encode(raw []byte) []byte {
buf := make([]byte, 2)
binary.BigEndian.PutUint16(buf, uint16(len(raw)))
buf = append(buf, raw...)
return buf
}
func getOriginDst(hdr []byte) (*net.UDPAddr, error) {
msgs, err := syscall.ParseSocketControlMessage(hdr)
if err != nil {
return nil, err
}
var origindst *net.UDPAddr
for _, msg := range msgs {
if msg.Header.Level == syscall.SOL_IP &&
msg.Header.Type == syscall.IP_RECVORIGDSTADDR {
originDstRaw := &syscall.RawSockaddrInet4{}
err := binary.Read(bytes.NewReader(msg.Data), binary.LittleEndian, originDstRaw)
if err != nil {
logs.Error("read origin dst fail: %v", err)
continue
}
// only support for ipv4
if originDstRaw.Family == syscall.AF_INET {
pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originDstRaw))
p := (*[2]byte)(unsafe.Pointer(&pp.Port))
origindst = &net.UDPAddr{
IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]),
Port: int(p[0])<<8 + int(p[1]),
}
}
}
}
if origindst == nil {
return nil, fmt.Errorf("get origin dst fail")
}
return origindst, nil
}
// randomDomain generate random domain for client
func randomDomain(num int64) string {

View File

@@ -0,0 +1,75 @@
package core
import (
"fmt"
"net"
"syscall"
)
func checksum_add(buf []byte, seed uint32) uint32 {
sum := seed
for i, l := 0, len(buf); i < l; i += 2 {
j := i + 1
if j == l {
sum += uint32(buf[i]) << 8
break
}
sum += uint32(buf[i])<<8 | uint32(buf[j])
}
return sum
}
func checksum_warp(seed uint32) uint16 {
sum := seed
for sum > 0xffff {
sum = (sum >> 16) + (sum & 0xffff)
}
csum := ^uint16(sum)
// RFC 768
if csum == 0 {
csum = 0xffff
}
return csum
}
func CheckSum(buf []byte) uint16 {
return checksum_warp(checksum_add(buf, 0))
}
func sendUDPViaRaw(fd int, src, dst *net.UDPAddr, payload []byte) error {
iplen, ulen := uint16(28+len(payload)), uint16(8+len(payload))
if iplen > 65535 {
return fmt.Errorf("too big packet")
}
// UDP checksum: sip + dip + udp-head + payload + PROTO + ulen
data := make([]byte, iplen)
data[9] = syscall.IPPROTO_UDP
copy(data[12:16], src.IP.To4())
copy(data[16:20], dst.IP.To4())
data[20] = byte(src.Port >> 8)
data[21] = byte(src.Port)
data[22] = byte(dst.Port >> 8)
data[23] = byte(dst.Port)
data[24] = byte(ulen >> 8)
data[25] = byte(ulen)
copy(data[28:], payload)
uc := checksum_warp(checksum_add(data, uint32(ulen)))
data[26] = byte(uc >> 8)
data[27] = byte(uc)
data[0] = 0x45
data[2] = byte(iplen >> 8)
data[3] = byte(iplen)
data[6] = 0x40
data[8] = 64
ipc := CheckSum(data[:20])
data[10] = byte(ipc >> 8)
data[11] = byte(ipc)
addr := syscall.SockaddrInet4{Port: dst.Port}
copy(addr.Addr[:], data[16:20])
return syscall.Sendto(fd, data, 0, &addr)
}