Rewrite WS transport handler

This commit is contained in:
Alexey Khit
2022-12-04 23:24:20 +03:00
parent 69b17230f3
commit b7718b33b8
7 changed files with 84 additions and 69 deletions

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"github.com/AlexxIT/go2rtc/cmd/app" "github.com/AlexxIT/go2rtc/cmd/app"
"github.com/AlexxIT/go2rtc/cmd/streams" "github.com/AlexxIT/go2rtc/cmd/streams"
"github.com/AlexxIT/go2rtc/pkg/streamer"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"net" "net"
"net/http" "net/http"
@@ -128,26 +127,3 @@ func streamsHandler(w http.ResponseWriter, r *http.Request) {
e.SetIndent("", " ") e.SetIndent("", " ")
_ = e.Encode(v) _ = e.Encode(v)
} }
func apiWS(w http.ResponseWriter, r *http.Request) {
ctx := new(Context)
if err := ctx.Upgrade(w, r); err != nil {
origin := r.Header.Get("Origin")
log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin)
return
}
defer ctx.Close()
for {
msg := new(streamer.Message)
if err := ctx.Conn.ReadJSON(msg); err != nil {
log.Trace().Err(err).Caller().Send()
return
}
handler := wsHandlers[msg.Type]
if handler != nil {
handler(ctx, msg)
}
}
}

View File

@@ -45,50 +45,85 @@ func initWS(origin string) {
} }
} }
func apiWS(w http.ResponseWriter, r *http.Request) {
ws, err := wsUp.Upgrade(w, r, nil)
if err != nil {
origin := r.Header.Get("Origin")
log.Error().Err(err).Caller().Msgf("host=%s origin=%s", r.Host, origin)
return
}
tr := &Transport{Request: r}
tr.OnWrite(func(msg interface{}) {
if data, ok := msg.([]byte); ok {
_ = ws.WriteMessage(websocket.BinaryMessage, data)
} else {
_ = ws.WriteJSON(msg)
}
})
for {
msg := new(streamer.Message)
if err = ws.ReadJSON(msg); err != nil {
log.Trace().Err(err).Caller().Send()
_ = ws.Close()
break
}
if handler := wsHandlers[msg.Type]; handler != nil {
handler(tr, msg)
}
}
tr.Close()
}
var wsUp *websocket.Upgrader var wsUp *websocket.Upgrader
type WSHandler func(ctx *Context, msg *streamer.Message) type WSHandler func(tr *Transport, msg *streamer.Message)
type Context struct { type Transport struct {
Conn *websocket.Conn
Request *http.Request Request *http.Request
Consumer interface{} // TODO: rewrite Consumer interface{} // TODO: rewrite
onClose []func() mx sync.Mutex
mu sync.Mutex
onChange func()
onWrite func(msg interface{})
onClose []func()
} }
func (ctx *Context) Upgrade(w http.ResponseWriter, r *http.Request) (err error) { func (t *Transport) OnWrite(f func(msg interface{})) {
ctx.Conn, err = wsUp.Upgrade(w, r, nil) t.mx.Lock()
ctx.Request = r if t.onChange != nil {
return t.onChange()
}
t.onWrite = f
t.mx.Unlock()
} }
func (ctx *Context) Close() { func (t *Transport) Write(msg interface{}) {
for _, f := range ctx.onClose { t.mx.Lock()
t.onWrite(msg)
t.mx.Unlock()
}
func (t *Transport) Close() {
for _, f := range t.onClose {
f() f()
} }
_ = ctx.Conn.Close()
} }
func (ctx *Context) Write(msg interface{}) { func (t *Transport) Error(err error) {
ctx.mu.Lock() t.Write(&streamer.Message{
if data, ok := msg.([]byte); ok {
_ = ctx.Conn.WriteMessage(websocket.BinaryMessage, data)
} else {
_ = ctx.Conn.WriteJSON(msg)
}
ctx.mu.Unlock()
}
func (ctx *Context) Error(err error) {
ctx.Write(&streamer.Message{
Type: "error", Value: err.Error(), Type: "error", Value: err.Error(),
}) })
} }
func (ctx *Context) OnClose(f func()) { func (t *Transport) OnChange(f func()) {
ctx.onClose = append(ctx.onClose, f) t.onChange = f
}
func (t *Transport) OnClose(f func()) {
t.onClose = append(t.onClose, f)
} }

View File

@@ -100,7 +100,7 @@ func handlerStream(w http.ResponseWriter, r *http.Request) {
//log.Trace().Msg("[api.mjpeg] close") //log.Trace().Msg("[api.mjpeg] close")
} }
func handlerWS(ctx *api.Context, msg *streamer.Message) { func handlerWS(ctx *api.Transport, msg *streamer.Message) {
src := ctx.Request.URL.Query().Get("src") src := ctx.Request.URL.Query().Get("src")
stream := streams.GetOrNew(src) stream := streams.GetOrNew(src)
if stream == nil { if stream == nil {

View File

@@ -10,7 +10,7 @@ import (
const packetSize = 8192 const packetSize = 8192
func handlerWS(ctx *api.Context, msg *streamer.Message) { func handlerWS(ctx *api.Transport, msg *streamer.Message) {
src := ctx.Request.URL.Query().Get("src") src := ctx.Request.URL.Query().Get("src")
stream := streams.GetOrNew(src) stream := streams.GetOrNew(src)
if stream == nil { if stream == nil {
@@ -59,7 +59,7 @@ func handlerWS(ctx *api.Context, msg *streamer.Message) {
cons.Start() cons.Start()
} }
func handlerWS4(ctx *api.Context, msg *streamer.Message) { func handlerWS4(ctx *api.Transport, msg *streamer.Message) {
src := ctx.Request.URL.Query().Get("src") src := ctx.Request.URL.Query().Get("src")
stream := streams.GetOrNew(src) stream := streams.GetOrNew(src)
if stream == nil { if stream == nil {

View File

@@ -13,7 +13,7 @@ func AddCandidate(address string) {
candidates = append(candidates, address) candidates = append(candidates, address)
} }
func asyncCandidates(ctx *api.Context) { func asyncCandidates(ctx *api.Transport) {
for _, address := range candidates { for _, address := range candidates {
address, err := webrtc.LookupIP(address) address, err := webrtc.LookupIP(address)
if err != nil { if err != nil {
@@ -79,7 +79,7 @@ func syncCanditates(answer string) (string, error) {
return string(data), nil return string(data), nil
} }
func candidateHandler(ctx *api.Context, msg *streamer.Message) { func candidateHandler(ctx *api.Transport, msg *streamer.Message) {
if ctx.Consumer == nil { if ctx.Consumer == nil {
return return
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/AlexxIT/go2rtc/pkg/webrtc" "github.com/AlexxIT/go2rtc/pkg/webrtc"
pion "github.com/pion/webrtc/v3" pion "github.com/pion/webrtc/v3"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"io/ioutil" "io"
"net" "net"
"net/http" "net/http"
) )
@@ -66,8 +66,8 @@ var log zerolog.Logger
var NewPConn func() (*pion.PeerConnection, error) var NewPConn func() (*pion.PeerConnection, error)
func asyncHandler(ctx *api.Context, msg *streamer.Message) { func asyncHandler(tr *api.Transport, msg *streamer.Message) {
src := ctx.Request.URL.Query().Get("src") src := tr.Request.URL.Query().Get("src")
stream := streams.Get(src) stream := streams.Get(src)
if stream == nil { if stream == nil {
return return
@@ -85,7 +85,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) {
return return
} }
conn.UserAgent = ctx.Request.UserAgent() conn.UserAgent = tr.Request.UserAgent()
conn.Listen(func(msg interface{}) { conn.Listen(func(msg interface{}) {
switch msg := msg.(type) { switch msg := msg.(type) {
case pion.PeerConnectionState: case pion.PeerConnectionState:
@@ -96,7 +96,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) {
if msg != nil { if msg != nil {
s := msg.ToJSON().Candidate s := msg.ToJSON().Candidate
log.Trace().Str("candidate", s).Msg("[webrtc] local") log.Trace().Str("candidate", s).Msg("[webrtc] local")
ctx.Write(&streamer.Message{Type: "webrtc/candidate", Value: s}) tr.Write(&streamer.Message{Type: "webrtc/candidate", Value: s})
} }
} }
}) })
@@ -107,7 +107,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) {
if err = conn.SetOffer(offer); err != nil { if err = conn.SetOffer(offer); err != nil {
log.Warn().Err(err).Caller().Msg("conn.SetOffer") log.Warn().Err(err).Caller().Msg("conn.SetOffer")
ctx.Error(err) tr.Error(err)
return return
} }
@@ -115,7 +115,7 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) {
if err = stream.AddConsumer(conn); err != nil { if err = stream.AddConsumer(conn); err != nil {
log.Warn().Err(err).Caller().Msg("stream.AddConsumer") log.Warn().Err(err).Caller().Msg("stream.AddConsumer")
_ = conn.Conn.Close() _ = conn.Conn.Close()
ctx.Error(err) tr.Error(err)
return return
} }
@@ -127,15 +127,15 @@ func asyncHandler(ctx *api.Context, msg *streamer.Message) {
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msg("conn.GetAnswer") log.Error().Err(err).Caller().Msg("conn.GetAnswer")
ctx.Error(err) tr.Error(err)
return return
} }
ctx.Consumer = conn tr.Consumer = conn
ctx.Write(&streamer.Message{Type: "webrtc/answer", Value: answer}) tr.Write(&streamer.Message{Type: "webrtc/answer", Value: answer})
asyncCandidates(ctx) asyncCandidates(tr)
} }
func syncHandler(w http.ResponseWriter, r *http.Request) { func syncHandler(w http.ResponseWriter, r *http.Request) {
@@ -146,7 +146,7 @@ func syncHandler(w http.ResponseWriter, r *http.Request) {
} }
// get offer // get offer
offer, err := ioutil.ReadAll(r.Body) offer, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
log.Error().Err(err).Caller().Msg("ioutil.ReadAll") log.Error().Err(err).Caller().Msg("ioutil.ReadAll")
return return

View File

@@ -49,6 +49,10 @@ func (c *Conn) Init() {
//fmt.Printf("TODO: webrtc ontrack %+v\n", remote) //fmt.Printf("TODO: webrtc ontrack %+v\n", remote)
}) })
c.Conn.OnDataChannel(func(channel *webrtc.DataChannel) {
c.Fire(channel)
})
// OK connection: // OK connection:
// 15:01:46 ICE connection state changed: checking // 15:01:46 ICE connection state changed: checking
// 15:01:46 peer connection state changed: connected // 15:01:46 peer connection state changed: connected