diff --git a/tlsLayer/tlsLayer.go b/tlsLayer/tlsLayer.go index 5bfd15e..f820791 100644 --- a/tlsLayer/tlsLayer.go +++ b/tlsLayer/tlsLayer.go @@ -7,6 +7,7 @@ package tlsLayer import ( "crypto/tls" + "crypto/x509" "strings" "unsafe" @@ -33,19 +34,30 @@ func rejectUnknownGetCertificateFunc(certs []*tls.Certificate) func(hello *tls.C if len(certs) == 0 { return nil, utils.ErrInErr{ErrDesc: "len(certs) == 0", ErrDetail: utils.ErrInvalidData} } + if hello == nil { + return nil, utils.ErrInErr{ErrDesc: "hello==nil", ErrDetail: utils.ErrInvalidData} + } sni := strings.ToLower(hello.ServerName) gsni := "*" if index := strings.IndexByte(sni, '.'); index != -1 { gsni += sni[index:] } - for _, keyPair := range certs { - if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { - return keyPair, nil + for _, cert := range certs { + if cert.Leaf == nil { + var e error + cert.Leaf, e = x509.ParseCertificate(cert.Certificate[0]) + if e != nil { + return nil, utils.ErrInErr{ErrDesc: "rejectUnknown: x509.ParseCertificate failed ", ErrDetail: e} + } } - for _, name := range keyPair.Leaf.DNSNames { + + if cert.Leaf.Subject.CommonName == sni || cert.Leaf.Subject.CommonName == gsni { + return cert, nil + } + for _, name := range cert.Leaf.DNSNames { if name == sni || name == gsni { - return keyPair, nil + return cert, nil } } } @@ -69,7 +81,7 @@ func GetTlsConfig(mustHasCert bool, conf Conf) *tls.Config { if err != nil { - if ce := utils.CanLogErr("Can't create tls cert"); ce != nil { + if ce := utils.CanLogErr("Can't init tls cert"); ce != nil { ce.Write(zap.String("cert", conf.CertConf.CertFile), zap.String("key", conf.CertConf.KeyFile), zap.Error(err)) }