Files
core/vendor/github.com/datarhei/gosrt/dial.go
2022-08-12 18:42:53 +03:00

708 lines
16 KiB
Go

package srt
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"math/rand"
"net"
"os"
"sync"
"syscall"
"time"
"github.com/datarhei/gosrt/internal/circular"
"github.com/datarhei/gosrt/internal/crypto"
"github.com/datarhei/gosrt/internal/packet"
)
// ErrClientClosed is the error a dialer reports from Read, Write and their
// packet-level variants after the client connection was voluntarily closed.
var (
	ErrClientClosed = errors.New("srt: client closed")
)
// dialer is the caller side of an SRT connection and implements the Conn
// interface. It owns the UDP socket, drives the handshake and, once a
// connection is established, shuttles packets between the socket and it.
type dialer struct {
	// UDP socket and its endpoints.
	pc         *net.UDPConn
	localAddr  net.Addr
	remoteAddr net.Addr

	config Config

	// Values this side announces during the handshake.
	socketId                    uint32
	initialPacketSequenceNumber circular.Number
	crypto                      crypto.Crypto

	// The established connection; written under connLock.
	conn     *srtConn
	connLock sync.RWMutex

	// Outcome of the handshake, consumed by Dial.
	connChan chan connResponse

	// Reference point for packet timestamps.
	start time.Time

	rcvQueue chan packet.Packet // for packets that come from the wire
	sndQueue chan packet.Packet // for packets that go to the wire

	// Shutdown state, see isShutdown and Close.
	shutdown     bool
	shutdownLock sync.RWMutex
	shutdownOnce sync.Once

	// Cancel functions for the reader and writer loops.
	stopReader context.CancelFunc
	stopWriter context.CancelFunc

	// Carries the error that terminated the socket read loop.
	doneChan chan error
}
// connResponse is the handshake outcome handed to Dial over connChan: either
// an established connection (err == nil) or the reason the handshake failed.
type connResponse struct {
	conn *srtConn // established connection, nil on failure
	err  error    // non-nil if the handshake failed
}
// Dial connects to the address using the SRT protocol with the given config
// and returns a Conn interface.
//
// The address is of the form "host:port".
//
// Example:
//
//	Dial("srt", "127.0.0.1:3000", DefaultConfig())
//
// Dial blocks until the handshake with the listener has concluded or
// config.ConnectionTimeout has elapsed.
//
// In case of an error the returned Conn is nil and the error is non-nil.
func Dial(network, address string, config Config) (Conn, error) {
	if network != "srt" {
		return nil, fmt.Errorf("the network must be 'srt'")
	}

	if err := config.Validate(); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}

	if config.Logger == nil {
		config.Logger = NewLogger(nil)
	}

	dl := &dialer{
		config: config,
	}

	raddr, err := net.ResolveUDPAddr("udp", address)
	if err != nil {
		return nil, fmt.Errorf("unable to resolve address: %w", err)
	}

	pc, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		return nil, fmt.Errorf("failed dialing: %w", err)
	}

	// Set TOS and TTL directly on the socket. Unlike pc.File(), the raw
	// connection does not duplicate the file descriptor, so nothing leaks.
	if config.IPTOS > 0 || config.IPTTL > 0 {
		rawConn, err := pc.SyscallConn()
		if err != nil {
			pc.Close()
			return nil, err
		}

		var sockErr error
		ctrlErr := rawConn.Control(func(fd uintptr) {
			if config.IPTOS > 0 {
				if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, config.IPTOS); err != nil {
					sockErr = fmt.Errorf("failed setting socket option TOS: %w", err)
					return
				}
			}

			if config.IPTTL > 0 {
				if err := syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, config.IPTTL); err != nil {
					sockErr = fmt.Errorf("failed setting socket option TTL: %w", err)
				}
			}
		})
		if ctrlErr != nil {
			pc.Close()
			return nil, ctrlErr
		}
		if sockErr != nil {
			pc.Close()
			return nil, sockErr
		}
	}

	dl.pc = pc

	dl.localAddr = pc.LocalAddr()
	dl.remoteAddr = pc.RemoteAddr()

	dl.conn = nil
	// Buffered so that a handshake reply arriving after Dial stopped
	// waiting (timeout) does not block the reader goroutine forever.
	dl.connChan = make(chan connResponse, 1)

	dl.rcvQueue = make(chan packet.Packet, 2048)
	dl.sndQueue = make(chan packet.Packet, 2048)

	// Buffered so that the read goroutine can always deliver its final
	// error and exit, even if nobody calls Read/Write after Close.
	dl.doneChan = make(chan error, 1)

	dl.start = time.Now()

	// create a new socket ID
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	dl.socketId = r.Uint32()
	dl.initialPacketSequenceNumber = circular.New(r.Uint32()&packet.MAX_SEQUENCENUMBER, packet.MAX_SEQUENCENUMBER)

	// Read datagrams from the socket and hand them to the receive queue.
	// The goroutine reports exactly one error on doneChan before it exits.
	go func() {
		buffer := make([]byte, MAX_MSS_SIZE) // MTU size

		for {
			if dl.isShutdown() {
				dl.doneChan <- ErrClientClosed
				return
			}

			pc.SetReadDeadline(time.Now().Add(3 * time.Second))
			n, _, err := pc.ReadFrom(buffer)
			if err != nil {
				if errors.Is(err, os.ErrDeadlineExceeded) {
					continue
				}

				if dl.isShutdown() {
					dl.doneChan <- ErrClientClosed
					return
				}

				dl.doneChan <- err
				return
			}

			p := packet.NewPacket(dl.remoteAddr, buffer[:n])
			if p == nil {
				continue
			}

			// non-blocking
			select {
			case dl.rcvQueue <- p:
			default:
				dl.log("dial", func() string { return "receive queue is full" })
			}
		}
	}()

	var readerCtx context.Context
	readerCtx, dl.stopReader = context.WithCancel(context.Background())
	go dl.reader(readerCtx)

	var writerCtx context.Context
	writerCtx, dl.stopWriter = context.WithCancel(context.Background())
	go dl.writer(writerCtx)

	// Send the initial handshake request
	dl.sendInduction()

	dl.log("dial", func() string { return "waiting for response" })

	// Wait for the handshake to conclude or time out. Using a select instead
	// of a time.AfterFunc callback leaves no goroutine behind that could block
	// on connChan after Dial has returned.
	timer := time.NewTimer(dl.config.ConnectionTimeout)
	defer timer.Stop()

	var response connResponse
	select {
	case response = <-dl.connChan:
	case <-timer.C:
		response = connResponse{
			conn: nil,
			err:  fmt.Errorf("connection timeout. server didn't respond"),
		}
	}

	if response.err != nil {
		dl.Close()
		return nil, response.err
	}

	dl.connLock.Lock()
	dl.conn = response.conn
	dl.connLock.Unlock()

	return dl, nil
}
// checkConnection polls, without blocking, whether the socket read loop has
// reported its terminating error. If it has, the dialer is closed and that
// error is returned; otherwise it returns nil.
func (dl *dialer) checkConnection() error {
	var err error

	select {
	case err = <-dl.doneChan:
	default:
		return nil
	}

	dl.Close()

	return err
}
// reader drains the receive queue until ctx is cancelled. Packets are ignored
// after shutdown or when they are addressed to a different socket ID.
// Handshake control packets go to handleHandshake; everything else is pushed
// into the established connection, or ignored if there is none yet.
func (dl *dialer) reader(ctx context.Context) {
	defer dl.log("dial", func() string { return "left reader loop" })

	dl.log("dial", func() string { return "reader loop started" })

	for {
		var p packet.Packet

		select {
		case <-ctx.Done():
			return
		case p = <-dl.rcvQueue:
		}

		if dl.isShutdown() {
			continue
		}

		dl.log("packet:recv:dump", func() string { return p.Dump() })

		hdr := p.Header()

		if hdr.DestinationSocketId != dl.socketId {
			continue
		}

		if hdr.IsControlPacket && hdr.ControlType == packet.CTRLTYPE_HANDSHAKE {
			dl.handleHandshake(p)
			continue
		}

		dl.connLock.RLock()
		if dl.conn != nil {
			dl.conn.push(p)
		}
		dl.connLock.RUnlock()
	}
}
// send hands p to the writer via the send queue without ever blocking. If the
// queue is full, p is not queued and the event is logged.
func (dl *dialer) send(p packet.Packet) {
	select {
	case dl.sndQueue <- p:
		return
	default:
	}

	dl.log("dial", func() string { return "send queue is full" })
}
// writer takes packets from the send queue, marshals them and writes them to
// the UDP socket until ctx is cancelled.
//
// A packet that fails to marshal is decommissioned and skipped. Control
// packets are decommissioned after being written; data packets are not,
// because they may be sent again.
func (dl *dialer) writer(ctx context.Context) {
	defer func() {
		dl.log("dial", func() string { return "left writer loop" })
	}()

	dl.log("dial", func() string { return "writer loop started" })

	// Reused across iterations to avoid an allocation per packet.
	var data bytes.Buffer

	for {
		select {
		case <-ctx.Done():
			return
		case p := <-dl.sndQueue:
			data.Reset()

			if err := p.Marshal(&data); err != nil {
				p.Decommission()
				dl.log("packet:send:error", func() string { return "marshalling packet failed" })
				continue
			}

			buffer := data.Bytes()

			dl.log("packet:send:dump", func() string { return p.Dump() })

			// Write the packet's contents to the wire. A failed write is
			// logged instead of silently discarded; the loop keeps running.
			if _, err := dl.pc.Write(buffer); err != nil {
				dl.log("packet:send:error", func() string { return fmt.Sprintf("writing packet failed: %s", err) })
			}

			if p.Header().IsControlPacket {
				// Control packets can be decommissioned because they will not be sent again
				p.Decommission()
			}
		}
	}
}
// handleHandshake processes a handshake control packet received from the
// listener (4.3.1. Caller-Listener Handshake). The received packet is reused
// to build the reply.
//
//   - Induction response: verifies version 5 and the magic number 0x4A17,
//     sets up the crypto context if a passphrase is configured, and sends a
//     conclusion request.
//   - Conclusion response: verifies version, minimum SRT version, required
//     SRT flags, live mode and the effective MSS, then creates the srtConn
//     and delivers it to Dial.
//   - Anything else is reported to Dial as a rejection or as unsupported.
//
// Every outcome meant for Dial is sent on connChan. Handshake packets that
// arrive after a connection has been established are ignored.
func (dl *dialer) handleHandshake(p packet.Packet) {
	// Dial stops receiving from connChan once it has a connection. Handling a
	// late or duplicated handshake would create another srtConn and block the
	// reader goroutine forever on connChan.
	dl.connLock.RLock()
	connected := dl.conn != nil
	dl.connLock.RUnlock()
	if connected {
		dl.log("dial", func() string { return "ignoring handshake, connection is already established" })
		return
	}

	cif := &packet.CIFHandshake{}

	err := p.UnmarshalCIF(cif)

	dl.log("handshake:recv:dump", func() string { return p.Dump() })
	dl.log("handshake:recv:cif", func() string { return cif.String() })

	if err != nil {
		dl.log("handshake:recv:error", func() string { return err.Error() })
		return
	}

	// assemble the response (4.3.1. Caller-Listener Handshake)
	p.Header().ControlType = packet.CTRLTYPE_HANDSHAKE
	p.Header().SubType = 0
	p.Header().TypeSpecific = 0
	p.Header().Timestamp = uint32(time.Since(dl.start).Microseconds())
	p.Header().DestinationSocketId = cif.SRTSocketId

	switch cif.HandshakeType {
	case packet.HSTYPE_INDUCTION:
		// Verify version
		if cif.Version != 5 {
			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer doesn't support handshake v5"),
			}

			return
		}

		// Verify magic number
		if cif.ExtensionField != 0x4A17 {
			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer sent the wrong magic number"),
			}

			return
		}

		// Setup crypto context
		if len(dl.config.Passphrase) != 0 {
			keylen := dl.config.PBKeylen

			// If the server advertises a specific block cipher family and key size,
			// use this one, otherwise, use the configured one
			switch cif.EncryptionField {
			case 2:
				keylen = 16
			case 3:
				keylen = 24
			case 4:
				keylen = 32
			}

			cr, err := crypto.New(keylen)
			if err != nil {
				dl.connChan <- connResponse{
					conn: nil,
					err:  fmt.Errorf("failed creating crypto context: %w", err),
				}

				// The handshake has failed; do not go on and send a
				// conclusion request without a crypto context.
				return
			}

			dl.crypto = cr
		}

		cif.IsRequest = true
		cif.HandshakeType = packet.HSTYPE_CONCLUSION
		cif.InitialPacketSequenceNumber = dl.initialPacketSequenceNumber
		cif.MaxTransmissionUnitSize = dl.config.MSS // MTU size
		cif.MaxFlowWindowSize = dl.config.FC
		cif.SRTSocketId = dl.socketId
		cif.PeerIP.FromNetAddr(dl.localAddr)

		cif.HasHS = true
		cif.SRTVersion = SRT_VERSION
		cif.SRTFlags.TSBPDSND = true
		cif.SRTFlags.TSBPDRCV = true
		cif.SRTFlags.CRYPT = true // must always set to true
		cif.SRTFlags.TLPKTDROP = true
		cif.SRTFlags.PERIODICNAK = true
		cif.SRTFlags.REXMITFLG = true
		cif.SRTFlags.STREAM = false
		cif.SRTFlags.PACKET_FILTER = false
		cif.RecvTSBPDDelay = uint16(dl.config.ReceiverLatency.Milliseconds())
		cif.SendTSBPDDelay = uint16(dl.config.PeerLatency.Milliseconds())

		cif.HasSID = true
		cif.StreamId = dl.config.StreamId

		if dl.crypto != nil {
			cif.HasKM = true
			cif.SRTKM = &packet.CIFKM{}

			if err := dl.crypto.MarshalKM(cif.SRTKM, dl.config.Passphrase, packet.EvenKeyEncrypted); err != nil {
				dl.connChan <- connResponse{
					conn: nil,
					err:  err,
				}

				return
			}
		}

		p.MarshalCIF(cif)

		dl.log("handshake:send:dump", func() string { return p.Dump() })
		dl.log("handshake:send:cif", func() string { return cif.String() })

		dl.send(p)

	case packet.HSTYPE_CONCLUSION:
		// We only support HSv5
		if cif.Version != 5 {
			dl.sendShutdown(cif.SRTSocketId)

			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer doesn't support handshake v5"),
			}

			return
		}

		// Check if the peer version is sufficient
		if cif.SRTVersion < dl.config.MinVersion {
			dl.sendShutdown(cif.SRTSocketId)

			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer SRT version is not sufficient"),
			}

			return
		}

		// Check the required SRT flags
		if !cif.SRTFlags.TSBPDSND || !cif.SRTFlags.TSBPDRCV || !cif.SRTFlags.TLPKTDROP || !cif.SRTFlags.PERIODICNAK || !cif.SRTFlags.REXMITFLG {
			dl.sendShutdown(cif.SRTSocketId)

			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer doesn't agree on SRT flags"),
			}

			return
		}

		// We only support live streaming
		if cif.SRTFlags.STREAM {
			dl.sendShutdown(cif.SRTSocketId)

			dl.connChan <- connResponse{
				conn: nil,
				err:  fmt.Errorf("peer doesn't support live streaming"),
			}

			return
		}

		// Use the largest TSBPD delay as advertised by the listener, but
		// at least 120ms
		tsbpdDelay := uint16(120)
		if cif.RecvTSBPDDelay > tsbpdDelay {
			tsbpdDelay = cif.RecvTSBPDDelay
		}

		if cif.SendTSBPDDelay > tsbpdDelay {
			tsbpdDelay = cif.SendTSBPDDelay
		}

		// If the peer has a smaller MTU size, adjust to it
		if cif.MaxTransmissionUnitSize < dl.config.MSS {
			dl.config.MSS = cif.MaxTransmissionUnitSize
			dl.config.PayloadSize = dl.config.MSS - SRT_HEADER_SIZE - UDP_HEADER_SIZE

			if dl.config.PayloadSize < MIN_PAYLOAD_SIZE {
				dl.sendShutdown(cif.SRTSocketId)

				dl.connChan <- connResponse{
					conn: nil,
					err:  fmt.Errorf("effective MSS too small (%d bytes) to fit the minimal payload size (%d bytes)", dl.config.MSS, MIN_PAYLOAD_SIZE),
				}

				return
			}
		}

		// Create a new connection
		conn := newSRTConn(srtConnConfig{
			localAddr:                   dl.localAddr,
			remoteAddr:                  dl.remoteAddr,
			config:                      dl.config,
			start:                       dl.start,
			socketId:                    dl.socketId,
			peerSocketId:                cif.SRTSocketId,
			tsbpdTimeBase:               uint64(time.Since(dl.start).Microseconds()),
			tsbpdDelay:                  uint64(tsbpdDelay) * 1000,
			initialPacketSequenceNumber: cif.InitialPacketSequenceNumber,
			crypto:                      dl.crypto,
			keyBaseEncryption:           packet.EvenKeyEncrypted,
			onSend:                      dl.send,
			onShutdown:                  func(socketId uint32) { dl.Close() },
			logger:                      dl.config.Logger,
		})

		dl.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s)", conn.SocketId(), conn.StreamId()) })

		dl.connChan <- connResponse{
			conn: conn,
			err:  nil,
		}

	default:
		var err error

		if cif.HandshakeType.IsRejection() {
			err = fmt.Errorf("connection rejected: %s", cif.HandshakeType.String())
		} else {
			err = fmt.Errorf("unsupported handshake: %s", cif.HandshakeType.String())
		}

		dl.connChan <- connResponse{
			conn: nil,
			err:  err,
		}
	}
}
// sendInduction queues the caller's first handshake packet (HSTYPE_INDUCTION).
// It is addressed to destination socket ID 0, carries version 4 with extension
// field 2, a zero SYN cookie and zero initial sequence number, and announces
// this dialer's socket ID, MSS, flow window and local address.
func (dl *dialer) sendInduction() {
	p := packet.NewPacket(dl.remoteAddr, nil)

	hdr := p.Header()
	hdr.IsControlPacket = true
	hdr.ControlType = packet.CTRLTYPE_HANDSHAKE
	hdr.SubType = 0
	hdr.TypeSpecific = 0
	hdr.Timestamp = uint32(time.Since(dl.start).Microseconds())
	hdr.DestinationSocketId = 0

	handshake := &packet.CIFHandshake{
		IsRequest:                   true,
		Version:                     4,
		EncryptionField:             0,
		ExtensionField:              2,
		InitialPacketSequenceNumber: circular.New(0, packet.MAX_SEQUENCENUMBER),
		MaxTransmissionUnitSize:     dl.config.MSS, // MTU size
		MaxFlowWindowSize:           dl.config.FC,
		HandshakeType:               packet.HSTYPE_INDUCTION,
		SRTSocketId:                 dl.socketId,
		SynCookie:                   0,
	}
	handshake.PeerIP.FromNetAddr(dl.localAddr)

	p.MarshalCIF(handshake)

	dl.log("handshake:send:dump", func() string { return p.Dump() })
	dl.log("handshake:send:cif", func() string { return handshake.String() })

	dl.send(p)
}
// sendShutdown queues a CTRLTYPE_SHUTDOWN control packet with a 4-byte zero
// payload, addressed to the peer's socket ID.
func (dl *dialer) sendShutdown(peerSocketId uint32) {
	var payload [4]byte
	binary.BigEndian.PutUint32(payload[:], 0)

	p := packet.NewPacket(dl.remoteAddr, nil)
	p.SetData(payload[:])

	hdr := p.Header()
	hdr.IsControlPacket = true
	hdr.ControlType = packet.CTRLTYPE_SHUTDOWN
	hdr.TypeSpecific = 0
	hdr.Timestamp = uint32(time.Since(dl.start).Microseconds())
	hdr.DestinationSocketId = peerSocketId

	dl.log("control:send:shutdown:dump", func() string { return p.Dump() })

	dl.send(p)
}
// LocalAddr returns the local address reported by the established connection.
func (dl *dialer) LocalAddr() net.Addr {
	c := dl.conn
	return c.LocalAddr()
}

// RemoteAddr returns the peer address reported by the established connection.
func (dl *dialer) RemoteAddr() net.Addr {
	c := dl.conn
	return c.RemoteAddr()
}

// SocketId returns the local SRT socket ID of the established connection.
func (dl *dialer) SocketId() uint32 {
	c := dl.conn
	return c.SocketId()
}

// PeerSocketId returns the peer's SRT socket ID of the established connection.
func (dl *dialer) PeerSocketId() uint32 {
	c := dl.conn
	return c.PeerSocketId()
}

// StreamId returns the stream ID of the established connection.
func (dl *dialer) StreamId() string {
	c := dl.conn
	return c.StreamId()
}
// isShutdown reports whether Close has marked the dialer as shut down. The
// flag is read under shutdownLock.
func (dl *dialer) isShutdown() bool {
	dl.shutdownLock.RLock()
	closed := dl.shutdown
	dl.shutdownLock.RUnlock()

	return closed
}
// Close shuts the dialer down; only the first call has any effect. It marks
// the dialer as shut down, closes the established connection if there is
// one, stops the reader and writer loops, closes the UDP socket, and discards
// a read-loop error if one is already pending. It always returns nil.
func (dl *dialer) Close() error {
	shutdown := func() {
		dl.shutdownLock.Lock()
		dl.shutdown = true
		dl.shutdownLock.Unlock()

		dl.connLock.RLock()
		if c := dl.conn; c != nil {
			c.Close()
		}
		dl.connLock.RUnlock()

		dl.stopReader()
		dl.stopWriter()

		dl.log("dial", func() string { return "closing socket" })
		dl.pc.Close()

		// Discard a pending read-loop error without waiting for one.
		select {
		case <-dl.doneChan:
		default:
		}
	}

	dl.shutdownOnce.Do(shutdown)

	return nil
}
// Read reads from the established connection into p. If the socket read loop
// has already terminated, the dialer is closed and that error is returned.
func (dl *dialer) Read(p []byte) (n int, err error) {
	if err = dl.checkConnection(); err != nil {
		return 0, err
	}

	dl.connLock.RLock()
	defer dl.connLock.RUnlock()

	n, err = dl.conn.Read(p)
	return n, err
}
// readPacket reads the next packet from the established connection. If the
// socket read loop has already terminated, the dialer is closed and that
// error is returned.
func (dl *dialer) readPacket() (packet.Packet, error) {
	if checkErr := dl.checkConnection(); checkErr != nil {
		return nil, checkErr
	}

	dl.connLock.RLock()
	defer dl.connLock.RUnlock()

	pkt, err := dl.conn.readPacket()
	return pkt, err
}
// Write writes p to the established connection. If the socket read loop has
// already terminated, the dialer is closed and that error is returned.
func (dl *dialer) Write(p []byte) (n int, err error) {
	if err = dl.checkConnection(); err != nil {
		return 0, err
	}

	dl.connLock.RLock()
	defer dl.connLock.RUnlock()

	n, err = dl.conn.Write(p)
	return n, err
}
// writePacket writes the packet p to the established connection. If the
// socket read loop has already terminated, the dialer is closed and that
// error is returned.
func (dl *dialer) writePacket(p packet.Packet) error {
	if checkErr := dl.checkConnection(); checkErr != nil {
		return checkErr
	}

	dl.connLock.RLock()
	defer dl.connLock.RUnlock()

	err := dl.conn.writePacket(p)
	return err
}
// SetDeadline delegates to the established connection's SetDeadline.
func (dl *dialer) SetDeadline(t time.Time) error {
	c := dl.conn
	return c.SetDeadline(t)
}

// SetReadDeadline delegates to the established connection's SetReadDeadline.
func (dl *dialer) SetReadDeadline(t time.Time) error {
	c := dl.conn
	return c.SetReadDeadline(t)
}

// SetWriteDeadline delegates to the established connection's SetWriteDeadline.
func (dl *dialer) SetWriteDeadline(t time.Time) error {
	c := dl.conn
	return c.SetWriteDeadline(t)
}

// Stats returns the statistics of the established connection.
func (dl *dialer) Stats() Statistics {
	c := dl.conn
	return c.Stats()
}
// log forwards a lazily built message to the configured logger under topic,
// tagged with this dialer's socket ID. The constant 2 is passed through
// unchanged; presumably it is a caller-depth hint for the logger (confirm
// against Logger.Print).
func (dl *dialer) log(topic string, message func() string) {
	logger := dl.config.Logger
	logger.Print(topic, dl.socketId, 2, message)
}