diff --git a/advLayer/ws/conn.go b/advLayer/ws/conn.go index aa0190d..f130c92 100644 --- a/advLayer/ws/conn.go +++ b/advLayer/ws/conn.go @@ -15,7 +15,6 @@ import ( // 因此我们包装一下,统一使用Read和Write函数 来读写 二进制数据。因为我们这里是代理, type Conn struct { net.Conn - first_nextFrameCalled bool state ws.State r *wsutil.Reader diff --git a/advLayer/ws/server.go b/advLayer/ws/server.go index abfd324..1daeda8 100644 --- a/advLayer/ws/server.go +++ b/advLayer/ws/server.go @@ -1,6 +1,7 @@ package ws import ( + "bytes" "encoding/base64" "io" "net" @@ -19,6 +20,11 @@ import ( // 683 * 4 = 2732, 若你不信,运行 we_test.go中的 TestBase64Len const MaxEarlyDataLen_Base64 = 2732 +var ( + connectionBs = []byte("Connection") + upgradeBs = []byte("Upgrade") +) + //implements advLayer.SingleServer type Server struct { Creator @@ -108,7 +114,36 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) { optionalFirstBuffer := rp.WholeRequestBuf - if rp.Method != "GET" || s.Thepath != rp.Path { + notWsRequest := false + + //因为 gobwas 会先自行给错误的连接 返回 错误信息,而这不行,所以我们先过滤一遍。 + //header 我们只过滤一个 connection 就行. 要是怕攻击者用 “对的path,method 和错误的header” 进行探测, + // 那你设一个复杂的path就ok了。 + + if rp.Method != "GET" || s.Thepath != rp.Path || len(rp.Headers) == 0 { + notWsRequest = true + + } else { + hasUpgrade := false + for _, rh := range rp.Headers { + httpLayer.CanonicalizeHeaderKey(rh.Head) + if bytes.Equal(rh.Head, connectionBs) { + + httpLayer.CanonicalizeHeaderKey(rh.Value) + if bytes.Equal(rh.Value, upgradeBs) { + + hasUpgrade = true + break + } + } + } + if !hasUpgrade { + notWsRequest = true + } + + } + + if notWsRequest { return httpLayer.FallbackMeta{ Conn: underlay, H1RequestBuf: optionalFirstBuffer, @@ -117,7 +152,6 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) { }, httpLayer.ErrShouldFallback } - theWrongPath := "" var thePotentialEarlyData []byte requestHeaderNotGivenCount := s.requestHeaderCheckCount @@ -134,27 +168,6 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) { //我们这里就是先用 httpLayer 过滤 再和 buffer一起传入 ws 包 // ReadBufferSize默认是 4096,已经够大 - OnRequest: func(uri []byte) error { - struri := string(uri) - if struri != s.Thepath { - - theWrongPath = struri - - //return utils.NewDataErr("ws path not match", nil, struri[:min]) - //发现这个错误除了在程序里返回外,还会直接显示到 浏览器上!这会被探测到的。 - // 所以只能显示标准http错误, 然后通过闭包的方式 把path信息传递到外部. - if ce := utils.CanLogWarn("ws path not match"); ce != nil { - min := len(s.Thepath) - if len(struri) < min { - min = len(struri) - } - //log.Println("ws path not match", struri[:min]) - ce.Write(zap.String("wrong path", struri[:min])) - } - return ws.RejectConnectionError(ws.RejectionStatus(http.StatusBadRequest)) - } - return nil - }, OnHeader: func(key, value []byte) error { if s.noNeedToCheckRequestHeaders { return nil @@ -229,11 +242,13 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) { _, err := theUpgrader.Upgrade(rw) if err != nil { - if len(theWrongPath) > 0 { - //ws的Method必为Get, 否则 gobwas 会返回 ErrHandshakeBadMethod - return nil, &httpLayer.RequestErr{Path: theWrongPath} - } - return nil, err + + return httpLayer.FallbackMeta{ + Conn: underlay, + H1RequestBuf: optionalFirstBuffer, + Path: rp.Path, + Method: rp.Method, + }, httpLayer.ErrShouldFallback } theConn := &Conn{ diff --git a/cmd/verysimple/Makefile b/cmd/verysimple/Makefile index 343a13a..5c90546 100644 --- a/cmd/verysimple/Makefile +++ b/cmd/verysimple/Makefile @@ -158,6 +158,7 @@ win10: clean: rm -f ${prefix} rm -f ${prefix}.exe + rm -f BUILD_VERSION rm -f $(linuxAmdFn) rm -f $(linuxArmFn) diff --git a/httpLayer/fallback.go b/httpLayer/fallback.go index 079153a..7db657f 100644 --- a/httpLayer/fallback.go +++ b/httpLayer/fallback.go @@ -24,7 +24,7 @@ const ( all_non_default_fallbacktype_count = iota - 2 - alpn_unspecified = 0 + //alpn_unspecified = 0 ) const ( diff --git a/httpLayer/fallbackConditionSet.go b/httpLayer/fallbackConditionSet.go index 010eeb8..3bff4b7 100644 --- a/httpLayer/fallbackConditionSet.go +++ b/httpLayer/fallbackConditionSet.go @@ -79,7 +79,6 @@ func (fcs *FallbackConditionSet) setSingle(t byte, s string, b byte) { case Fallback_alpn: fcs.AlpnMask = b } - return } func (fcs *FallbackConditionSet) setSingleByInt(t int, s string, b byte) { @@ -93,7 +92,6 @@ func (fcs *FallbackConditionSet) setSingleByInt(t int, s string, b byte) { fcs.Sni = s } - return } func (fcs *FallbackConditionSet) extractSingle(t byte) (r FallbackConditionSet) { diff --git a/httpLayer/fallback_test.go b/httpLayer/fallback_test.go index 5f3ad4b..79ab9f5 100644 --- a/httpLayer/fallback_test.go +++ b/httpLayer/fallback_test.go @@ -16,7 +16,8 @@ var testf = httpLayer.FallbackConditionSet{ var testMap = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult) var testMap2 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult) -var testMap3 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult) + +//var testMap3 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult) const map2Mask = httpLayer.Fallback_sni | httpLayer.Fallback_alpn diff --git a/httpLayer/h1_requestfilter.go b/httpLayer/h1_requestfilter.go index 73ce0be..38ddd69 100644 --- a/httpLayer/h1_requestfilter.go +++ b/httpLayer/h1_requestfilter.go @@ -1,12 +1,21 @@ package httpLayer +import ( + "bytes" +) + +type RawHeader struct { + Head []byte + Value []byte +} + // 从数据中试图获取 http请求的 path,和 method. -// failreason>0 表示获取失败. 不会返回小于0的值。 +// failreason!=0 表示获取失败. // 同时可以用这个方法判断明文 是不是 http1.1, http1.0, http0.9的 http请求。 // 如果是http代理的话,判断方式会有变化,所以需要 isproxy 参数。 // // 此方法亦可用于 判断一个http请求头部是否合法。 -func GetH1RequestMethod_and_PATH_from_Bytes(bs []byte, isproxy bool) (version, method string, path string, failreason int) { +func ParseH1Request(bs []byte, isproxy bool) (version, method string, path string, headers []RawHeader, failreason int) { if len(bs) < 16 { //http0.9 最小长度为16, http1.0及1.1最小长度为18 failreason = 1 @@ -131,8 +140,43 @@ func GetH1RequestMethod_and_PATH_from_Bytes(bs []byte, isproxy bool) (version, m return } - version = string(bs[i+6 : i+9]) path = string(bs[shouldSlashIndex:i]) + + if string(bs[i+1:i+5]) != "HTTP" { + failreason = -10 + return + } + + version = string(bs[i+6 : i+9]) + if bs[i+9] != '\r' || bs[i+10] != '\n' { + failreason = -11 + return + } + + leftBs := bs[i+11:] + + indexOfEnding := bytes.Index(leftBs, HeaderENDING_bytes) + if indexOfEnding < 0 { + failreason = -12 + return + + } + headerBytes := leftBs[:indexOfEnding] + headerBytesList := bytes.Split(headerBytes, []byte(CRLF)) + for _, header := range headerBytesList { + + ss := bytes.SplitN(header, []byte(":"), 2) + if len(ss) != 2 { + failreason = -13 + return + } + headers = append(headers, RawHeader{ + Head: bytes.TrimLeft(ss[0], " "), + Value: bytes.TrimLeft(ss[1], " "), + }) + + } + return } } diff --git a/httpLayer/h1_requestfilter_test.go b/httpLayer/h1_requestfilter_test.go index f5854ce..dc54e8f 100644 --- a/httpLayer/h1_requestfilter_test.go +++ b/httpLayer/h1_requestfilter_test.go @@ -5,28 +5,28 @@ import "testing" func TestGetPath(t *testing.T) { str1 := "GET /sdfdsffs.html HTTP/0.9\r\n" - _, method, p1, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str1), false) + _, method, p1, _, falreason := ParseH1Request([]byte(str1), false) if p1 != "/sdfdsffs.html" || method != "GET" { t.Log("get path failed", p1, len(str1), falreason, len("/sdfdsffs.html")) t.FailNow() } str2 := "CONNECT /sdfdsffs.html HTTP/0.9\r\n" - _, _, p2, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str2), false) + _, _, p2, _, falreason := ParseH1Request([]byte(str2), false) if p2 != "/sdfdsffs.html" { t.Log("get path failed", falreason, p2) t.FailNow() } str3 := "GET /x HTTP/0.9\r" - _, _, p3, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str3), false) + _, _, p3, _, falreason := ParseH1Request([]byte(str3), false) if p3 == "/x" { //尾缀长度不够 t.Log("get path failed", len(str3), falreason, p3) t.FailNow() } str3 = "GET /x HTTP/0.9\r\n" - _, _, p3, falreason = GetH1RequestMethod_and_PATH_from_Bytes([]byte(str3), false) + _, _, p3, _, falreason = ParseH1Request([]byte(str3), false) if p3 != "/x" { t.Log("get path failed", len(str3), falreason, p3) t.FailNow() @@ -36,7 +36,7 @@ func TestGetPath(t *testing.T) { str4 := "GET " + requestStr + " HTTP/1.1\r\n" - _, _, p4, failreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str4), true) + _, _, p4, _, failreason := ParseH1Request([]byte(str4), true) if p4 != requestStr { t.Log("get path failed", len(str4), failreason, p4) t.FailNow() diff --git a/httpLayer/header.go b/httpLayer/header.go index 6e7fc1f..0e3c29d 100644 --- a/httpLayer/header.go +++ b/httpLayer/header.go @@ -13,6 +13,11 @@ import ( "golang.org/x/exp/slices" ) +const ( + toLower = 'a' - 'A' // for use with OR. + toUpper = ^byte(toLower) // for use with AND. +) + //return a clone of m with headers trimmed to one value func TrimHeaders(m map[string][]string) (result map[string][]string) { @@ -24,6 +29,20 @@ func TrimHeaders(m map[string][]string) (result map[string][]string) { return } +// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except +// that it operates with slice of bytes and modifies it inplace without copying. copied from gobwas/ws +func CanonicalizeHeaderKey(k []byte) { + upper := true + for i, c := range k { + if upper && 'a' <= c && c <= 'z' { + k[i] &= toUpper + } else if !upper && 'A' <= c && c <= 'Z' { + k[i] |= toLower + } + upper = c == '-' + } +} + //all values in template is given by real func AllHeadersIn(template map[string][]string, realh http.Header) (ok bool, firstNotMatchKey string) { for k, vs := range template { @@ -151,7 +170,7 @@ func (h *HeaderPreset) AssignDefaultValue() { h.Prepare() } -func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes.Buffer) { +func (h *HeaderPreset) ReadRequest(underlay net.Conn) (leftBuf *bytes.Buffer, err error) { var rp H1RequestParser err = rp.ReadAndParse(underlay) @@ -190,7 +209,7 @@ func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes for _, header := range headerBytesList { //log.Println("ReadRequest read header", string(h)) hs := string(header) - ss := strings.Split(hs, ":") + ss := strings.SplitN(hs, ":", 2) if len(ss) != 2 { err = utils.ErrInvalidData return @@ -234,7 +253,7 @@ func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes rp.WholeRequestBuf.Next(4) - return nil, rp.WholeRequestBuf + return rp.WholeRequestBuf, nil } func (h *HeaderPreset) WriteRequest(underlay net.Conn, payload []byte) error { @@ -255,7 +274,7 @@ func (h *HeaderPreset) WriteRequest(underlay net.Conn, payload []byte) error { return r.Write(underlay) } -func (h *HeaderPreset) ReadResponse(underlay net.Conn) (err error, leftBuf *bytes.Buffer) { +func (h *HeaderPreset) ReadResponse(underlay net.Conn) (leftBuf *bytes.Buffer, err error) { bs := utils.GetPacket() var n int @@ -278,7 +297,7 @@ func (h *HeaderPreset) ReadResponse(underlay net.Conn) (err error, leftBuf *byte buf := bytes.NewBuffer(bs[indexOfEnding+4 : n]) - return nil, buf + return buf, nil } func (h *HeaderPreset) WriteResponse(underlay net.Conn, payload []byte) error { @@ -330,7 +349,7 @@ func (c *HeaderConn) Read(p []byte) (n int, err error) { if c.IsServerEnd { if c.optionalReader == nil { - err, buf = c.H.ReadRequest(c.Conn) + buf, err = c.H.ReadRequest(c.Conn) if err != nil { err = utils.ErrInErr{ErrDesc: "http HeaderConn Read failed, at serverEnd", ErrDetail: err} return @@ -341,7 +360,7 @@ func (c *HeaderConn) Read(p []byte) (n int, err error) { } else { if c.optionalReader == nil { - err, buf = c.H.ReadResponse(c.Conn) + buf, err = c.H.ReadResponse(c.Conn) if err != nil { err = utils.ErrInErr{ErrDesc: "http HeaderConn Read failed", ErrDetail: err} return diff --git a/httpLayer/httpLayer.go b/httpLayer/httpLayer.go index b270a2f..70e10b2 100644 --- a/httpLayer/httpLayer.go +++ b/httpLayer/httpLayer.go @@ -119,6 +119,7 @@ type H1RequestParser struct { Method string WholeRequestBuf *bytes.Buffer Failreason int //为0表示没错误 + Headers []RawHeader } // 尝试读取数据并解析HTTP请求, 解析道道 数据会存入 RequestParser 结构中. @@ -134,7 +135,7 @@ func (rhr *H1RequestParser) ReadAndParse(r io.Reader) error { buf := bytes.NewBuffer(data) rhr.WholeRequestBuf = buf - rhr.Version, rhr.Method, rhr.Path, rhr.Failreason = GetH1RequestMethod_and_PATH_from_Bytes(data, false) + rhr.Version, rhr.Method, rhr.Path, rhr.Headers, rhr.Failreason = ParseH1Request(data, false) if rhr.Failreason != 0 { return utils.ErrInErr{ErrDesc: "httpLayer ReadAndParse failed", ErrDetail: ErrNotHTTP_Request, Data: rhr.Failreason} } diff --git a/iics.go b/iics.go index 07d0188..1ed9504 100644 --- a/iics.go +++ b/iics.go @@ -105,7 +105,7 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re if iics.fallbackFirstBuffer != nil && theRequestPath == "" { var failreason int - _, _, theRequestPath, failreason = httpLayer.GetH1RequestMethod_and_PATH_from_Bytes(iics.fallbackFirstBuffer.Bytes(), false) + _, _, theRequestPath, _, failreason = httpLayer.ParseH1Request(iics.fallbackFirstBuffer.Bytes(), false) if failreason != 0 { theRequestPath = "" @@ -164,7 +164,7 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re } - //默认回落, 每个listen配置 都可 有一个自己独享的默认回落 + //默认回落, 每个listen配置 都 有一个自己独享的默认回落配置 (fallback = 80 这种) if defaultFallbackAddr := iics.inServer.GetFallback(); defaultFallbackAddr != nil { @@ -176,6 +176,10 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re targetAddr = *defaultFallbackAddr result = 0 + } else { + + result = -1 } + return } diff --git a/proxy/http/server.go b/proxy/http/server.go index 6dcec42..501176b 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -64,7 +64,7 @@ func (s *Server) Handshake(underlay net.Conn) (newconn net.Conn, _ netLayer.MsgC // "CONNECT is intended only for use in requests to a proxy. " 总之CONNECT命令专门用于代理. // GET如果 path也是带 http:// 头的话,也是可以的,但是这种只适用于http代理,无法用于https。 - _, method, path, failreason := httpLayer.GetH1RequestMethod_and_PATH_from_Bytes(bs[:n], true) + _, method, path, _, failreason := httpLayer.ParseH1Request(bs[:n], true) if failreason != 0 { err = utils.ErrInErr{ErrDesc: "get method/path failed", ErrDetail: utils.ErrInvalidData, Data: []any{method, failreason}} @@ -102,7 +102,7 @@ func (s *Server) Handshake(underlay net.Conn) (newconn net.Conn, _ netLayer.MsgC } addressStr = hostPortURL.Host - if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口, 默认80 + if !strings.Contains(hostPortURL.Host, ":") { //host不带端口, 默认80 addressStr = hostPortURL.Host + ":80" } } diff --git a/tlsLayer/server.go b/tlsLayer/server.go index 0750b30..38f1c64 100644 --- a/tlsLayer/server.go +++ b/tlsLayer/server.go @@ -2,7 +2,6 @@ package tlsLayer import ( "crypto/tls" - "log" "net" "unsafe" @@ -23,7 +22,7 @@ func NewServer(host, certFile, keyFile string, isInsecure bool, alpnList []strin return nil, err } - //发现服务端必须给出 http/1.1 等,否则不会协商出这个alpn,而我们为了回落,是需要协商出所有可能需要等 alpn的。 + //发现服务端必须给出 http/1.1 等,否则不会协商出这个alpn,而我们为了回落,是需要协商出所有可能需要的 alpn的。 //而且我们如果不提供 h1 和 h2 的alpn的话,很容易被审查者察觉的。 @@ -39,8 +38,6 @@ func NewServer(host, certFile, keyFile string, isInsecure bool, alpnList []strin } } - log.Println("NewServer", alpnList) - s := &Server{ tlsConfig: &tls.Config{ InsecureSkipVerify: isInsecure, diff --git a/utils/pool.go b/utils/pool.go index b90278c..38e3771 100644 --- a/utils/pool.go +++ b/utils/pool.go @@ -6,7 +6,7 @@ import ( ) var ( - mtuPool sync.Pool //专门储存 长度为 MTU 的 []byte + mtuPool sync.Pool //专门储存 长度为 MTU 的 *[]byte, 注意这里存储的是指针 // packetPool 专门储存 长度为 MaxPacketLen 的 []byte // @@ -34,13 +34,15 @@ func init() { mtuPool = sync.Pool{ New: func() interface{} { - return make([]byte, MTU) + bs := make([]byte, MTU) + return &bs }, } packetPool = sync.Pool{ New: func() interface{} { - return make([]byte, MaxPacketLen) + bs := make([]byte, MaxPacketLen) + return &bs }, } @@ -64,7 +66,8 @@ func PutBuf(buf *bytes.Buffer) { //建议在 Read net.Conn 时, 使用 GetPacket函数 获取到足够大的 []byte(MaxBufLen) func GetPacket() []byte { - return packetPool.Get().([]byte) + bsPtr := packetPool.Get().(*[]byte) + return *bsPtr } // 放回用 GetPacket 获取的 []byte @@ -72,23 +75,27 @@ func PutPacket(bs []byte) { c := cap(bs) if c < MaxPacketLen { if c >= MTU { - mtuPool.Put(bs[:MTU]) + bs = bs[:MTU] + mtuPool.Put(&bs) } return } - packetPool.Put(bs[:MaxPacketLen]) + bs = bs[:MaxPacketLen] + packetPool.Put(&bs) } // 从Pool中获取一个 MTU 长度的 []byte func GetMTU() []byte { - return mtuPool.Get().([]byte) + bs := mtuPool.Get().(*[]byte) + return *bs } // 从pool中获取 []byte, 根据给出长度不同,来源于的Pool会不同. func GetBytes(size int) []byte { if size <= MTU { - bs := mtuPool.Get().([]byte) + bsPtr := mtuPool.Get().(*[]byte) + bs := *bsPtr return bs[:size] } @@ -103,8 +110,13 @@ func PutBytes(bs []byte) { return } else if c < MaxPacketLen { - mtuPool.Put(bs[:MTU]) + if c != MTU { + bs = bs[:MTU] + + } + mtuPool.Put(&bs) } else { - packetPool.Put(bs[:MaxPacketLen]) + bs = bs[:MaxPacketLen] + packetPool.Put(&bs) } }