WebRTC module refactoring

This commit is contained in:
Alexey Khit
2023-02-24 10:45:40 +03:00
parent eca79f1c0b
commit 3fb917f00f
6 changed files with 108 additions and 85 deletions

View File

@@ -12,7 +12,23 @@ import (
// Message - struct for data exchange in Web API // Message - struct for data exchange in Web API
type Message struct { type Message struct {
Type string `json:"type"` Type string `json:"type"`
Value interface{} `json:"value,omitempty"` Value any `json:"value,omitempty"`
}
func (m *Message) String() string {
if s, ok := m.Value.(string); ok {
return s
}
return ""
}
func (m *Message) GetString(key string) string {
if v, ok := m.Value.(map[string]any); ok {
if s, ok := v[key].(string); ok {
return s
}
}
return ""
} }
type WSHandler func(tr *Transport, msg *Message) error type WSHandler func(tr *Transport, msg *Message) error

View File

@@ -21,7 +21,7 @@ func handlerWSMSE(tr *api.Transport, msg *api.Message) error {
UserAgent: tr.Request.UserAgent(), UserAgent: tr.Request.UserAgent(),
} }
if codecs, ok := msg.Value.(string); ok { if codecs := msg.String(); codecs != "" {
log.Trace().Str("codecs", codecs).Msgf("[mp4] new WS/MSE consumer") log.Trace().Str("codecs", codecs).Msgf("[mp4] new WS/MSE consumer")
cons.Medias = parseMedias(codecs, true) cons.Medias = parseMedias(codecs, true)
} }
@@ -69,7 +69,7 @@ func handlerWSMP4(tr *api.Transport, msg *api.Message) error {
OnlyKeyframe: true, OnlyKeyframe: true,
} }
if codecs, ok := msg.Value.(string); ok { if codecs := msg.String(); codecs != "" {
log.Trace().Str("codecs", codecs).Msgf("[mp4] new WS/MP4 consumer") log.Trace().Str("codecs", codecs).Msgf("[mp4] new WS/MP4 consumer")
cons.Medias = parseMedias(codecs, false) cons.Medias = parseMedias(codecs, false)
} }

View File

@@ -87,10 +87,10 @@ func candidateHandler(tr *api.Transport, msg *api.Message) error {
if tr.Consumer == nil { if tr.Consumer == nil {
return nil return nil
} }
if conn := tr.Consumer.(*webrtc.Conn); conn != nil { if conn := tr.Consumer.(*webrtc.Server); conn != nil {
s := msg.Value.(string) candidate := msg.String()
log.Trace().Str("candidate", s).Msg("[webrtc] remote") log.Trace().Str("candidate", candidate).Msg("[webrtc] remote")
conn.AddCandidate(s) conn.AddCandidate(candidate)
} }
return nil return nil
} }

View File

@@ -34,14 +34,16 @@ func Init() {
log = app.GetLogger("webrtc") log = app.GetLogger("webrtc")
address := cfg.Mod.Listen address := cfg.Mod.Listen
// create pionAPI with custom codecs list and custom network settings
pionAPI, err := webrtc.NewAPI(address) pionAPI, err := webrtc.NewAPI(address)
if pionAPI == nil { if pionAPI == nil {
log.Error().Err(err).Caller().Msg("webrtc.NewAPI") log.Error().Err(err).Caller().Send()
return return
} }
if err != nil { if err != nil {
log.Warn().Err(err).Msg("[webrtc] listen") log.Warn().Err(err).Caller().Send()
} else if address != "" { } else if address != "" {
log.Info().Str("addr", address).Msg("[webrtc] listen") log.Info().Str("addr", address).Msg("[webrtc] listen")
_, Port, _ = net.SplitHostPort(address) _, Port, _ = net.SplitHostPort(address)
@@ -52,12 +54,13 @@ func Init() {
SDPSemantics: pion.SDPSemanticsUnifiedPlanWithFallback, SDPSemantics: pion.SDPSemanticsUnifiedPlanWithFallback,
} }
NewPConn = func() (*pion.PeerConnection, error) { newPeerConnection = func() (*pion.PeerConnection, error) {
return pionAPI.NewPeerConnection(pionConf) return pionAPI.NewPeerConnection(pionConf)
} }
candidates = cfg.Mod.Candidates candidates = cfg.Mod.Candidates
api.HandleWS("webrtc", asyncHandler)
api.HandleWS("webrtc/offer", asyncHandler) api.HandleWS("webrtc/offer", asyncHandler)
api.HandleWS("webrtc/candidate", candidateHandler) api.HandleWS("webrtc/candidate", candidateHandler)
@@ -67,7 +70,7 @@ func Init() {
var Port string var Port string
var log zerolog.Logger var log zerolog.Logger
var NewPConn func() (*pion.PeerConnection, error) var newPeerConnection func() (*pion.PeerConnection, error)
func asyncHandler(tr *api.Transport, msg *api.Message) error { func asyncHandler(tr *api.Transport, msg *api.Message) error {
src := tr.Request.URL.Query().Get("src") src := tr.Request.URL.Query().Get("src")
@@ -78,22 +81,23 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error {
log.Debug().Str("url", src).Msg("[webrtc] new consumer") log.Debug().Str("url", src).Msg("[webrtc] new consumer")
var err error // create new PeerConnection instance
pc, err := newPeerConnection()
// create new webrtc instance
conn := new(webrtc.Conn)
conn.Conn, err = NewPConn()
if err != nil { if err != nil {
log.Error().Err(err).Caller().Send() log.Error().Err(err).Caller().Send()
return err return err
} }
conn.UserAgent = tr.Request.UserAgent() // apiV2 - json/object exchange, V2 - raw SDP and raw Candidates exchange
conn.Listen(func(msg interface{}) { apiV2 := msg.Type == "webrtc"
cons := webrtc.NewServer(pc)
cons.UserAgent = tr.Request.UserAgent()
cons.Listen(func(msg any) {
switch msg := msg.(type) { switch msg := msg.(type) {
case pion.PeerConnectionState: case pion.PeerConnectionState:
if msg == pion.PeerConnectionStateClosed { if msg == pion.PeerConnectionStateClosed {
stream.RemoveConsumer(conn) stream.RemoveConsumer(cons)
} }
case *pion.ICECandidate: case *pion.ICECandidate:
if msg != nil { if msg != nil {
@@ -105,25 +109,29 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error {
}) })
// 1. SetOffer, so we can get remote client codecs // 1. SetOffer, so we can get remote client codecs
offer := msg.Value.(string) var offer string
if apiV2 {
offer = msg.GetString("sdp")
} else {
offer = msg.String()
}
log.Trace().Msgf("[webrtc] offer:\n%s", offer) log.Trace().Msgf("[webrtc] offer:\n%s", offer)
if err = conn.SetOffer(offer); err != nil { if err = cons.SetOffer(offer); err != nil {
log.Warn().Err(err).Caller().Send() log.Warn().Err(err).Caller().Send()
return err return err
} }
// 2. AddConsumer, so we get new tracks // 2. AddConsumer, so we get new tracks
if err = stream.AddConsumer(conn); err != nil { if err = stream.AddConsumer(cons); err != nil {
log.Debug().Err(err).Msg("[webrtc] add consumer") log.Debug().Err(err).Msg("[webrtc] add consumer")
_ = conn.Conn.Close() _ = cons.Close()
return err return err
} }
conn.Init()
// 3. Exchange SDP without waiting all candidates // 3. Exchange SDP without waiting all candidates
answer, err := conn.GetAnswer() answer, err := cons.GetAnswer()
log.Trace().Msgf("[webrtc] answer\n%s", answer) log.Trace().Msgf("[webrtc] answer\n%s", answer)
if err != nil { if err != nil {
@@ -131,7 +139,7 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error {
return err return err
} }
tr.Consumer = conn tr.Consumer = cons
tr.Write(&api.Message{Type: "webrtc/answer", Value: answer}) tr.Write(&api.Message{Type: "webrtc/answer", Value: answer})
@@ -140,11 +148,6 @@ func asyncHandler(tr *api.Transport, msg *api.Message) error {
return nil return nil
} }
type SDP struct {
Type string `json:"type"`
Sdp string `json:"sdp"`
}
// syncHandler // syncHandler
func syncHandler(w http.ResponseWriter, r *http.Request) { func syncHandler(w http.ResponseWriter, r *http.Request) {
url := r.URL.Query().Get("src") url := r.URL.Query().Get("src")
@@ -153,21 +156,23 @@ func syncHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var offer string
ct := r.Header.Get("Content-Type") ct := r.Header.Get("Content-Type")
if ct != "" { if ct != "" {
ct, _, _ = mime.ParseMediaType(ct) ct, _, _ = mime.ParseMediaType(ct)
} }
if ct == "application/json" { // apiV2 - json/object exchange, V1 - raw SDP exchange
var v SDP apiV2 := ct == "application/json"
if err := json.NewDecoder(r.Body).Decode(&v); err != nil {
var offer string
if apiV2 {
var sd pion.SessionDescription
if err := json.NewDecoder(r.Body).Decode(&sd); err != nil {
log.Error().Err(err).Caller().Send() log.Error().Err(err).Caller().Send()
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
offer = v.Sdp offer = sd.SDP
} else { } else {
body, err := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
@@ -186,10 +191,12 @@ func syncHandler(w http.ResponseWriter, r *http.Request) {
} }
// send SDP to client // send SDP to client
if ct == "application/json" { if apiV2 {
w.Header().Set("Content-Type", ct) w.Header().Set("Content-Type", ct)
v := SDP{Sdp: answer, Type: "answer"} v := pion.SessionDescription{
Type: pion.SDPTypeAnswer, SDP: answer,
}
if err = json.NewEncoder(w).Encode(v); err != nil { if err = json.NewEncoder(w).Encode(v); err != nil {
log.Error().Err(err).Caller().Send() log.Error().Err(err).Caller().Send()
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -201,17 +208,15 @@ func syncHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func ExchangeSDP( func ExchangeSDP(stream *streams.Stream, offer string, userAgent string) (answer string, err error) {
stream *streams.Stream, offer string, userAgent string, pc, err := newPeerConnection()
) (answer string, err error) {
// create new webrtc instance
conn := new(webrtc.Conn)
conn.Conn, err = NewPConn()
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msg("NewPConn") log.Error().Err(err).Caller().Msg("NewPConn")
return return
} }
// create new webrtc instance
conn := webrtc.NewServer(pc)
conn.UserAgent = userAgent conn.UserAgent = userAgent
conn.Listen(func(msg interface{}) { conn.Listen(func(msg interface{}) {
switch msg := msg.(type) { switch msg := msg.(type) {
@@ -233,12 +238,10 @@ func ExchangeSDP(
// 2. AddConsumer, so we get new tracks // 2. AddConsumer, so we get new tracks
if err = stream.AddConsumer(conn); err != nil { if err = stream.AddConsumer(conn); err != nil {
log.Warn().Err(err).Caller().Msg("stream.AddConsumer") log.Warn().Err(err).Caller().Msg("stream.AddConsumer")
_ = conn.Conn.Close() _ = conn.Close()
return return
} }
conn.Init()
// exchange sdp without waiting all candidates // exchange sdp without waiting all candidates
//answer, err := conn.ExchangeSDP(offer, false) //answer, err := conn.ExchangeSDP(offer, false)
answer, err = conn.GetCompleteAnswer() answer, err = conn.GetCompleteAnswer()

View File

@@ -9,13 +9,11 @@ import (
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
) )
// Consumer func (c *Server) GetMedias() []*streamer.Media {
func (c *Conn) GetMedias() []*streamer.Media {
return c.medias return c.medias
} }
func (c *Conn) AddTrack(media *streamer.Media, track *streamer.Track) *streamer.Track { func (c *Server) AddTrack(media *streamer.Media, track *streamer.Track) *streamer.Track {
switch track.Direction { switch track.Direction {
// send our track to WebRTC consumer // send our track to WebRTC consumer
case streamer.DirectionSendonly: case streamer.DirectionSendonly:
@@ -43,7 +41,7 @@ func (c *Conn) AddTrack(media *streamer.Media, track *streamer.Track) *streamer.
return nil return nil
} }
if _, err = c.Conn.AddTrack(trackLocal); err != nil { if _, err = c.conn.AddTrack(trackLocal); err != nil {
return nil return nil
} }
@@ -80,7 +78,7 @@ func (c *Conn) AddTrack(media *streamer.Media, track *streamer.Track) *streamer.
// receive track from WebRTC consumer (microphone, backchannel, two way audio) // receive track from WebRTC consumer (microphone, backchannel, two way audio)
case streamer.DirectionRecvonly: case streamer.DirectionRecvonly:
for _, tr := range c.Conn.GetTransceivers() { for _, tr := range c.conn.GetTransceivers() {
if tr.Mid() != media.MID { if tr.Mid() != media.MID {
continue continue
} }
@@ -106,7 +104,7 @@ func (c *Conn) AddTrack(media *streamer.Media, track *streamer.Track) *streamer.
panic("wrong direction") panic("wrong direction")
} }
func (c *Conn) MarshalJSON() ([]byte, error) { func (c *Server) MarshalJSON() ([]byte, error) {
info := &streamer.Info{ info := &streamer.Info{
Type: "WebRTC client", Type: "WebRTC client",
RemoteAddr: c.remote(), RemoteAddr: c.remote(),

View File

@@ -6,12 +6,12 @@ import (
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
) )
type Conn struct { type Server struct {
streamer.Element streamer.Element
UserAgent string UserAgent string
Conn *webrtc.PeerConnection conn *webrtc.PeerConnection
medias []*streamer.Media medias []*streamer.Media
tracks []*streamer.Track tracks []*streamer.Track
@@ -20,12 +20,14 @@ type Conn struct {
send int send int
} }
func (c *Conn) Init() { func NewServer(conn *webrtc.PeerConnection) *Server {
c.Conn.OnICECandidate(func(candidate *webrtc.ICECandidate) { c := &Server{conn: conn}
conn.OnICECandidate(func(candidate *webrtc.ICECandidate) {
c.Fire(candidate) c.Fire(candidate)
}) })
c.Conn.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { conn.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
for _, track := range c.tracks { for _, track := range c.tracks {
if track.Direction != streamer.DirectionRecvonly { if track.Direction != streamer.DirectionRecvonly {
continue continue
@@ -50,7 +52,7 @@ func (c *Conn) Init() {
//fmt.Printf("TODO: webrtc ontrack %+v\n", remote) //fmt.Printf("TODO: webrtc ontrack %+v\n", remote)
}) })
c.Conn.OnDataChannel(func(channel *webrtc.DataChannel) { conn.OnDataChannel(func(channel *webrtc.DataChannel) {
c.Fire(channel) c.Fire(channel)
}) })
@@ -63,7 +65,7 @@ func (c *Conn) Init() {
// Fail connection: // Fail connection:
// 14:53:08 ICE connection state changed: checking // 14:53:08 ICE connection state changed: checking
// 14:53:39 peer connection state changed: failed // 14:53:39 peer connection state changed: failed
c.Conn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { conn.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
c.Fire(state) c.Fire(state)
// TODO: remove // TODO: remove
@@ -74,25 +76,24 @@ func (c *Conn) Init() {
c.Fire(streamer.StateNull) // TODO: remove c.Fire(streamer.StateNull) // TODO: remove
// disconnect event comes earlier, than failed // disconnect event comes earlier, than failed
// but it comes only for success connections // but it comes only for success connections
_ = c.Conn.Close() _ = conn.Close()
c.Conn = nil
case webrtc.PeerConnectionStateFailed: case webrtc.PeerConnectionStateFailed:
if c.Conn != nil { _ = conn.Close()
_ = c.Conn.Close()
}
} }
}) })
return c
} }
func (c *Conn) SetOffer(offer string) (err error) { func (c *Server) SetOffer(offer string) (err error) {
sdOffer := webrtc.SessionDescription{ desc := webrtc.SessionDescription{
Type: webrtc.SDPTypeOffer, SDP: offer, Type: webrtc.SDPTypeOffer, SDP: offer,
} }
if err = c.Conn.SetRemoteDescription(sdOffer); err != nil { if err = c.conn.SetRemoteDescription(desc); err != nil {
return return
} }
rawSDP := []byte(c.Conn.RemoteDescription().SDP) rawSDP := []byte(c.conn.RemoteDescription().SDP)
sd := &sdp.SessionDescription{} sd := &sdp.SessionDescription{}
if err = sd.Unmarshal(rawSDP); err != nil { if err = sd.Unmarshal(rawSDP); err != nil {
return return
@@ -116,8 +117,8 @@ func (c *Conn) SetOffer(offer string) (err error) {
return return
} }
func (c *Conn) GetAnswer() (answer string, err error) { func (c *Server) GetAnswer() (answer string, err error) {
for _, tr := range c.Conn.GetTransceivers() { for _, tr := range c.conn.GetTransceivers() {
if tr.Direction() != webrtc.RTPTransceiverDirectionSendonly { if tr.Direction() != webrtc.RTPTransceiverDirectionSendonly {
continue continue
} }
@@ -133,37 +134,42 @@ func (c *Conn) GetAnswer() (answer string, err error) {
} }
var sdAnswer webrtc.SessionDescription var sdAnswer webrtc.SessionDescription
sdAnswer, err = c.Conn.CreateAnswer(nil) sdAnswer, err = c.conn.CreateAnswer(nil)
if err != nil { if err != nil {
return return
} }
if err = c.Conn.SetLocalDescription(sdAnswer); err != nil { if err = c.conn.SetLocalDescription(sdAnswer); err != nil {
return return
} }
return sdAnswer.SDP, nil return sdAnswer.SDP, nil
} }
func (c *Conn) GetCompleteAnswer() (answer string, err error) { func (c *Server) GetCompleteAnswer() (answer string, err error) {
if _, err = c.GetAnswer(); err != nil { if _, err = c.GetAnswer(); err != nil {
return return
} }
<-webrtc.GatheringCompletePromise(c.Conn) <-webrtc.GatheringCompletePromise(c.conn)
return c.Conn.LocalDescription().SDP, nil return c.conn.LocalDescription().SDP, nil
} }
func (c *Conn) AddCandidate(candidate string) { func (c *Server) Close() error {
_ = c.Conn.AddICECandidate(webrtc.ICECandidateInit{Candidate: candidate}) return c.conn.Close()
} }
func (c *Conn) remote() string { func (c *Server) AddCandidate(candidate string) {
if c.Conn == nil { // pion uses only candidate value from json/object candidate struct
_ = c.conn.AddICECandidate(webrtc.ICECandidateInit{Candidate: candidate})
}
func (c *Server) remote() string {
if c.conn == nil {
return "" return ""
} }
for _, trans := range c.Conn.GetTransceivers() { for _, trans := range c.conn.GetTransceivers() {
if trans == nil { if trans == nil {
continue continue
} }