diff --git a/api.go b/api.go index 2358ed5..3d91e8c 100644 --- a/api.go +++ b/api.go @@ -1172,7 +1172,10 @@ func (s *Server) UpdatePushProxy(ctx context.Context, req *pb.PushProxyInfo) (re return } target := &PushProxy{} - s.DB.First(target, req.ID) + err = s.DB.First(target, req.ID).Error + if err != nil { + return + } target.Name = req.Name target.URL = req.PushURL target.ParentID = uint(req.ParentID) @@ -1205,6 +1208,23 @@ func (s *Server) UpdatePushProxy(ctx context.Context, req *pb.PushProxyInfo) (re target.RTT = time.Duration(int(req.Rtt)) * time.Millisecond target.StreamPath = req.StreamPath s.DB.Save(target) + s.PushProxies.Call(func() error { + if device, ok := s.PushProxies.Get(uint(req.ID)); ok { + if target.URL != device.URL || device.Audio != target.Audio || device.StreamPath != target.StreamPath { + device.Stop(task.ErrStopByUser) + device.WaitStopped() + s.PushProxies.Add(target) + return nil + } + if device.PushOnStart != target.PushOnStart && target.PushOnStart && device.Handler != nil && device.Status == PushProxyStatusOnline { + device.Handler.Push() + } + device.Name = target.Name + device.PushOnStart = target.PushOnStart + device.Description = target.Description + } + return nil + }) res = &pb.SuccessResponse{} return } diff --git a/plugin/webrtc/api.go b/plugin/webrtc/api.go index 0459195..d69b91e 100644 --- a/plugin/webrtc/api.go +++ b/plugin/webrtc/api.go @@ -1,18 +1,20 @@ package plugin_webrtc import ( + "encoding/json" "fmt" "io" "net/http" "strings" . "github.com/pion/webrtc/v3" + "m7s.live/v5/pkg/task" . "m7s.live/v5/plugin/webrtc/pkg" ) // https://datatracker.ietf.org/doc/html/draft-ietf-wish-whip -func (conf *WebRTCPlugin) Push_(w http.ResponseWriter, r *http.Request) { - streamPath := r.URL.Path[len("/push/"):] +func (conf *WebRTCPlugin) servePush(w http.ResponseWriter, r *http.Request) { + streamPath := r.PathValue("streamPath") rawQuery := r.URL.RawQuery auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Bearer ") { @@ -69,9 +71,9 @@ func (conf *WebRTCPlugin) Push_(w http.ResponseWriter, r *http.Request) { } } -func (conf *WebRTCPlugin) Play_(w http.ResponseWriter, r *http.Request) { +func (conf *WebRTCPlugin) servePlay(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/sdp") - streamPath := r.URL.Path[len("/play/"):] + streamPath := r.PathValue("streamPath") rawQuery := r.URL.RawQuery var conn Connection conn.EnableDC = conf.EnableDC @@ -112,3 +114,124 @@ func (conf *WebRTCPlugin) Play_(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) } } + +// Batch 通过单个 PeerConnection 实现多个流的推拉 +func (conf *WebRTCPlugin) Batch(w http.ResponseWriter, r *http.Request) { + conn := NewSingleConnection() + conn.EnableDC = true // Enable DataChannel for signaling + conn.Logger = conf.Logger + bytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + conn.SDP = string(bytes) + if conn.PeerConnection, err = conf.api.NewPeerConnection(Configuration{ + ICEServers: conf.ICEServers, + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Create data channel for signaling + dataChannel, err := conn.PeerConnection.CreateDataChannel("signal", nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + dataChannel.OnMessage(func(msg DataChannelMessage) { + var signal Signal + if err := json.Unmarshal(msg.Data, &signal); err != nil { + conf.Error("failed to unmarshal signal", "error", err) + return + } + + switch signal.Type { + case SignalTypePublish: + if publisher, err := conf.Publish(conf.Context, signal.StreamPath); err == nil { + conn.Publisher = publisher + conn.Publisher.RemoteAddr = r.RemoteAddr + conn.Receive() + // Renegotiate SDP after successful publish + if answer, err := conn.GetAnswer(); err == nil { + dataChannel.SendText(NewAnswerSingal(answer.SDP)) + } else { + dataChannel.SendText(NewErrorSignal(err.Error(), signal.StreamPath)) + } + } else { + dataChannel.SendText(NewErrorSignal(err.Error(), signal.StreamPath)) + } + case SignalTypeSubscribe: + if err := conn.SetRemoteDescription(SessionDescription{ + Type: SDPTypeOffer, + SDP: signal.Offer, + }); err != nil { + dataChannel.SendText(NewErrorSignal("Failed to set remote description: "+err.Error(), "")) + return + } + // First remove subscribers that are not in the new list + for streamPath := range conn.Subscribers { + found := false + for _, newPath := range signal.StreamList { + if streamPath == newPath { + found = true + break + } + } + if !found { + conn.RemoveSubscriber(streamPath) + } + } + // Then add new subscribers + for _, streamPath := range signal.StreamList { + // Skip if already subscribed + if conn.HasSubscriber(streamPath) { + continue + } + if subscriber, err := conf.Subscribe(conf.Context, streamPath); err == nil { + subscriber.RemoteAddr = r.RemoteAddr + conn.AddSubscriber(streamPath, subscriber) + } else { + dataChannel.SendText(NewErrorSignal(err.Error(), streamPath)) + } + } + case SignalTypeUnpublish: + // Handle stream removal + if conn.Publisher != nil && conn.Publisher.StreamPath == signal.StreamPath { + conn.Publisher.Stop(task.ErrStopByUser) + conn.Publisher = nil + // Renegotiate SDP after unpublish + if answer, err := conn.GetAnswer(); err == nil { + dataChannel.SendText(NewAnswerSingal(answer.SDP)) + } else { + dataChannel.SendText(NewErrorSignal(err.Error(), signal.StreamPath)) + } + } + case SignalTypeAnswer: + // Handle received answer from browser + if err := conn.SetRemoteDescription(SessionDescription{ + Type: SDPTypeAnswer, + SDP: signal.Answer, + }); err != nil { + dataChannel.SendText(NewErrorSignal("Failed to set remote description: "+err.Error(), "")) + } + } + }) + + conf.AddTask(conn) + if err = conn.WaitStarted(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err = conn.SetRemoteDescription(SessionDescription{Type: SDPTypeOffer, SDP: conn.SDP}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if answer, err := conn.GetAnswer(); err == nil { + w.Header().Set("Content-Type", "application/sdp") + w.Write([]byte(answer.SDP)) + } else { + http.Error(w, err.Error(), http.StatusBadRequest) + } +} diff --git a/plugin/webrtc/index.go b/plugin/webrtc/index.go index c86245d..6b33111 100644 --- a/plugin/webrtc/index.go +++ b/plugin/webrtc/index.go @@ -44,7 +44,9 @@ type WebRTCPlugin struct { func (p *WebRTCPlugin) RegisterHandler() map[string]http.HandlerFunc { return map[string]http.HandlerFunc{ - "/test/{name}": p.testPage, + "/test/{name}": p.testPage, + "/push/{streamPath...}": p.servePush, + "/play/{streamPath...}": p.servePlay, } } diff --git a/plugin/webrtc/pkg/batcher.go b/plugin/webrtc/pkg/batcher.go new file mode 100644 index 0000000..391198a --- /dev/null +++ b/plugin/webrtc/pkg/batcher.go @@ -0,0 +1,64 @@ +package webrtc + +import "encoding/json" + +type SignalType string + +const ( + SignalTypeSubscribe SignalType = "subscribe" + SignalTypePublish SignalType = "publish" + SignalTypeUnpublish SignalType = "unpublish" + SignalTypeAnswer SignalType = "answer" +) + +type Signal struct { + Type SignalType `json:"type"` + StreamList []string `json:"streamList"` + Offer string `json:"offer"` + Answer string `json:"answer"` + StreamPath string `json:"streamPath"` +} + +type SignalStreamPath struct { + Type string `json:"type"` + StreamPath string `json:"streamPath"` +} + +func NewRemoveSingal(streamPath string) string { + s := SignalStreamPath{ + Type: "remove", + StreamPath: streamPath, + } + b, _ := json.Marshal(s) + return string(b) +} + +type SignalSDP struct { + Type string `json:"type"` + SDP string `json:"sdp"` +} + +func NewAnswerSingal(sdp string) string { + s := SignalSDP{ + Type: "answer", + SDP: sdp, + } + b, _ := json.Marshal(s) + return string(b) +} + +type SignalError struct { + Type string `json:"type"` + Message string `json:"message"` + StreamPath string `json:"streamPath,omitempty"` +} + +func NewErrorSignal(message string, streamPath string) string { + s := SignalError{ + Type: "error", + Message: message, + StreamPath: streamPath, + } + b, _ := json.Marshal(s) + return string(b) +} diff --git a/plugin/webrtc/pkg/connection.go b/plugin/webrtc/pkg/connection.go index 52ef1a3..f3c5a69 100644 --- a/plugin/webrtc/pkg/connection.go +++ b/plugin/webrtc/pkg/connection.go @@ -31,9 +31,11 @@ type Connection struct { func (IO *Connection) Start() (err error) { if IO.Publisher != nil { + IO.Depend(IO.Publisher) IO.Receive() } if IO.Subscriber != nil { + IO.Depend(IO.Subscriber) IO.Send() } IO.OnICECandidate(func(ice *ICECandidate) { @@ -90,7 +92,6 @@ func (IO *Connection) GetAnswer() (*SessionDescription, error) { func (IO *Connection) Receive() { puber := IO.Publisher - IO.Depend(puber) IO.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { IO.Info("OnTrack", "kind", track.Kind().String(), "payloadType", uint8(track.Codec().PayloadType)) var n int @@ -200,13 +201,11 @@ func (IO *Connection) Receive() { }) } -func (IO *Connection) Send() (err error) { - suber := IO.Subscriber - IO.Depend(suber) +func (IO *Connection) SendSubscriber(subscriber *m7s.Subscriber) (err error) { var useDC bool var audioTLSRTP, videoTLSRTP *TrackLocalStaticRTP var audioSender, videoSender *RTPSender - vctx, actx := suber.Publisher.GetVideoCodecCtx(), suber.Publisher.GetAudioCodecCtx() + vctx, actx := subscriber.Publisher.GetVideoCodecCtx(), subscriber.Publisher.GetAudioCodecCtx() if IO.EnableDC && vctx != nil && vctx.FourCC() == codec.FourCC_H265 { useDC = true } @@ -229,7 +228,7 @@ func (IO *Connection) Send() (err error) { return } } - videoTLSRTP, err = NewTrackLocalStaticRTP(rcc.RTPCodecCapability, videoCodec.String(), suber.StreamPath) + videoTLSRTP, err = NewTrackLocalStaticRTP(rcc.RTPCodecCapability, videoCodec.String(), subscriber.StreamPath) if err != nil { return } @@ -241,7 +240,7 @@ func (IO *Connection) Send() (err error) { rtcpBuf := make([]byte, 1500) for { if n, _, rtcpErr := videoSender.Read(rtcpBuf); rtcpErr != nil { - suber.Warn("rtcp read error", "error", rtcpErr) + subscriber.Warn("rtcp read error", "error", rtcpErr) return } else { if p, err := rtcp.Unmarshal(rtcpBuf[:n]); err == nil { @@ -271,7 +270,7 @@ func (IO *Connection) Send() (err error) { return } } - audioTLSRTP, err = NewTrackLocalStaticRTP(rcc.RTPCodecCapability, audioCodec.String(), suber.StreamPath) + audioTLSRTP, err = NewTrackLocalStaticRTP(rcc.RTPCodecCapability, audioCodec.String(), subscriber.StreamPath) if err != nil { return } @@ -282,11 +281,11 @@ func (IO *Connection) Send() (err error) { } var dc *DataChannel if useDC { - dc, err = IO.CreateDataChannel(suber.StreamPath, nil) + dc, err = IO.CreateDataChannel(subscriber.StreamPath, nil) if err != nil { return } - dc.OnOpen(func(){ + dc.OnOpen(func() { var live flv.Live live.WriteFlvTag = func(buffers net.Buffers) (err error) { r := util.NewReadableBuffersFromBytes(buffers...) @@ -306,18 +305,18 @@ func (IO *Connection) Send() (err error) { }) return } - live.Subscriber = suber + live.Subscriber = subscriber err = live.Run() dc.Close() }) } else { if audioSender == nil { - suber.SubAudio = false + subscriber.SubAudio = false } if videoSender == nil { - suber.SubVideo = false + subscriber.SubVideo = false } - go m7s.PlayBlock(suber, func(frame *mrtp.Audio) (err error) { + go m7s.PlayBlock(subscriber, func(frame *mrtp.Audio) (err error) { for _, p := range frame.Packets { if err = audioTLSRTP.WriteRTP(p); err != nil { return @@ -336,6 +335,49 @@ func (IO *Connection) Send() (err error) { return } +func (IO *Connection) Send() (err error) { + if IO.Subscriber != nil { + err = IO.SendSubscriber(IO.Subscriber) + } + return +} + func (IO *Connection) Dispose() { IO.PeerConnection.Close() } + +// SingleConnection extends Connection to handle multiple subscribers in a single WebRTC connection +type SingleConnection struct { + Connection + Subscribers map[string]*m7s.Subscriber // map streamPath to subscriber +} + +func NewSingleConnection() *SingleConnection { + return &SingleConnection{ + Subscribers: make(map[string]*m7s.Subscriber), + } +} + +// AddSubscriber adds a new subscriber to the connection and starts sending +func (c *SingleConnection) AddSubscriber(streamPath string, subscriber *m7s.Subscriber) { + c.Subscribers[streamPath] = subscriber + if err := c.SendSubscriber(subscriber); err != nil { + c.Error("failed to start subscriber", "error", err, "streamPath", streamPath) + subscriber.Stop(err) + delete(c.Subscribers, streamPath) + } +} + +// RemoveSubscriber removes a subscriber from the connection +func (c *SingleConnection) RemoveSubscriber(streamPath string) { + if subscriber, ok := c.Subscribers[streamPath]; ok { + subscriber.Stop(task.ErrStopByUser) + delete(c.Subscribers, streamPath) + } +} + +// HasSubscriber checks if a stream is already subscribed +func (c *SingleConnection) HasSubscriber(streamPath string) bool { + _, ok := c.Subscribers[streamPath] + return ok +}