mirror of
https://github.com/csznet/goForward.git
synced 2025-09-26 19:31:11 +08:00
311 lines
8.1 KiB
Go
311 lines
8.1 KiB
Go
package forward
|
|
|
|
import (
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"csz.net/goForward/conf"
|
|
"csz.net/goForward/sql"
|
|
)
|
|
|
|
type ConnectionStats struct {
|
|
conf.ConnectionStats
|
|
TotalBytesOld uint64 `gorm:"-"`
|
|
TotalBytesLock sync.Mutex `gorm:"-"`
|
|
TCPConnections []net.Conn `gorm:"-"` // 用于存储 TCP 连接
|
|
TcpTime int `gorm:"-"` // TCP无传输时间
|
|
}
|
|
|
|
// 保存多个连接信息
|
|
type LargeConnectionStats struct {
|
|
Connections []*ConnectionStats `json:"connections"`
|
|
}
|
|
|
|
// bufPool recycles the scratch buffers used by the forwarding loops so that
// they are not reallocated for every read. Each buffer handed out by New is
// a 4096-byte []byte.
var bufPool = sync.Pool{
	New: func() interface{} {
		scratch := make([]byte, 4096)
		return scratch
	},
}
|
|
|
|
// 开启转发,负责分发具体转发
|
|
func Run(stats *ConnectionStats) {
|
|
defer releaseResources(stats) // 在函数返回时释放资源
|
|
var ctx, cancel = context.WithCancel(context.Background())
|
|
var innerWg sync.WaitGroup
|
|
defer cancel()
|
|
innerWg.Add(1)
|
|
go func() {
|
|
stats.printStats(ctx)
|
|
innerWg.Done()
|
|
}()
|
|
fmt.Printf("【%s】监听端口 %s 转发至 %s:%s\n", stats.Protocol, stats.LocalPort, stats.RemoteAddr, stats.RemotePort)
|
|
if stats.Protocol == "udp" {
|
|
// UDP转发
|
|
localAddr, err := net.ResolveUDPAddr("udp", ":"+stats.LocalPort)
|
|
if err != nil {
|
|
log.Fatalln("解析本地地址时发生错误:", err)
|
|
}
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", stats.RemoteAddr+":"+stats.RemotePort)
|
|
if err != nil {
|
|
log.Fatalln("解析远程地址时发生错误:", err)
|
|
}
|
|
conn, err := net.ListenUDP("udp", localAddr)
|
|
if err != nil {
|
|
log.Fatalln("监听时发生错误:", err)
|
|
}
|
|
defer conn.Close()
|
|
go func() {
|
|
for {
|
|
select {
|
|
case stopPort := <-conf.Ch:
|
|
if stopPort == stats.LocalPort+stats.Protocol {
|
|
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
|
|
conn.Close()
|
|
cancel()
|
|
return
|
|
} else {
|
|
conf.Ch <- stopPort
|
|
time.Sleep(3 * time.Second)
|
|
}
|
|
default:
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
}
|
|
}()
|
|
innerWg.Add(1)
|
|
go func() {
|
|
stats.handleUDPConnection(conn, remoteAddr, ctx)
|
|
innerWg.Done()
|
|
}()
|
|
} else {
|
|
// TCP转发
|
|
listener, err := net.Listen("tcp", ":"+stats.LocalPort)
|
|
if err != nil {
|
|
log.Fatalln("监听时发生错误:", err)
|
|
}
|
|
defer listener.Close()
|
|
go func() {
|
|
for {
|
|
select {
|
|
case stopPort := <-conf.Ch:
|
|
fmt.Println("通道信息:" + stopPort)
|
|
fmt.Println("当前端口:" + stats.LocalPort)
|
|
if stopPort == stats.LocalPort+stats.Protocol {
|
|
fmt.Printf("【%s】停止监听端口 %s\n", stats.Protocol, stats.LocalPort)
|
|
listener.Close()
|
|
cancel()
|
|
// 遍历并关闭所有 TCP 连接
|
|
for _, conn := range stats.TCPConnections {
|
|
conn.Close()
|
|
}
|
|
return
|
|
} else {
|
|
conf.Ch <- stopPort
|
|
time.Sleep(3 * time.Second)
|
|
}
|
|
default:
|
|
time.Sleep(1 * time.Second)
|
|
}
|
|
}
|
|
}()
|
|
for {
|
|
clientConn, err := listener.Accept()
|
|
if err != nil {
|
|
log.Println("【"+stats.LocalPort+"】接受连接时发生错误:", err)
|
|
cancel()
|
|
break
|
|
}
|
|
innerWg.Add(1)
|
|
go func() {
|
|
stats.handleTCPConnection(clientConn, ctx, cancel)
|
|
innerWg.Done()
|
|
}()
|
|
}
|
|
}
|
|
innerWg.Wait()
|
|
}
|
|
|
|
// TCP转发
|
|
func (cs *ConnectionStats) handleTCPConnection(clientConn net.Conn, ctx context.Context, cancel context.CancelFunc) {
|
|
defer clientConn.Close()
|
|
remoteConn, err := net.Dial("tcp", cs.RemoteAddr+":"+cs.RemotePort)
|
|
if err != nil {
|
|
log.Println("【"+cs.LocalPort+"】连接远程地址时发生错误:", err)
|
|
return
|
|
}
|
|
defer remoteConn.Close()
|
|
cs.TCPConnections = append(cs.TCPConnections, clientConn, remoteConn) // 添加连接到列表
|
|
var copyWG sync.WaitGroup
|
|
copyWG.Add(2)
|
|
go func() {
|
|
defer copyWG.Done()
|
|
if err := cs.copyBytes(clientConn, remoteConn); err != nil {
|
|
log.Println("复制字节时发生错误:", err)
|
|
cancel() // Assuming `cancel` is the cancel function from the context
|
|
}
|
|
}()
|
|
go func() {
|
|
defer copyWG.Done()
|
|
if err := cs.copyBytes(remoteConn, clientConn); err != nil {
|
|
log.Println("复制字节时发生错误:", err)
|
|
cancel() // Assuming `cancel` is the cancel function from the context
|
|
}
|
|
}()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// 如果上级 context 被取消,停止接收新连接
|
|
return
|
|
default:
|
|
copyWG.Wait()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// UDP转发
|
|
func (cs *ConnectionStats) handleUDPConnection(localConn *net.UDPConn, remoteAddr *net.UDPAddr, ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
buf := bufPool.Get().([]byte)
|
|
n, _, err := localConn.ReadFromUDP(buf)
|
|
if err != nil {
|
|
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
|
|
return
|
|
}
|
|
fmt.Printf("收到长度为 %d 的UDP数据包\n", n)
|
|
cs.TotalBytesLock.Lock()
|
|
cs.TotalBytes += uint64(n)
|
|
cs.TotalBytesLock.Unlock()
|
|
|
|
// 处理消息的边界和错误情况
|
|
go func() {
|
|
cs.forwardUDPMessage(localConn, remoteAddr, buf[:n])
|
|
bufPool.Put(buf)
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (cs *ConnectionStats) forwardUDPMessage(localConn *net.UDPConn, remoteAddr *net.UDPAddr, message []byte) {
|
|
// 在消息前面添加消息长度信息
|
|
length := make([]byte, 2)
|
|
binary.BigEndian.PutUint16(length, uint16(len(message)))
|
|
// 组合消息长度和实际消息
|
|
data := append(length, message...)
|
|
_, err := localConn.WriteToUDP(data, remoteAddr)
|
|
if err != nil {
|
|
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
|
|
}
|
|
|
|
}
|
|
|
|
func (cs *ConnectionStats) copyBytes(dst, src net.Conn) error {
|
|
buf := bufPool.Get().([]byte)
|
|
defer bufPool.Put(buf)
|
|
for {
|
|
n, err := src.Read(buf)
|
|
if n > 0 {
|
|
cs.TotalBytesLock.Lock()
|
|
cs.TotalBytes += uint64(n)
|
|
cs.TotalBytesLock.Unlock()
|
|
_, err := dst.Write(buf[:n])
|
|
if err != nil {
|
|
log.Println("【"+cs.LocalPort+"】写入目标时发生错误:", err)
|
|
return err
|
|
}
|
|
}
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
log.Println("【"+cs.LocalPort+"】从源读取时发生错误:", err)
|
|
break
|
|
}
|
|
}
|
|
// 关闭连接
|
|
dst.Close()
|
|
src.Close()
|
|
return nil
|
|
}
|
|
|
|
// 定时打印和处理流量变化
|
|
func (cs *ConnectionStats) printStats(ctx context.Context) {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
defer ticker.Stop() // 在函数结束时停止定时器
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
cs.TotalBytesLock.Lock()
|
|
if cs.TotalBytes > cs.TotalBytesOld {
|
|
if cs.Protocol == "tcp" {
|
|
cs.TcpTime = 0
|
|
}
|
|
var total string
|
|
if cs.TotalBytes > 0 && float64(cs.TotalBytes)/(1024*1024) < 0.5 {
|
|
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024), 'f', 2, 64) + "KB"
|
|
} else {
|
|
total = strconv.FormatFloat(float64(cs.TotalBytes)/(1024*1024), 'f', 2, 64) + "MB"
|
|
}
|
|
fmt.Printf("【%s】端口 %s 统计流量: %s\n", cs.Protocol, cs.LocalPort, total)
|
|
//统计更换单位
|
|
var gb uint64 = 1073741824
|
|
if cs.TotalBytes >= gb {
|
|
cs.TotalGigabyte = cs.TotalGigabyte + 1
|
|
sql.UpdateForwardGb(cs.Id, cs.TotalGigabyte)
|
|
cs.TotalBytes = cs.TotalBytes - gb
|
|
}
|
|
cs.TotalBytesOld = cs.TotalBytes
|
|
sql.UpdateForwardBytes(cs.Id, cs.TotalBytes)
|
|
fmt.Printf("【%s】端口 %s 当前连接数: %d\n", cs.Protocol, cs.LocalPort, len(cs.TCPConnections))
|
|
} else {
|
|
if cs.Protocol == "tcp" {
|
|
// fmt.Printf("【%s】端口 %s 当前超时秒: %d\n", cs.Protocol, cs.LocalPort, cs.TcpTime)
|
|
if cs.TcpTime >= conf.TcpTimeout {
|
|
// fmt.Printf("【%s】端口 %s 超时关闭\n", cs.Protocol, cs.LocalPort)
|
|
for i := len(cs.TCPConnections) - 1; i >= 0; i-- {
|
|
conn := cs.TCPConnections[i]
|
|
conn.Close()
|
|
// 从连接列表中移除关闭的连接
|
|
cs.TCPConnections = append(cs.TCPConnections[:i], cs.TCPConnections[i+1:]...)
|
|
}
|
|
} else {
|
|
cs.TcpTime = cs.TcpTime + 5
|
|
}
|
|
}
|
|
}
|
|
cs.TotalBytesLock.Unlock()
|
|
//当协程退出时执行
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// 关闭 TCP 连接并从切片中移除
|
|
func closeTCPConnections(stats *ConnectionStats) {
|
|
stats.TotalBytesLock.Lock()
|
|
defer stats.TotalBytesLock.Unlock()
|
|
for i, conn := range stats.TCPConnections {
|
|
conn.Close()
|
|
stats.TCPConnections[i] = nil
|
|
}
|
|
stats.TCPConnections = nil // 清空切片
|
|
}
|
|
|
|
// 释放资源
|
|
func releaseResources(stats *ConnectionStats) {
|
|
closeTCPConnections(stats)
|
|
}
|