Files
webrtc/internal/mux/mux.go
2025-09-19 06:19:21 +03:00

217 lines
4.5 KiB
Go

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package mux multiplexes packets on a single socket (RFC 7983).
package mux
import (
"errors"
"io"
"net"
"sync"
"github.com/pion/ice/v4"
"github.com/pion/logging"
"github.com/pion/transport/v3/packetio"
)
// Limits the Mux applies to data it holds on behalf of endpoints.
const (
	// maxBufferSize is the byte limit set on every endpoint's buffer; once it
	// is reached, writes to that buffer fail instead of growing it (1 MB).
	maxBufferSize = 1_000_000

	// maxPendingPackets is the largest number of unmatched packets kept in the
	// pending queue while no endpoint accepts them.
	maxPendingPackets = 15
)
// Config gathers everything NewMux needs to build a Mux.
type Config struct {
	// Conn is the connection the Mux reads packets from; Mux.Close closes it.
	Conn net.Conn

	// BufferSize is the length, in bytes, of the buffer each packet is read
	// into.
	BufferSize int

	// LoggerFactory supplies the Mux's logger. NewMux calls it without a nil
	// check, so it must be set.
	LoggerFactory logging.LoggerFactory
}
// Mux demultiplexes packets read from one connection to the Endpoints whose
// MatchFunc accepts them.
type Mux struct {
	// nextConn is the connection packets are read from.
	nextConn net.Conn

	// bufferSize is the length of the buffer readLoop reads into.
	bufferSize int

	// lock guards endpoints, isClosed and pendingPackets.
	lock           sync.Mutex
	endpoints      map[*Endpoint]MatchFunc
	isClosed       bool
	pendingPackets [][]byte

	// closedCh is closed once readLoop has returned.
	closedCh chan struct{}

	log logging.LeveledLogger
}
// NewMux builds a Mux on top of config.Conn and starts the goroutine that
// reads from it. That goroutine keeps running until reading or dispatching
// fails; Close closes the connection and waits for it to finish.
func NewMux(config Config) *Mux {
	m := &Mux{
		nextConn:   config.Conn,
		bufferSize: config.BufferSize,
		endpoints:  make(map[*Endpoint]MatchFunc),
		closedCh:   make(chan struct{}),
		log:        config.LoggerFactory.NewLogger("mux"),
	}

	go m.readLoop()

	return m
}
// NewEndpoint registers and returns a new Endpoint that is fed the packets
// matchFunc accepts. Packets queued earlier because no endpoint matched them
// are handed to the new endpoint by a separate goroutine.
//
// NOTE(review): because that hand-off runs asynchronously, a packet
// dispatched between registration and the flush can reach the endpoint
// before older queued packets.
func (m *Mux) NewEndpoint(matchFunc MatchFunc) *Endpoint {
	buffer := packetio.NewBuffer()
	// Cap the bytes the endpoint may hold; writes past the cap fail.
	buffer.SetLimitSize(maxBufferSize)

	endpoint := &Endpoint{mux: m, buffer: buffer}

	m.lock.Lock()
	m.endpoints[endpoint] = matchFunc
	m.lock.Unlock()

	go m.handlePendingPackets(endpoint, matchFunc)

	return endpoint
}
// RemoveEndpoint unregisters e, so later calls to dispatch no longer consider
// it. It does not close e's buffer.
func (m *Mux) RemoveEndpoint(e *Endpoint) {
	m.lock.Lock()
	delete(m.endpoints, e)
	m.lock.Unlock()
}
// Close closes and unregisters every Endpoint and marks the Mux as closed, so
// unmatched packets are no longer queued. It then closes the underlying
// connection and waits for readLoop to return.
//
// Teardown continues even if closing an endpoint fails. The connection is
// therefore always closed and the read goroutine does not outlive the Mux.
// In that case the first endpoint error is returned after the wait. If
// closing the connection itself fails, Close returns without waiting for
// readLoop. It returns the first endpoint error if there was one, and the
// connection error otherwise.
func (m *Mux) Close() error {
	var firstErr error

	m.lock.Lock()
	for e := range m.endpoints {
		if err := e.close(); err != nil && firstErr == nil {
			firstErr = err
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(m.endpoints, e)
	}
	m.isClosed = true
	m.lock.Unlock()

	if err := m.nextConn.Close(); err != nil {
		if firstErr != nil {
			return firstErr
		}

		return err
	}

	// Wait for readLoop to end; it is expected to return once reading from
	// the closed nextConn fails.
	<-m.closedCh

	return firstErr
}
// readLoop reads packets from nextConn into one reusable buffer and dispatches
// each of them. It handles errors as follows:
//   - Short-buffer and timeout read errors are logged, and reading continues.
//   - io.EOF and ice.ErrClosed end the loop silently.
//   - A dispatch error wrapping io.ErrClosedPipe also ends it silently.
//   - Any other read or dispatch error is logged and ends the loop.
//
// closedCh is closed when the loop returns.
func (m *Mux) readLoop() {
	defer close(m.closedCh)

	packet := make([]byte, m.bufferSize)
	for {
		n, err := m.nextConn.Read(packet)
		if err != nil {
			if errors.Is(err, io.EOF) || errors.Is(err, ice.ErrClosed) {
				return
			}
			if errors.Is(err, io.ErrShortBuffer) || errors.Is(err, packetio.ErrTimeout) {
				m.log.Errorf("mux: failed to read from packetio.Buffer %s", err.Error())

				continue
			}
			m.log.Errorf("mux: ending readLoop packetio.Buffer error %s", err.Error())

			return
		}

		err = m.dispatch(packet[:n])
		if err == nil {
			continue
		}
		// A closed endpoint buffer is not worth reporting; anything else is.
		if !errors.Is(err, io.ErrClosedPipe) {
			m.log.Errorf("mux: ending readLoop dispatch error %s", err.Error())
		}

		return
	}
}
// dispatch routes one packet read by readLoop to the first registered Endpoint
// whose MatchFunc accepts it. When several match, map iteration order decides
// which one gets it.
//
// If no endpoint matches and the Mux is still open, a copy of the packet is
// queued in pendingPackets so that an endpoint registered later can receive
// it. The queue holds at most maxPendingPackets entries. Past that limit, or
// once the Mux is closed, the packet is dropped.
//
// The packet is also dropped, and nil returned, in two cases:
//   - the matched endpoint's buffer is full;
//   - the endpoint's buffer was closed after it was matched while the Mux
//     itself is still open.
//
// Any other write error is returned to readLoop.
func (m *Mux) dispatch(buf []byte) error {
	if len(buf) == 0 {
		m.log.Warnf("Warning: mux: unable to dispatch zero length packet")

		return nil
	}

	var endpoint *Endpoint

	m.lock.Lock()
	for e, f := range m.endpoints {
		if f(buf) {
			endpoint = e

			break
		}
	}

	if endpoint == nil {
		defer m.lock.Unlock()

		if !m.isClosed {
			if len(m.pendingPackets) >= maxPendingPackets {
				m.log.Warnf(
					"Warning: mux: no endpoint for packet starting with %d, not adding to queue size(%d)",
					buf[0], //nolint:gosec // G602, false positive?
					len(m.pendingPackets),
				)
			} else {
				m.log.Warnf(
					"Warning: mux: no endpoint for packet starting with %d, adding to queue size(%d)",
					buf[0], //nolint:gosec // G602, false positive?
					len(m.pendingPackets),
				)
				// Store a copy: readLoop reuses buf for the next read.
				m.pendingPackets = append(m.pendingPackets, append([]byte{}, buf...))
			}
		}

		return nil
	}
	m.lock.Unlock()

	_, err := endpoint.buffer.Write(buf)
	switch {
	case errors.Is(err, packetio.ErrFull):
		// Expected when bytes are received faster than the endpoint can process them (#2152, #2180)
		m.log.Infof("mux: endpoint buffer is full, dropping packet")

		return nil
	case errors.Is(err, io.ErrClosedPipe):
		// The lock is released before the write, so the matched endpoint's
		// buffer may have been closed in between. readLoop stops on this
		// error, which would cut off every other endpoint. Only let it
		// propagate when the whole Mux is shutting down.
		m.lock.Lock()
		muxClosed := m.isClosed
		m.lock.Unlock()
		if !muxClosed {
			m.log.Debugf("mux: endpoint buffer is closed, dropping packet")

			return nil
		}
	}

	return err
}
// handlePendingPackets moves every queued packet accepted by matchFunc into
// endpoint's buffer. Packets that do not match stay queued for future
// endpoints.
//
// NewEndpoint starts this as a goroutine, so it may run after the endpoint
// has already been unregistered (for example by RemoveEndpoint or Close). In
// that case the queue is left untouched rather than written into a buffer
// that no longer belongs to the Mux.
func (m *Mux) handlePendingPackets(endpoint *Endpoint, matchFunc MatchFunc) {
	m.lock.Lock()
	defer m.lock.Unlock()

	if _, registered := m.endpoints[endpoint]; !registered {
		return
	}

	remaining := make([][]byte, 0, len(m.pendingPackets))
	for _, buf := range m.pendingPackets {
		if !matchFunc(buf) {
			remaining = append(remaining, buf)

			continue
		}
		if _, err := endpoint.buffer.Write(buf); err != nil {
			m.log.Warnf("Warning: mux: error writing packet to endpoint from pending queue: %s", err)
		}
	}
	m.pendingPackets = remaining
}