diff --git a/README.md b/README.md index 72b4c6d5..703f04d5 100644 --- a/README.md +++ b/README.md @@ -5,14 +5,13 @@ [![Go Report Card](https://goreportcard.com/badge/github.com/aler9/rtsp-simple-server)](https://goreportcard.com/report/github.com/aler9/rtsp-simple-server) [![Docker Hub](https://img.shields.io/badge/docker-aler9%2Frtsp--simple--server-blue)](https://hub.docker.com/r/aler9/rtsp-simple-server) -_rtsp-simple-server_ is a simple, ready-to-use and zero-dependency RTSP server and RTSP proxy, a software that allows multiple users to publish and read live video and audio streams. RTSP is a standardized protocol that defines how to perform these operations with the help of a server, that is contacted by both readers and publishers in order to negotiate a streaming protocol. The server is then responsible of relaying the publisher streams to the readers. +_rtsp-simple-server_ is a simple, ready-to-use and zero-dependency RTSP server and RTSP proxy, a software that allows multiple users to publish and read live video and audio streams over time. RTSP, RTP and RTCP are standardized protocol that describe how to perform these operations with the help of a server, that is contacted by both readers and publishers in order to negotiate a streaming protocol. The server is then responsible of relaying the publisher streams to the readers. Features: * Read and publish streams via UDP and TCP * Pull and serve streams from other RTSP servers (RTSP proxy) * Each stream can have multiple video and audio tracks, encoded in any format * Publish multiple streams at once, each in a separate path, that can be read by multiple users -* Supports the RTP/RTCP streaming protocol * Supports authentication * Supports running a script when a client connects or disconnects * Compatible with Linux, Windows and Mac, does not require any dependency or interpreter, it's a single executable diff --git a/main.go b/main.go index e71c37b2..bfc3bd53 100644 --- a/main.go +++ b/main.go @@ -104,13 +104,6 @@ type programEventClientPlay2 struct { func (programEventClientPlay2) isProgramEvent() {} -type programEventClientPause struct { - res chan error - client *serverClient -} - -func (programEventClientPause) isProgramEvent() {} - type programEventClientRecord struct { res chan error client *serverClient @@ -119,8 +112,8 @@ type programEventClientRecord struct { func (programEventClientRecord) isProgramEvent() {} type programEventClientFrameUdp struct { - trackFlowType trackFlowType addr *net.UDPAddr + trackFlowType trackFlowType buf []byte } @@ -544,54 +537,22 @@ outer: evt.client.state = _CLIENT_STATE_PLAY evt.res <- nil - case programEventClientPause: - p.receiverCount -= 1 - evt.client.state = _CLIENT_STATE_PRE_PLAY - evt.res <- nil - case programEventClientRecord: p.publisherCount += 1 evt.client.state = _CLIENT_STATE_RECORD evt.res <- nil case programEventClientFrameUdp: - // find publisher and track id from ip and port - cl, trackId := func() (*serverClient, int) { - for _, pub := range p.publishers { - cl, ok := pub.(*serverClient) - if !ok { - continue - } - - if cl.streamProtocol != _STREAM_PROTOCOL_UDP || - cl.state != _CLIENT_STATE_RECORD || - !cl.ip().Equal(evt.addr.IP) { - continue - } - - for i, t := range cl.streamTracks { - if evt.trackFlowType == _TRACK_FLOW_TYPE_RTP { - if t.rtpPort == evt.addr.Port { - return cl, i - } - } else { - if t.rtcpPort == evt.addr.Port { - return cl, i - } - } - } - } - return nil, -1 - }() - if cl == nil { + client, trackId := p.findPublisher(evt.addr, evt.trackFlowType) + if client == nil { continue } - cl.udpLastFrameTime = time.Now() - p.forwardTrack(cl.path, trackId, evt.trackFlowType, evt.buf) + client.udpLastFrameTime = time.Now() + p.forwardFrame(client.path, trackId, evt.trackFlowType, evt.buf) case programEventClientFrameTcp: - p.forwardTrack(evt.path, evt.trackId, evt.trackFlowType, evt.buf) + p.forwardFrame(evt.path, evt.trackId, evt.trackFlowType, evt.buf) case programEventStreamerReady: evt.streamer.ready = true @@ -611,7 +572,7 @@ outer: } case programEventStreamerFrame: - p.forwardTrack(evt.streamer.path, evt.trackId, evt.trackFlowType, evt.buf) + p.forwardFrame(evt.streamer.path, evt.trackId, evt.trackFlowType, evt.buf) case programEventTerminate: break outer @@ -642,9 +603,6 @@ outer: case programEventClientPlay2: evt.res <- fmt.Errorf("terminated") - case programEventClientPause: - evt.res <- fmt.Errorf("terminated") - case programEventClientRecord: evt.res <- fmt.Errorf("terminated") } @@ -672,27 +630,60 @@ func (p *program) close() { <-p.done } -func (p *program) forwardTrack(path string, id int, trackFlowType trackFlowType, frame []byte) { +func (p *program) findPublisher(addr *net.UDPAddr, trackFlowType trackFlowType) (*serverClient, int) { + for _, pub := range p.publishers { + cl, ok := pub.(*serverClient) + if !ok { + continue + } + + if cl.streamProtocol != _STREAM_PROTOCOL_UDP || + cl.state != _CLIENT_STATE_RECORD || + !cl.ip().Equal(addr.IP) { + continue + } + + for i, t := range cl.streamTracks { + if trackFlowType == _TRACK_FLOW_TYPE_RTP { + if t.rtpPort == addr.Port { + return cl, i + } + } else { + if t.rtcpPort == addr.Port { + return cl, i + } + } + } + } + return nil, -1 +} + +func (p *program) forwardFrame(path string, trackId int, trackFlowType trackFlowType, frame []byte) { for c := range p.clients { if c.path == path && c.state == _CLIENT_STATE_PLAY { if c.streamProtocol == _STREAM_PROTOCOL_UDP { if trackFlowType == _TRACK_FLOW_TYPE_RTP { - p.udplRtp.write(&net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: c.streamTracks[id].rtpPort, - }, frame) - + p.udplRtp.write(&udpAddrBufPair{ + addr: &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: c.streamTracks[trackId].rtpPort, + }, + buf: frame, + }) } else { - p.udplRtcp.write(&net.UDPAddr{ - IP: c.ip(), - Zone: c.zone(), - Port: c.streamTracks[id].rtcpPort, - }, frame) + p.udplRtcp.write(&udpAddrBufPair{ + addr: &net.UDPAddr{ + IP: c.ip(), + Zone: c.zone(), + Port: c.streamTracks[trackId].rtcpPort, + }, + buf: frame, + }) } } else { - c.writeFrame(trackFlowTypeToInterleavedChannel(id, trackFlowType), frame) + c.writeFrame(trackFlowTypeToInterleavedChannel(trackId, trackFlowType), frame) } } } diff --git a/server-client.go b/server-client.go index 1e54d399..23dff185 100644 --- a/server-client.go +++ b/server-client.go @@ -53,24 +53,23 @@ func (cs serverClientState) String() string { } type serverClient struct { - p *program - conn *gortsplib.ConnServer - state serverClientState - path string - authUser string - authPass string - authHelper *gortsplib.AuthServer - authFailures int - streamSdpText []byte // filled only if publisher - streamSdpParsed *sdp.Message // filled only if publisher - streamProtocol streamProtocol - streamTracks []*track - udpLastFrameTime time.Time - udpCheckStreamTicker *time.Ticker - readBuf *doubleBuffer - writeBuf *doubleBuffer + p *program + conn *gortsplib.ConnServer + state serverClientState + path string + authUser string + authPass string + authHelper *gortsplib.AuthServer + authFailures int + streamSdpText []byte // only if publisher + streamSdpParsed *sdp.Message // only if publisher + streamProtocol streamProtocol + streamTracks []*track + udpLastFrameTime time.Time + readBuf *doubleBuffer + writeBuf *doubleBuffer - writeChan chan *gortsplib.InterleavedFrame + writeChan chan *gortsplib.InterleavedFrame // only if state = _CLIENT_STATE_PLAY done chan struct{} } @@ -82,11 +81,9 @@ func newServerClient(p *program, nconn net.Conn) *serverClient { ReadTimeout: p.conf.ReadTimeout, WriteTimeout: p.conf.WriteTimeout, }), - state: _CLIENT_STATE_STARTING, - readBuf: newDoubleBuffer(512 * 1024), - writeBuf: newDoubleBuffer(2048), - writeChan: make(chan *gortsplib.InterleavedFrame), - done: make(chan struct{}), + state: _CLIENT_STATE_STARTING, + readBuf: newDoubleBuffer(512 * 1024), + done: make(chan struct{}), } go c.run() @@ -126,29 +123,30 @@ func (c *serverClient) run() { } } +outer: for { - req, err := c.conn.ReadRequest() - if err != nil { - if err != io.EOF { - c.log("ERR: %s", err) + switch c.state { + case _CLIENT_STATE_PLAY: + ok := c.runPlay() + if !ok { + break outer } - break - } - ok := c.handleRequest(req) - if !ok { - break + case _CLIENT_STATE_RECORD: + ok := c.runRecord() + if !ok { + break outer + } + + default: + ok := c.runNormal() + if !ok { + break outer + } } } - if c.udpCheckStreamTicker != nil { - c.udpCheckStreamTicker.Stop() - } - - go func() { - for range c.writeChan { - } - }() + c.conn.NetConn().Close() // close socket in case it has not been closed yet func() { if c.p.conf.PostScript != "" { @@ -160,16 +158,185 @@ func (c *serverClient) run() { } }() - done := make(chan struct{}) - c.p.events <- programEventClientClose{done, c} - <-done - - close(c.writeChan) - c.conn.NetConn().Close() // close socket in case it has not been closed yet - close(c.done) // close() never blocks } +var errClientChangeRunMode = errors.New("change run mode") +var errClientTerminate = errors.New("terminate") + +func (c *serverClient) runNormal() bool { + var ret bool + +outer: + for { + req, err := c.conn.ReadRequest() + if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } + ret = false + break outer + } + + err = c.handleRequest(req) + switch err { + case errClientChangeRunMode: + ret = true + break outer + + case errClientTerminate: + ret = false + break outer + } + } + + if !ret { + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + } + + return ret +} + +func (c *serverClient) runPlay() bool { + if c.streamProtocol == _STREAM_PROTOCOL_TCP { + writeDone := make(chan struct{}) + go func() { + defer close(writeDone) + for frame := range c.writeChan { + c.conn.WriteInterleavedFrame(frame) + } + }() + + buf := make([]byte, 2048) + for { + _, err := c.conn.NetConn().Read(buf) + if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } + break + } + } + + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + + close(c.writeChan) + <-writeDone + + } else { + for { + req, err := c.conn.ReadRequest() + if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } + break + } + + err = c.handleRequest(req) + if err != nil { + break + } + } + + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + } + + return false +} + +func (c *serverClient) runRecord() bool { + if c.streamProtocol == _STREAM_PROTOCOL_TCP { + frame := &gortsplib.InterleavedFrame{} + + outer: + for { + frame.Content = c.readBuf.swap() + frame.Content = frame.Content[:cap(frame.Content)] + recv, err := c.conn.ReadInterleavedFrameOrRequest(frame) + if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } + break outer + } + + switch recvt := recv.(type) { + case *gortsplib.InterleavedFrame: + trackId, trackFlowType := interleavedChannelToTrackFlowType(frame.Channel) + + if trackId >= len(c.streamTracks) { + c.log("ERR: invalid track id '%d'", trackId) + break outer + } + + c.p.events <- programEventClientFrameTcp{ + c.path, + trackId, + trackFlowType, + frame.Content, + } + + case *gortsplib.Request: + err := c.handleRequest(recvt) + if err != nil { + break outer + } + } + } + + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + + } else { + c.udpLastFrameTime = time.Now() + + udpCheckStreamTicker := time.NewTicker(_UDP_CHECK_STREAM_INTERVAL) + udpCheckStreamDone := make(chan struct{}) + go func() { + defer close(udpCheckStreamDone) + for range udpCheckStreamTicker.C { + if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER { + c.log("ERR: stream is dead") + c.conn.NetConn().Close() + break + } + } + }() + + for { + req, err := c.conn.ReadRequest() + if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } + break + } + + err = c.handleRequest(req) + if err != nil { + break + } + } + + done := make(chan struct{}) + c.p.events <- programEventClientClose{done, c} + <-done + + udpCheckStreamTicker.Stop() + <-udpCheckStreamDone + } + + return false +} + func (c *serverClient) close() { c.conn.NetConn().Close() <-c.done @@ -223,23 +390,12 @@ func (c *serverClient) authenticate(ips []interface{}, user string, pass string, } ip := c.ip() - - for _, item := range ips { - switch titem := item.(type) { - case net.IP: - if titem.Equal(ip) { - return nil - } - - case *net.IPNet: - if titem.Contains(ip) { - return nil - } - } + if !ipEqualOrInRange(ip, ips) { + c.log("ERR: ip '%s' not allowed", ip) + return errAuthCritical } - c.log("ERR: ip '%s' not allowed", ip) - return errAuthCritical + return nil }() if err != nil { return err @@ -304,13 +460,13 @@ func (c *serverClient) authenticate(ips []interface{}, user string, pass string, return nil } -func (c *serverClient) handleRequest(req *gortsplib.Request) bool { +func (c *serverClient) handleRequest(req *gortsplib.Request) error { c.log(string(req.Method)) cseq, ok := req.Header["CSeq"] if !ok || len(cseq) != 1 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("cseq missing")) - return false + return errClientTerminate } path := func() string { @@ -343,34 +499,33 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { string(gortsplib.ANNOUNCE), string(gortsplib.SETUP), string(gortsplib.PLAY), - string(gortsplib.PAUSE), string(gortsplib.RECORD), string(gortsplib.TEARDOWN), }, ", ")}, }, }) - return true + return nil case gortsplib.DESCRIBE: if c.state != _CLIENT_STATE_STARTING { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("client is in state '%s' instead of '%s'", c.state, _CLIENT_STATE_STARTING)) - return false + return errClientTerminate } pconf := c.findConfForPath(path) if pconf == nil { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("unable to find a valid configuration for path '%s'", path)) - return false + return errClientTerminate } err := c.authenticate(pconf.readIpsParsed, pconf.ReadUser, pconf.ReadPass, req) if err != nil { if err == errAuthCritical { - return false + return errClientTerminate } - return true + return nil } res := make(chan []byte) @@ -378,7 +533,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { sdp := <-res if sdp == nil { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("no one is streaming on path '%s'", path)) - return false + return errClientTerminate } c.conn.WriteResponse(&gortsplib.Response{ @@ -390,51 +545,51 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { }, Content: sdp, }) - return true + return nil case gortsplib.ANNOUNCE: if c.state != _CLIENT_STATE_STARTING { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("client is in state '%s' instead of '%s'", c.state, _CLIENT_STATE_STARTING)) - return false + return errClientTerminate } pconf := c.findConfForPath(path) if pconf == nil { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("unable to find a valid configuration for path '%s'", path)) - return false + return errClientTerminate } err := c.authenticate(pconf.publishIpsParsed, pconf.PublishUser, pconf.PublishPass, req) if err != nil { if err == errAuthCritical { - return false + return errClientTerminate } - return true + return nil } ct, ok := req.Header["Content-Type"] if !ok || len(ct) != 1 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("Content-Type header missing")) - return false + return errClientTerminate } if ct[0] != "application/sdp" { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("unsupported Content-Type '%s'", ct)) - return false + return errClientTerminate } sdpParsed, err := gortsplib.SDPParse(req.Content) if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("invalid SDP: %s", err)) - return false + return errClientTerminate } sdpParsed, req.Content = gortsplib.SDPFilter(sdpParsed, req.Content) if len(path) == 0 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path can't be empty")) - return false + return errClientTerminate } res := make(chan error) @@ -442,7 +597,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } c.streamSdpText = req.Content @@ -454,19 +609,19 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { "CSeq": cseq, }, }) - return true + return nil case gortsplib.SETUP: tsRaw, ok := req.Header["Transport"] if !ok || len(tsRaw) != 1 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header missing")) - return false + return errClientTerminate } th := gortsplib.ReadHeaderTransport(tsRaw[0]) if _, ok := th["multicast"]; ok { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("multicast is not supported")) - return false + return errClientTerminate } switch c.state { @@ -476,15 +631,15 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { if pconf == nil { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("unable to find a valid configuration for path '%s'", path)) - return false + return errClientTerminate } err := c.authenticate(pconf.readIpsParsed, pconf.ReadUser, pconf.ReadPass, req) if err != nil { if err == errAuthCritical { - return false + return errClientTerminate } - return true + return nil } // play via UDP @@ -501,23 +656,23 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { }() { if _, ok := c.p.protocols[_STREAM_PROTOCOL_UDP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) - return false + return errClientTerminate } rtpPort, rtcpPort := th.GetPorts("client_port") if rtpPort == 0 || rtcpPort == 0 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%s)", tsRaw[0])) - return false + return errClientTerminate } if c.path != "" && path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false + return errClientTerminate } if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return false + return errClientTerminate } res := make(chan error) @@ -525,7 +680,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } c.conn.WriteResponse(&gortsplib.Response{ @@ -541,23 +696,23 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { "Session": []string{"12345678"}, }, }) - return true + return nil // play via TCP } else if _, ok := th["RTP/AVP/TCP"]; ok { if _, ok := c.p.protocols[_STREAM_PROTOCOL_TCP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) - return false + return errClientTerminate } if c.path != "" && path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false + return errClientTerminate } if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't receive tracks with different protocols")) - return false + return errClientTerminate } res := make(chan error) @@ -565,7 +720,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err = <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } interleaved := fmt.Sprintf("%d-%d", ((len(c.streamTracks) - 1) * 2), ((len(c.streamTracks)-1)*2)+1) @@ -582,24 +737,24 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { "Session": []string{"12345678"}, }, }) - return true + return nil } else { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP, RTP/AVP/UDP or RTP/AVP/TCP) (%s)", tsRaw[0])) - return false + return errClientTerminate } // record case _CLIENT_STATE_ANNOUNCE, _CLIENT_STATE_PRE_RECORD: if _, ok := th["mode=record"]; !ok { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain mode=record")) - return false + return errClientTerminate } // after ANNOUNCE, c.path is already set if path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false + return errClientTerminate } // record via UDP @@ -616,23 +771,23 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { }() { if _, ok := c.p.protocols[_STREAM_PROTOCOL_UDP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("UDP streaming is disabled")) - return false + return errClientTerminate } rtpPort, rtcpPort := th.GetPorts("client_port") if rtpPort == 0 || rtcpPort == 0 { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not have valid client ports (%s)", tsRaw[0])) - return false + return errClientTerminate } if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_UDP { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return false + return errClientTerminate } if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return false + return errClientTerminate } res := make(chan error) @@ -640,7 +795,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } c.conn.WriteResponse(&gortsplib.Response{ @@ -656,35 +811,35 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { "Session": []string{"12345678"}, }, }) - return true + return nil // record via TCP } else if _, ok := th["RTP/AVP/TCP"]; ok { if _, ok := c.p.protocols[_STREAM_PROTOCOL_TCP]; !ok { c.writeResError(req, gortsplib.StatusUnsupportedTransport, fmt.Errorf("TCP streaming is disabled")) - return false + return errClientTerminate } if len(c.streamTracks) > 0 && c.streamProtocol != _STREAM_PROTOCOL_TCP { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("can't publish tracks with different protocols")) - return false + return errClientTerminate } interleaved := th.GetValue("interleaved") if interleaved == "" { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain the interleaved field")) - return false + return errClientTerminate } expInterleaved := fmt.Sprintf("%d-%d", 0+len(c.streamTracks)*2, 1+len(c.streamTracks)*2) if interleaved != expInterleaved { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("wrong interleaved value, expected '%s', got '%s'", expInterleaved, interleaved)) - return false + return errClientTerminate } if len(c.streamTracks) >= len(c.streamSdpParsed.Medias) { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("all the tracks have already been setup")) - return false + return errClientTerminate } res := make(chan error) @@ -692,7 +847,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } c.conn.WriteResponse(&gortsplib.Response{ @@ -707,28 +862,28 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { "Session": []string{"12345678"}, }, }) - return true + return nil } else { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("transport header does not contain a valid protocol (RTP/AVP, RTP/AVP/UDP or RTP/AVP/TCP) (%s)", tsRaw[0])) - return false + return errClientTerminate } default: c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("client is in state '%s'", c.state)) - return false + return errClientTerminate } case gortsplib.PLAY: if c.state != _CLIENT_STATE_PRE_PLAY { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("client is in state '%s' instead of '%s'", c.state, _CLIENT_STATE_PRE_PLAY)) - return false + return errClientTerminate } if path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false + return errClientTerminate } // check publisher existence @@ -737,7 +892,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { err := <-res if err != nil { c.writeResError(req, gortsplib.StatusBadRequest, err) - return false + return errClientTerminate } // write response before setting state @@ -751,6 +906,9 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { }, }) + c.writeBuf = newDoubleBuffer(2048) + c.writeChan = make(chan *gortsplib.InterleavedFrame) + // set state res = make(chan error) c.p.events <- programEventClientPlay2{res, c} @@ -763,72 +921,23 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return "tracks" }(), c.streamProtocol) - // when protocol is TCP, the RTSP connection becomes a RTP connection - if c.streamProtocol == _STREAM_PROTOCOL_TCP { - // write RTP frames sequentially - go func() { - for frame := range c.writeChan { - c.conn.WriteInterleavedFrame(frame) - } - }() - - // receive RTP feedback, do not parse it, wait until connection closes - buf := make([]byte, 2048) - for { - _, err := c.conn.NetConn().Read(buf) - if err != nil { - if err != io.EOF { - c.log("ERR: %s", err) - } - return false - } - } - } - - return true - - case gortsplib.PAUSE: - if c.state != _CLIENT_STATE_PLAY { - c.writeResError(req, gortsplib.StatusBadRequest, - fmt.Errorf("client is in state '%s' instead of '%s'", c.state, _CLIENT_STATE_PLAY)) - return false - } - - if path != c.path { - c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false - } - - c.log("paused") - - res := make(chan error) - c.p.events <- programEventClientPause{res, c} - <-res - - c.conn.WriteResponse(&gortsplib.Response{ - StatusCode: gortsplib.StatusOK, - Header: gortsplib.Header{ - "CSeq": cseq, - "Session": []string{"12345678"}, - }, - }) - return true + return errClientChangeRunMode case gortsplib.RECORD: if c.state != _CLIENT_STATE_PRE_RECORD { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("client is in state '%s' instead of '%s'", c.state, _CLIENT_STATE_PRE_RECORD)) - return false + return errClientTerminate } if path != c.path { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("path has changed")) - return false + return errClientTerminate } if len(c.streamTracks) != len(c.streamSdpParsed.Medias) { c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("not all tracks have been setup")) - return false + return errClientTerminate } c.conn.WriteResponse(&gortsplib.Response{ @@ -850,79 +959,14 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool { return "tracks" }(), c.streamProtocol) - // when protocol is TCP, the RTSP connection becomes a RTP connection - // receive RTP data and parse it - if c.streamProtocol == _STREAM_PROTOCOL_TCP { - frame := &gortsplib.InterleavedFrame{} - for { - frame.Content = c.readBuf.swap() - frame.Content = frame.Content[:cap(frame.Content)] - recv, err := c.conn.ReadInterleavedFrameOrRequest(frame) - if err != nil { - if err != io.EOF { - c.log("ERR: %s", err) - } - return false - } - - switch recvt := recv.(type) { - case *gortsplib.InterleavedFrame: - trackId, trackFlowType := interleavedChannelToTrackFlowType(frame.Channel) - - if trackId >= len(c.streamTracks) { - c.log("ERR: invalid track id '%d'", trackId) - return false - } - - c.p.events <- programEventClientFrameTcp{ - c.path, - trackId, - trackFlowType, - frame.Content, - } - - case *gortsplib.Request: - cseq, ok := recvt.Header["CSeq"] - if !ok || len(cseq) != 1 { - c.writeResError(recvt, gortsplib.StatusBadRequest, fmt.Errorf("cseq missing")) - return false - } - - switch recvt.Method { - case gortsplib.TEARDOWN: - // close connection silently - return false - - default: - c.writeResError(recvt, gortsplib.StatusBadRequest, fmt.Errorf("unhandled method '%s'", recvt.Method)) - return false - } - } - - } - } else { - c.udpLastFrameTime = time.Now() - c.udpCheckStreamTicker = time.NewTicker(_UDP_CHECK_STREAM_INTERVAL) - - go func() { - for range c.udpCheckStreamTicker.C { - if time.Since(c.udpLastFrameTime) >= _UDP_STREAM_DEAD_AFTER { - c.log("ERR: stream is dead") - c.conn.NetConn().Close() - break - } - } - }() - } - - return true + return errClientChangeRunMode case gortsplib.TEARDOWN: // close connection silently - return false + return errClientTerminate default: c.writeResError(req, gortsplib.StatusBadRequest, fmt.Errorf("unhandled method '%s'", req.Method)) - return false + return errClientTerminate } } diff --git a/server-udpl.go b/server-udpl.go index 44de8059..8aaab5bb 100644 --- a/server-udpl.go +++ b/server-udpl.go @@ -5,7 +5,7 @@ import ( "time" ) -type udpAddrFramePair struct { +type udpAddrBufPair struct { addr *net.UDPAddr buf []byte } @@ -17,7 +17,7 @@ type serverUdpListener struct { readBuf *doubleBuffer writeBuf *doubleBuffer - writeChan chan *udpAddrFramePair + writeChan chan *udpAddrBufPair done chan struct{} } @@ -35,7 +35,7 @@ func newServerUdpListener(p *program, port int, trackFlowType trackFlowType) (*s trackFlowType: trackFlowType, readBuf: newDoubleBuffer(2048), writeBuf: newDoubleBuffer(2048), - writeChan: make(chan *udpAddrFramePair), + writeChan: make(chan *udpAddrBufPair), done: make(chan struct{}), } @@ -69,8 +69,8 @@ func (l *serverUdpListener) run() { } l.p.events <- programEventClientFrameUdp{ - l.trackFlowType, addr, + l.trackFlowType, buf[:n], } } @@ -85,13 +85,12 @@ func (l *serverUdpListener) close() { <-l.done } -func (l *serverUdpListener) write(addr *net.UDPAddr, inbuf []byte) { +func (l *serverUdpListener) write(pair *udpAddrBufPair) { + // replace input buffer with write buffer buf := l.writeBuf.swap() - buf = buf[:len(inbuf)] - copy(buf, inbuf) + buf = buf[:len(pair.buf)] + copy(buf, pair.buf) + pair.buf = buf - l.writeChan <- &udpAddrFramePair{ - addr: addr, - buf: buf, - } + l.writeChan <- pair } diff --git a/streamer-udpl.go b/streamer-udpl.go index 1c2b1e58..9f102602 100644 --- a/streamer-udpl.go +++ b/streamer-udpl.go @@ -15,7 +15,6 @@ type streamerUdpListener struct { nconn *net.UDPConn running bool readBuf *doubleBuffer - lastFrameTime time.Time done chan struct{} } @@ -37,7 +36,6 @@ func newStreamerUdpListener(p *program, port int, streamer *streamer, publisherIp: publisherIp, nconn: nconn, readBuf: newDoubleBuffer(2048), - lastFrameTime: time.Now(), done: make(chan struct{}), } @@ -69,7 +67,7 @@ func (l *streamerUdpListener) run() { continue } - l.lastFrameTime = time.Now() + l.streamer.udpLastFrameTime = time.Now() l.p.events <- programEventStreamerFrame{l.streamer, l.trackId, l.trackFlowType, buf[:n]} } diff --git a/streamer.go b/streamer.go index 4cf99925..c25c7174 100644 --- a/streamer.go +++ b/streamer.go @@ -27,16 +27,17 @@ type streamerUdpListenerPair struct { } type streamer struct { - p *program - path string - ur *url.URL - proto streamProtocol - ready bool - clientSdpParsed *sdp.Message - serverSdpText []byte - serverSdpParsed *sdp.Message - firstTime bool - readBuf *doubleBuffer + p *program + path string + ur *url.URL + proto streamProtocol + ready bool + clientSdpParsed *sdp.Message + serverSdpText []byte + serverSdpParsed *sdp.Message + firstTime bool + udpLastFrameTime time.Time + readBuf *doubleBuffer terminate chan struct{} done chan struct{} @@ -402,6 +403,7 @@ func (s *streamer) runUdp(conn *gortsplib.ConnClient) bool { tickerSendKeepalive := time.NewTicker(_KEEPALIVE_INTERVAL) defer tickerSendKeepalive.Stop() + s.udpLastFrameTime = time.Now() tickerCheckStream := time.NewTicker(_CHECK_STREAM_INTERVAL) defer tickerCheckStream.Stop() @@ -431,21 +433,7 @@ func (s *streamer) runUdp(conn *gortsplib.ConnClient) bool { } case <-tickerCheckStream.C: - lastFrameTime := time.Time{} - - for _, pair := range streamerUdpListenerPairs { - lft := pair.udplRtp.lastFrameTime - if lft.After(lastFrameTime) { - lastFrameTime = lft - } - - lft = pair.udplRtcp.lastFrameTime - if lft.After(lastFrameTime) { - lastFrameTime = lft - } - } - - if time.Since(lastFrameTime) >= _STREAM_DEAD_AFTER { + if time.Since(s.udpLastFrameTime) >= _STREAM_DEAD_AFTER { s.log("ERR: stream is dead") return true } diff --git a/utils.go b/utils.go index 422d1e95..9942dc4c 100644 --- a/utils.go +++ b/utils.go @@ -29,6 +29,23 @@ func parseIpCidrList(in []string) ([]interface{}, error) { return ret, nil } +func ipEqualOrInRange(ip net.IP, ips []interface{}) bool { + for _, item := range ips { + switch titem := item.(type) { + case net.IP: + if titem.Equal(ip) { + return true + } + + case *net.IPNet: + if titem.Contains(ip) { + return true + } + } + } + return false +} + type trackFlowType int const (