diff --git a/proxy/tlsConfig.go b/proxy/tlsConfig.go index a3d2dea..dd4e2a0 100644 --- a/proxy/tlsConfig.go +++ b/proxy/tlsConfig.go @@ -72,7 +72,7 @@ func prepareTLS_forClient(com BaseInterface, dc *DialConf) error { return nil } -// use lc.Host, lc.TLSCert, lc.TLSKey, lc.Insecure, lc.Alpn. +// use lc.Host, lc.TLSCert, lc.TLSKey, lc.Insecure, lc.Alpn, lc.Extra func prepareTLS_forServer(com BaseInterface, lc *ListenConf) error { serc := com.GetBase() @@ -82,16 +82,15 @@ func prepareTLS_forServer(com BaseInterface, lc *ListenConf) error { alpnList := updateAlpnListByAdvLayer(com, lc.Alpn) - var minVer uint16 = tlsLayer.GetMinVerFromExtra(lc.Extra) - tlsserver, err := tlsLayer.NewServer(tlsLayer.Conf{ Host: lc.Host, CertConf: &tlsLayer.CertConf{ CertFile: lc.TLSCert, KeyFile: lc.TLSKey, CA: lc.CA, }, - Insecure: lc.Insecure, - AlpnList: alpnList, - Minver: minVer, + Insecure: lc.Insecure, + AlpnList: alpnList, + Minver: tlsLayer.GetMinVerFromExtra(lc.Extra), + RejectUnknownSni: tlsLayer.GetRejectUnknownSniFromExtra(lc.Extra), }) if err == nil { diff --git a/tlsLayer/tlsLayer.go b/tlsLayer/tlsLayer.go index 6ee85ad..aea4575 100644 --- a/tlsLayer/tlsLayer.go +++ b/tlsLayer/tlsLayer.go @@ -7,6 +7,7 @@ package tlsLayer import ( "crypto/tls" + "strings" "unsafe" "github.com/e1732a364fed/v2ray_simple/utils" @@ -16,12 +17,13 @@ import ( type Conf struct { Host string + Insecure bool + Minver uint16 + AlpnList []string CertConf *CertConf - Insecure bool - Use_uTls bool //only client - AlpnList []string - Minver uint16 + Use_uTls bool //only client + RejectUnknownSni bool //only server } func GetMinVerFromExtra(extra map[string]any) uint16 { @@ -39,6 +41,43 @@ func GetMinVerFromExtra(extra map[string]any) uint16 { return tls.VersionTLS13 } +func GetRejectUnknownSniFromExtra(extra map[string]any) bool { + if len(extra) > 0 { + if thing := extra["rejectUnknownSni"]; thing != nil { + if is, ok := utils.AnyToBool(thing); ok && is { + return true + } + } + } + + return false +} + +func rejectUnknownGetCertificateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if len(certs) == 0 { + return nil, utils.ErrInErr{ErrDesc: "len(certs) == 0", ErrDetail: utils.ErrInvalidData} + } + sni := strings.ToLower(hello.ServerName) + + gsni := "*" + if index := strings.IndexByte(sni, '.'); index != -1 { + gsni += sni[index:] + } + for _, keyPair := range certs { + if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { + return keyPair, nil + } + for _, name := range keyPair.Leaf.DNSNames { + if name == sni || name == gsni { + return keyPair, nil + } + } + } + return nil, utils.ErrInErr{ErrDesc: "rejectUnknownSNI", ErrDetail: utils.ErrInvalidData, Data: sni} + } +} + func GetTlsConfig(mustHasCert bool, conf Conf) *tls.Config { var certArray []tls.Certificate var err error @@ -83,6 +122,9 @@ func GetTlsConfig(mustHasCert bool, conf Conf) *tls.Config { tConf.ClientAuth = tls.RequireAndVerifyClientCert } } + if conf.RejectUnknownSni { + tConf.GetCertificate = rejectUnknownGetCertificateFunc(utils.ArrayToPtrArray(certArray)) + } return tConf } diff --git a/utils/algo.go b/utils/algo.go index fb8be29..15da75e 100644 --- a/utils/algo.go +++ b/utils/algo.go @@ -8,6 +8,13 @@ import ( "golang.org/x/exp/slices" ) +func ArrayToPtrArray[T any](a []T) (r []*T) { + for _, v := range a { + r = append(r, &v) + } + return +} + //Combinatorics //////////////////////////////////////////////////////////////// // func AllSubSets edited from https://github.com/mxschmitt/golang-combinations with MIT License