diff --git a/plugin/rtmp/index.go b/plugin/rtmp/index.go index 11a61db..1a60199 100644 --- a/plugin/rtmp/index.go +++ b/plugin/rtmp/index.go @@ -44,51 +44,51 @@ func (p *RTMPPlugin) OnTCPConnect(conn *net.TCPConn) task.ITask { return ret } -func (task *RTMPServer) Go() (err error) { - if err = task.Handshake(task.conf.C2); err != nil { - task.Error("handshake", "error", err) +func (server *RTMPServer) Go() (err error) { + if err = server.Handshake(server.conf.C2); err != nil { + server.Error("handshake", "error", err) return } var commander Commander var gstreamid uint32 for err == nil { - if commander, err = task.RecvMessage(); err == nil { - task.Debug("recv cmd", "commandName", commander.GetCommand().CommandName) + if commander, err = server.RecvMessage(); err == nil { + server.Debug("recv cmd", "commandName", commander.GetCommand().CommandName) switch cmd := commander.(type) { case *CallMessage: //connect - task.SetDescriptions(cmd.Object) + server.SetDescriptions(cmd.Object) app := cmd.Object["app"] // 客户端要连接到的服务应用名 objectEncoding := cmd.Object["objectEncoding"] // AMF编码方法 switch v := objectEncoding.(type) { case float64: - task.ObjectEncoding = v + server.ObjectEncoding = v default: - task.ObjectEncoding = 0 + server.ObjectEncoding = 0 } - task.AppName = app.(string) - task.Info("connect", "appName", task.AppName, "objectEncoding", task.ObjectEncoding) - err = task.SendMessage(RTMP_MSG_ACK_SIZE, Uint32Message(512<<10)) + server.AppName = app.(string) + server.Info("connect", "appName", server.AppName, "objectEncoding", server.ObjectEncoding) + err = server.SendMessage(RTMP_MSG_ACK_SIZE, Uint32Message(512<<10)) if err != nil { - task.Error("sendMessage ack size", "error", err) + server.Error("sendMessage ack size", "error", err) return } - task.WriteChunkSize = task.conf.ChunkSize - err = task.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(task.conf.ChunkSize)) + server.WriteChunkSize = server.conf.ChunkSize + err = server.SendMessage(RTMP_MSG_CHUNK_SIZE, Uint32Message(server.conf.ChunkSize)) if err != nil { - task.Error("sendMessage chunk size", "error", err) + server.Error("sendMessage chunk size", "error", err) return } - err = task.SendMessage(RTMP_MSG_BANDWIDTH, &SetPeerBandwidthMessage{ + err = server.SendMessage(RTMP_MSG_BANDWIDTH, &SetPeerBandwidthMessage{ AcknowledgementWindowsize: uint32(512 << 10), LimitType: byte(2), }) if err != nil { - task.Error("sendMessage bandwidth", "error", err) + server.Error("sendMessage bandwidth", "error", err) return } - err = task.SendStreamID(RTMP_USER_STREAM_BEGIN, 0) + err = server.SendStreamID(RTMP_USER_STREAM_BEGIN, 0) if err != nil { - task.Error("sendMessage stream begin", "error", err) + server.Error("sendMessage stream begin", "error", err) return } m := new(ResponseConnectMessage) @@ -103,16 +103,18 @@ func (task *RTMPServer) Go() (err error) { m.Infomation = map[string]any{ "level": Level_Status, "code": NetConnection_Connect_Success, - "objectEncoding": task.ObjectEncoding, + "objectEncoding": server.ObjectEncoding, } - err = task.SendMessage(RTMP_MSG_AMF0_COMMAND, m) + err = server.SendMessage(RTMP_MSG_AMF0_COMMAND, m) if err != nil { - task.Error("sendMessage connect", "error", err) + server.Error("sendMessage connect", "error", err) + } else { + server.OnConnected() } case *CommandMessage: // "createStream" gstreamid++ - task.Info("createStream:", "streamId", gstreamid) - task.ResponseCreateStream(cmd.TransactionId, gstreamid) + server.Info("createStream:", "streamId", gstreamid) + server.ResponseCreateStream(cmd.TransactionId, gstreamid) case *CURDStreamMessage: // if stream, ok := receivers[cmd.StreamId]; ok { // stream.Stop() @@ -133,11 +135,11 @@ func (task *RTMPServer) Go() (err error) { // err = nc.SendMessage(RTMP_MSG_AMF0_COMMAND, m) case *PublishMessage: ns := NetStream{ - NetConnection: &task.NetConnection, + NetConnection: &server.NetConnection, StreamID: cmd.StreamId, } var publisher *m7s.Publisher - publisher, err = task.conf.Publish(task.Context, task.AppName+"/"+cmd.PublishingName) + publisher, err = server.conf.Publish(server.Context, server.AppName+"/"+cmd.PublishingName) if err != nil { err = ns.Response(cmd.TransactionId, NetStream_Publish_BadName, Level_Error) } else { @@ -149,18 +151,18 @@ func (task *RTMPServer) Go() (err error) { err = ns.BeginPublish(cmd.TransactionId) } if err != nil { - task.Error("sendMessage publish", "error", err) + server.Error("sendMessage publish", "error", err) } else { - publisher.Using(task) + publisher.Using(server) } case *PlayMessage: - streamPath := task.AppName + "/" + cmd.StreamName + streamPath := server.AppName + "/" + cmd.StreamName ns := NetStream{ - NetConnection: &task.NetConnection, + NetConnection: &server.NetConnection, StreamID: cmd.StreamId, } var suber *m7s.Subscriber - suber, err = task.conf.Subscribe(task.Context, streamPath) + suber, err = server.conf.Subscribe(server.Context, streamPath) if err != nil { err = ns.Response(cmd.TransactionId, NetStream_Play_Failed, Level_Error) } else { @@ -169,13 +171,13 @@ func (task *RTMPServer) Go() (err error) { ns.Subscribe(suber) } if err != nil { - task.Error("sendMessage play", "error", err) + server.Error("sendMessage play", "error", err) } } } else if err == io.EOF || errors.Is(err, io.ErrUnexpectedEOF) { - task.Info("rtmp client closed", "error", err) + server.Info("rtmp client closed", "error", err) } else { - task.Warn("ReadMessage", "error", err) + server.Warn("ReadMessage", "error", err) } } return diff --git a/plugin/rtmp/pkg/client.go b/plugin/rtmp/pkg/client.go index b0babbb..cb1428a 100644 --- a/plugin/rtmp/pkg/client.go +++ b/plugin/rtmp/pkg/client.go @@ -180,6 +180,7 @@ func (c *Client) Run() (err error) { err = c.SendMessage(RTMP_MSG_AMF0_COMMAND, &CommandMessage{"createStream", 2}) if err == nil { c.Info("connected") + c.OnConnected() } } case *ResponseCreateStreamMessage: diff --git a/plugin/rtmp/pkg/net-connection.go b/plugin/rtmp/pkg/net-connection.go index 18e38ea..514a5ac 100644 --- a/plugin/rtmp/pkg/net-connection.go +++ b/plugin/rtmp/pkg/net-connection.go @@ -50,6 +50,19 @@ type Writers = map[uint32]*struct { *m7s.Publisher } +type PingTask struct { + task.TickTask + NetConnection *NetConnection +} + +func (t *PingTask) GetTickInterval() time.Duration { + return time.Second * 10 +} + +func (t *PingTask) Tick(any) { + t.NetConnection.SendPingRequest() +} + type NetConnection struct { task.Job *util.BufReader @@ -77,7 +90,7 @@ func NewNetConnection(conn net.Conn) (ret *NetConnection) { func (nc *NetConnection) Init(conn net.Conn) { nc.Conn = conn - nc.BufReader = util.NewBufReaderWithTimeout(conn, 10*time.Second) + nc.BufReader = util.NewBufReaderWithTimeout(conn, 30*time.Second) nc.bandwidth = RTMP_MAX_CHUNK_SIZE << 3 nc.ReadChunkSize = RTMP_DEFAULT_CHUNK_SIZE nc.WriteChunkSize = RTMP_DEFAULT_CHUNK_SIZE @@ -89,6 +102,12 @@ func (nc *NetConnection) Init(conn net.Conn) { nc.Writers = make(Writers) } +func (nc *NetConnection) OnConnected() { + nc.AddTask(&PingTask{ + NetConnection: nc, + }) +} + func (nc *NetConnection) Dispose() { nc.Conn.Close() nc.BufReader.Recycle() @@ -429,7 +448,6 @@ func (nc *NetConnection) SendMessage(t byte, msg RtmpMessage) (err error) { nc.totalWrite += nc.writeSeqNum nc.writeSeqNum = 0 err = nc.SendMessage(RTMP_MSG_ACK, Uint32Message(nc.totalWrite)) - err = nc.SendPingRequest() } for !nc.writing.CompareAndSwap(false, true) { runtime.Gosched()