move each goroutine in a dedicated struct (#285)

This commit is contained in:
Alessandro Ros
2023-05-17 21:14:00 +02:00
committed by GitHub
parent f94885005f
commit b0947c133e
24 changed files with 675 additions and 641 deletions

304
server.go
View File

@@ -41,7 +41,7 @@ type sessionRequestReq struct {
res chan sessionRequestRes
}
type streamMulticastIPReq struct {
type chGetMulticastIPReq struct {
res chan net.IP
}
@@ -124,7 +124,7 @@ type Server struct {
wg sync.WaitGroup
multicastNet *net.IPNet
multicastNextIP net.IP
tcpListener net.Listener
tcpListener *serverTCPListener
udpRTPListener *serverUDPListener
udpRTCPListener *serverUDPListener
sessions map[string]*ServerSession
@@ -132,10 +132,12 @@ type Server struct {
closeError error
// in
connClose chan *ServerConn
sessionRequest chan sessionRequestReq
sessionClose chan *ServerSession
streamMulticastIP chan streamMulticastIPReq
chNewConn chan net.Conn
chAcceptErr chan error
chCloseConn chan *ServerConn
chHandleRequest chan sessionRequestReq
chCloseSession chan *ServerSession
chGetMulticastIP chan chGetMulticastIPReq
}
// Start starts the server.
@@ -287,8 +289,19 @@ func (s *Server) Start() error {
s.multicastNextIP = s.multicastNet.IP
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{})
s.chNewConn = make(chan net.Conn)
s.chAcceptErr = make(chan error)
s.chCloseConn = make(chan *ServerConn)
s.chHandleRequest = make(chan sessionRequestReq)
s.chCloseSession = make(chan *ServerSession)
s.chGetMulticastIP = make(chan chGetMulticastIPReq)
var err error
s.tcpListener, err = s.Listen(restrictNetwork("tcp", s.RTSPAddress))
s.tcpListener, err = newServerTCPListener(s)
if err != nil {
if s.udpRTPListener != nil {
s.udpRTPListener.close()
@@ -296,11 +309,10 @@ func (s *Server) Start() error {
if s.udpRTCPListener != nil {
s.udpRTCPListener.close()
}
s.ctxCancel()
return err
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())
s.wg.Add(1)
go s.run()
@@ -324,131 +336,7 @@ func (s *Server) Wait() error {
func (s *Server) run() {
defer s.wg.Done()
s.sessions = make(map[string]*ServerSession)
s.conns = make(map[*ServerConn]struct{})
s.connClose = make(chan *ServerConn)
s.sessionRequest = make(chan sessionRequestReq)
s.sessionClose = make(chan *ServerSession)
s.streamMulticastIP = make(chan streamMulticastIPReq)
s.wg.Add(1)
connNew := make(chan net.Conn)
acceptErr := make(chan error)
go func() {
defer s.wg.Done()
err := func() error {
for {
nconn, err := s.tcpListener.Accept()
if err != nil {
return err
}
select {
case connNew <- nconn:
case <-s.ctx.Done():
nconn.Close()
}
}
}()
select {
case acceptErr <- err:
case <-s.ctx.Done():
}
}()
s.closeError = func() error {
for {
select {
case err := <-acceptErr:
return err
case nconn := <-connNew:
sc := newServerConn(s, nconn)
s.conns[sc] = struct{}{}
case sc := <-s.connClose:
if _, ok := s.conns[sc]; !ok {
continue
}
delete(s.conns, sc)
sc.Close()
case req := <-s.sessionRequest:
if ss, ok := s.sessions[req.id]; ok {
if !req.sc.ip().Equal(ss.author.ip()) ||
req.sc.zone() != ss.author.zone() {
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{},
}
continue
}
select {
case ss.request <- req:
case <-ss.ctx.Done():
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
}
} else {
if !req.create {
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusSessionNotFound,
},
err: liberrors.ErrServerSessionNotFound{},
}
continue
}
ss := newServerSession(s, req.sc)
s.sessions[ss.secretID] = ss
select {
case ss.request <- req:
case <-ss.ctx.Done():
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
}
}
case ss := <-s.sessionClose:
if sss, ok := s.sessions[ss.secretID]; !ok || sss != ss {
continue
}
delete(s.sessions, ss.secretID)
ss.Close()
case req := <-s.streamMulticastIP:
ip32 := uint32(s.multicastNextIP[0])<<24 | uint32(s.multicastNextIP[1])<<16 |
uint32(s.multicastNextIP[2])<<8 | uint32(s.multicastNextIP[3])
mask := uint32(s.multicastNet.Mask[0])<<24 | uint32(s.multicastNet.Mask[1])<<16 |
uint32(s.multicastNet.Mask[2])<<8 | uint32(s.multicastNet.Mask[3])
ip32 = (ip32 & mask) | ((ip32 + 1) & ^mask)
ip := make(net.IP, 4)
ip[0] = byte(ip32 >> 24)
ip[1] = byte(ip32 >> 16)
ip[2] = byte(ip32 >> 8)
ip[3] = byte(ip32)
s.multicastNextIP = ip
req.res <- ip
case <-s.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}()
s.closeError = s.runInner()
s.ctxCancel()
@@ -460,7 +348,100 @@ func (s *Server) run() {
s.udpRTPListener.close()
}
s.tcpListener.Close()
s.tcpListener.close()
}
func (s *Server) runInner() error {
for {
select {
case err := <-s.chAcceptErr:
return err
case nconn := <-s.chNewConn:
sc := newServerConn(s, nconn)
s.conns[sc] = struct{}{}
case sc := <-s.chCloseConn:
if _, ok := s.conns[sc]; !ok {
continue
}
delete(s.conns, sc)
sc.Close()
case req := <-s.chHandleRequest:
if ss, ok := s.sessions[req.id]; ok {
if !req.sc.ip().Equal(ss.author.ip()) ||
req.sc.zone() != ss.author.zone() {
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerCannotUseSessionCreatedByOtherIP{},
}
continue
}
select {
case ss.chHandleRequest <- req:
case <-ss.ctx.Done():
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
}
} else {
if !req.create {
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusSessionNotFound,
},
err: liberrors.ErrServerSessionNotFound{},
}
continue
}
ss := newServerSession(s, req.sc)
s.sessions[ss.secretID] = ss
select {
case ss.chHandleRequest <- req:
case <-ss.ctx.Done():
req.res <- sessionRequestRes{
res: &base.Response{
StatusCode: base.StatusBadRequest,
},
err: liberrors.ErrServerTerminated{},
}
}
}
case ss := <-s.chCloseSession:
if sss, ok := s.sessions[ss.secretID]; !ok || sss != ss {
continue
}
delete(s.sessions, ss.secretID)
ss.Close()
case req := <-s.chGetMulticastIP:
ip32 := uint32(s.multicastNextIP[0])<<24 | uint32(s.multicastNextIP[1])<<16 |
uint32(s.multicastNextIP[2])<<8 | uint32(s.multicastNextIP[3])
mask := uint32(s.multicastNet.Mask[0])<<24 | uint32(s.multicastNet.Mask[1])<<16 |
uint32(s.multicastNet.Mask[2])<<8 | uint32(s.multicastNet.Mask[3])
ip32 = (ip32 & mask) | ((ip32 + 1) & ^mask)
ip := make(net.IP, 4)
ip[0] = byte(ip32 >> 24)
ip[1] = byte(ip32 >> 16)
ip[2] = byte(ip32 >> 8)
ip[3] = byte(ip32)
s.multicastNextIP = ip
req.res <- ip
case <-s.ctx.Done():
return liberrors.ErrServerTerminated{}
}
}
}
// StartAndWait starts the server and waits until a fatal error.
@@ -472,3 +453,56 @@ func (s *Server) StartAndWait() error {
return s.Wait()
}
func (s *Server) getMulticastIP() (net.IP, error) {
res := make(chan net.IP)
select {
case s.chGetMulticastIP <- chGetMulticastIPReq{res: res}:
return <-res, nil
case <-s.ctx.Done():
return nil, fmt.Errorf("terminated")
}
}
func (s *Server) newConn(nconn net.Conn) {
select {
case s.chNewConn <- nconn:
case <-s.ctx.Done():
nconn.Close()
}
}
func (s *Server) acceptErr(err error) {
select {
case s.chAcceptErr <- err:
case <-s.ctx.Done():
}
}
func (s *Server) closeConn(sc *ServerConn) {
select {
case s.chCloseConn <- sc:
case <-s.ctx.Done():
}
}
func (s *Server) closeSession(ss *ServerSession) {
select {
case s.chCloseSession <- ss:
case <-s.ctx.Done():
}
}
func (s *Server) handleRequest(req sessionRequestReq) (*base.Response, *ServerSession, error) {
select {
case s.chHandleRequest <- req:
res := <-req.res
return res.res, res.ss, res.err
case <-s.ctx.Done():
return &base.Response{
StatusCode: base.StatusBadRequest,
}, req.sc.session, liberrors.ErrServerTerminated{}
}
}