1. 国密算法完成

2. 修复插件修改后未销毁的问题
This commit is contained in:
Liujian
2025-07-07 18:23:29 +08:00
parent dbfb361d39
commit f0d2ab24bb
7 changed files with 95 additions and 69 deletions

View File

@@ -1,6 +1,7 @@
package certs
import (
"crypto/tls"
"errors"
"sync"
"sync/atomic"
@@ -11,52 +12,58 @@ import (
var errorCertificateNotExit = errors.New("not exist cert")
type ICert interface {
SaveCert(workerId string, cert []*gmtls.Certificate)
DelCert(workerId string)
}
var (
workerMaps = make(map[string][]*gmtls.Certificate)
lock = sync.RWMutex{}
currentWorkers = make(map[string]*tls.Certificate)
gmWorkers = make(map[string][]*gmtls.Certificate)
lock = sync.RWMutex{}
// currentCert 普通TLS证书
currentCert = atomic.Pointer[config.Cert]{}
currentCert = atomic.Pointer[config.Cert[tls.Certificate]]{}
// gmCert gmTLS证书
gmCert = atomic.Pointer[config.Cert]{}
gmCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{}
// gmEncCert gmTLS加密证书
gmEncCert = atomic.Pointer[config.Cert]{}
gmEncCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{}
)
func init() {
currentCert.Store(config.NewCert(nil))
currentCert.Store(config.NewCert[tls.Certificate](nil))
gmCert.Store(config.NewCert[gmtls.Certificate](nil))
gmEncCert.Store(config.NewCert[gmtls.Certificate](nil))
}
func DelCert(workerId string) {
lock.Lock()
defer lock.Unlock()
delete(workerMaps, workerId)
delete(currentWorkers, workerId)
rebuild()
}
func SaveCert(workerId string, certs []*gmtls.Certificate) {
func SaveCert(workerId string, certs *tls.Certificate) {
lock.Lock()
defer lock.Unlock()
workerMaps[workerId] = certs
currentWorkers[workerId] = certs
rebuild()
}
func rebuild() {
currentMap := make(map[string]*gmtls.Certificate)
func SaveGMCert(workerId string, certs []*gmtls.Certificate) {
lock.Lock()
defer lock.Unlock()
gmWorkers[workerId] = certs
gmRebuild()
}
func DelGMCert(workerId string) {
lock.Lock()
defer lock.Unlock()
delete(gmWorkers, workerId)
gmRebuild()
}
func gmRebuild() {
gmMap := make(map[string]*gmtls.Certificate)
gmEncMap := make(map[string]*gmtls.Certificate)
for _, cs := range workerMaps {
for _, cs := range gmWorkers {
l := len(cs)
switch {
case l == 1:
i := cs[0]
currentMap[i.Leaf.Subject.CommonName] = i
for _, dnsName := range i.Leaf.DNSNames {
currentMap[dnsName] = i
}
case l == 2:
i := cs[0]
gmMap[i.Leaf.Subject.CommonName] = i
@@ -74,26 +81,36 @@ func rebuild() {
}
}
currentCert.Swap(config.NewCert(currentMap))
gmCert.Swap(config.NewCert(gmMap))
gmEncCert.Swap(config.NewCert(gmEncMap))
}
func rebuild() {
currentMap := make(map[string]*tls.Certificate)
for _, cs := range currentWorkers {
i := cs
currentMap[i.Leaf.Subject.CommonName] = i
for _, dnsName := range i.Leaf.DNSNames {
currentMap[dnsName] = i
}
}
currentCert.Swap(config.NewCert(currentMap))
}
func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
func GetCertificateFunc(certsLocal ...*config.Cert[tls.Certificate]) func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(certsLocal) == 0 {
return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return currentCert.Load().GetCertificate(info)
}
}
certList := make([]*config.Cert, 0, len(certsLocal))
certList := make([]*config.Cert[tls.Certificate], 0, len(certsLocal))
for _, c := range certsLocal {
if c != nil {
certList = append(certList, c)
}
}
return func(info *gmtls.ClientHelloInfo) (certificate *gmtls.Certificate, err error) {
return func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) {
certificate, err = currentCert.Load().GetCertificate(info)
if certificate != nil {
return
@@ -112,19 +129,8 @@ func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHello
}
func GetAutoCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
func GetGMCertificateFunc() func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
gmFlag := false
// 检查支持协议中是否包含GMSSL
for _, v := range info.SupportedVersions {
if v == gmtls.VersionGMSSL {
gmFlag = true
break
}
}
if !gmFlag {
return GetCertificateFunc(certsLocal...)(info)
}
return gmCert.Load().GetCertificate(info)
}
}

View File

@@ -1,12 +1,13 @@
package certs
import (
"crypto/tls"
"crypto/x509"
"github.com/eolinker/apinto/certs"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/utils"
"github.com/eolinker/eosc"
"github.com/tjfoc/gmsm/gmtls"
"github.com/tjfoc/gmsm/x509"
)
var (
@@ -53,7 +54,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
}
w.config = config
certs.SaveCert(w.Id(), []*gmtls.Certificate{cert})
certs.SaveCert(w.Id(), cert)
return nil
}
@@ -66,7 +67,7 @@ func (w *Worker) CheckSkill(string) bool {
return false
}
func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) {
func parseCert(privateKey, pemValue string) (*tls.Certificate, error) {
cert, err := genCert([]byte(privateKey), []byte(pemValue))
if err == nil {
return cert, nil
@@ -83,8 +84,8 @@ func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) {
return genCert(keydata, pem)
}
func genCert(key, pem []byte) (*gmtls.Certificate, error) {
certificate, err := gmtls.X509KeyPair(pem, key)
func genCert(key, pem []byte) (*tls.Certificate, error) {
certificate, err := tls.X509KeyPair(pem, key)
if err != nil {
return nil, err
}

View File

@@ -40,7 +40,7 @@ func (w *Worker) Check(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
func (w *Worker) Destroy() error {
controller.Del(w.Id())
certs.DelCert(w.Id())
certs.DelGMCert(w.Id())
return nil
}
@@ -64,7 +64,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
}
w.config = c
certs.SaveCert(w.Id(), []*gmtls.Certificate{signCert, encCert})
certs.SaveGMCert(w.Id(), []*gmtls.Certificate{signCert, encCert})
return nil
}

View File

@@ -92,7 +92,11 @@ func (p *PluginManager) Reset(conf interface{}) error {
list := p.pluginObjs.List()
// 遍历,全量更新
for _, v := range list {
old := v.fs
v.fs = p.createFilters(v.conf)
if old != nil {
old.Destroy()
}
}
return nil

View File

@@ -49,7 +49,7 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) {
}
wg := sync.WaitGroup{}
tcp, ssl := tf.Listen(listenCfg.ListenUrls...)
tcp, ssl, gmSsl := tf.Listen(listenCfg.ListenUrls...)
listenerByPort := make(map[int][]net.Listener)
for _, l := range tcp {
@@ -57,18 +57,27 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) {
listenerByPort[port] = append(listenerByPort[port], l)
}
if len(ssl) > 0 {
tlsConfig := &tls.Config{
GetCertificate: certs.GetCertificateFunc(),
MinVersion: tls.VersionSSL30,
MaxVersion: tls.VersionTLS13,
}
for _, l := range ssl {
log.Debug("ssl listen: ", l.Addr().String())
port := readPort(l.Addr())
listenerByPort[port] = append(listenerByPort[port], tls.NewListener(l, tlsConfig))
}
}
if len(gmSsl) > 0 {
support := gmtls.NewGMSupport()
support.EnableMixMode()
gmTlsConfig := &gmtls.Config{
GetCertificate: certs.GetAutoCertificateFunc(),
GetCertificate: certs.GetGMCertificateFunc(),
GetKECertificate: certs.GetKECertificate(),
GMSupport: support,
MinVersion: gmtls.VersionGMSSL,
MaxVersion: tls.VersionTLS13,
}
for _, l := range ssl {
log.Debug("ssl listen: ", l.Addr().String())
for _, l := range gmSsl {
log.Debug("gm ssl listen: ", l.Addr().String())
port := readPort(l.Addr())
listenerByPort[port] = append(listenerByPort[port], gmtls.NewListener(l, gmTlsConfig))
}

10
go.mod
View File

@@ -1,8 +1,8 @@
module github.com/eolinker/apinto
go 1.21
go 1.23
toolchain go1.21.1
toolchain go1.23.1
require (
github.com/Shopify/sarama v1.32.0
@@ -11,7 +11,7 @@ require (
github.com/clbanning/mxj v1.8.4
github.com/coocood/freecache v1.2.2
github.com/dubbogo/gost v1.13.1
github.com/eolinker/eosc v0.20.6
github.com/eolinker/eosc v0.21.1
github.com/fasthttp/websocket v1.5.0
github.com/fullstorydev/grpcurl v1.8.7
github.com/go-redis/redis/v8 v8.11.5
@@ -206,4 +206,6 @@ require (
replace github.com/soheilhy/cmux v0.1.5 => github.com/hmzzrcs/cmux v0.1.6
replace github.com/eolinker/eosc => ../eosc
//replace (
// github.com/eolinker/eosc => ../eosc
//)

View File

@@ -156,17 +156,21 @@ type timestampChecker struct {
endTime time.Time
}
func newTimestampChecker(timeRange string) (*timestampChecker, error) {
// 正则表达式:匹配 HH:mm:ss - HH:mm:ss
regex := `^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00) - ((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$`
re := regexp.MustCompile(regex)
if !re.MatchString(timeRange) {
return nil, fmt.Errorf("invalid time format, expected HH:mm:ss - HH:mm:ss (00:00:00 - 24:00:00)")
}
var (
timeRangeRegex = regexp.MustCompile(`^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$`)
)
func newTimestampChecker(timeRange string) (*timestampChecker, error) {
// 提取开始时间和结束时间
times := strings.Split(timeRange, " - ")
startTimeStr, endTimeStr := times[0], times[1]
times := strings.Split(timeRange, "-")
startTimeStr, endTimeStr := strings.TrimSpace(times[0]), strings.TrimSpace(times[1])
if !timeRangeRegex.MatchString(startTimeStr) {
return nil, fmt.Errorf("invalid time format for start time: %s", startTimeStr)
}
if !timeRangeRegex.MatchString(endTimeStr) {
return nil, fmt.Errorf("invalid time format for end time: %s", endTimeStr)
}
// 解析开始时间和结束时间(假设在当前日期)
startTime, err := time.Parse("15:04:05", startTimeStr)
if err != nil {