Files
go2rtc/pkg/webrtc/conn.go
2025-04-07 16:56:38 +03:00

221 lines
5.0 KiB
Go

package webrtc
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/AlexxIT/go2rtc/pkg/core"
"github.com/pion/rtcp"
"github.com/pion/rtp"
"github.com/pion/webrtc/v4"
)
type Conn struct {
core.Connection
core.Listener
Mode core.Mode `json:"mode"`
pc *webrtc.PeerConnection
offer string
closed core.Waiter
}
func NewConn(pc *webrtc.PeerConnection) *Conn {
c := &Conn{
Connection: core.Connection{
ID: core.NewID(),
FormatName: "webrtc",
Transport: pc,
},
pc: pc,
}
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
// last candidate will be empty
if candidate != nil {
c.Fire(candidate)
}
})
pc.OnDataChannel(func(channel *webrtc.DataChannel) {
c.Fire(channel)
})
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
if state != webrtc.ICEConnectionStateChecking {
return
}
pc.SCTP().Transport().ICETransport().OnSelectedCandidatePairChange(
func(pair *webrtc.ICECandidatePair) {
// fix situation when candidate pair changes multiple times
if i := strings.IndexByte(c.Protocol, '+'); i > 0 {
c.Protocol = c.Protocol[:i]
}
c.Protocol += "+" + pair.Remote.Protocol.String()
c.RemoteAddr = fmt.Sprintf(
"%s:%d %s", sanitizeIP6(pair.Remote.Address), pair.Remote.Port, pair.Remote.Typ,
)
if pair.Remote.RelatedAddress != "" {
c.RemoteAddr += fmt.Sprintf(
" %s:%d", sanitizeIP6(pair.Remote.RelatedAddress), pair.Remote.RelatedPort,
)
}
},
)
})
pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
media, codec := c.getMediaCodec(remote)
if media == nil {
return
}
track, err := c.GetTrack(media, codec)
if err != nil {
return
}
switch c.Mode {
case core.ModePassiveProducer, core.ModeActiveProducer:
// replace the theoretical list of codecs with the actual list of codecs
if len(media.Codecs) > 1 {
media.Codecs = []*core.Codec{codec}
}
}
if c.Mode == core.ModePassiveProducer && remote.Kind() == webrtc.RTPCodecTypeVideo {
go func() {
pkts := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: uint32(remote.SSRC())}}
for range time.NewTicker(time.Second * 2).C {
if err := pc.WriteRTCP(pkts); err != nil {
return
}
}
}()
}
for {
b := make([]byte, ReceiveMTU)
n, _, err := remote.Read(b)
if err != nil {
return
}
c.Recv += n
packet := &rtp.Packet{}
if err := packet.Unmarshal(b[:n]); err != nil {
return
}
if len(packet.Payload) == 0 {
continue
}
track.WriteRTP(packet)
}
})
// OK connection:
// 15:01:46 ICE connection state changed: checking
// 15:01:46 peer connection state changed: connected
// 15:01:54 peer connection state changed: disconnected
// 15:02:20 peer connection state changed: failed
//
// Fail connection:
// 14:53:08 ICE connection state changed: checking
// 14:53:39 peer connection state changed: failed
pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
c.Fire(state)
switch state {
case webrtc.PeerConnectionStateConnected:
for _, sender := range c.Senders {
sender.Start()
}
case webrtc.PeerConnectionStateDisconnected, webrtc.PeerConnectionStateFailed, webrtc.PeerConnectionStateClosed:
// disconnect event comes earlier, than failed
// but it comes only for success connections
_ = c.Close()
}
})
return c
}
func (c *Conn) MarshalJSON() ([]byte, error) {
return json.Marshal(c.Connection)
}
func (c *Conn) Close() error {
c.closed.Done(nil)
return c.pc.Close()
}
func (c *Conn) AddCandidate(candidate string) error {
// pion uses only candidate value from json/object candidate struct
return c.pc.AddICECandidate(webrtc.ICECandidateInit{Candidate: candidate})
}
func (c *Conn) getTranseiver(mid string) *webrtc.RTPTransceiver {
for _, tr := range c.pc.GetTransceivers() {
if tr.Mid() == mid {
return tr
}
}
return nil
}
func (c *Conn) getSenderTrack(mid string) *Track {
if tr := c.getTranseiver(mid); tr != nil {
if s := tr.Sender(); s != nil {
if t := s.Track().(*Track); t != nil {
return t
}
}
}
return nil
}
func (c *Conn) getMediaCodec(remote *webrtc.TrackRemote) (*core.Media, *core.Codec) {
for _, tr := range c.pc.GetTransceivers() {
// search Transeiver for this TrackRemote
if tr.Receiver() == nil || tr.Receiver().Track() != remote {
continue
}
// search Media for this MID
for _, media := range c.Medias {
if media.ID != tr.Mid() || media.Direction != core.DirectionRecvonly {
continue
}
// search codec for this PayloadType
for _, codec := range media.Codecs {
if codec.PayloadType != uint8(remote.PayloadType()) {
continue
}
return media, codec
}
}
}
// fix moment when core.ModePassiveProducer or core.ModeActiveProducer
// sends new codec with new payload type to same media
// check GetTrack
panic(core.Caller())
return nil, nil
}
func sanitizeIP6(host string) string {
if strings.IndexByte(host, ':') > 0 {
return "[" + host + "]"
}
return host
}