diff --git a/examples/trojan.client.toml b/examples/trojan.client.toml new file mode 100644 index 0000000..d286b76 --- /dev/null +++ b/examples/trojan.client.toml @@ -0,0 +1,13 @@ +[[listen]] +protocol = "socks5" +host = "127.0.0.1" +port = 10800 + + +[[dial]] +protocol = "trojans" +uuid = "a684455c-b14f-11ea-bf0d-42010aaa0003" +host = "127.0.0.1" +port = 4434 +version = 0 +insecure = true diff --git a/examples/trojan.server.toml b/examples/trojan.server.toml new file mode 100644 index 0000000..25fc02a --- /dev/null +++ b/examples/trojan.server.toml @@ -0,0 +1,14 @@ +[[listen]] +protocol = "trojans" +uuid = "a684455c-b14f-11ea-bf0d-42010aaa0003" +host = "0.0.0.0" +port = 4434 +version = 0 +insecure = true +fallback = ":80" +cert = "cert.pem" +key = "cert.key" + +[[dial]] +protocol = "direct" + diff --git a/main.go b/main.go index ab9bb09..7254e51 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,8 @@ import ( _ "github.com/hahahrfool/v2ray_simple/proxy/direct" _ "github.com/hahahrfool/v2ray_simple/proxy/dokodemo" _ "github.com/hahahrfool/v2ray_simple/proxy/http" + _ "github.com/hahahrfool/v2ray_simple/proxy/socks5" + _ "github.com/hahahrfool/v2ray_simple/proxy/trojan" ) const ( diff --git a/netLayer/addr.go b/netLayer/addr.go index aa42d3e..5845d06 100644 --- a/netLayer/addr.go +++ b/netLayer/addr.go @@ -338,6 +338,7 @@ dialedPart: // 如果a的ip不为空,则会返回 AtypIP4 或 AtypIP6,否则会返回 AtypDomain // Return address bytes and type // 如果atyp类型是 域名,则 第一字节为该域名的总长度, 其余字节为域名内容。 +// 如果类型是ip,则会拷贝出该ip的数据的副本。 func (a *Addr) AddressBytes() ([]byte, byte) { var addr []byte var atyp byte diff --git a/proxy/trojan/client.go b/proxy/trojan/client.go index a5000c0..b8feb6b 100644 --- a/proxy/trojan/client.go +++ b/proxy/trojan/client.go @@ -12,12 +12,16 @@ import ( "github.com/hahahrfool/v2ray_simple/utils" ) +func init() { + proxy.RegisterClient(Name, ClientCreator{}) +} + //作为对照,可以参考 https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/client.go type ClientCreator struct{} func (_ ClientCreator) NewClientFromURL(u *url.URL) (proxy.Client, error) { - return nil, errors.New("not implemented") + return nil, utils.ErrNotImplemented } func (_ ClientCreator) NewClient(dc *proxy.DialConf) (proxy.Client, error) { @@ -37,10 +41,10 @@ type Client struct { } func (c *Client) Name() string { - return name + return Name } -func WriteTargetToBuf(target netLayer.Addr, buf *bytes.Buffer) { +func WriteAddrToBuf(target netLayer.Addr, buf *bytes.Buffer) { if len(target.IP) > 0 { if ip4 := target.IP.To4(); ip4 == nil { buf.WriteByte(netLayer.AtypIP6) @@ -69,7 +73,7 @@ func (c *Client) Handshake(underlay net.Conn, target netLayer.Addr) (io.ReadWrit buf.Write(c.password_hexStringBytes) buf.Write(crlf) buf.WriteByte(CmdConnect) - WriteTargetToBuf(target, buf) + WriteAddrToBuf(target, buf) _, err := underlay.Write(buf.Bytes()) utils.PutBuf(buf) @@ -89,12 +93,12 @@ func (c *Client) EstablishUDPChannel(underlay net.Conn, target netLayer.Addr) (n buf.Write(c.password_hexStringBytes) buf.Write(crlf) buf.WriteByte(CmdUDPAssociate) - WriteTargetToBuf(target, buf) + WriteAddrToBuf(target, buf) _, err := underlay.Write(buf.Bytes()) utils.PutBuf(buf) if err != nil { return nil, err } - return UDPConn{underlay}, nil + return NewUDPConn(underlay, nil), nil } diff --git a/proxy/trojan/server.go b/proxy/trojan/server.go new file mode 100644 index 0000000..19bb1c2 --- /dev/null +++ b/proxy/trojan/server.go @@ -0,0 +1,161 @@ +package trojan + +import ( + "bytes" + "errors" + "io" + "net" + "net/url" + "time" + + "github.com/hahahrfool/v2ray_simple/netLayer" + "github.com/hahahrfool/v2ray_simple/proxy" + "github.com/hahahrfool/v2ray_simple/utils" +) + +func init() { + proxy.RegisterServer(Name, &ServerCreator{}) +} + +type Server struct { + proxy.ProxyCommonStruct + + userHashes map[string]bool + + //mux4Hashes sync.RWMutex +} +type ServerCreator struct{} + +func (_ ServerCreator) NewServer(lc *proxy.ListenConf) (proxy.Server, error) { + uuidStr := lc.Uuid + + s := &Server{ + userHashes: make(map[string]bool), + } + + s.userHashes[string(SHA224_hexStringBytes(uuidStr))] = true + + return s, nil +} + +func (_ ServerCreator) NewServerFromURL(u *url.URL) (proxy.Server, error) { + return nil, utils.ErrNotImplemented +} +func (s *Server) Name() string { + return Name +} + +func (s *Server) Handshake(underlay net.Conn) (result io.ReadWriteCloser, msgConn netLayer.MsgConn, targetAddr netLayer.Addr, returnErr error) { + if err := underlay.SetReadDeadline(time.Now().Add(time.Second * 4)); err != nil { + returnErr = err + return + } + defer underlay.SetReadDeadline(time.Time{}) + + readbs := utils.GetBytes(utils.StandardBytesLength) + + wholeReadLen, err := underlay.Read(readbs) + if err != nil { + returnErr = utils.ErrInErr{ErrDesc: "read err", ErrDetail: err, Data: wholeReadLen} + return + } + + if wholeReadLen < 17 { + //根据下面回答,HTTP的最小长度恰好是16字节,但是是0.9版本。1.0是18字节,1.1还要更长。总之我们可以直接不返回fallback地址 + //https://stackoverflow.com/questions/25047905/http-request-minimum-size-in-bytes/25065089 + + returnErr = utils.ErrInErr{ErrDesc: "fallback, msg too short", Data: wholeReadLen} + return + } + + readbuf := bytes.NewBuffer(readbs[:wholeReadLen]) + + goto realPart + +errorPart: + + //所返回的buffer必须包含所有数据,而Buffer不支持回退,所以只能重新New + returnErr = &utils.ErrFirstBuffer{ + Err: returnErr, + First: bytes.NewBuffer(readbs[:wholeReadLen]), + } + return + +realPart: + + if wholeReadLen < 56+8+1 { + returnErr = utils.ErrInvalidData + goto errorPart + } + + //可参考 https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/server.go + + hash := readbuf.Next(56) + hashStr := string(hash) + if !s.userHashes[hashStr] { + returnErr = errors.New("hash not match") + goto errorPart + } + + crb, _ := readbuf.ReadByte() + lfb, _ := readbuf.ReadByte() + if crb != crlf[0] || lfb != crlf[1] { + returnErr = utils.ErrInvalidData + goto errorPart + } + + cmdb, _ := readbuf.ReadByte() + + var isudp bool + switch cmdb { + default: + returnErr = utils.ErrInvalidData + goto errorPart + case CmdConnect: + + case CmdUDPAssociate: + isudp = true + } + + targetAddr, err = GetAddrFromReader(readbuf) + if err != nil { + returnErr = err + goto errorPart + } + if isudp { + targetAddr.Network = "udp" + } + crb, err = readbuf.ReadByte() + if err != nil { + returnErr = err + goto errorPart + } + lfb, err = readbuf.ReadByte() + if err != nil { + returnErr = err + goto errorPart + } + if crb != crlf[0] || lfb != crlf[1] { + returnErr = utils.ErrInvalidData + goto errorPart + } + + if isudp { + return nil, NewUDPConn(underlay, io.MultiReader(readbuf, underlay)), targetAddr, nil + + } else { + if readbuf.Len() == 0 { + return underlay, nil, targetAddr, nil + } else { + return &UserTCPConn{ + Conn: underlay, + optionalReader: io.MultiReader(readbuf, underlay), + remainFirstBufLen: readbuf.Len(), + hash: hashStr, + underlayIsBasic: netLayer.IsBasicConn(underlay), + isServerEnd: true, + }, nil, targetAddr, nil + } + + } +} diff --git a/proxy/trojan/tcpconn.go b/proxy/trojan/tcpconn.go new file mode 100644 index 0000000..a4b6c88 --- /dev/null +++ b/proxy/trojan/tcpconn.go @@ -0,0 +1,68 @@ +package trojan + +import ( + "bufio" + "io" + "net" + + "github.com/hahahrfool/v2ray_simple/netLayer" +) + +//trojan比较简洁,这个 UserTCPConn 只是用于读取握手读取时读到的剩余的缓存 +type UserTCPConn struct { + net.Conn + optionalReader io.Reader //在使用了缓存读取握手包头后,就产生了buffer中有剩余数据的可能性,此时就要使用MultiReader + + remainFirstBufLen int //记录读取握手包头时读到的buf的长度. 如果我们读超过了这个部分的话,实际上我们就可以不再使用 optionalReader 读取, 而是直接从Conn读取 + + underlayIsBasic bool + + hash string + isServerEnd bool //for v0 + + bufr *bufio.Reader //for udp + isntFirstPacket bool //for v0 + +} + +func (uc *UserTCPConn) Read(p []byte) (int, error) { + if uc.remainFirstBufLen > 0 { + n, err := uc.optionalReader.Read(p) + if n > 0 { + uc.remainFirstBufLen -= n + } + return n, err + } else { + return uc.Conn.Read(p) + } +} +func (uc *UserTCPConn) Write(p []byte) (int, error) { + return uc.Conn.Write(p) +} + +func (c *UserTCPConn) EverPossibleToSplice() bool { + + if netLayer.IsBasicConn(c.Conn) { + return true + } + if s, ok := c.Conn.(netLayer.Splicer); ok { + return s.EverPossibleToSplice() + } + return false +} + +func (c *UserTCPConn) CanSplice() (r bool, conn net.Conn) { + if c.remainFirstBufLen > 0 { + return false, nil + } + + if netLayer.IsBasicConn(c.Conn) { + r = true + conn = c.Conn + + } else if s, ok := c.Conn.(netLayer.Splicer); ok { + r, conn = s.CanSplice() + } + + return +} diff --git a/proxy/trojan/trojan.go b/proxy/trojan/trojan.go index f5af20f..67acad8 100644 --- a/proxy/trojan/trojan.go +++ b/proxy/trojan/trojan.go @@ -5,14 +5,19 @@ package trojan import ( "crypto/sha256" + "errors" "fmt" + "net" + + "github.com/hahahrfool/v2ray_simple/netLayer" + "github.com/hahahrfool/v2ray_simple/utils" ) const ( ATypIP4 = 0x1 ATypDomain = 0x3 ATypIP6 = 0x4 - name = "trojan" + Name = "trojan" ) const ( CmdConnect = 0x01 @@ -23,6 +28,13 @@ var ( crlf = []byte{0x0d, 0x0a} ) +func SHA224(password string) (r [28]byte) { + hash := sha256.New224() + hash.Write([]byte(password)) + copy(r[:], hash.Sum(nil)) + return +} + //trojan 的前56字节 是 sha224的28字节 每字节 转义成 ascii的 表示16进制的 两个字符 func SHA224_hexStringBytes(password string) []byte { hash := sha256.New224() @@ -34,3 +46,78 @@ func SHA224_hexStringBytes(password string) []byte { } return []byte(str) } + +//依照trojan协议的格式读取 地址的域名、ip、port信息 +func GetAddrFromReader(buf utils.ByteReader) (addr netLayer.Addr, err error) { + var b1 byte + b1, err = buf.ReadByte() + if err != nil { + return + } + switch b1 { + case ATypDomain: + var b2 byte + b2, err = buf.ReadByte() + if err != nil { + return + } + if b2 == 0 { + err = errors.New("got ATypDomain but domain lenth is marked to be 0") + return + } + bs := utils.GetBytes(int(b2)) + var n int + n, err = buf.Read(bs) + if err != nil { + return + } + if n != int(b2) { + err = utils.ErrShortRead + return + } + addr.Name = string(bs[:n]) + case ATypIP4: + bs := make([]byte, 4) + var n int + n, err = buf.Read(bs) + if err != nil { + return + } + if n != 4 { + err = utils.ErrShortRead + return + } + addr.IP = bs + case ATypIP6: + bs := make([]byte, net.IPv6len) + var n int + n, err = buf.Read(bs) + if err != nil { + return + } + if n != 4 { + err = utils.ErrShortRead + return + } + addr.IP = bs + default: + err = utils.ErrInvalidData + return + } + + pb1, err := buf.ReadByte() + if err != nil { + return + } + pb2, err := buf.ReadByte() + if err != nil { + return + } + port := uint16(pb1)<<8 + uint16(pb2) + if port == 0 { + err = utils.ErrInvalidData + return + } + addr.Port = int(port) + return +} diff --git a/proxy/trojan/udpConn.go b/proxy/trojan/udpConn.go index 9d69a9d..0b1e4e2 100644 --- a/proxy/trojan/udpConn.go +++ b/proxy/trojan/udpConn.go @@ -1,13 +1,31 @@ package trojan import ( + "bufio" + "io" "net" "github.com/hahahrfool/v2ray_simple/netLayer" + "github.com/hahahrfool/v2ray_simple/utils" ) type UDPConn struct { net.Conn + optionalReader io.Reader + + bufr *bufio.Reader +} + +func NewUDPConn(conn net.Conn, optionalReader io.Reader) (uc *UDPConn) { + uc = new(UDPConn) + uc.Conn = conn + if optionalReader != nil { + uc.optionalReader = optionalReader + uc.bufr = bufio.NewReader(optionalReader) + } else { + uc.bufr = bufio.NewReader(conn) + } + return } func (u UDPConn) Fullcone() bool { @@ -17,11 +35,62 @@ func (u UDPConn) CloseConnWithRaddr(raddr netLayer.Addr) error { return u.Close() } func (u UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) { + addr, err := GetAddrFromReader(u.bufr) + if err != nil { + return nil, addr, err + } + addr.Network = "udp" - return nil, netLayer.Addr{}, nil + lb1, err := u.bufr.ReadByte() + if err != nil { + return nil, addr, err + } + lb2, err := u.bufr.ReadByte() + if err != nil { + return nil, addr, err + } + lenth := uint16(lb1)<<8 + uint16(lb2) + if lenth == 0 { + return nil, addr, utils.ErrInvalidData + } + + cr_b, err := u.bufr.ReadByte() + if err != nil { + return nil, addr, err + } + lf_b, err := u.bufr.ReadByte() + if err != nil { + return nil, addr, err + } + if cr_b != crlf[0] || lf_b != crlf[1] { + return nil, addr, utils.ErrInvalidData + } + + bs := utils.GetBytes(int(lenth)) + n, err := u.bufr.Read(bs) + if err != nil { + if n > 0 { + return bs[:n], addr, err + } + return nil, addr, err + } + + return bs[:n], addr, nil } -func (u UDPConn) WriteTo([]byte, netLayer.Addr) error { +func (u UDPConn) WriteTo(bs []byte, addr netLayer.Addr) error { - return nil + abs, atype := addr.AddressBytes() + buf := utils.GetBuf() + buf.WriteByte(atype) + buf.Write(abs) + buf.WriteByte(byte(len(bs) >> 8)) + buf.WriteByte(byte(len(bs) << 8 >> 8)) + buf.Write(crlf) + buf.Write(bs) + + _, err := u.Conn.Write(buf.Bytes()) + utils.PutBuf(buf) + + return err } diff --git a/proxy/vless/server.go b/proxy/vless/server.go index f8bd7b8..f9afd05 100644 --- a/proxy/vless/server.go +++ b/proxy/vless/server.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "net" "net/url" "sync" @@ -354,7 +353,6 @@ realPart: } if targetAddr.IsUDP() { - log.Println("targetAddr", targetAddr.IP, targetAddr.Name) return nil, &UDPConn{ Conn: underlay, version: int(version), diff --git a/proxy/vless/tcpconn.go b/proxy/vless/tcpconn.go index 1e332c6..309a5c7 100644 --- a/proxy/vless/tcpconn.go +++ b/proxy/vless/tcpconn.go @@ -88,9 +88,6 @@ func (c *UserTCPConn) WriteBuffers(buffers [][]byte) (int64, error) { //本作的 ws.Conn 实现了 utils.MultiWriter if c.underlayIsBasic { - //如果是基本Conn,则不用担心 WriteTo篡改buffers的问题, 因为它会直接调用底层 writev - //nb := net.Buffers(buffers) - //return nb.WriteTo(c.Conn) //发现它还是会篡改??什么鬼 return utils.BuffersWriteTo(buffers, c.Conn) } else if mr, ok := c.Conn.(utils.MultiWriter); ok { diff --git a/proxy/vless/udpConn.go b/proxy/vless/udpConn.go index 76cb8f1..ebaae9b 100644 --- a/proxy/vless/udpConn.go +++ b/proxy/vless/udpConn.go @@ -70,7 +70,6 @@ func (u *UDPConn) WriteTo(p []byte, raddr netLayer.Addr) error { } -//从 客户端读取 udp请求 func (u *UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) { var from io.Reader = u.Conn diff --git a/utils/error.go b/utils/error.go index 2d7f900..de0eaf0 100644 --- a/utils/error.go +++ b/utils/error.go @@ -11,6 +11,8 @@ var ErrNotImplemented = errors.New("not implemented") var ErrNilParameter = errors.New("nil parameter") var ErrNilOrWrongParameter = errors.New("nil or wrong parameter") var ErrWrongParameter = errors.New("wrong parameter") +var ErrShortRead = errors.New("short read") +var ErrInvalidData = errors.New("invalid data") //没啥特殊的 type NumErr struct { diff --git a/utils/utils.go b/utils/utils.go index 0465d44..000a24c 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -17,6 +17,11 @@ func init() { rand.Seed(time.Now().Unix()) } +type ByteReader interface { + ReadByte() (byte, error) + Read(p []byte) (n int, err error) +} + func IsFlagPassed(name string) bool { found := false flag.Visit(func(f *flag.Flag) {