Files
lkm/stream/forward_sink.go
2025-06-04 20:55:18 +08:00

229 lines
5.3 KiB
Go

package stream
import (
	"encoding/binary"
	"fmt"
	"net"
	"time"

	"github.com/lkmio/avformat/collections"
	"github.com/lkmio/lkm/log"
	"github.com/lkmio/transport"
)
type TransportType int
const (
TransportTypeUDP TransportType = iota
TransportTypeTCPClient
TransportTypeTCPServer
)
func (t TransportType) String() string {
switch t {
case TransportTypeUDP:
return "udp"
case TransportTypeTCPClient:
return "tcp_client"
case TransportTypeTCPServer:
return "tcp_server"
default:
panic("invalid transport type")
}
}
// ForwardSink is a sink that forwards a stream; it is used for cascading,
// intercom/broadcast, and JT1078-to-GB28181 conversion.
type ForwardSink struct {
	BaseSink
	// socket is the underlying transport created by NewForwardSink
	// (UDP client, TCP client, or TCP server depending on transportType).
	socket transport.Transport
	// transportType selects how data is delivered to the peer.
	transportType TransportType
	// receiveTimer closes the sink if no peer connects in time; in this file it
	// is only started for TCP-server sinks, so it may be nil.
	receiveTimer *time.Timer
	// ssrc is written into outgoing RTP packets whose SSRC does not match it.
	ssrc uint32
	// requireSSRCMatch: when the SSRC must match, every send checks the packet's
	// SSRC and, if it differs, sends a copy with the SSRC rewritten.
	requireSSRCMatch bool
	// rtpBuffer holds the scratch buffers for SSRC-rewritten copies; it is
	// allocated lazily by Write.
	rtpBuffer *RtpBuffer
}
// OnConnected is the transport callback invoked when a TCP peer connects.
// It stops the wait-for-connection timer (when one was started) and binds the
// connection to the sink. It always returns nil (no data to send back).
func (f *ForwardSink) OnConnected(conn net.Conn) []byte {
	log.Sugar.Infof("%s 连接 conn: %s", f.Protocol, conn.RemoteAddr())

	// StartReceiveTimer is only called for TCP-server sinks in this file, so
	// receiveTimer can be nil here; calling Stop on a nil *time.Timer panics.
	if f.receiveTimer != nil {
		f.receiveTimer.Stop()
	}

	// If f.Conn were assigned before EnableAsyncWriteMode ran, a send could slip
	// in between and panic. To be safe, both steps run together through
	// ExecuteSyncEventOnTransStreamPublisher (the main goroutine, per the
	// original note).
	ExecuteSyncEventOnTransStreamPublisher(f.SourceID, func() {
		f.Conn = conn
		f.BaseSink.EnableAsyncWriteMode(512)
	})
	return nil
}
// OnPacket is the transport callback for data received from the peer.
// The received payload is not used and no reply is produced.
func (f *ForwardSink) OnPacket(conn net.Conn, data []byte) []byte {
	var reply []byte
	return reply
}
// OnDisConnected is the transport callback invoked when the peer connection
// ends. It logs the remote address and closes the sink; err is not inspected.
func (f *ForwardSink) OnDisConnected(conn net.Conn, err error) {
	remote := conn.RemoteAddr()
	log.Sugar.Infof("%s 断开连接 conn: %s", f.Protocol, remote)

	f.Close()
}
// Write forwards one batch of packets to the peer.
//
// Each element of data is laid out as a 2-byte prefix followed by an RTP
// packet: the SSRC is read at offset 2+8, and UDP sends skip the first 2 bytes.
// TCP sinks drop data until a connection exists.
//
// When requireSSRCMatch is set and the first packet's SSRC differs from f.ssrc,
// every packet in the batch is copied into an RtpBuffer slot and its SSRC is
// rewritten:
//   - UDP copies are sent immediately.
//   - TCP copies are handed to BaseSink.Write.
//
// Otherwise the packets are forwarded unchanged.
func (f *ForwardSink) Write(index int, data []*collections.ReferenceCounter[[]byte], ts int64, keyVideo bool) error {
	// TCP sinks wait for a connection before forwarding anything.
	if TransportTypeUDP != f.transportType && f.Conn == nil {
		return nil
	}

	processedData := data

	// SSRC mismatch: copy the packets and stamp the configured SSRC.
	// len(data) > 0 guards the data[0] access for an empty batch.
	if f.requireSSRCMatch && len(data) > 0 && f.ssrc != binary.BigEndian.Uint32(data[0].Get()[2+8:]) {
		if TransportTypeUDP == f.transportType {
			// UDP copies are written out immediately, so one scratch slot is enough.
			if f.rtpBuffer == nil {
				f.rtpBuffer = NewRtpBuffer(1)
			}
			udp := f.socket.(*transport.UDPClient)
			for _, datum := range data {
				src := datum.Get()
				buf := f.rtpBuffer.Get().Get()
				copy(buf, src)
				binary.BigEndian.PutUint32(buf[2+8:], f.ssrc)
				_ = udp.Write(buf[2:len(src)]) // best-effort send, as in the unmodified path
			}
			// Everything has already been sent.
			return nil
		}

		if f.rtpBuffer == nil {
			f.rtpBuffer = NewRtpBuffer(1024)
		}
		// Appending (rather than indexing into a zero-length slice) avoids the
		// index-out-of-range panic on the first packet.
		copied := make([]*collections.ReferenceCounter[[]byte], 0, len(data))
		for _, datum := range data {
			src := datum.Get()
			counter := f.rtpBuffer.Get()
			buf := counter.Get()
			copy(buf, src)
			binary.BigEndian.PutUint32(buf[2+8:], f.ssrc)
			counter.ResetData(buf[:len(src)])
			counter.Refer()
			copied = append(copied, counter)
		}
		// Keep the reference count at 1.
		for _, pkt := range copied {
			pkt.Release()
		}
		processedData = copied
	}

	if TransportTypeUDP == f.transportType {
		udp := f.socket.(*transport.UDPClient)
		for _, datum := range processedData {
			_ = udp.Write(datum.Get()[2:]) // best-effort send; errors are not propagated
		}
		return nil
	}
	return f.BaseSink.Write(index, processedData, ts, keyVideo)
}
// Close shuts down the forward sink. It closes the embedded BaseSink, then the
// underlying socket, stops the wait-for-connection timer and clears the RTP
// copy buffer. Resources that were never created (nil) are skipped.
func (f *ForwardSink) Close() {
	f.BaseSink.Close()

	if sock := f.socket; sock != nil {
		sock.Close()
	}
	if timer := f.receiveTimer; timer != nil {
		timer.Stop()
	}
	if buf := f.rtpBuffer; buf != nil {
		buf.Clear()
	}
}
// StartReceiveTimer arms a one-shot timer for a TCP-server sink. If no
// connection has been bound (f.Conn is still nil) after ForwardSinkWaitTimeout
// seconds, the sink is closed.
func (f *ForwardSink) StartReceiveTimer() {
	onTimeout := func() {
		if f.Conn != nil {
			return
		}
		log.Sugar.Infof("%s 等待连接超时, 关闭sink", f.Protocol)
		f.Close()
	}
	f.receiveTimer = time.AfterFunc(ForwardSinkWaitTimeout*time.Second, onTimeout)
}
// GetSSRC reports the SSRC that Write stamps into forwarded RTP packets when
// SSRC matching is required.
func (f *ForwardSink) GetSSRC() uint32 {
	configured := f.ssrc
	return configured
}
// NewForwardSink creates a ForwardSink for the stream sourceId. The transport
// type determines the socket:
//   - TransportTypeUDP: a UDP client targeting addr is created through manager.
//   - TransportTypeTCPClient: a local port is allocated through manager and
//     addr is dialed from it.
//   - TransportTypeTCPServer: a TCP server is created and starts accepting,
//     and the wait-for-connection timer is started; addr is not used.
//
// SSRC matching is enabled by default (see Write).
//
// It returns the sink and the port reported by the socket's ListenPort. It
// returns an error if an address cannot be resolved, the socket cannot be
// created or connected, or transportType is not one of the known values.
func NewForwardSink(transportType TransportType, protocol TransStreamProtocol, sinkId SinkID, sourceId string, addr string, manager transport.Manager, ssrc uint32) (*ForwardSink, int, error) {
	sink := &ForwardSink{
		BaseSink:         BaseSink{ID: sinkId, SourceID: sourceId, State: SessionStateCreated, Protocol: protocol},
		transportType:    transportType,
		ssrc:             ssrc,
		requireSSRCMatch: true, // SSRC must match by default
	}

	switch transportType {
	case TransportTypeUDP:
		remoteAddr, err := net.ResolveUDPAddr("udp", addr)
		if err != nil {
			return nil, 0, err
		}
		client, err := manager.NewUDPClient(remoteAddr)
		if err != nil {
			return nil, 0, err
		}
		sink.socket = client

	case TransportTypeTCPClient:
		client := transport.TCPClient{}
		err := manager.AllocPort(true, func(port uint16) error {
			localAddr, err := net.ResolveTCPAddr("tcp", ListenAddr(int(port)))
			if err != nil {
				return err
			}
			remoteAddr, err := net.ResolveTCPAddr("tcp", addr)
			if err != nil {
				return err
			}
			client.SetHandler(sink)
			conn, err := client.Connect(localAddr, remoteAddr)
			if err != nil {
				return err
			}
			sink.Conn = conn
			return nil
		})
		if err != nil {
			return nil, 0, err
		}
		sink.socket = &client

	case TransportTypeTCPServer:
		tcpServer, err := manager.NewTCPServer()
		if err != nil {
			return nil, 0, err
		}
		tcpServer.SetHandler(sink)
		tcpServer.Accept()
		sink.socket = tcpServer
		sink.StartReceiveTimer()

	default:
		// Without this, sink.socket stays nil and ListenPort below panics.
		return nil, 0, fmt.Errorf("invalid transport type: %d", int(transportType))
	}

	return sink, sink.socket.ListenPort(), nil
}