mirror of
https://github.com/eolinker/apinto
synced 2025-09-26 21:01:19 +08:00
1. 国密算法完成
2. 修复插件修改后未销毁的问题
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -11,52 +12,58 @@ import (
|
||||
|
||||
var errorCertificateNotExit = errors.New("not exist cert")
|
||||
|
||||
type ICert interface {
|
||||
SaveCert(workerId string, cert []*gmtls.Certificate)
|
||||
DelCert(workerId string)
|
||||
}
|
||||
|
||||
var (
|
||||
workerMaps = make(map[string][]*gmtls.Certificate)
|
||||
lock = sync.RWMutex{}
|
||||
currentWorkers = make(map[string]*tls.Certificate)
|
||||
gmWorkers = make(map[string][]*gmtls.Certificate)
|
||||
|
||||
lock = sync.RWMutex{}
|
||||
// currentCert 普通TLS证书
|
||||
currentCert = atomic.Pointer[config.Cert]{}
|
||||
currentCert = atomic.Pointer[config.Cert[tls.Certificate]]{}
|
||||
// gmCert gmTLS证书
|
||||
gmCert = atomic.Pointer[config.Cert]{}
|
||||
gmCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{}
|
||||
// gmEncCert gmTLS加密证书
|
||||
gmEncCert = atomic.Pointer[config.Cert]{}
|
||||
gmEncCert = atomic.Pointer[config.Cert[gmtls.Certificate]]{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
currentCert.Store(config.NewCert(nil))
|
||||
|
||||
currentCert.Store(config.NewCert[tls.Certificate](nil))
|
||||
gmCert.Store(config.NewCert[gmtls.Certificate](nil))
|
||||
gmEncCert.Store(config.NewCert[gmtls.Certificate](nil))
|
||||
}
|
||||
func DelCert(workerId string) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
delete(workerMaps, workerId)
|
||||
delete(currentWorkers, workerId)
|
||||
rebuild()
|
||||
}
|
||||
|
||||
func SaveCert(workerId string, certs []*gmtls.Certificate) {
|
||||
func SaveCert(workerId string, certs *tls.Certificate) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
workerMaps[workerId] = certs
|
||||
currentWorkers[workerId] = certs
|
||||
rebuild()
|
||||
}
|
||||
func rebuild() {
|
||||
currentMap := make(map[string]*gmtls.Certificate)
|
||||
|
||||
func SaveGMCert(workerId string, certs []*gmtls.Certificate) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
gmWorkers[workerId] = certs
|
||||
gmRebuild()
|
||||
}
|
||||
|
||||
func DelGMCert(workerId string) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
delete(gmWorkers, workerId)
|
||||
gmRebuild()
|
||||
}
|
||||
|
||||
func gmRebuild() {
|
||||
gmMap := make(map[string]*gmtls.Certificate)
|
||||
gmEncMap := make(map[string]*gmtls.Certificate)
|
||||
for _, cs := range workerMaps {
|
||||
for _, cs := range gmWorkers {
|
||||
l := len(cs)
|
||||
switch {
|
||||
case l == 1:
|
||||
i := cs[0]
|
||||
currentMap[i.Leaf.Subject.CommonName] = i
|
||||
for _, dnsName := range i.Leaf.DNSNames {
|
||||
currentMap[dnsName] = i
|
||||
}
|
||||
case l == 2:
|
||||
i := cs[0]
|
||||
gmMap[i.Leaf.Subject.CommonName] = i
|
||||
@@ -74,26 +81,36 @@ func rebuild() {
|
||||
}
|
||||
|
||||
}
|
||||
currentCert.Swap(config.NewCert(currentMap))
|
||||
gmCert.Swap(config.NewCert(gmMap))
|
||||
gmEncCert.Swap(config.NewCert(gmEncMap))
|
||||
}
|
||||
func rebuild() {
|
||||
currentMap := make(map[string]*tls.Certificate)
|
||||
for _, cs := range currentWorkers {
|
||||
i := cs
|
||||
currentMap[i.Leaf.Subject.CommonName] = i
|
||||
for _, dnsName := range i.Leaf.DNSNames {
|
||||
currentMap[dnsName] = i
|
||||
}
|
||||
}
|
||||
currentCert.Swap(config.NewCert(currentMap))
|
||||
}
|
||||
|
||||
func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
|
||||
func GetCertificateFunc(certsLocal ...*config.Cert[tls.Certificate]) func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
|
||||
if len(certsLocal) == 0 {
|
||||
|
||||
return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
|
||||
return func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return currentCert.Load().GetCertificate(info)
|
||||
}
|
||||
}
|
||||
certList := make([]*config.Cert, 0, len(certsLocal))
|
||||
certList := make([]*config.Cert[tls.Certificate], 0, len(certsLocal))
|
||||
for _, c := range certsLocal {
|
||||
if c != nil {
|
||||
certList = append(certList, c)
|
||||
}
|
||||
}
|
||||
return func(info *gmtls.ClientHelloInfo) (certificate *gmtls.Certificate, err error) {
|
||||
return func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, err error) {
|
||||
certificate, err = currentCert.Load().GetCertificate(info)
|
||||
if certificate != nil {
|
||||
return
|
||||
@@ -112,19 +129,8 @@ func GetCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHello
|
||||
|
||||
}
|
||||
|
||||
func GetAutoCertificateFunc(certsLocal ...*config.Cert) func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
|
||||
func GetGMCertificateFunc() func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
|
||||
return func(info *gmtls.ClientHelloInfo) (*gmtls.Certificate, error) {
|
||||
gmFlag := false
|
||||
// 检查支持协议中是否包含GMSSL
|
||||
for _, v := range info.SupportedVersions {
|
||||
if v == gmtls.VersionGMSSL {
|
||||
gmFlag = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !gmFlag {
|
||||
return GetCertificateFunc(certsLocal...)(info)
|
||||
}
|
||||
return gmCert.Load().GetCertificate(info)
|
||||
}
|
||||
}
|
||||
|
@@ -1,12 +1,13 @@
|
||||
package certs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/eolinker/apinto/certs"
|
||||
"github.com/eolinker/apinto/drivers"
|
||||
"github.com/eolinker/apinto/utils"
|
||||
"github.com/eolinker/eosc"
|
||||
"github.com/tjfoc/gmsm/gmtls"
|
||||
"github.com/tjfoc/gmsm/x509"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -53,7 +54,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
|
||||
}
|
||||
|
||||
w.config = config
|
||||
certs.SaveCert(w.Id(), []*gmtls.Certificate{cert})
|
||||
certs.SaveCert(w.Id(), cert)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -66,7 +67,7 @@ func (w *Worker) CheckSkill(string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) {
|
||||
func parseCert(privateKey, pemValue string) (*tls.Certificate, error) {
|
||||
cert, err := genCert([]byte(privateKey), []byte(pemValue))
|
||||
if err == nil {
|
||||
return cert, nil
|
||||
@@ -83,8 +84,8 @@ func parseCert(privateKey, pemValue string) (*gmtls.Certificate, error) {
|
||||
return genCert(keydata, pem)
|
||||
}
|
||||
|
||||
func genCert(key, pem []byte) (*gmtls.Certificate, error) {
|
||||
certificate, err := gmtls.X509KeyPair(pem, key)
|
||||
func genCert(key, pem []byte) (*tls.Certificate, error) {
|
||||
certificate, err := tls.X509KeyPair(pem, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -40,7 +40,7 @@ func (w *Worker) Check(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
|
||||
func (w *Worker) Destroy() error {
|
||||
|
||||
controller.Del(w.Id())
|
||||
certs.DelCert(w.Id())
|
||||
certs.DelGMCert(w.Id())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (w *Worker) Reset(conf interface{}, _ map[eosc.RequireId]eosc.IWorker) erro
|
||||
}
|
||||
|
||||
w.config = c
|
||||
certs.SaveCert(w.Id(), []*gmtls.Certificate{signCert, encCert})
|
||||
certs.SaveGMCert(w.Id(), []*gmtls.Certificate{signCert, encCert})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -92,7 +92,11 @@ func (p *PluginManager) Reset(conf interface{}) error {
|
||||
list := p.pluginObjs.List()
|
||||
// 遍历,全量更新
|
||||
for _, v := range list {
|
||||
old := v.fs
|
||||
v.fs = p.createFilters(v.conf)
|
||||
if old != nil {
|
||||
old.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@@ -49,7 +49,7 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) {
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
tcp, ssl := tf.Listen(listenCfg.ListenUrls...)
|
||||
tcp, ssl, gmSsl := tf.Listen(listenCfg.ListenUrls...)
|
||||
|
||||
listenerByPort := make(map[int][]net.Listener)
|
||||
for _, l := range tcp {
|
||||
@@ -57,18 +57,27 @@ func initListener(tf traffic.ITraffic, listenCfg *config.ListenUrl) {
|
||||
listenerByPort[port] = append(listenerByPort[port], l)
|
||||
}
|
||||
if len(ssl) > 0 {
|
||||
tlsConfig := &tls.Config{
|
||||
GetCertificate: certs.GetCertificateFunc(),
|
||||
MinVersion: tls.VersionSSL30,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
for _, l := range ssl {
|
||||
log.Debug("ssl listen: ", l.Addr().String())
|
||||
port := readPort(l.Addr())
|
||||
listenerByPort[port] = append(listenerByPort[port], tls.NewListener(l, tlsConfig))
|
||||
}
|
||||
}
|
||||
if len(gmSsl) > 0 {
|
||||
support := gmtls.NewGMSupport()
|
||||
support.EnableMixMode()
|
||||
gmTlsConfig := &gmtls.Config{
|
||||
GetCertificate: certs.GetAutoCertificateFunc(),
|
||||
GetCertificate: certs.GetGMCertificateFunc(),
|
||||
GetKECertificate: certs.GetKECertificate(),
|
||||
GMSupport: support,
|
||||
MinVersion: gmtls.VersionGMSSL,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
for _, l := range ssl {
|
||||
log.Debug("ssl listen: ", l.Addr().String())
|
||||
for _, l := range gmSsl {
|
||||
log.Debug("gm ssl listen: ", l.Addr().String())
|
||||
port := readPort(l.Addr())
|
||||
listenerByPort[port] = append(listenerByPort[port], gmtls.NewListener(l, gmTlsConfig))
|
||||
}
|
||||
|
10
go.mod
10
go.mod
@@ -1,8 +1,8 @@
|
||||
module github.com/eolinker/apinto
|
||||
|
||||
go 1.21
|
||||
go 1.23
|
||||
|
||||
toolchain go1.21.1
|
||||
toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/Shopify/sarama v1.32.0
|
||||
@@ -11,7 +11,7 @@ require (
|
||||
github.com/clbanning/mxj v1.8.4
|
||||
github.com/coocood/freecache v1.2.2
|
||||
github.com/dubbogo/gost v1.13.1
|
||||
github.com/eolinker/eosc v0.20.6
|
||||
github.com/eolinker/eosc v0.21.1
|
||||
github.com/fasthttp/websocket v1.5.0
|
||||
github.com/fullstorydev/grpcurl v1.8.7
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
@@ -206,4 +206,6 @@ require (
|
||||
|
||||
replace github.com/soheilhy/cmux v0.1.5 => github.com/hmzzrcs/cmux v0.1.6
|
||||
|
||||
replace github.com/eolinker/eosc => ../eosc
|
||||
//replace (
|
||||
// github.com/eolinker/eosc => ../eosc
|
||||
//)
|
||||
|
@@ -156,17 +156,21 @@ type timestampChecker struct {
|
||||
endTime time.Time
|
||||
}
|
||||
|
||||
func newTimestampChecker(timeRange string) (*timestampChecker, error) {
|
||||
// 正则表达式:匹配 HH:mm:ss - HH:mm:ss
|
||||
regex := `^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00) - ((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$`
|
||||
re := regexp.MustCompile(regex)
|
||||
if !re.MatchString(timeRange) {
|
||||
return nil, fmt.Errorf("invalid time format, expected HH:mm:ss - HH:mm:ss (00:00:00 - 24:00:00)")
|
||||
}
|
||||
var (
|
||||
timeRangeRegex = regexp.MustCompile(`^((?:[01]\d|2[0-3]):[0-5]\d:[0-5]\d|24:00:00)$`)
|
||||
)
|
||||
|
||||
func newTimestampChecker(timeRange string) (*timestampChecker, error) {
|
||||
// 提取开始时间和结束时间
|
||||
times := strings.Split(timeRange, " - ")
|
||||
startTimeStr, endTimeStr := times[0], times[1]
|
||||
times := strings.Split(timeRange, "-")
|
||||
startTimeStr, endTimeStr := strings.TrimSpace(times[0]), strings.TrimSpace(times[1])
|
||||
|
||||
if !timeRangeRegex.MatchString(startTimeStr) {
|
||||
return nil, fmt.Errorf("invalid time format for start time: %s", startTimeStr)
|
||||
}
|
||||
if !timeRangeRegex.MatchString(endTimeStr) {
|
||||
return nil, fmt.Errorf("invalid time format for end time: %s", endTimeStr)
|
||||
}
|
||||
// 解析开始时间和结束时间(假设在当前日期)
|
||||
startTime, err := time.Parse("15:04:05", startTimeStr)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user