diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index 6e8b36b0..a565aa20 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -124,14 +124,25 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas }, nil } -// called after receiving a frame. -func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { +// called after receiving a RTP packet. +func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route frames to readers + // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload) + } +} + +// called after receiving a RTCP packet. +func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // if we are the publisher, route packet to readers + if ctx.Session == sh.publisher { + sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload) } } diff --git a/examples/server/main.go b/examples/server/main.go index d37fa4f5..3fb63cb1 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -123,14 +123,25 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas }, nil } -// called after receiving a frame. -func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { +// called after receiving a RTP packet. +func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { sh.mutex.Lock() defer sh.mutex.Unlock() - // if we are the publisher, route frames to readers + // if we are the publisher, route packet to readers if ctx.Session == sh.publisher { - sh.stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload) + } +} + +// called after receiving a RTCP packet. +func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) { + sh.mutex.Lock() + defer sh.mutex.Unlock() + + // if we are the publisher, route packet to readers + if ctx.Session == sh.publisher { + sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload) } } diff --git a/server_publish_test.go b/server_publish_test.go index c9f79b9d..14cbc0c1 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -6,7 +6,6 @@ import ( "crypto/tls" "net" "strconv" - "sync/atomic" "testing" "time" @@ -615,7 +614,6 @@ func TestServerPublish(t *testing.T) { connClosed := make(chan struct{}) sessionOpened := make(chan struct{}) sessionClosed := make(chan struct{}) - rtpReceived := uint64(0) s := &Server{ Handler: &testServerHandler{ @@ -646,18 +644,14 @@ func TestServerPublish(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { - if atomic.SwapUint64(&rtpReceived, 1) == 0 { - require.Equal(t, 0, ctx.TrackID) - require.Equal(t, StreamTypeRTP, ctx.StreamType) - require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) - } else { - require.Equal(t, 0, ctx.TrackID) - require.Equal(t, StreamTypeRTCP, ctx.StreamType) - require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) - - ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) - } + onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { + require.Equal(t, 0, ctx.TrackID) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) + }, + onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { + require.Equal(t, 0, ctx.TrackID) + require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) + ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) }, }, } @@ -967,7 +961,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { + onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { t.Error("should not happen") }, }, @@ -1475,7 +1469,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { + onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { }, }, UDPRTPAddress: "127.0.0.1:8000", diff --git a/server_read_test.go b/server_read_test.go index 0f5870ca..24a0ed00 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -303,14 +303,13 @@ func TestServerRead(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { + onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { // skip multicast loopback if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { return } require.Equal(t, 0, ctx.TrackID) - require.Equal(t, StreamTypeRTCP, ctx.StreamType) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) close(framesReceived) }, @@ -1281,8 +1280,6 @@ func TestServerReadUDPChangeConn(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { - }, onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, @@ -1380,8 +1377,6 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { - }, }, UDPRTPAddress: "127.0.0.1:8000", UDPRTCPAddress: "127.0.0.1:8001", diff --git a/server_test.go b/server_test.go index 7abed063..f79e6459 100644 --- a/server_test.go +++ b/server_test.go @@ -48,7 +48,8 @@ type testServerHandler struct { onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) - onFrame func(*ServerHandlerOnFrameCtx) + onPacketRTP func(*ServerHandlerOnPacketRTPCtx) + onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx) onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error) onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) } @@ -119,9 +120,15 @@ func (sh *testServerHandler) OnPause(ctx *ServerHandlerOnPauseCtx) (*base.Respon return nil, fmt.Errorf("unimplemented") } -func (sh *testServerHandler) OnFrame(ctx *ServerHandlerOnFrameCtx) { - if sh.onFrame != nil { - sh.onFrame(ctx) +func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) { + if sh.onPacketRTP != nil { + sh.onPacketRTP(ctx) + } +} + +func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) { + if sh.onPacketRTCP != nil { + sh.onPacketRTCP(ctx) } } @@ -392,12 +399,20 @@ func TestServerHighLevelPublishRead(t *testing.T) { StatusCode: base.StatusOK, }, nil }, - onFrame: func(ctx *ServerHandlerOnFrameCtx) { + onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) { mutex.Lock() defer mutex.Unlock() if ctx.Session == publisher { - stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) + stream.WriteFrame(ctx.TrackID, StreamTypeRTP, ctx.Payload) + } + }, + onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) { + mutex.Lock() + defer mutex.Unlock() + + if ctx.Session == publisher { + stream.WriteFrame(ctx.TrackID, StreamTypeRTCP, ctx.Payload) } }, }, diff --git a/serverconn.go b/serverconn.go index 6bb4e2b7..18cef584 100644 --- a/serverconn.go +++ b/serverconn.go @@ -159,13 +159,22 @@ func (sc *ServerConn) run() { time.Now(), streamType, frame.Payload) } - if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok { - h.OnFrame(&ServerHandlerOnFrameCtx{ - Session: sc.tcpSession, - TrackID: trackID, - StreamType: streamType, - Payload: frame.Payload, - }) + if streamType == StreamTypeRTP { + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Payload: frame.Payload, + }) + } + } else { + if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: sc.tcpSession, + TrackID: trackID, + Payload: frame.Payload, + }) + } } } diff --git a/serverhandler.go b/serverhandler.go index 15ca3f2e..511bf8c1 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -179,15 +179,26 @@ type ServerHandlerOnSetParameter interface { OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error) } -// ServerHandlerOnFrameCtx is the context of a frame. -type ServerHandlerOnFrameCtx struct { +// ServerHandlerOnPacketRTPCtx is the context of a RTP packet. +type ServerHandlerOnPacketRTPCtx struct { Session *ServerSession TrackID int - StreamType StreamType Payload []byte } -// ServerHandlerOnFrame can be implemented by a ServerHandler. -type ServerHandlerOnFrame interface { - OnFrame(*ServerHandlerOnFrameCtx) +// ServerHandlerOnPacketRTP can be implemented by a ServerHandler. +type ServerHandlerOnPacketRTP interface { + OnPacketRTP(*ServerHandlerOnPacketRTPCtx) +} + +// ServerHandlerOnPacketRTCPCtx is the context of a RTCP packet. +type ServerHandlerOnPacketRTCPCtx struct { + Session *ServerSession + TrackID int + Payload []byte +} + +// ServerHandlerOnPacketRTCP can be implemented by a ServerHandler. +type ServerHandlerOnPacketRTCP interface { + OnPacketRTCP(*ServerHandlerOnPacketRTCPCtx) } diff --git a/serverudpl.go b/serverudpl.go index 49b090f1..ca26b73f 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -208,13 +208,22 @@ func (u *serverUDPListener) run() { clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) } - if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok { - h.OnFrame(&ServerHandlerOnFrameCtx{ - Session: clientData.ss, - TrackID: clientData.trackID, - StreamType: u.streamType, - Payload: buf[:n], - }) + if u.streamType == StreamTypeRTP { + if h, ok := u.s.Handler.(ServerHandlerOnPacketRTP); ok { + h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{ + Session: clientData.ss, + TrackID: clientData.trackID, + Payload: buf[:n], + }) + } + } else { + if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok { + h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{ + Session: clientData.ss, + TrackID: clientData.trackID, + Payload: buf[:n], + }) + } } }() }