diff --git a/cmd/api/api.go b/cmd/api/api.go index 940130b9..322d23c4 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -4,7 +4,6 @@ import ( "encoding/json" "github.com/AlexxIT/go2rtc/cmd/app" "github.com/AlexxIT/go2rtc/cmd/streams" - "github.com/AlexxIT/go2rtc/pkg/streamer" "github.com/rs/zerolog" "net" "net/http" @@ -128,26 +127,3 @@ func streamsHandler(w http.ResponseWriter, r *http.Request) { e.SetIndent("", " ") _ = e.Encode(v) } - -func apiWS(w http.ResponseWriter, r *http.Request) { - ctx := new(Context) - if err := ctx.Upgrade(w, r); err != nil { - origin := r.Header.Get("Origin") - log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin) - return - } - defer ctx.Close() - - for { - msg := new(streamer.Message) - if err := ctx.Conn.ReadJSON(msg); err != nil { - log.Trace().Err(err).Caller().Send() - return - } - - handler := wsHandlers[msg.Type] - if handler != nil { - handler(ctx, msg) - } - } -} diff --git a/cmd/api/ws.go b/cmd/api/ws.go index 50a8ced9..68224d25 100644 --- a/cmd/api/ws.go +++ b/cmd/api/ws.go @@ -45,50 +45,85 @@ func initWS(origin string) { } } +func apiWS(w http.ResponseWriter, r *http.Request) { + ws, err := wsUp.Upgrade(w, r, nil) + if err != nil { + origin := r.Header.Get("Origin") + log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin) + return + } + + tr := &Transport{Request: r} + tr.OnWrite(func(msg interface{}) { + if data, ok := msg.([]byte); ok { + _ = ws.WriteMessage(websocket.BinaryMessage, data) + } else { + _ = ws.WriteJSON(msg) + } + }) + + for { + msg := new(streamer.Message) + if err = ws.ReadJSON(msg); err != nil { + log.Trace().Err(err).Caller().Send() + _ = ws.Close() + break + } + + if handler := wsHandlers[msg.Type]; handler != nil { + handler(tr, msg) + } + } + + tr.Close() +} + var wsUp *websocket.Upgrader -type WSHandler func(ctx *Context, msg *streamer.Message) +type WSHandler func(tr *Transport, msg *streamer.Message) -type Context struct { - Conn *websocket.Conn +type Transport struct { Request *http.Request Consumer interface{} // TODO: rewrite - onClose []func() - mu sync.Mutex + mx sync.Mutex + + onChange func() + onWrite func(msg interface{}) + onClose []func() } -func (ctx *Context) Upgrade(w http.ResponseWriter, r *http.Request) (err error) { - ctx.Conn, err = wsUp.Upgrade(w, r, nil) - ctx.Request = r - return +func (t *Transport) OnWrite(f func(msg interface{})) { + t.mx.Lock() + if t.onChange != nil { + t.onChange() + } + t.onWrite = f + t.mx.Unlock() } -func (ctx *Context) Close() { - for _, f := range ctx.onClose { +func (t *Transport) Write(msg interface{}) { + t.mx.Lock() + t.onWrite(msg) + t.mx.Unlock() +} + +func (t *Transport) Close() { + for _, f := range t.onClose { f() } - _ = ctx.Conn.Close() } -func (ctx *Context) Write(msg interface{}) { - ctx.mu.Lock() - - if data, ok := msg.([]byte); ok { - _ = ctx.Conn.WriteMessage(websocket.BinaryMessage, data) - } else { - _ = ctx.Conn.WriteJSON(msg) - } - - ctx.mu.Unlock() -} - -func (ctx *Context) Error(err error) { - ctx.Write(&streamer.Message{ +func (t *Transport) Error(err error) { + t.Write(&streamer.Message{ Type: "error", Value: err.Error(), }) } -func (ctx *Context) OnClose(f func()) { - ctx.onClose = append(ctx.onClose, f) +func (t *Transport) OnChange(f func()) { + t.onChange = f +} + +func (t *Transport) OnClose(f func()) { + t.onClose = append(t.onClose, f) } diff --git a/cmd/mjpeg/mjpeg.go b/cmd/mjpeg/mjpeg.go index d080ee67..2219df5f 100644 --- a/cmd/mjpeg/mjpeg.go +++ b/cmd/mjpeg/mjpeg.go @@ -100,7 +100,7 @@ func handlerStream(w http.ResponseWriter, r *http.Request) { //log.Trace().Msg("[api.mjpeg] close") } -func handlerWS(ctx *api.Context, msg *streamer.Message) { +func handlerWS(ctx *api.Transport, msg *streamer.Message) { src := ctx.Request.URL.Query().Get("src") stream := streams.GetOrNew(src) if stream == nil { diff --git a/cmd/mp4/ws.go b/cmd/mp4/ws.go index d1268507..6b4c891d 100644 --- a/cmd/mp4/ws.go +++ b/cmd/mp4/ws.go @@ -10,7 +10,7 @@ import ( const packetSize = 8192 -func handlerWS(ctx *api.Context, msg *streamer.Message) { +func handlerWS(ctx *api.Transport, msg *streamer.Message) { src := ctx.Request.URL.Query().Get("src") stream := streams.GetOrNew(src) if stream == nil { @@ -59,7 +59,7 @@ func handlerWS(ctx *api.Context, msg *streamer.Message) { cons.Start() } -func handlerWS4(ctx *api.Context, msg *streamer.Message) { +func handlerWS4(ctx *api.Transport, msg *streamer.Message) { src := ctx.Request.URL.Query().Get("src") stream := streams.GetOrNew(src) if stream == nil { diff --git a/cmd/webrtc/candidates.go b/cmd/webrtc/candidates.go index feb965be..bc0e6fd4 100644 --- a/cmd/webrtc/candidates.go +++ b/cmd/webrtc/candidates.go @@ -13,7 +13,7 @@ func AddCandidate(address string) { candidates = append(candidates, address) } -func asyncCandidates(ctx *api.Context) { +func asyncCandidates(ctx *api.Transport) { for _, address := range candidates { address, err := webrtc.LookupIP(address) if err != nil { @@ -79,7 +79,7 @@ func syncCanditates(answer string) (string, error) { return string(data), nil } -func candidateHandler(ctx *api.Context, msg *streamer.Message) { +func candidateHandler(ctx *api.Transport, msg *streamer.Message) { if ctx.Consumer == nil { return } diff --git a/cmd/webrtc/webrtc.go b/cmd/webrtc/webrtc.go index e07af7f1..524d962a 100644 --- a/cmd/webrtc/webrtc.go +++ b/cmd/webrtc/webrtc.go @@ -8,7 +8,7 @@ import ( "github.com/AlexxIT/go2rtc/pkg/webrtc" pion "github.com/pion/webrtc/v3" "github.com/rs/zerolog" - "io/ioutil" + "io" "net" "net/http" ) @@ -66,8 +66,8 @@ var log zerolog.Logger var NewPConn func() (*pion.PeerConnection, error) -func asyncHandler(ctx *api.Context, msg *streamer.Message) { - src := ctx.Request.URL.Query().Get("src") +func asyncHandler(tr *api.Transport, msg *streamer.Message) { + src := tr.Request.URL.Query().Get("src") stream := streams.Get(src) if stream == nil { return @@ -85,7 +85,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) { return } - conn.UserAgent = ctx.Request.UserAgent() + conn.UserAgent = tr.Request.UserAgent() conn.Listen(func(msg interface{}) { switch msg := msg.(type) { case pion.PeerConnectionState: @@ -96,7 +96,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) { if msg != nil { s := msg.ToJSON().Candidate log.Trace().Str("candidate", s).Msg("[webrtc] local") - ctx.Write(&streamer.Message{Type: "webrtc/candidate", Value: s}) + tr.Write(&streamer.Message{Type: "webrtc/candidate", Value: s}) } } }) @@ -107,7 +107,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) { if err = conn.SetOffer(offer); err != nil { log.Warn().Err(err).Caller().Msg("conn.SetOffer") - ctx.Error(err) + tr.Error(err) return } @@ -115,7 +115,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) { if err = stream.AddConsumer(conn); err != nil { log.Warn().Err(err).Caller().Msg("stream.AddConsumer") _ = conn.Conn.Close() - ctx.Error(err) + tr.Error(err) return } @@ -127,15 +127,15 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) { if err != nil { log.Error().Err(err).Caller().Msg("conn.GetAnswer") - ctx.Error(err) + tr.Error(err) return } - ctx.Consumer = conn + tr.Consumer = conn - ctx.Write(&streamer.Message{Type: "webrtc/answer", Value: answer}) + tr.Write(&streamer.Message{Type: "webrtc/answer", Value: answer}) - asyncCandidates(ctx) + asyncCandidates(tr) } func syncHandler(w http.ResponseWriter, r *http.Request) { @@ -146,7 +146,7 @@ func syncHandler(w http.ResponseWriter, r *http.Request) { } // get offer - offer, err := ioutil.ReadAll(r.Body) + offer, err := io.ReadAll(r.Body) if err != nil { log.Error().Err(err).Caller().Msg("ioutil.ReadAll") return diff --git a/pkg/webrtc/conn.go b/pkg/webrtc/conn.go index a4e0a103..eb4d5974 100644 --- a/pkg/webrtc/conn.go +++ b/pkg/webrtc/conn.go @@ -49,6 +49,10 @@ func (c *Conn) Init() { //fmt.Printf("TODO: webrtc ontrack %+v\n", remote) }) + c.Conn.OnDataChannel(func(channel *webrtc.DataChannel) { + c.Fire(channel) + }) + // OK connection: // 15:01:46 ICE connection state changed: checking // 15:01:46 peer connection state changed: connected