修订代码, 实现trojan协议

This commit is contained in:
hahahrfool
2022-04-08 22:44:01 +08:00
parent ce735dbb99
commit 3dc53554df
14 changed files with 436 additions and 16 deletions

View File

@@ -0,0 +1,13 @@
[[listen]]
protocol = "socks5"
host = "127.0.0.1"
port = 10800
[[dial]]
protocol = "trojans"
uuid = "a684455c-b14f-11ea-bf0d-42010aaa0003"
host = "127.0.0.1"
port = 4434
version = 0
insecure = true

View File

@@ -0,0 +1,14 @@
[[listen]]
protocol = "trojans"
uuid = "a684455c-b14f-11ea-bf0d-42010aaa0003"
host = "0.0.0.0"
port = 4434
version = 0
insecure = true
fallback = ":80"
cert = "cert.pem"
key = "cert.key"
[[dial]]
protocol = "direct"

View File

@@ -30,6 +30,8 @@ import (
_ "github.com/hahahrfool/v2ray_simple/proxy/direct" _ "github.com/hahahrfool/v2ray_simple/proxy/direct"
_ "github.com/hahahrfool/v2ray_simple/proxy/dokodemo" _ "github.com/hahahrfool/v2ray_simple/proxy/dokodemo"
_ "github.com/hahahrfool/v2ray_simple/proxy/http" _ "github.com/hahahrfool/v2ray_simple/proxy/http"
_ "github.com/hahahrfool/v2ray_simple/proxy/socks5"
_ "github.com/hahahrfool/v2ray_simple/proxy/trojan"
) )
const ( const (

View File

@@ -338,6 +338,7 @@ dialedPart:
// 如果a的ip不为空则会返回 AtypIP4 或 AtypIP6否则会返回 AtypDomain // 如果a的ip不为空则会返回 AtypIP4 或 AtypIP6否则会返回 AtypDomain
// Return address bytes and type // Return address bytes and type
// 如果atyp类型是 域名,则 第一字节为该域名的总长度, 其余字节为域名内容。 // 如果atyp类型是 域名,则 第一字节为该域名的总长度, 其余字节为域名内容。
// 如果类型是ip则会拷贝出该ip的数据的副本。
func (a *Addr) AddressBytes() ([]byte, byte) { func (a *Addr) AddressBytes() ([]byte, byte) {
var addr []byte var addr []byte
var atyp byte var atyp byte

View File

@@ -12,12 +12,16 @@ import (
"github.com/hahahrfool/v2ray_simple/utils" "github.com/hahahrfool/v2ray_simple/utils"
) )
func init() {
proxy.RegisterClient(Name, ClientCreator{})
}
//作为对照,可以参考 https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/client.go //作为对照,可以参考 https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/client.go
type ClientCreator struct{} type ClientCreator struct{}
func (_ ClientCreator) NewClientFromURL(u *url.URL) (proxy.Client, error) { func (_ ClientCreator) NewClientFromURL(u *url.URL) (proxy.Client, error) {
return nil, errors.New("not implemented") return nil, utils.ErrNotImplemented
} }
func (_ ClientCreator) NewClient(dc *proxy.DialConf) (proxy.Client, error) { func (_ ClientCreator) NewClient(dc *proxy.DialConf) (proxy.Client, error) {
@@ -37,10 +41,10 @@ type Client struct {
} }
func (c *Client) Name() string { func (c *Client) Name() string {
return name return Name
} }
func WriteTargetToBuf(target netLayer.Addr, buf *bytes.Buffer) { func WriteAddrToBuf(target netLayer.Addr, buf *bytes.Buffer) {
if len(target.IP) > 0 { if len(target.IP) > 0 {
if ip4 := target.IP.To4(); ip4 == nil { if ip4 := target.IP.To4(); ip4 == nil {
buf.WriteByte(netLayer.AtypIP6) buf.WriteByte(netLayer.AtypIP6)
@@ -69,7 +73,7 @@ func (c *Client) Handshake(underlay net.Conn, target netLayer.Addr) (io.ReadWrit
buf.Write(c.password_hexStringBytes) buf.Write(c.password_hexStringBytes)
buf.Write(crlf) buf.Write(crlf)
buf.WriteByte(CmdConnect) buf.WriteByte(CmdConnect)
WriteTargetToBuf(target, buf) WriteAddrToBuf(target, buf)
_, err := underlay.Write(buf.Bytes()) _, err := underlay.Write(buf.Bytes())
utils.PutBuf(buf) utils.PutBuf(buf)
@@ -89,12 +93,12 @@ func (c *Client) EstablishUDPChannel(underlay net.Conn, target netLayer.Addr) (n
buf.Write(c.password_hexStringBytes) buf.Write(c.password_hexStringBytes)
buf.Write(crlf) buf.Write(crlf)
buf.WriteByte(CmdUDPAssociate) buf.WriteByte(CmdUDPAssociate)
WriteTargetToBuf(target, buf) WriteAddrToBuf(target, buf)
_, err := underlay.Write(buf.Bytes()) _, err := underlay.Write(buf.Bytes())
utils.PutBuf(buf) utils.PutBuf(buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return UDPConn{underlay}, nil return NewUDPConn(underlay, nil), nil
} }

161
proxy/trojan/server.go Normal file
View File

@@ -0,0 +1,161 @@
package trojan
import (
"bytes"
"errors"
"io"
"net"
"net/url"
"time"
"github.com/hahahrfool/v2ray_simple/netLayer"
"github.com/hahahrfool/v2ray_simple/proxy"
"github.com/hahahrfool/v2ray_simple/utils"
)
func init() {
proxy.RegisterServer(Name, &ServerCreator{})
}
type Server struct {
proxy.ProxyCommonStruct
userHashes map[string]bool
//mux4Hashes sync.RWMutex
}
type ServerCreator struct{}
func (_ ServerCreator) NewServer(lc *proxy.ListenConf) (proxy.Server, error) {
uuidStr := lc.Uuid
s := &Server{
userHashes: make(map[string]bool),
}
s.userHashes[string(SHA224_hexStringBytes(uuidStr))] = true
return s, nil
}
func (_ ServerCreator) NewServerFromURL(u *url.URL) (proxy.Server, error) {
return nil, utils.ErrNotImplemented
}
func (s *Server) Name() string {
return Name
}
func (s *Server) Handshake(underlay net.Conn) (result io.ReadWriteCloser, msgConn netLayer.MsgConn, targetAddr netLayer.Addr, returnErr error) {
if err := underlay.SetReadDeadline(time.Now().Add(time.Second * 4)); err != nil {
returnErr = err
return
}
defer underlay.SetReadDeadline(time.Time{})
readbs := utils.GetBytes(utils.StandardBytesLength)
wholeReadLen, err := underlay.Read(readbs)
if err != nil {
returnErr = utils.ErrInErr{ErrDesc: "read err", ErrDetail: err, Data: wholeReadLen}
return
}
if wholeReadLen < 17 {
//根据下面回答HTTP的最小长度恰好是16字节但是是0.9版本。1.0是18字节1.1还要更长。总之我们可以直接不返回fallback地址
//https://stackoverflow.com/questions/25047905/http-request-minimum-size-in-bytes/25065089
returnErr = utils.ErrInErr{ErrDesc: "fallback, msg too short", Data: wholeReadLen}
return
}
readbuf := bytes.NewBuffer(readbs[:wholeReadLen])
goto realPart
errorPart:
//所返回的buffer必须包含所有数据而Buffer不支持回退所以只能重新New
returnErr = &utils.ErrFirstBuffer{
Err: returnErr,
First: bytes.NewBuffer(readbs[:wholeReadLen]),
}
return
realPart:
if wholeReadLen < 56+8+1 {
returnErr = utils.ErrInvalidData
goto errorPart
}
//可参考 https://github.com/p4gefau1t/trojan-go/blob/master/tunnel/trojan/server.go
hash := readbuf.Next(56)
hashStr := string(hash)
if !s.userHashes[hashStr] {
returnErr = errors.New("hash not match")
goto errorPart
}
crb, _ := readbuf.ReadByte()
lfb, _ := readbuf.ReadByte()
if crb != crlf[0] || lfb != crlf[1] {
returnErr = utils.ErrInvalidData
goto errorPart
}
cmdb, _ := readbuf.ReadByte()
var isudp bool
switch cmdb {
default:
returnErr = utils.ErrInvalidData
goto errorPart
case CmdConnect:
case CmdUDPAssociate:
isudp = true
}
targetAddr, err = GetAddrFromReader(readbuf)
if err != nil {
returnErr = err
goto errorPart
}
if isudp {
targetAddr.Network = "udp"
}
crb, err = readbuf.ReadByte()
if err != nil {
returnErr = err
goto errorPart
}
lfb, err = readbuf.ReadByte()
if err != nil {
returnErr = err
goto errorPart
}
if crb != crlf[0] || lfb != crlf[1] {
returnErr = utils.ErrInvalidData
goto errorPart
}
if isudp {
return nil, NewUDPConn(underlay, io.MultiReader(readbuf, underlay)), targetAddr, nil
} else {
if readbuf.Len() == 0 {
return underlay, nil, targetAddr, nil
} else {
return &UserTCPConn{
Conn: underlay,
optionalReader: io.MultiReader(readbuf, underlay),
remainFirstBufLen: readbuf.Len(),
hash: hashStr,
underlayIsBasic: netLayer.IsBasicConn(underlay),
isServerEnd: true,
}, nil, targetAddr, nil
}
}
}

68
proxy/trojan/tcpconn.go Normal file
View File

@@ -0,0 +1,68 @@
package trojan
import (
"bufio"
"io"
"net"
"github.com/hahahrfool/v2ray_simple/netLayer"
)
//trojan比较简洁这个 UserTCPConn 只是用于读取握手读取时读到的剩余的缓存
type UserTCPConn struct {
net.Conn
optionalReader io.Reader //在使用了缓存读取握手包头后就产生了buffer中有剩余数据的可能性此时就要使用MultiReader
remainFirstBufLen int //记录读取握手包头时读到的buf的长度. 如果我们读超过了这个部分的话,实际上我们就可以不再使用 optionalReader 读取, 而是直接从Conn读取
underlayIsBasic bool
hash string
isServerEnd bool //for v0
bufr *bufio.Reader //for udp
isntFirstPacket bool //for v0
}
func (uc *UserTCPConn) Read(p []byte) (int, error) {
if uc.remainFirstBufLen > 0 {
n, err := uc.optionalReader.Read(p)
if n > 0 {
uc.remainFirstBufLen -= n
}
return n, err
} else {
return uc.Conn.Read(p)
}
}
func (uc *UserTCPConn) Write(p []byte) (int, error) {
return uc.Conn.Write(p)
}
func (c *UserTCPConn) EverPossibleToSplice() bool {
if netLayer.IsBasicConn(c.Conn) {
return true
}
if s, ok := c.Conn.(netLayer.Splicer); ok {
return s.EverPossibleToSplice()
}
return false
}
func (c *UserTCPConn) CanSplice() (r bool, conn net.Conn) {
if c.remainFirstBufLen > 0 {
return false, nil
}
if netLayer.IsBasicConn(c.Conn) {
r = true
conn = c.Conn
} else if s, ok := c.Conn.(netLayer.Splicer); ok {
r, conn = s.CanSplice()
}
return
}

View File

@@ -5,14 +5,19 @@ package trojan
import ( import (
"crypto/sha256" "crypto/sha256"
"errors"
"fmt" "fmt"
"net"
"github.com/hahahrfool/v2ray_simple/netLayer"
"github.com/hahahrfool/v2ray_simple/utils"
) )
const ( const (
ATypIP4 = 0x1 ATypIP4 = 0x1
ATypDomain = 0x3 ATypDomain = 0x3
ATypIP6 = 0x4 ATypIP6 = 0x4
name = "trojan" Name = "trojan"
) )
const ( const (
CmdConnect = 0x01 CmdConnect = 0x01
@@ -23,6 +28,13 @@ var (
crlf = []byte{0x0d, 0x0a} crlf = []byte{0x0d, 0x0a}
) )
func SHA224(password string) (r [28]byte) {
hash := sha256.New224()
hash.Write([]byte(password))
copy(r[:], hash.Sum(nil))
return
}
//trojan 的前56字节 是 sha224的28字节 每字节 转义成 ascii的 表示16进制的 两个字符 //trojan 的前56字节 是 sha224的28字节 每字节 转义成 ascii的 表示16进制的 两个字符
func SHA224_hexStringBytes(password string) []byte { func SHA224_hexStringBytes(password string) []byte {
hash := sha256.New224() hash := sha256.New224()
@@ -34,3 +46,78 @@ func SHA224_hexStringBytes(password string) []byte {
} }
return []byte(str) return []byte(str)
} }
//依照trojan协议的格式读取 地址的域名、ip、port信息
func GetAddrFromReader(buf utils.ByteReader) (addr netLayer.Addr, err error) {
var b1 byte
b1, err = buf.ReadByte()
if err != nil {
return
}
switch b1 {
case ATypDomain:
var b2 byte
b2, err = buf.ReadByte()
if err != nil {
return
}
if b2 == 0 {
err = errors.New("got ATypDomain but domain lenth is marked to be 0")
return
}
bs := utils.GetBytes(int(b2))
var n int
n, err = buf.Read(bs)
if err != nil {
return
}
if n != int(b2) {
err = utils.ErrShortRead
return
}
addr.Name = string(bs[:n])
case ATypIP4:
bs := make([]byte, 4)
var n int
n, err = buf.Read(bs)
if err != nil {
return
}
if n != 4 {
err = utils.ErrShortRead
return
}
addr.IP = bs
case ATypIP6:
bs := make([]byte, net.IPv6len)
var n int
n, err = buf.Read(bs)
if err != nil {
return
}
if n != 4 {
err = utils.ErrShortRead
return
}
addr.IP = bs
default:
err = utils.ErrInvalidData
return
}
pb1, err := buf.ReadByte()
if err != nil {
return
}
pb2, err := buf.ReadByte()
if err != nil {
return
}
port := uint16(pb1)<<8 + uint16(pb2)
if port == 0 {
err = utils.ErrInvalidData
return
}
addr.Port = int(port)
return
}

View File

@@ -1,13 +1,31 @@
package trojan package trojan
import ( import (
"bufio"
"io"
"net" "net"
"github.com/hahahrfool/v2ray_simple/netLayer" "github.com/hahahrfool/v2ray_simple/netLayer"
"github.com/hahahrfool/v2ray_simple/utils"
) )
type UDPConn struct { type UDPConn struct {
net.Conn net.Conn
optionalReader io.Reader
bufr *bufio.Reader
}
func NewUDPConn(conn net.Conn, optionalReader io.Reader) (uc *UDPConn) {
uc = new(UDPConn)
uc.Conn = conn
if optionalReader != nil {
uc.optionalReader = optionalReader
uc.bufr = bufio.NewReader(optionalReader)
} else {
uc.bufr = bufio.NewReader(conn)
}
return
} }
func (u UDPConn) Fullcone() bool { func (u UDPConn) Fullcone() bool {
@@ -17,11 +35,62 @@ func (u UDPConn) CloseConnWithRaddr(raddr netLayer.Addr) error {
return u.Close() return u.Close()
} }
func (u UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) { func (u UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) {
addr, err := GetAddrFromReader(u.bufr)
if err != nil {
return nil, addr, err
}
addr.Network = "udp"
return nil, netLayer.Addr{}, nil lb1, err := u.bufr.ReadByte()
if err != nil {
return nil, addr, err
}
lb2, err := u.bufr.ReadByte()
if err != nil {
return nil, addr, err
}
lenth := uint16(lb1)<<8 + uint16(lb2)
if lenth == 0 {
return nil, addr, utils.ErrInvalidData
}
cr_b, err := u.bufr.ReadByte()
if err != nil {
return nil, addr, err
}
lf_b, err := u.bufr.ReadByte()
if err != nil {
return nil, addr, err
}
if cr_b != crlf[0] || lf_b != crlf[1] {
return nil, addr, utils.ErrInvalidData
}
bs := utils.GetBytes(int(lenth))
n, err := u.bufr.Read(bs)
if err != nil {
if n > 0 {
return bs[:n], addr, err
}
return nil, addr, err
}
return bs[:n], addr, nil
} }
func (u UDPConn) WriteTo([]byte, netLayer.Addr) error { func (u UDPConn) WriteTo(bs []byte, addr netLayer.Addr) error {
return nil abs, atype := addr.AddressBytes()
buf := utils.GetBuf()
buf.WriteByte(atype)
buf.Write(abs)
buf.WriteByte(byte(len(bs) >> 8))
buf.WriteByte(byte(len(bs) << 8 >> 8))
buf.Write(crlf)
buf.Write(bs)
_, err := u.Conn.Write(buf.Bytes())
utils.PutBuf(buf)
return err
} }

View File

@@ -6,7 +6,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"net/url" "net/url"
"sync" "sync"
@@ -354,7 +353,6 @@ realPart:
} }
if targetAddr.IsUDP() { if targetAddr.IsUDP() {
log.Println("targetAddr", targetAddr.IP, targetAddr.Name)
return nil, &UDPConn{ return nil, &UDPConn{
Conn: underlay, Conn: underlay,
version: int(version), version: int(version),

View File

@@ -88,9 +88,6 @@ func (c *UserTCPConn) WriteBuffers(buffers [][]byte) (int64, error) {
//本作的 ws.Conn 实现了 utils.MultiWriter //本作的 ws.Conn 实现了 utils.MultiWriter
if c.underlayIsBasic { if c.underlayIsBasic {
//如果是基本Conn则不用担心 WriteTo篡改buffers的问题, 因为它会直接调用底层 writev
//nb := net.Buffers(buffers)
//return nb.WriteTo(c.Conn) //发现它还是会篡改??什么鬼
return utils.BuffersWriteTo(buffers, c.Conn) return utils.BuffersWriteTo(buffers, c.Conn)
} else if mr, ok := c.Conn.(utils.MultiWriter); ok { } else if mr, ok := c.Conn.(utils.MultiWriter); ok {

View File

@@ -70,7 +70,6 @@ func (u *UDPConn) WriteTo(p []byte, raddr netLayer.Addr) error {
} }
//从 客户端读取 udp请求
func (u *UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) { func (u *UDPConn) ReadFrom() ([]byte, netLayer.Addr, error) {
var from io.Reader = u.Conn var from io.Reader = u.Conn

View File

@@ -11,6 +11,8 @@ var ErrNotImplemented = errors.New("not implemented")
var ErrNilParameter = errors.New("nil parameter") var ErrNilParameter = errors.New("nil parameter")
var ErrNilOrWrongParameter = errors.New("nil or wrong parameter") var ErrNilOrWrongParameter = errors.New("nil or wrong parameter")
var ErrWrongParameter = errors.New("wrong parameter") var ErrWrongParameter = errors.New("wrong parameter")
var ErrShortRead = errors.New("short read")
var ErrInvalidData = errors.New("invalid data")
//没啥特殊的 //没啥特殊的
type NumErr struct { type NumErr struct {

View File

@@ -17,6 +17,11 @@ func init() {
rand.Seed(time.Now().Unix()) rand.Seed(time.Now().Unix())
} }
type ByteReader interface {
ReadByte() (byte, error)
Read(p []byte) (n int, err error)
}
func IsFlagPassed(name string) bool { func IsFlagPassed(name string) bool {
found := false found := false
flag.Visit(func(f *flag.Flag) { flag.Visit(func(f *flag.Flag) {