修订代码; 完善ws; 令Pool使用指针,而不是slice

令 websocket在path访问正确但是不是ws连接时,也进行回落,而不是返回一个错误

将 GetH1RequestMethod_and_PATH_from_Bytes 改名为 ParseH1Request, 且支持 读取header

同时新增了 RawHeader 结构 用于 上述目的。httpLayer还添加了 CanonicalizeHeaderKey 方法。

令Pool使用指针 后,测速从 3200左右上升至3800左右,也不知道是不是这个优化导致的。如果是的话,那也太猛了。
This commit is contained in:
e1732a364fed
2022-05-07 09:51:45 +08:00
parent dface33524
commit 3e7e779920
14 changed files with 158 additions and 67 deletions

View File

@@ -15,7 +15,6 @@ import (
// 因此我们包装一下统一使用Read和Write函数 来读写 二进制数据。因为我们这里是代理,
type Conn struct {
net.Conn
first_nextFrameCalled bool
state ws.State
r *wsutil.Reader

View File

@@ -1,6 +1,7 @@
package ws
import (
"bytes"
"encoding/base64"
"io"
"net"
@@ -19,6 +20,11 @@ import (
// 683 * 4 = 2732, 若你不信,运行 we_test.go中的 TestBase64Len
const MaxEarlyDataLen_Base64 = 2732
var (
connectionBs = []byte("Connection")
upgradeBs = []byte("Upgrade")
)
//implements advLayer.SingleServer
type Server struct {
Creator
@@ -108,7 +114,36 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) {
optionalFirstBuffer := rp.WholeRequestBuf
if rp.Method != "GET" || s.Thepath != rp.Path {
notWsRequest := false
//因为 gobwas 会先自行给错误的连接 返回 错误信息,而这不行,所以我们先过滤一遍。
//header 我们只过滤一个 connection 就行. 要是怕攻击者用 “对的path,method 和错误的header” 进行探测,
// 那你设一个复杂的path就ok了。
if rp.Method != "GET" || s.Thepath != rp.Path || len(rp.Headers) == 0 {
notWsRequest = true
} else {
hasUpgrade := false
for _, rh := range rp.Headers {
httpLayer.CanonicalizeHeaderKey(rh.Head)
if bytes.Equal(rh.Head, connectionBs) {
httpLayer.CanonicalizeHeaderKey(rh.Value)
if bytes.Equal(rh.Value, upgradeBs) {
hasUpgrade = true
break
}
}
}
if !hasUpgrade {
notWsRequest = true
}
}
if notWsRequest {
return httpLayer.FallbackMeta{
Conn: underlay,
H1RequestBuf: optionalFirstBuffer,
@@ -117,7 +152,6 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) {
}, httpLayer.ErrShouldFallback
}
theWrongPath := ""
var thePotentialEarlyData []byte
requestHeaderNotGivenCount := s.requestHeaderCheckCount
@@ -134,27 +168,6 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) {
//我们这里就是先用 httpLayer 过滤 再和 buffer一起传入 ws 包
// ReadBufferSize默认是 4096已经够大
OnRequest: func(uri []byte) error {
struri := string(uri)
if struri != s.Thepath {
theWrongPath = struri
//return utils.NewDataErr("ws path not match", nil, struri[:min])
//发现这个错误除了在程序里返回外,还会直接显示到 浏览器上!这会被探测到的。
// 所以只能显示标准http错误, 然后通过闭包的方式 把path信息传递到外部.
if ce := utils.CanLogWarn("ws path not match"); ce != nil {
min := len(s.Thepath)
if len(struri) < min {
min = len(struri)
}
//log.Println("ws path not match", struri[:min])
ce.Write(zap.String("wrong path", struri[:min]))
}
return ws.RejectConnectionError(ws.RejectionStatus(http.StatusBadRequest))
}
return nil
},
OnHeader: func(key, value []byte) error {
if s.noNeedToCheckRequestHeaders {
return nil
@@ -229,11 +242,13 @@ func (s *Server) Handshake(underlay net.Conn) (net.Conn, error) {
_, err := theUpgrader.Upgrade(rw)
if err != nil {
if len(theWrongPath) > 0 {
//ws的Method必为Get, 否则 gobwas 会返回 ErrHandshakeBadMethod
return nil, &httpLayer.RequestErr{Path: theWrongPath}
}
return nil, err
return httpLayer.FallbackMeta{
Conn: underlay,
H1RequestBuf: optionalFirstBuffer,
Path: rp.Path,
Method: rp.Method,
}, httpLayer.ErrShouldFallback
}
theConn := &Conn{

View File

@@ -158,6 +158,7 @@ win10:
clean:
rm -f ${prefix}
rm -f ${prefix}.exe
rm -f BUILD_VERSION
rm -f $(linuxAmdFn)
rm -f $(linuxArmFn)

View File

@@ -24,7 +24,7 @@ const (
all_non_default_fallbacktype_count = iota - 2
alpn_unspecified = 0
//alpn_unspecified = 0
)
const (

View File

@@ -79,7 +79,6 @@ func (fcs *FallbackConditionSet) setSingle(t byte, s string, b byte) {
case Fallback_alpn:
fcs.AlpnMask = b
}
return
}
func (fcs *FallbackConditionSet) setSingleByInt(t int, s string, b byte) {
@@ -93,7 +92,6 @@ func (fcs *FallbackConditionSet) setSingleByInt(t int, s string, b byte) {
fcs.Sni = s
}
return
}
func (fcs *FallbackConditionSet) extractSingle(t byte) (r FallbackConditionSet) {

View File

@@ -16,7 +16,8 @@ var testf = httpLayer.FallbackConditionSet{
var testMap = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult)
var testMap2 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult)
var testMap3 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult)
//var testMap3 = make(map[httpLayer.FallbackConditionSet]*httpLayer.FallbackResult)
const map2Mask = httpLayer.Fallback_sni | httpLayer.Fallback_alpn

View File

@@ -1,12 +1,21 @@
package httpLayer
import (
"bytes"
)
type RawHeader struct {
Head []byte
Value []byte
}
// 从数据中试图获取 http请求的 path,和 method.
// failreason>0 表示获取失败. 不会返回小于0的值。
// failreason!=0 表示获取失败.
// 同时可以用这个方法判断明文 是不是 http1.1, http1.0, http0.9的 http请求。
// 如果是http代理的话判断方式会有变化,所以需要 isproxy 参数。
//
// 此方法亦可用于 判断一个http请求头部是否合法。
func GetH1RequestMethod_and_PATH_from_Bytes(bs []byte, isproxy bool) (version, method string, path string, failreason int) {
func ParseH1Request(bs []byte, isproxy bool) (version, method string, path string, headers []RawHeader, failreason int) {
if len(bs) < 16 { //http0.9 最小长度为16 http1.0及1.1最小长度为18
failreason = 1
@@ -131,8 +140,43 @@ func GetH1RequestMethod_and_PATH_from_Bytes(bs []byte, isproxy bool) (version, m
return
}
version = string(bs[i+6 : i+9])
path = string(bs[shouldSlashIndex:i])
if string(bs[i+1:i+5]) != "HTTP" {
failreason = -10
return
}
version = string(bs[i+6 : i+9])
if bs[i+9] != '\r' || bs[i+10] != '\n' {
failreason = -11
return
}
leftBs := bs[i+11:]
indexOfEnding := bytes.Index(leftBs, HeaderENDING_bytes)
if indexOfEnding < 0 {
failreason = -12
return
}
headerBytes := leftBs[:indexOfEnding]
headerBytesList := bytes.Split(headerBytes, []byte(CRLF))
for _, header := range headerBytesList {
ss := bytes.SplitN(header, []byte(":"), 2)
if len(ss) != 2 {
failreason = -13
return
}
headers = append(headers, RawHeader{
Head: bytes.TrimLeft(ss[0], " "),
Value: bytes.TrimLeft(ss[1], " "),
})
}
return
}
}

View File

@@ -5,28 +5,28 @@ import "testing"
func TestGetPath(t *testing.T) {
str1 := "GET /sdfdsffs.html HTTP/0.9\r\n"
_, method, p1, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str1), false)
_, method, p1, _, falreason := ParseH1Request([]byte(str1), false)
if p1 != "/sdfdsffs.html" || method != "GET" {
t.Log("get path failed", p1, len(str1), falreason, len("/sdfdsffs.html"))
t.FailNow()
}
str2 := "CONNECT /sdfdsffs.html HTTP/0.9\r\n"
_, _, p2, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str2), false)
_, _, p2, _, falreason := ParseH1Request([]byte(str2), false)
if p2 != "/sdfdsffs.html" {
t.Log("get path failed", falreason, p2)
t.FailNow()
}
str3 := "GET /x HTTP/0.9\r"
_, _, p3, falreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str3), false)
_, _, p3, _, falreason := ParseH1Request([]byte(str3), false)
if p3 == "/x" { //尾缀长度不够
t.Log("get path failed", len(str3), falreason, p3)
t.FailNow()
}
str3 = "GET /x HTTP/0.9\r\n"
_, _, p3, falreason = GetH1RequestMethod_and_PATH_from_Bytes([]byte(str3), false)
_, _, p3, _, falreason = ParseH1Request([]byte(str3), false)
if p3 != "/x" {
t.Log("get path failed", len(str3), falreason, p3)
t.FailNow()
@@ -36,7 +36,7 @@ func TestGetPath(t *testing.T) {
str4 := "GET " + requestStr + " HTTP/1.1\r\n"
_, _, p4, failreason := GetH1RequestMethod_and_PATH_from_Bytes([]byte(str4), true)
_, _, p4, _, failreason := ParseH1Request([]byte(str4), true)
if p4 != requestStr {
t.Log("get path failed", len(str4), failreason, p4)
t.FailNow()

View File

@@ -13,6 +13,11 @@ import (
"golang.org/x/exp/slices"
)
const (
toLower = 'a' - 'A' // for use with OR.
toUpper = ^byte(toLower) // for use with AND.
)
//return a clone of m with headers trimmed to one value
func TrimHeaders(m map[string][]string) (result map[string][]string) {
@@ -24,6 +29,20 @@ func TrimHeaders(m map[string][]string) (result map[string][]string) {
return
}
// Algorithm below is like standard textproto/CanonicalMIMEHeaderKey, except
// that it operates with slice of bytes and modifies it inplace without copying. copied from gobwas/ws
func CanonicalizeHeaderKey(k []byte) {
upper := true
for i, c := range k {
if upper && 'a' <= c && c <= 'z' {
k[i] &= toUpper
} else if !upper && 'A' <= c && c <= 'Z' {
k[i] |= toLower
}
upper = c == '-'
}
}
//all values in template is given by real
func AllHeadersIn(template map[string][]string, realh http.Header) (ok bool, firstNotMatchKey string) {
for k, vs := range template {
@@ -151,7 +170,7 @@ func (h *HeaderPreset) AssignDefaultValue() {
h.Prepare()
}
func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes.Buffer) {
func (h *HeaderPreset) ReadRequest(underlay net.Conn) (leftBuf *bytes.Buffer, err error) {
var rp H1RequestParser
err = rp.ReadAndParse(underlay)
@@ -190,7 +209,7 @@ func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes
for _, header := range headerBytesList {
//log.Println("ReadRequest read header", string(h))
hs := string(header)
ss := strings.Split(hs, ":")
ss := strings.SplitN(hs, ":", 2)
if len(ss) != 2 {
err = utils.ErrInvalidData
return
@@ -234,7 +253,7 @@ func (h *HeaderPreset) ReadRequest(underlay net.Conn) (err error, leftBuf *bytes
rp.WholeRequestBuf.Next(4)
return nil, rp.WholeRequestBuf
return rp.WholeRequestBuf, nil
}
func (h *HeaderPreset) WriteRequest(underlay net.Conn, payload []byte) error {
@@ -255,7 +274,7 @@ func (h *HeaderPreset) WriteRequest(underlay net.Conn, payload []byte) error {
return r.Write(underlay)
}
func (h *HeaderPreset) ReadResponse(underlay net.Conn) (err error, leftBuf *bytes.Buffer) {
func (h *HeaderPreset) ReadResponse(underlay net.Conn) (leftBuf *bytes.Buffer, err error) {
bs := utils.GetPacket()
var n int
@@ -278,7 +297,7 @@ func (h *HeaderPreset) ReadResponse(underlay net.Conn) (err error, leftBuf *byte
buf := bytes.NewBuffer(bs[indexOfEnding+4 : n])
return nil, buf
return buf, nil
}
func (h *HeaderPreset) WriteResponse(underlay net.Conn, payload []byte) error {
@@ -330,7 +349,7 @@ func (c *HeaderConn) Read(p []byte) (n int, err error) {
if c.IsServerEnd {
if c.optionalReader == nil {
err, buf = c.H.ReadRequest(c.Conn)
buf, err = c.H.ReadRequest(c.Conn)
if err != nil {
err = utils.ErrInErr{ErrDesc: "http HeaderConn Read failed, at serverEnd", ErrDetail: err}
return
@@ -341,7 +360,7 @@ func (c *HeaderConn) Read(p []byte) (n int, err error) {
} else {
if c.optionalReader == nil {
err, buf = c.H.ReadResponse(c.Conn)
buf, err = c.H.ReadResponse(c.Conn)
if err != nil {
err = utils.ErrInErr{ErrDesc: "http HeaderConn Read failed", ErrDetail: err}
return

View File

@@ -119,6 +119,7 @@ type H1RequestParser struct {
Method string
WholeRequestBuf *bytes.Buffer
Failreason int //为0表示没错误
Headers []RawHeader
}
// 尝试读取数据并解析HTTP请求, 解析道道 数据会存入 RequestParser 结构中.
@@ -134,7 +135,7 @@ func (rhr *H1RequestParser) ReadAndParse(r io.Reader) error {
buf := bytes.NewBuffer(data)
rhr.WholeRequestBuf = buf
rhr.Version, rhr.Method, rhr.Path, rhr.Failreason = GetH1RequestMethod_and_PATH_from_Bytes(data, false)
rhr.Version, rhr.Method, rhr.Path, rhr.Headers, rhr.Failreason = ParseH1Request(data, false)
if rhr.Failreason != 0 {
return utils.ErrInErr{ErrDesc: "httpLayer ReadAndParse failed", ErrDetail: ErrNotHTTP_Request, Data: rhr.Failreason}
}

View File

@@ -105,7 +105,7 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re
if iics.fallbackFirstBuffer != nil && theRequestPath == "" {
var failreason int
_, _, theRequestPath, failreason = httpLayer.GetH1RequestMethod_and_PATH_from_Bytes(iics.fallbackFirstBuffer.Bytes(), false)
_, _, theRequestPath, _, failreason = httpLayer.ParseH1Request(iics.fallbackFirstBuffer.Bytes(), false)
if failreason != 0 {
theRequestPath = ""
@@ -164,7 +164,7 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re
}
//默认回落, 每个listen配置 都 有一个自己独享的默认回落
//默认回落, 每个listen配置 都 有一个自己独享的默认回落配置 (fallback = 80 这种)
if defaultFallbackAddr := iics.inServer.GetFallback(); defaultFallbackAddr != nil {
@@ -176,6 +176,10 @@ func checkfallback(iics incomingInserverConnState) (targetAddr netLayer.Addr, re
targetAddr = *defaultFallbackAddr
result = 0
} else {
result = -1
}
return
}

View File

@@ -64,7 +64,7 @@ func (s *Server) Handshake(underlay net.Conn) (newconn net.Conn, _ netLayer.MsgC
// "CONNECT is intended only for use in requests to a proxy. " 总之CONNECT命令专门用于代理.
// GET如果 path也是带 http:// 头的话也是可以的但是这种只适用于http代理无法用于https。
_, method, path, failreason := httpLayer.GetH1RequestMethod_and_PATH_from_Bytes(bs[:n], true)
_, method, path, _, failreason := httpLayer.ParseH1Request(bs[:n], true)
if failreason != 0 {
err = utils.ErrInErr{ErrDesc: "get method/path failed", ErrDetail: utils.ErrInvalidData, Data: []any{method, failreason}}
@@ -102,7 +102,7 @@ func (s *Server) Handshake(underlay net.Conn) (newconn net.Conn, _ netLayer.MsgC
}
addressStr = hostPortURL.Host
if strings.Index(hostPortURL.Host, ":") == -1 { //host不带端口 默认80
if !strings.Contains(hostPortURL.Host, ":") { //host不带端口 默认80
addressStr = hostPortURL.Host + ":80"
}
}

View File

@@ -2,7 +2,6 @@ package tlsLayer
import (
"crypto/tls"
"log"
"net"
"unsafe"
@@ -23,7 +22,7 @@ func NewServer(host, certFile, keyFile string, isInsecure bool, alpnList []strin
return nil, err
}
//发现服务端必须给出 http/1.1 等否则不会协商出这个alpn而我们为了回落是需要协商出所有可能需要 alpn的。
//发现服务端必须给出 http/1.1 等否则不会协商出这个alpn而我们为了回落是需要协商出所有可能需要 alpn的。
//而且我们如果不提供 h1 和 h2 的alpn的话很容易被审查者察觉的。
@@ -39,8 +38,6 @@ func NewServer(host, certFile, keyFile string, isInsecure bool, alpnList []strin
}
}
log.Println("NewServer", alpnList)
s := &Server{
tlsConfig: &tls.Config{
InsecureSkipVerify: isInsecure,

View File

@@ -6,7 +6,7 @@ import (
)
var (
mtuPool sync.Pool //专门储存 长度为 MTU 的 []byte
mtuPool sync.Pool //专门储存 长度为 MTU 的 *[]byte, 注意这里存储的是指针
// packetPool 专门储存 长度为 MaxPacketLen 的 []byte
//
@@ -34,13 +34,15 @@ func init() {
mtuPool = sync.Pool{
New: func() interface{} {
return make([]byte, MTU)
bs := make([]byte, MTU)
return &bs
},
}
packetPool = sync.Pool{
New: func() interface{} {
return make([]byte, MaxPacketLen)
bs := make([]byte, MaxPacketLen)
return &bs
},
}
@@ -64,7 +66,8 @@ func PutBuf(buf *bytes.Buffer) {
//建议在 Read net.Conn 时, 使用 GetPacket函数 获取到足够大的 []byteMaxBufLen
func GetPacket() []byte {
return packetPool.Get().([]byte)
bsPtr := packetPool.Get().(*[]byte)
return *bsPtr
}
// 放回用 GetPacket 获取的 []byte
@@ -72,23 +75,27 @@ func PutPacket(bs []byte) {
c := cap(bs)
if c < MaxPacketLen {
if c >= MTU {
mtuPool.Put(bs[:MTU])
bs = bs[:MTU]
mtuPool.Put(&bs)
}
return
}
packetPool.Put(bs[:MaxPacketLen])
bs = bs[:MaxPacketLen]
packetPool.Put(&bs)
}
// 从Pool中获取一个 MTU 长度的 []byte
func GetMTU() []byte {
return mtuPool.Get().([]byte)
bs := mtuPool.Get().(*[]byte)
return *bs
}
// 从pool中获取 []byte, 根据给出长度不同来源于的Pool会不同.
func GetBytes(size int) []byte {
if size <= MTU {
bs := mtuPool.Get().([]byte)
bsPtr := mtuPool.Get().(*[]byte)
bs := *bsPtr
return bs[:size]
}
@@ -103,8 +110,13 @@ func PutBytes(bs []byte) {
return
} else if c < MaxPacketLen {
mtuPool.Put(bs[:MTU])
if c != MTU {
bs = bs[:MTU]
}
mtuPool.Put(&bs)
} else {
packetPool.Put(bs[:MaxPacketLen])
bs = bs[:MaxPacketLen]
packetPool.Put(&bs)
}
}