diff --git a/certs/cert.go b/certs/cert.go index 48605a7b..04e740bc 100644 --- a/certs/cert.go +++ b/certs/cert.go @@ -1,6 +1,7 @@ package certs import ( + "crypto/tls" "errors" "sync" "sync/atomic" @@ -11,52 +12,58 @@ import ( var errorCertificateNotExit = errors.New("not exist cert") -type ICert interface { - SaveCert(workerId string, cert []*gmtls.Certificate) - DelCert(workerId string) -} - var ( - workerMaps = make(map[string][]*gmtls.Certificate) - lock = sync.RWMutex{} + currentWorkers = make(map[string]*tls.Certificate) + gmWorkers = make(map[string][]*gmtls.Certificate) + + lock = sync.RWMutex{} // currentCert 普通TLS证书 - currentCert = atomic.Pointer[config.Cert]{} + currentCert = atomic.Pointer[config.Cert[tls.Certificate]]{} // gmCert gmTLS证书 - gmCert = atomic.Pointer[config.Cert]{} + gmCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{} // gmEncCert gmTLS加密证书 - gmEncCert = atomic.Pointer[config.Cert]{} + gmEncCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{} ) func init() { - currentCert.Store(config.NewCert(nil)) - + currentCert.Store(config.NewCert[tls.Certificate](nil)) + gmCert.Store(config.NewCert[gmtls.Certificate](nil)) + gmEncCert.Store(config.NewCert[gmtls.Certificate](nil)) } func DelCert(workerId string) { lock.Lock() defer lock.Unlock() - delete(workerMaps, workerId) + delete(currentWorkers, workerId) rebuild() } -func SaveCert(workerId string, certs []*gmtls.Certificate) { +func SaveCert(workerId string, certs *tls.Certificate) { lock.Lock() defer lock.Unlock() - workerMaps[workerId] = certs + currentWorkers[workerId] = certs rebuild() } -func rebuild() { - currentMap := make(map[string]*gmtls.Certificate) + +func SaveGMCert(workerId string, certs []*gmtls.Certificate) { + lock.Lock() + defer lock.Unlock() + gmWorkers[workerId] = certs + gmRebuild() +} + +func DelGMCert(workerId string) { + lock.Lock() + defer lock.Unlock() + delete(gmWorkers, workerId) + gmRebuild() +} + +func gmRebuild() { gmMap := make(map[string]*gmtls.Certificate) gmEncMap := make(map[string]*gmtls.Certificate) - for _, cs := range workerMaps { + for _, cs := range gmWorkers { l := len(cs) switch { - case l == 1: - i := cs[0] - currentMap[i.Leaf.Subject.CommonName] = i - for _, dnsName := range i.Leaf.DNSNames { - currentMap[dnsName] = i - } case l == 2: i := cs[0] gmMap[i.Leaf.Subject.CommonName] = i @@ -74,26 +81,36 @@ func rebuild() { } } - currentCert.Swap(config.NewCert(currentMap)) gmCert.Swap(config.NewCert(gmMap)) gmEncCert.Swap(config.NewCert(gmEncMap)) } +func rebuild() { + currentMap := make(map[string]*tls.Certificate) + for _, cs := range currentWorkers { + i := cs + currentMap[i.Leaf.Subject.CommonName] = i + for _, dnsName := range i.Leaf.DNSNames { + currentMap[dnsName] = i + } + } + currentCert.Swap(config.NewCert(currentMap)) +} -func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) { +func GetCertificateFunc(certsLocal ...*config.Cert[tls.Certificate]) func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { if len(certsLocal) == 0 { - return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) { + return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return currentCert.Load().GetCertificate(info) } } - certList := make([]*config.Cert, 0, len(certsLocal)) + certList := make([]*config.Cert[tls.Certificate], 0, len(certsLocal)) for _, c := range certsLocal { if c != nil { certList = append(certList, c) } } - return func(info *gmtls.ClientHelloInfo) (certificate *gmtls.Certificate, err error) { + return func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) { certificate, err = currentCert.Load().GetCertificate(info) if certificate != nil { return @@ -112,19 +129,8 @@ func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHello } -func GetAutoCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) { +func GetGMCertificateFunc() func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) { return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) { - gmFlag := false - // 检查支持协议中是否包含GMSSL - for _, v := range info.SupportedVersions { - if v == gmtls.VersionGMSSL { - gmFlag = true - break - } - } - if !gmFlag { - return GetCertificateFunc(certsLocal...)(info) - } return gmCert.Load().GetCertificate(info) } } diff --git a/drivers/certs/worker.go b/drivers/certs/worker.go index f51b12c2..c5a01987 100644 --- a/drivers/certs/worker.go +++ b/drivers/certs/worker.go @@ -1,12 +1,13 @@ package certs import ( + "crypto/tls" + "crypto/x509" + "github.com/eolinker/apinto/certs" "github.com/eolinker/apinto/drivers" "github.com/eolinker/apinto/utils" "github.com/eolinker/eosc" - "github.com/tjfoc/gmsm/gmtls" - "github.com/tjfoc/gmsm/x509" ) var ( @@ -53,7 +54,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro } w.config = config - certs.SaveCert(w.Id(), []*gmtls.Certificate{cert}) + certs.SaveCert(w.Id(), cert) return nil } @@ -66,7 +67,7 @@ func (w *Worker) CheckSkill(string) bool { return false } -func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) { +func parseCert(privateKey, pemValue string) (*tls.Certificate, error) { cert, err := genCert([]byte(privateKey), []byte(pemValue)) if err == nil { return cert, nil @@ -83,8 +84,8 @@ func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) { return genCert(keydata, pem) } -func genCert(key, pem []byte) (*gmtls.Certificate, error) { - certificate, err := gmtls.X509KeyPair(pem, key) +func genCert(key, pem []byte) (*tls.Certificate, error) { + certificate, err := tls.X509KeyPair(pem, key) if err != nil { return nil, err } diff --git a/drivers/gm-certs/worker.go b/drivers/gm-certs/worker.go index 9b67454e..4768a358 100644 --- a/drivers/gm-certs/worker.go +++ b/drivers/gm-certs/worker.go @@ -40,7 +40,7 @@ func (w *Worker) Check(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro func (w *Worker) Destroy() error { controller.Del(w.Id()) - certs.DelCert(w.Id()) + certs.DelGMCert(w.Id()) return nil } @@ -64,7 +64,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro } w.config = c - certs.SaveCert(w.Id(), []*gmtls.Certificate{signCert, encCert}) + certs.SaveGMCert(w.Id(), []*gmtls.Certificate{signCert, encCert}) return nil } diff --git a/drivers/plugin-manager/manager.go b/drivers/plugin-manager/manager.go index c4b47347..c58c4a45 100644 --- a/drivers/plugin-manager/manager.go +++ b/drivers/plugin-manager/manager.go @@ -92,7 +92,11 @@ func (p *PluginManager) Reset(conf interface{}) error { list := p.pluginObjs.List() // 遍历,全量更新 for _, v := range list { + old := v.fs v.fs = p.createFilters(v.conf) + if old != nil { + old.Destroy() + } } return nil diff --git a/drivers/router/listener.go b/drivers/router/listener.go index b742df98..0e5c2b6c 100644 --- a/drivers/router/listener.go +++ b/drivers/router/listener.go @@ -49,7 +49,7 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) { } wg := sync.WaitGroup{} - tcp, ssl := tf.Listen(listenCfg.ListenUrls...) + tcp, ssl, gmSsl := tf.Listen(listenCfg.ListenUrls...) listenerByPort := make(map[int][]net.Listener) for _, l := range tcp { @@ -57,18 +57,27 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) { listenerByPort[port] = append(listenerByPort[port], l) } if len(ssl) > 0 { + tlsConfig := &tls.Config{ + GetCertificate: certs.GetCertificateFunc(), + MinVersion: tls.VersionSSL30, + MaxVersion: tls.VersionTLS13, + } + for _, l := range ssl { + log.Debug("ssl listen: ", l.Addr().String()) + port := readPort(l.Addr()) + listenerByPort[port] = append(listenerByPort[port], tls.NewListener(l, tlsConfig)) + } + } + if len(gmSsl) > 0 { support := gmtls.NewGMSupport() support.EnableMixMode() gmTlsConfig := &gmtls.Config{ - GetCertificate: certs.GetAutoCertificateFunc(), + GetCertificate: certs.GetGMCertificateFunc(), GetKECertificate: certs.GetKECertificate(), GMSupport: support, - MinVersion: gmtls.VersionGMSSL, - MaxVersion: tls.VersionTLS13, } - - for _, l := range ssl { - log.Debug("ssl listen: ", l.Addr().String()) + for _, l := range gmSsl { + log.Debug("gm ssl listen: ", l.Addr().String()) port := readPort(l.Addr()) listenerByPort[port] = append(listenerByPort[port], gmtls.NewListener(l, gmTlsConfig)) } diff --git a/go.mod b/go.mod index 06376af0..c4bfb01d 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/eolinker/apinto -go 1.21 +go 1.23 -toolchain go1.21.1 +toolchain go1.23.1 require ( github.com/Shopify/sarama v1.32.0 @@ -11,7 +11,7 @@ require ( github.com/clbanning/mxj v1.8.4 github.com/coocood/freecache v1.2.2 github.com/dubbogo/gost v1.13.1 - github.com/eolinker/eosc v0.20.6 + github.com/eolinker/eosc v0.21.1 github.com/fasthttp/websocket v1.5.0 github.com/fullstorydev/grpcurl v1.8.7 github.com/go-redis/redis/v8 v8.11.5 @@ -206,4 +206,6 @@ require ( replace github.com/soheilhy/cmux v0.1.5 => github.com/hmzzrcs/cmux v0.1.6 -replace github.com/eolinker/eosc => ../eosc +//replace ( +// github.com/eolinker/eosc => ../eosc +//) diff --git a/strategy/checker.go b/strategy/checker.go index c7980ac8..225c753e 100644 --- a/strategy/checker.go +++ b/strategy/checker.go @@ -156,17 +156,21 @@ type timestampChecker struct { endTime time.Time } -func newTimestampChecker(timeRange string) (*timestampChecker, error) { - // 正则表达式:匹配 HH:mm:ss - HH:mm:ss - regex := `^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00) - ((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$` - re := regexp.MustCompile(regex) - if !re.MatchString(timeRange) { - return nil, fmt.Errorf("invalid time format, expected HH:mm:ss - HH:mm:ss (00:00:00 - 24:00:00)") - } +var ( + timeRangeRegex = regexp.MustCompile(`^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$`) +) +func newTimestampChecker(timeRange string) (*timestampChecker, error) { // 提取开始时间和结束时间 - times := strings.Split(timeRange, " - ") - startTimeStr, endTimeStr := times[0], times[1] + times := strings.Split(timeRange, "-") + startTimeStr, endTimeStr := strings.TrimSpace(times[0]), strings.TrimSpace(times[1]) + + if !timeRangeRegex.MatchString(startTimeStr) { + return nil, fmt.Errorf("invalid time format for start time: %s", startTimeStr) + } + if !timeRangeRegex.MatchString(endTimeStr) { + return nil, fmt.Errorf("invalid time format for end time: %s", endTimeStr) + } // 解析开始时间和结束时间(假设在当前日期) startTime, err := time.Parse("15:04:05", startTimeStr) if err != nil {