mirror of
				https://github.com/aler9/gortsplib
				synced 2025-10-31 10:36:26 +08:00 
			
		
		
		
	perform frame readings and writings in separate routines, in order to increase UDP throughput and avoid freezes caused by a single laggy reader (https://github.com/aler9/rtsp-simple-server/issues/125) (https://github.com/aler9/rtsp-simple-server/issues/162)
This commit is contained in:
		
							
								
								
									
										168
									
								
								serverconn.go
									
									
									
									
									
								
							
							
						
						
									
										168
									
								
								serverconn.go
									
									
									
									
									
								
							| @@ -8,13 +8,13 @@ import ( | ||||
| 	"net" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/aler9/gortsplib/pkg/base" | ||||
| 	"github.com/aler9/gortsplib/pkg/headers" | ||||
| 	"github.com/aler9/gortsplib/pkg/multibuffer" | ||||
| 	"github.com/aler9/gortsplib/pkg/ringbuffer" | ||||
| 	"github.com/aler9/gortsplib/pkg/rtcpreceiver" | ||||
| ) | ||||
|  | ||||
| @@ -23,11 +23,13 @@ const ( | ||||
| 	serverConnWriteBufferSize        = 4096 | ||||
| 	serverConnCheckStreamInterval    = 5 * time.Second | ||||
| 	serverConnReceiverReportInterval = 10 * time.Second | ||||
| 	serverConnTCPFrameReadBufferSize = 2048 | ||||
| ) | ||||
|  | ||||
| // server errors. | ||||
| var ( | ||||
| 	ErrServerTeardown = errors.New("teardown") | ||||
| 	ErrServerTeardown    = errors.New("teardown") | ||||
| 	errServerCSeqMissing = errors.New("CSeq is missing") | ||||
| ) | ||||
|  | ||||
| // ServerConnState is the state of the connection. | ||||
| @@ -138,20 +140,24 @@ type ServerConn struct { | ||||
| 	state              ServerConnState | ||||
| 	tracks             map[int]ServerConnTrack | ||||
| 	tracksProtocol     *StreamProtocol | ||||
| 	rtcpReceivers      []*rtcpreceiver.RTCPReceiver | ||||
| 	udpLastFrameTimes  []*int64 | ||||
| 	writeMutex         sync.Mutex | ||||
| 	readHandlers       ServerConnReadHandlers | ||||
| 	nextFramesEnabled  bool | ||||
| 	rtcpReceivers      []*rtcpreceiver.RTCPReceiver | ||||
| 	doEnableFrames     bool | ||||
| 	framesEnabled      bool | ||||
| 	readTimeoutEnabled bool | ||||
| 	udpTimeout         *int32 | ||||
|  | ||||
| 	// writer | ||||
| 	frameRingBuffer     *ringbuffer.RingBuffer | ||||
| 	backgroundWriteDone chan struct{} | ||||
|  | ||||
| 	// background record | ||||
| 	backgroundRecordTerminate chan struct{} | ||||
| 	backgroundRecordDone      chan struct{} | ||||
| 	udpTimeout                int32 | ||||
| 	udpLastFrameTimes         []*int64 | ||||
|  | ||||
| 	// in | ||||
| 	terminate chan struct{} | ||||
|  | ||||
| 	backgroundRecordTerminate chan struct{} | ||||
| 	backgroundRecordDone      chan struct{} | ||||
| } | ||||
|  | ||||
| func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { | ||||
| @@ -163,13 +169,14 @@ func newServerConn(conf ServerConf, nconn net.Conn) *ServerConn { | ||||
| 	}() | ||||
|  | ||||
| 	return &ServerConn{ | ||||
| 		conf:       conf, | ||||
| 		nconn:      nconn, | ||||
| 		br:         bufio.NewReaderSize(conn, serverConnReadBufferSize), | ||||
| 		bw:         bufio.NewWriterSize(conn, serverConnWriteBufferSize), | ||||
| 		tracks:     make(map[int]ServerConnTrack), | ||||
| 		udpTimeout: new(int32), | ||||
| 		terminate:  make(chan struct{}), | ||||
| 		conf:                conf, | ||||
| 		nconn:               nconn, | ||||
| 		br:                  bufio.NewReaderSize(conn, serverConnReadBufferSize), | ||||
| 		bw:                  bufio.NewWriterSize(conn, serverConnWriteBufferSize), | ||||
| 		tracks:              make(map[int]ServerConnTrack), | ||||
| 		frameRingBuffer:     ringbuffer.New(conf.ReadBufferCount), | ||||
| 		backgroundWriteDone: make(chan struct{}), | ||||
| 		terminate:           make(chan struct{}), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -206,6 +213,30 @@ func (sc *ServerConn) Tracks() map[int]ServerConnTrack { | ||||
| 	return sc.tracks | ||||
| } | ||||
|  | ||||
| func (sc *ServerConn) backgroundWrite() { | ||||
| 	defer close(sc.backgroundWriteDone) | ||||
|  | ||||
| 	for { | ||||
| 		what, ok := sc.frameRingBuffer.Pull() | ||||
| 		if !ok { | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		switch w := what.(type) { | ||||
| 		case *base.InterleavedFrame: | ||||
| 			sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 			w.Write(sc.bw) | ||||
|  | ||||
| 		case *base.Response: | ||||
| 			sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 			w.Write(sc.bw) | ||||
|  | ||||
| 		default: | ||||
| 			panic(fmt.Errorf("unsupported type: %T", what)) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (sc *ServerConn) checkState(allowed map[ServerConnState]struct{}) error { | ||||
| 	if _, ok := allowed[sc.state]; ok { | ||||
| 		return nil | ||||
| @@ -236,12 +267,12 @@ func (sc *ServerConn) frameModeEnable() { | ||||
| 	switch sc.state { | ||||
| 	case ServerConnStatePlay: | ||||
| 		if *sc.tracksProtocol == StreamProtocolTCP { | ||||
| 			sc.nextFramesEnabled = true | ||||
| 			sc.doEnableFrames = true | ||||
| 		} | ||||
|  | ||||
| 	case ServerConnStateRecord: | ||||
| 		if *sc.tracksProtocol == StreamProtocolTCP { | ||||
| 			sc.nextFramesEnabled = true | ||||
| 			sc.doEnableFrames = true | ||||
| 			sc.readTimeoutEnabled = true | ||||
|  | ||||
| 		} else { | ||||
| @@ -266,16 +297,25 @@ func (sc *ServerConn) frameModeEnable() { | ||||
| func (sc *ServerConn) frameModeDisable() { | ||||
| 	switch sc.state { | ||||
| 	case ServerConnStatePlay: | ||||
| 		sc.nextFramesEnabled = false | ||||
| 		if *sc.tracksProtocol == StreamProtocolTCP { | ||||
| 			sc.framesEnabled = false | ||||
| 			sc.frameRingBuffer.Close() | ||||
| 			<-sc.backgroundWriteDone | ||||
| 		} | ||||
|  | ||||
| 	case ServerConnStateRecord: | ||||
| 		close(sc.backgroundRecordTerminate) | ||||
| 		<-sc.backgroundRecordDone | ||||
|  | ||||
| 		sc.nextFramesEnabled = false | ||||
| 		sc.readTimeoutEnabled = false | ||||
| 		if *sc.tracksProtocol == StreamProtocolTCP { | ||||
| 			sc.readTimeoutEnabled = false | ||||
| 			sc.nconn.SetReadDeadline(time.Time{}) | ||||
|  | ||||
| 		if *sc.tracksProtocol == StreamProtocolUDP { | ||||
| 			sc.framesEnabled = false | ||||
| 			sc.frameRingBuffer.Close() | ||||
| 			<-sc.backgroundWriteDone | ||||
|  | ||||
| 		} else { | ||||
| 			for _, track := range sc.tracks { | ||||
| 				sc.conf.UDPRTPListener.removePublisher(sc.ip(), track.rtpPort) | ||||
| 				sc.conf.UDPRTCPListener.removePublisher(sc.ip(), track.rtcpPort) | ||||
| @@ -285,6 +325,13 @@ func (sc *ServerConn) frameModeDisable() { | ||||
| } | ||||
|  | ||||
| func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { | ||||
| 	if cseq, ok := req.Header["CSeq"]; !ok || len(cseq) != 1 { | ||||
| 		return &base.Response{ | ||||
| 			StatusCode: base.StatusBadRequest, | ||||
| 			Header:     base.Header{}, | ||||
| 		}, errServerCSeqMissing | ||||
| 	} | ||||
|  | ||||
| 	if sc.readHandlers.OnRequest != nil { | ||||
| 		sc.readHandlers.OnRequest(req) | ||||
| 	} | ||||
| @@ -676,19 +723,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { | ||||
|  | ||||
| func (sc *ServerConn) backgroundRead() error { | ||||
| 	handleRequestOuter := func(req *base.Request) error { | ||||
| 		// check cseq | ||||
| 		cseq, ok := req.Header["CSeq"] | ||||
| 		if !ok || len(cseq) != 1 { | ||||
| 			sc.writeMutex.Lock() | ||||
| 			sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 			base.Response{ | ||||
| 				StatusCode: base.StatusBadRequest, | ||||
| 				Header:     base.Header{}, | ||||
| 			}.Write(sc.bw) | ||||
| 			sc.writeMutex.Unlock() | ||||
| 			return errors.New("CSeq is missing") | ||||
| 		} | ||||
|  | ||||
| 		res, err := sc.handleRequest(req) | ||||
|  | ||||
| 		if res.Header == nil { | ||||
| @@ -696,7 +730,9 @@ func (sc *ServerConn) backgroundRead() error { | ||||
| 		} | ||||
|  | ||||
| 		// add cseq | ||||
| 		res.Header["CSeq"] = cseq | ||||
| 		if err != errServerCSeqMissing { | ||||
| 			res.Header["CSeq"] = req.Header["CSeq"] | ||||
| 		} | ||||
|  | ||||
| 		// add server | ||||
| 		res.Header["Server"] = base.HeaderValue{"gortsplib"} | ||||
| @@ -705,33 +741,42 @@ func (sc *ServerConn) backgroundRead() error { | ||||
| 			sc.readHandlers.OnResponse(res) | ||||
| 		} | ||||
|  | ||||
| 		sc.writeMutex.Lock() | ||||
| 		// start background write | ||||
| 		if sc.doEnableFrames { | ||||
| 			sc.doEnableFrames = false | ||||
| 			sc.framesEnabled = true | ||||
|  | ||||
| 		sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 		res.Write(sc.bw) | ||||
| 			// write response before frames | ||||
| 			sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 			res.Write(sc.bw) | ||||
|  | ||||
| 		// set framesEnabled after sending the response | ||||
| 		// in order to start sending frames after the response, never before | ||||
| 		if sc.framesEnabled != sc.nextFramesEnabled { | ||||
| 			sc.framesEnabled = sc.nextFramesEnabled | ||||
| 			// start background write | ||||
| 			sc.frameRingBuffer.Reset() | ||||
| 			sc.backgroundWriteDone = make(chan struct{}) | ||||
| 			go sc.backgroundWrite() | ||||
|  | ||||
| 			// write to background write | ||||
| 		} else if sc.framesEnabled { | ||||
| 			sc.frameRingBuffer.Push(res) | ||||
|  | ||||
| 			// write directly | ||||
| 		} else { | ||||
| 			sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 			res.Write(sc.bw) | ||||
| 		} | ||||
|  | ||||
| 		sc.writeMutex.Unlock() | ||||
|  | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	var req base.Request | ||||
| 	var frame base.InterleavedFrame | ||||
| 	tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, clientConnTCPFrameReadBufferSize) | ||||
| 	tcpFrameBuffer := multibuffer.New(sc.conf.ReadBufferCount, serverConnTCPFrameReadBufferSize) | ||||
| 	var errRet error | ||||
|  | ||||
| outer: | ||||
| 	for { | ||||
| 		if sc.readTimeoutEnabled { | ||||
| 			sc.nconn.SetReadDeadline(time.Now().Add(sc.conf.ReadTimeout)) | ||||
| 		} else { | ||||
| 			sc.nconn.SetReadDeadline(time.Time{}) | ||||
| 		} | ||||
|  | ||||
| 		if sc.framesEnabled { | ||||
| @@ -764,7 +809,7 @@ outer: | ||||
| 		} else { | ||||
| 			err := req.Read(sc.br) | ||||
| 			if err != nil { | ||||
| 				if atomic.LoadInt32(sc.udpTimeout) == 1 { | ||||
| 				if atomic.LoadInt32(&sc.udpTimeout) == 1 { | ||||
| 					errRet = fmt.Errorf("no UDP packets received recently (maybe there's a firewall/NAT in between)") | ||||
| 				} else { | ||||
| 					errRet = err | ||||
| @@ -801,41 +846,34 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { | ||||
| } | ||||
|  | ||||
| // WriteFrame writes a frame. | ||||
| func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) error { | ||||
| 	sc.writeMutex.Lock() | ||||
| 	defer sc.writeMutex.Unlock() | ||||
|  | ||||
| func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { | ||||
| 	if *sc.tracksProtocol == StreamProtocolUDP { | ||||
| 		track := sc.tracks[trackID] | ||||
|  | ||||
| 		if streamType == StreamTypeRTP { | ||||
| 			return sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ | ||||
| 			sc.conf.UDPRTPListener.write(payload, &net.UDPAddr{ | ||||
| 				IP:   sc.ip(), | ||||
| 				Zone: sc.zone(), | ||||
| 				Port: track.rtpPort, | ||||
| 			}) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		return sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ | ||||
| 		sc.conf.UDPRTCPListener.write(payload, &net.UDPAddr{ | ||||
| 			IP:   sc.ip(), | ||||
| 			Zone: sc.zone(), | ||||
| 			Port: track.rtcpPort, | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// StreamProtocolTCP | ||||
|  | ||||
| 	if !sc.framesEnabled { | ||||
| 		return errors.New("frames are disabled") | ||||
| 	} | ||||
|  | ||||
| 	sc.nconn.SetWriteDeadline(time.Now().Add(sc.conf.WriteTimeout)) | ||||
| 	frame := base.InterleavedFrame{ | ||||
| 	sc.frameRingBuffer.Push(&base.InterleavedFrame{ | ||||
| 		TrackID:    trackID, | ||||
| 		StreamType: streamType, | ||||
| 		Payload:    payload, | ||||
| 	} | ||||
| 	return frame.Write(sc.bw) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (sc *ServerConn) backgroundRecord() { | ||||
| @@ -859,7 +897,7 @@ func (sc *ServerConn) backgroundRecord() { | ||||
| 				last := time.Unix(atomic.LoadInt64(lastUnix), 0) | ||||
|  | ||||
| 				if now.Sub(last) >= sc.conf.ReadTimeout { | ||||
| 					atomic.StoreInt32(sc.udpTimeout, 1) | ||||
| 					atomic.StoreInt32(&sc.udpTimeout, 1) | ||||
| 					sc.nconn.Close() | ||||
| 					return | ||||
| 				} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 aler9
					aler9