// Mirror of https://github.com/langhuihui/monibuca.git
// (synced 2025-10-04 11:46:23 +08:00; 163 lines, 4.2 KiB, Go)

package config
|
||
|
||
import (
	"crypto/tls"
	_ "embed"
	"errors"
	"log/slog"
	"net"
	"runtime"
	"time"

	"github.com/langhuihui/gotask"
)
|
||
|
||
// LocalCert is the certificate bundle embedded at build time from
// local.monibuca.com_bundle.pem. GetTLSConfig falls back to it (together
// with LocalKey) when neither a certificate file nor a key file is configured.
//
//go:embed local.monibuca.com_bundle.pem
var LocalCert []byte

// LocalKey is the private key embedded at build time from
// local.monibuca.com.key; it is paired with LocalCert by GetTLSConfig.
//
//go:embed local.monibuca.com.key
var LocalKey []byte
|
||
|
||
func GetTLSConfig(certFile, keyFile string) (tslConfig *tls.Config, err error) {
|
||
var keyPair tls.Certificate
|
||
if certFile != "" || keyFile != "" {
|
||
keyPair, err = tls.LoadX509KeyPair(certFile, keyFile)
|
||
} else {
|
||
keyPair, err = tls.X509KeyPair(LocalCert, LocalKey)
|
||
}
|
||
if err == nil {
|
||
tslConfig = &tls.Config{
|
||
Certificates: []tls.Certificate{keyPair},
|
||
NextProtos: []string{"monibuca"},
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// TCP holds the listener settings shared by ListenTCPWork (plain TCP on
// ListenAddr) and ListenTCPTLSWork (TLS on ListenAddrTLS). The desc tags are
// human-readable option descriptions; the default tag on AutoListen is
// presumably applied by the config loader — confirm.
type TCP struct {
	// ListenAddr is the plain TCP listen address in ip:port form; the ip may
	// be omitted to listen on all interfaces.
	ListenAddr string `desc:"监听地址,格式为ip:port,ip 可省略默认监听所有网卡"`

	// ListenAddrTLS is the TLS listen address in ip:port form (ip optional).
	// NOTE(review): its desc tag is identical to ListenAddr's and does not
	// mention TLS.
	ListenAddrTLS string `desc:"监听地址,格式为ip:port,ip 可省略默认监听所有网卡"`

	// CertFile is the certificate file. If both CertFile and KeyFile are
	// empty, GetTLSConfig uses the embedded LocalCert/LocalKey instead.
	CertFile string `desc:"证书文件"`

	// KeyFile is the private key file matching CertFile.
	KeyFile string `desc:"私钥文件"`

	// ListenNum is the number of accept loops run in parallel on the
	// listener; 0 means one per CPU (runtime.NumCPU).
	ListenNum int `desc:"同时并行监听数量,0为CPU核心数量"`

	// NoDelay reports whether Nagle's algorithm is disabled. When false,
	// accepted connections get SetNoDelay(false) (Nagle enabled); when true
	// they keep Go's default of TCP_NODELAY.
	NoDelay bool `desc:"是否禁用Nagle算法"`

	// WriteBuffer is the socket write buffer size applied to each accepted
	// connection when greater than 0.
	WriteBuffer int `desc:"写缓冲区大小"`

	// ReadBuffer is the socket read buffer size applied to each accepted
	// connection when greater than 0.
	ReadBuffer int `desc:"读缓冲区大小"`

	// KeepAlive reports whether TCP keep-alive should be enabled.
	// NOTE(review): not read anywhere in this file; accepted connections keep
	// Go's default keep-alive behaviour unless it is applied elsewhere.
	KeepAlive bool `desc:"是否启用KeepAlive"`

	// AutoListen reports whether listening should start automatically
	// (default tag: true). NOTE(review): not read in this file; presumably
	// consumed by the code that creates the listen works.
	AutoListen bool `default:"true" desc:"是否自动监听"`
}
|
||
|
||
func (config *TCP) CreateTCPWork(logger *slog.Logger, handler TCPHandler) *ListenTCPWork {
|
||
ret := &ListenTCPWork{TCP: config, handler: handler}
|
||
ret.SetDescription("listenAddr", config.ListenAddr)
|
||
ret.Logger = logger.With("addr", config.ListenAddr)
|
||
return ret
|
||
}
|
||
|
||
func (config *TCP) CreateTCPTLSWork(logger *slog.Logger, handler TCPHandler) *ListenTCPTLSWork {
|
||
ret := &ListenTCPTLSWork{ListenTCPWork{TCP: config, handler: handler}}
|
||
ret.SetDescription("listenAddr", config.ListenAddrTLS)
|
||
ret.Logger = logger.With("addr", config.ListenAddrTLS)
|
||
return ret
|
||
}
|
||
|
||
// TCPHandler builds the task that serves one accepted connection. The accept
// loop in ListenTCPWork calls it with the connection's *net.TCPConn and passes
// the returned task.ITask to AddTask on the listening work.
type TCPHandler = func(conn *net.TCPConn) task.ITask
|
||
|
||
// ListenTCPWork is a gotask work item that owns a TCP listener. Start opens
// the listener and, when handler is non-nil, launches accept loops; Dispose
// closes the listener.
type ListenTCPWork struct {
	task.Work

	// TCP supplies the listener configuration (addresses, buffers,
	// ListenNum, ...).
	*TCP

	// Listener is the listening socket, assigned by Start.
	net.Listener

	// handler creates the task serving each accepted connection; when nil,
	// Start opens the listener but launches no accept loops.
	handler TCPHandler
}
|
||
|
||
func (task *ListenTCPWork) Start() (err error) {
|
||
task.Listener, err = net.Listen("tcp", task.ListenAddr)
|
||
if err == nil {
|
||
task.Info("listen tcp")
|
||
} else {
|
||
task.Error("failed to listen tcp", "error", err)
|
||
return err
|
||
}
|
||
if task.handler == nil {
|
||
return nil
|
||
}
|
||
count := task.ListenNum
|
||
if count == 0 {
|
||
count = runtime.NumCPU()
|
||
}
|
||
for range count {
|
||
go task.listen(task.handler)
|
||
}
|
||
return
|
||
}
|
||
|
||
func (task *ListenTCPWork) Dispose() {
|
||
task.Info("tcp server stop")
|
||
task.Listener.Close()
|
||
}
|
||
|
||
// ListenTCPTLSWork is the TLS variant of ListenTCPWork. It reuses the
// embedded work's fields and its Dispose, and overrides Start to open a TLS
// listener on ListenAddrTLS.
type ListenTCPTLSWork struct {
	ListenTCPWork
}
|
||
|
||
func (task *ListenTCPTLSWork) Start() (err error) {
|
||
var tlsConfig *tls.Config
|
||
if tlsConfig, err = GetTLSConfig(task.CertFile, task.KeyFile); err != nil {
|
||
return
|
||
}
|
||
task.Listener, err = tls.Listen("tcp", task.ListenAddrTLS, tlsConfig)
|
||
if err == nil {
|
||
task.Info("listen tcp tls")
|
||
} else {
|
||
task.Error("failed to listen tcp tls", "error", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
func (task *ListenTCPWork) listen(handler TCPHandler) {
|
||
var tempDelay time.Duration
|
||
for {
|
||
conn, err := task.Accept()
|
||
if err != nil {
|
||
if ne, ok := err.(net.Error); ok && !ne.Timeout() {
|
||
if tempDelay == 0 {
|
||
tempDelay = 5 * time.Millisecond
|
||
} else {
|
||
tempDelay *= 2
|
||
}
|
||
if max := 1 * time.Second; tempDelay > max {
|
||
tempDelay = max
|
||
}
|
||
// slog.Warnf("%s: Accept error: %v; retrying in %v", tcp.DownListenAddr, err, tempDelay)
|
||
time.Sleep(tempDelay)
|
||
continue
|
||
}
|
||
return
|
||
}
|
||
var tcpConn *net.TCPConn
|
||
switch v := conn.(type) {
|
||
case *net.TCPConn:
|
||
tcpConn = v
|
||
case *tls.Conn:
|
||
tcpConn = v.NetConn().(*net.TCPConn)
|
||
}
|
||
if !task.NoDelay {
|
||
tcpConn.SetNoDelay(false)
|
||
}
|
||
if task.WriteBuffer > 0 {
|
||
if err := tcpConn.SetWriteBuffer(task.WriteBuffer); err != nil {
|
||
task.Error("failed to set write buffer", "error", err)
|
||
continue
|
||
}
|
||
}
|
||
if task.ReadBuffer > 0 {
|
||
if err := tcpConn.SetReadBuffer(task.ReadBuffer); err != nil {
|
||
task.Error("failed to set read buffer", "error", err)
|
||
continue
|
||
}
|
||
}
|
||
tempDelay = 0
|
||
subTask := handler(tcpConn)
|
||
task.AddTask(subTask)
|
||
}
|
||
}
|