server: split OnFrame into OnPacketRTP and OnPacketRTCP

This commit is contained in:
aler9
2021-10-30 16:03:08 +02:00
committed by Alessandro Ros
parent 62bd19f770
commit 472430f900
8 changed files with 111 additions and 56 deletions

View File

@@ -124,14 +124,25 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
}, nil }, nil
} }
// called after receiving a frame. // called after receiving a RTP packet.
func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers // if we are the publisher, route packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload)
}
}
// called after receiving a RTCP packet.
func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
sh.mutex.Lock()
defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers
if ctx.Session == sh.publisher {
sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload)
} }
} }

View File

@@ -123,14 +123,25 @@ func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*bas
}, nil }, nil
} }
// called after receiving a frame. // called after receiving a RTP packet.
func (sh *serverHandler) OnFrame(ctx *gortsplib.ServerHandlerOnFrameCtx) { func (sh *serverHandler) OnPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
sh.mutex.Lock() sh.mutex.Lock()
defer sh.mutex.Unlock() defer sh.mutex.Unlock()
// if we are the publisher, route frames to readers // if we are the publisher, route packet to readers
if ctx.Session == sh.publisher { if ctx.Session == sh.publisher {
sh.stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTP, ctx.Payload)
}
}
// called after receiving a RTCP packet.
func (sh *serverHandler) OnPacketRTCP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
sh.mutex.Lock()
defer sh.mutex.Unlock()
// if we are the publisher, route packet to readers
if ctx.Session == sh.publisher {
sh.stream.WriteFrame(ctx.TrackID, gortsplib.StreamTypeRTCP, ctx.Payload)
} }
} }

View File

@@ -6,7 +6,6 @@ import (
"crypto/tls" "crypto/tls"
"net" "net"
"strconv" "strconv"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -615,7 +614,6 @@ func TestServerPublish(t *testing.T) {
connClosed := make(chan struct{}) connClosed := make(chan struct{})
sessionOpened := make(chan struct{}) sessionOpened := make(chan struct{})
sessionClosed := make(chan struct{}) sessionClosed := make(chan struct{})
rtpReceived := uint64(0)
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
@@ -646,18 +644,14 @@ func TestServerPublish(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
if atomic.SwapUint64(&rtpReceived, 1) == 0 {
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, StreamTypeRTP, ctx.StreamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload)
} else { },
onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) {
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, StreamTypeRTCP, ctx.StreamType)
require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload) require.Equal(t, []byte{0x05, 0x06, 0x07, 0x08}, ctx.Payload)
ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C}) ctx.Session.WriteFrame(0, StreamTypeRTCP, []byte{0x09, 0x0A, 0x0B, 0x0C})
}
}, },
}, },
} }
@@ -967,7 +961,7 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
t.Error("should not happen") t.Error("should not happen")
}, },
}, },
@@ -1475,7 +1469,7 @@ func TestServerPublishUDPChangeConn(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
}, },
}, },
UDPRTPAddress: "127.0.0.1:8000", UDPRTPAddress: "127.0.0.1:8000",

View File

@@ -303,14 +303,13 @@ func TestServerRead(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) {
// skip multicast loopback // skip multicast loopback
if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 { if transport == "multicast" && atomic.AddUint64(&counter, 1) <= 1 {
return return
} }
require.Equal(t, 0, ctx.TrackID) require.Equal(t, 0, ctx.TrackID)
require.Equal(t, StreamTypeRTCP, ctx.StreamType)
require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload) require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, ctx.Payload)
close(framesReceived) close(framesReceived)
}, },
@@ -1281,8 +1280,6 @@ func TestServerReadUDPChangeConn(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) {
},
onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) { onGetParameter: func(ctx *ServerHandlerOnGetParameterCtx) (*base.Response, error) {
return &base.Response{ return &base.Response{
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
@@ -1380,8 +1377,6 @@ func TestServerReadErrorUDPSamePorts(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) {
},
}, },
UDPRTPAddress: "127.0.0.1:8000", UDPRTPAddress: "127.0.0.1:8000",
UDPRTCPAddress: "127.0.0.1:8001", UDPRTCPAddress: "127.0.0.1:8001",

View File

@@ -48,7 +48,8 @@ type testServerHandler struct {
onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error) onPlay func(*ServerHandlerOnPlayCtx) (*base.Response, error)
onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error) onRecord func(*ServerHandlerOnRecordCtx) (*base.Response, error)
onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error) onPause func(*ServerHandlerOnPauseCtx) (*base.Response, error)
onFrame func(*ServerHandlerOnFrameCtx) onPacketRTP func(*ServerHandlerOnPacketRTPCtx)
onPacketRTCP func(*ServerHandlerOnPacketRTCPCtx)
onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error) onSetParameter func(*ServerHandlerOnSetParameterCtx) (*base.Response, error)
onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error)
} }
@@ -119,9 +120,15 @@ func (sh *testServerHandler) OnPause(ctx *ServerHandlerOnPauseCtx) (*base.Respon
return nil, fmt.Errorf("unimplemented") return nil, fmt.Errorf("unimplemented")
} }
func (sh *testServerHandler) OnFrame(ctx *ServerHandlerOnFrameCtx) { func (sh *testServerHandler) OnPacketRTP(ctx *ServerHandlerOnPacketRTPCtx) {
if sh.onFrame != nil { if sh.onPacketRTP != nil {
sh.onFrame(ctx) sh.onPacketRTP(ctx)
}
}
func (sh *testServerHandler) OnPacketRTCP(ctx *ServerHandlerOnPacketRTCPCtx) {
if sh.onPacketRTCP != nil {
sh.onPacketRTCP(ctx)
} }
} }
@@ -392,12 +399,20 @@ func TestServerHighLevelPublishRead(t *testing.T) {
StatusCode: base.StatusOK, StatusCode: base.StatusOK,
}, nil }, nil
}, },
onFrame: func(ctx *ServerHandlerOnFrameCtx) { onPacketRTP: func(ctx *ServerHandlerOnPacketRTPCtx) {
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
if ctx.Session == publisher { if ctx.Session == publisher {
stream.WriteFrame(ctx.TrackID, ctx.StreamType, ctx.Payload) stream.WriteFrame(ctx.TrackID, StreamTypeRTP, ctx.Payload)
}
},
onPacketRTCP: func(ctx *ServerHandlerOnPacketRTCPCtx) {
mutex.Lock()
defer mutex.Unlock()
if ctx.Session == publisher {
stream.WriteFrame(ctx.TrackID, StreamTypeRTCP, ctx.Payload)
} }
}, },
}, },

View File

@@ -159,14 +159,23 @@ func (sc *ServerConn) run() {
time.Now(), streamType, frame.Payload) time.Now(), streamType, frame.Payload)
} }
if h, ok := sc.s.Handler.(ServerHandlerOnFrame); ok { if streamType == StreamTypeRTP {
h.OnFrame(&ServerHandlerOnFrameCtx{ if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.tcpSession, Session: sc.tcpSession,
TrackID: trackID, TrackID: trackID,
StreamType: streamType,
Payload: frame.Payload, Payload: frame.Payload,
}) })
} }
} else {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
TrackID: trackID,
Payload: frame.Payload,
})
}
}
} }
case *base.Request: case *base.Request:

View File

@@ -179,15 +179,26 @@ type ServerHandlerOnSetParameter interface {
OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error) OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error)
} }
// ServerHandlerOnFrameCtx is the context of a frame. // ServerHandlerOnPacketRTPCtx is the context of a RTP packet.
type ServerHandlerOnFrameCtx struct { type ServerHandlerOnPacketRTPCtx struct {
Session *ServerSession Session *ServerSession
TrackID int TrackID int
StreamType StreamType
Payload []byte Payload []byte
} }
// ServerHandlerOnFrame can be implemented by a ServerHandler. // ServerHandlerOnPacketRTP can be implemented by a ServerHandler.
type ServerHandlerOnFrame interface { type ServerHandlerOnPacketRTP interface {
OnFrame(*ServerHandlerOnFrameCtx) OnPacketRTP(*ServerHandlerOnPacketRTPCtx)
}
// ServerHandlerOnPacketRTCPCtx is the context of a RTCP packet.
type ServerHandlerOnPacketRTCPCtx struct {
Session *ServerSession
TrackID int
Payload []byte
}
// ServerHandlerOnPacketRTCP can be implemented by a ServerHandler.
type ServerHandlerOnPacketRTCP interface {
OnPacketRTCP(*ServerHandlerOnPacketRTCPCtx)
} }

View File

@@ -208,14 +208,23 @@ func (u *serverUDPListener) run() {
clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n]) clientData.ss.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, u.streamType, buf[:n])
} }
if h, ok := u.s.Handler.(ServerHandlerOnFrame); ok { if u.streamType == StreamTypeRTP {
h.OnFrame(&ServerHandlerOnFrameCtx{ if h, ok := u.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: clientData.ss, Session: clientData.ss,
TrackID: clientData.trackID, TrackID: clientData.trackID,
StreamType: u.streamType,
Payload: buf[:n], Payload: buf[:n],
}) })
} }
} else {
if h, ok := u.s.Handler.(ServerHandlerOnPacketRTCP); ok {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: clientData.ss,
TrackID: clientData.trackID,
Payload: buf[:n],
})
}
}
}() }()
} }
}() }()