Files
kubevpn/pkg/core/tcp.go
2025-04-12 12:30:05 +08:00

89 lines
2.2 KiB
Go

package core
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/wencaiwulue/kubevpn/v2/pkg/config"
plog "github.com/wencaiwulue/kubevpn/v2/pkg/log"
"github.com/wencaiwulue/kubevpn/v2/pkg/util"
)
// tcpTransporter dials TCP connections. When tlsConfig is non-nil, Dial wraps
// each connection in a TLS client; when it is nil, the raw TCP connection is
// handed back unchanged.
type tcpTransporter struct{ tlsConfig *tls.Config }
// TCPTransporter builds a TCP Transporter from the TLS material in tlsInfo.
//
// tlsInfo is handed to util.GetTlsClientConfig. Outcomes:
//   - success: the returned transporter dials with TLS;
//   - util.ErrNoTLSConfig: a warning is logged and raw TCP is used;
//   - any other error: the error is logged and raw TCP is used as well.
//
// NOTE(review): the last case silently downgrades to plaintext when the TLS
// material is present but unusable; confirm this fallback is intended.
func TCPTransporter(tlsInfo map[string][]byte) Transporter {
	tlsConfig, err := util.GetTlsClientConfig(tlsInfo)
	switch {
	case err == nil:
		return &tcpTransporter{tlsConfig: tlsConfig}
	case errors.Is(err, util.ErrNoTLSConfig):
		plog.G(context.Background()).Warn("tls config not found in config, use raw tcp mode")
	default:
		plog.G(context.Background()).Errorf("failed to get tls client config: %v", err)
	}
	// Every error path falls back to a transporter without TLS.
	return &tcpTransporter{}
}
// Dial opens a TCP connection to addr, honouring ctx and config.DialTimeout.
//
// Without a TLS configuration, the raw TCP connection is returned. Otherwise
// the connection is wrapped in a TLS client and the handshake is completed
// before Dial returns.
//
// Completing the handshake here means:
//   - handshake failures are reported by Dial rather than by the first
//     Read/Write;
//   - the handshake is bounded by ctx and, when positive, by
//     config.DialTimeout;
//   - the underlying TCP connection is closed if the handshake fails.
func (tr *tcpTransporter) Dial(ctx context.Context, addr string) (net.Conn, error) {
	dialer := &net.Dialer{Timeout: config.DialTimeout}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}
	if tr.tlsConfig == nil {
		plog.G(ctx).Debugf("tls config not found in config, use raw tcp mode")
		return conn, nil
	}
	plog.G(ctx).Debugf("use tls mode")

	// A lazily-run handshake (on first I/O) has no deadline. A peer that
	// accepts TCP but never answers would then block the caller's first
	// write forever.
	hsCtx := ctx
	if config.DialTimeout > 0 {
		// A zero timeout means "no limit" for net.Dialer. Avoid passing it to
		// context.WithTimeout, which would expire immediately.
		var cancel context.CancelFunc
		hsCtx, cancel = context.WithTimeout(ctx, config.DialTimeout)
		defer cancel() // Cancelling after a completed handshake does not affect the conn.
	}
	tlsConn := tls.Client(conn, tr.tlsConfig)
	if err := tlsConn.HandshakeContext(hsCtx); err != nil {
		_ = conn.Close()
		return nil, err
	}
	return tlsConn, nil
}
func TCPListener(addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
listener, err := net.ListenTCP("tcp", laddr)
if err != nil {
return nil, err
}
serverConfig, err := util.GetTlsServerConfig(nil)
if err != nil {
if errors.Is(err, util.ErrNoTLSConfig) {
plog.G(context.Background()).Warn("tls config not found in config, use raw tcp mode")
return &tcpKeepAliveListener{TCPListener: listener}, nil
}
plog.G(context.Background()).Errorf("failed to get tls server config: %v", err)
return nil, err
}
return tls.NewListener(&tcpKeepAliveListener{TCPListener: listener}, serverConfig), nil
}
// tcpKeepAliveListener embeds *net.TCPListener and overrides Accept to tune
// the socket options of every accepted connection.
type tcpKeepAliveListener struct{ *net.TCPListener }
// Accept waits for the next TCP connection and configures it before returning
// it. The configuration is: keep-alive enabled, keep-alive period set to
// config.KeepAliveTime, and TCP_NODELAY enabled.
//
// If any of these options cannot be applied, the freshly accepted connection
// is closed and the error is returned, so no file descriptor is leaked.
// NOTE(review): such per-connection failures are returned from Accept and may
// terminate a caller's accept loop; confirm callers treat them as non-fatal.
func (ln *tcpKeepAliveListener) Accept() (net.Conn, error) {
	conn, err := ln.AcceptTCP()
	if err != nil {
		return nil, err
	}

	// Apply the options in order, stopping at the first failure.
	err = conn.SetKeepAlive(true)
	if err == nil {
		err = conn.SetKeepAlivePeriod(config.KeepAliveTime)
	}
	if err == nil {
		err = conn.SetNoDelay(true)
	}
	if err != nil {
		// The connection is already established; release it before reporting.
		_ = conn.Close()
		return nil, err
	}
	return conn, nil
}