diff --git a/rtmp/rtmp.go b/rtmp/rtmp.go index 3c219709..4990b49d 100644 --- a/rtmp/rtmp.go +++ b/rtmp/rtmp.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "fmt" "net" + "net/url" "path/filepath" "strings" "sync" @@ -326,18 +327,53 @@ func (s *server) log(who, action, path, message string, client net.Addr) { }).Log(message) } +// getToken returns the path and the token found in the URL. If the token +// was part of the path, the token is removed from the path. The token in +// the query string takes precedence. The token in the path is assumed to +// be the last path element. +func getToken(u *url.URL) (string, string) { + q := u.Query() + token := q.Get("token") + + if len(token) != 0 { + // The token was in the query. Return the unmomdified path and the token + return u.Path, token + } + + pathElements := strings.Split(u.EscapedPath(), "/") + nPathElements := len(pathElements) + + if nPathElements == 0 { + return u.Path, "" + } + + // Return the path without the token + return strings.Join(pathElements[:nPathElements-1], "/"), pathElements[nPathElements-1] +} + // handlePlay is called when a RTMP client wants to play a stream func (s *server) handlePlay(conn *rtmp.Conn) { client := conn.NetConn().RemoteAddr() - // Check the token - q := conn.URL.Query() - token := q.Get("token") + defer conn.Close() - if len(s.token) != 0 && s.token != token { - s.log("PLAY", "FORBIDDEN", conn.URL.Path, "invalid token ("+token+")", client) - conn.Close() - return + playPath := conn.URL.Path + + // Check the token in the URL if one is required + if len(s.token) != 0 { + path, token := getToken(conn.URL) + + if len(token) == 0 { + s.log("PLAY", "FORBIDDEN", path, "no streamkey provided", client) + return + } + + if s.token != token { + s.log("PLAY", "FORBIDDEN", path, "invalid streamkey ("+token+")", client) + return + } + + playPath = path } /* @@ -361,14 +397,14 @@ func (s *server) handlePlay(conn *rtmp.Conn) { // Look for the stream s.lock.RLock() - ch := s.channels[conn.URL.Path] + ch := s.channels[playPath] s.lock.RUnlock() if ch != nil { // Set the metadata for the client conn.SetMetaData(ch.metadata) - s.log("PLAY", "START", conn.URL.Path, "", client) + s.log("PLAY", "START", playPath, "", client) // Get a cursor and apply filters cursor := ch.queue.Oldest() @@ -395,32 +431,39 @@ func (s *server) handlePlay(conn *rtmp.Conn) { ch.RemoveSubscriber(id) - s.log("PLAY", "STOP", conn.URL.Path, "", client) + s.log("PLAY", "STOP", playPath, "", client) } else { - s.log("PLAY", "NOTFOUND", conn.URL.Path, "", client) + s.log("PLAY", "NOTFOUND", playPath, "", client) } - - conn.Close() } // handlePublish is called when a RTMP client wants to publish a stream func (s *server) handlePublish(conn *rtmp.Conn) { client := conn.NetConn().RemoteAddr() - // Check the token - q := conn.URL.Query() - token := q.Get("token") + defer conn.Close() - if len(s.token) != 0 && s.token != token { - s.log("PUBLISH", "FORBIDDEN", conn.URL.Path, "invalid token ("+token+")", client) - conn.Close() - return + playPath := conn.URL.Path + + if len(s.token) != 0 { + path, token := getToken(conn.URL) + + if len(token) == 0 { + s.log("PLAY", "FORBIDDEN", path, "no streamkey provided", client) + return + } + + if s.token != token { + s.log("PLAY", "FORBIDDEN", path, "invalid streamkey ("+token+")", client) + return + } + + playPath = path } // Check the app patch - if !strings.HasPrefix(conn.URL.Path, s.app) { + if !strings.HasPrefix(playPath, s.app) { s.log("PUBLISH", "FORBIDDEN", conn.URL.Path, "invalid app", client) - conn.Close() return } @@ -428,8 +471,7 @@ func (s *server) handlePublish(conn *rtmp.Conn) { streams, _ := conn.Streams() if len(streams) == 0 { - s.log("PUBLISH", "INVALID", conn.URL.Path, "no streams available", client) - conn.Close() + s.log("PUBLISH", "INVALID", playPath, "no streams available", client) return } @@ -437,7 +479,7 @@ func (s *server) handlePublish(conn *rtmp.Conn) { ch := s.channels[conn.URL.Path] if ch == nil { - reference := strings.TrimPrefix(strings.TrimSuffix(conn.URL.Path, filepath.Ext(conn.URL.Path)), s.app+"/") + reference := strings.TrimPrefix(strings.TrimSuffix(playPath, filepath.Ext(playPath)), s.app+"/") // Create a new channel ch = newChannel(conn, reference, s.collector) @@ -456,7 +498,7 @@ func (s *server) handlePublish(conn *rtmp.Conn) { } } - s.channels[conn.URL.Path] = ch + s.channels[playPath] = ch } else { ch = nil } @@ -464,27 +506,24 @@ func (s *server) handlePublish(conn *rtmp.Conn) { s.lock.Unlock() if ch == nil { - s.log("PUBLISH", "CONFLICT", conn.URL.Path, "already publishing", client) - conn.Close() + s.log("PUBLISH", "CONFLICT", playPath, "already publishing", client) return } - s.log("PUBLISH", "START", conn.URL.Path, "", client) + s.log("PUBLISH", "START", playPath, "", client) for _, stream := range streams { - s.log("PUBLISH", "STREAM", conn.URL.Path, stream.Type().String(), client) + s.log("PUBLISH", "STREAM", playPath, stream.Type().String(), client) } // Ingest the data avutil.CopyPackets(ch.queue, conn) s.lock.Lock() - delete(s.channels, conn.URL.Path) + delete(s.channels, playPath) s.lock.Unlock() ch.Close() - s.log("PUBLISH", "STOP", conn.URL.Path, "", client) - - conn.Close() + s.log("PUBLISH", "STOP", playPath, "", client) } diff --git a/rtmp/rtmp_test.go b/rtmp/rtmp_test.go new file mode 100644 index 00000000..20bb5274 --- /dev/null +++ b/rtmp/rtmp_test.go @@ -0,0 +1,26 @@ +package rtmp + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestToken(t *testing.T) { + data := [][]string{ + {"/foo/bar", "/foo", "bar"}, + {"/foo/bar?token=abc", "/foo/bar", "abc"}, + {"/foo/bar/abc", "/foo/bar", "abc"}, + } + + for _, d := range data { + u, err := url.Parse(d[0]) + require.NoError(t, err) + + path, token := getToken(u) + + require.Equal(t, d[1], path, "url=%s", u.String()) + require.Equal(t, d[2], token, "url=%s", u.String()) + } +}