decode the configuration when decoding JSON

This commit is contained in:
aler9
2021-09-27 10:36:28 +02:00
committed by Alessandro Ros
parent 54292d712e
commit 6921a402d1
18 changed files with 532 additions and 285 deletions

View File

@@ -0,0 +1,51 @@
package conf
import (
"encoding/json"
"fmt"
"github.com/aler9/gortsplib/pkg/headers"
)
// AuthMethods is the authMethods parameter.
type AuthMethods []headers.AuthMethod
// MarshalJSON marshals a AuthMethods into JSON.
func (d AuthMethods) MarshalJSON() ([]byte, error) {
var out []string
for _, v := range d {
switch v {
case headers.AuthBasic:
out = append(out, "basic")
default:
out = append(out, "digest")
}
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a AuthMethods from JSON.
func (d *AuthMethods) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
for _, v := range in {
switch v {
case "basic":
*d = append(*d, headers.AuthBasic)
case "digest":
*d = append(*d, headers.AuthDigest)
default:
return fmt.Errorf("unsupported authentication method: %s", in)
}
}
return nil
}

View File

@@ -15,26 +15,6 @@ import (
"github.com/aler9/rtsp-simple-server/internal/logger" "github.com/aler9/rtsp-simple-server/internal/logger"
) )
// Encryption is an encryption policy.
type Encryption int
// encryption policies.
const (
EncryptionNo Encryption = iota
EncryptionOptional
EncryptionStrict
)
// Protocol is a RTSP protocol
type Protocol int
// RTSP protocols.
const (
ProtocolUDP Protocol = iota
ProtocolMulticast
ProtocolTCP
)
func decrypt(key string, byts []byte) ([]byte, error) { func decrypt(key string, byts []byte) ([]byte, error) {
enc, err := base64.StdEncoding.DecodeString(string(byts)) enc, err := base64.StdEncoding.DecodeString(string(byts))
if err != nil { if err != nil {
@@ -121,10 +101,8 @@ func loadFromFile(fpath string, conf *Conf) (bool, error) {
// Conf is a configuration. // Conf is a configuration.
type Conf struct { type Conf struct {
// general // general
LogLevel string `json:"logLevel"` LogLevel LogLevel `json:"logLevel"`
LogLevelParsed logger.Level `json:"-"` LogDestinations LogDestinations `json:"logDestinations"`
LogDestinations []string `json:"logDestinations"`
LogDestinationsParsed map[logger.Destination]struct{} `json:"-"`
LogFile string `json:"logFile"` LogFile string `json:"logFile"`
ReadTimeout StringDuration `json:"readTimeout"` ReadTimeout StringDuration `json:"readTimeout"`
WriteTimeout StringDuration `json:"writeTimeout"` WriteTimeout StringDuration `json:"writeTimeout"`
@@ -140,10 +118,8 @@ type Conf struct {
// RTSP // RTSP
RTSPDisable bool `json:"rtspDisable"` RTSPDisable bool `json:"rtspDisable"`
Protocols []string `json:"protocols"` Protocols Protocols `json:"protocols"`
ProtocolsParsed map[Protocol]struct{} `json:"-"` Encryption Encryption `json:"encryption"`
Encryption string `json:"encryption"`
EncryptionParsed Encryption `json:"-"`
RTSPAddress string `json:"rtspAddress"` RTSPAddress string `json:"rtspAddress"`
RTSPSAddress string `json:"rtspsAddress"` RTSPSAddress string `json:"rtspsAddress"`
RTPAddress string `json:"rtpAddress"` RTPAddress string `json:"rtpAddress"`
@@ -153,8 +129,7 @@ type Conf struct {
MulticastRTCPPort int `json:"multicastRTCPPort"` MulticastRTCPPort int `json:"multicastRTCPPort"`
ServerKey string `json:"serverKey"` ServerKey string `json:"serverKey"`
ServerCert string `json:"serverCert"` ServerCert string `json:"serverCert"`
AuthMethods []string `json:"authMethods"` AuthMethods AuthMethods `json:"authMethods"`
AuthMethodsParsed []headers.AuthMethod `json:"-"`
ReadBufferSize int `json:"readBufferSize"` ReadBufferSize int `json:"readBufferSize"`
// RTMP // RTMP
@@ -197,52 +172,26 @@ func Load(fpath string) (*Conf, bool, error) {
// CheckAndFillMissing checks the configuration for errors and fills missing fields. // CheckAndFillMissing checks the configuration for errors and fills missing fields.
func (conf *Conf) CheckAndFillMissing() error { func (conf *Conf) CheckAndFillMissing() error {
if conf.LogLevel == "" { if conf.LogLevel == 0 {
conf.LogLevel = "info" conf.LogLevel = LogLevel(logger.Info)
}
switch conf.LogLevel {
case "warn":
conf.LogLevelParsed = logger.Warn
case "info":
conf.LogLevelParsed = logger.Info
case "debug":
conf.LogLevelParsed = logger.Debug
default:
return fmt.Errorf("unsupported log level: %s", conf.LogLevel)
} }
if len(conf.LogDestinations) == 0 { if len(conf.LogDestinations) == 0 {
conf.LogDestinations = []string{"stdout"} conf.LogDestinations = LogDestinations{logger.DestinationStdout: {}}
}
conf.LogDestinationsParsed = make(map[logger.Destination]struct{})
for _, dest := range conf.LogDestinations {
switch dest {
case "stdout":
conf.LogDestinationsParsed[logger.DestinationStdout] = struct{}{}
case "file":
conf.LogDestinationsParsed[logger.DestinationFile] = struct{}{}
case "syslog":
conf.LogDestinationsParsed[logger.DestinationSyslog] = struct{}{}
default:
return fmt.Errorf("unsupported log destination: %s", dest)
}
} }
if conf.LogFile == "" { if conf.LogFile == "" {
conf.LogFile = "rtsp-simple-server.log" conf.LogFile = "rtsp-simple-server.log"
} }
if conf.ReadTimeout == 0 { if conf.ReadTimeout == 0 {
conf.ReadTimeout = 10 * StringDuration(time.Second) conf.ReadTimeout = 10 * StringDuration(time.Second)
} }
if conf.WriteTimeout == 0 { if conf.WriteTimeout == 0 {
conf.WriteTimeout = 10 * StringDuration(time.Second) conf.WriteTimeout = 10 * StringDuration(time.Second)
} }
if conf.ReadBufferCount == 0 { if conf.ReadBufferCount == 0 {
conf.ReadBufferCount = 512 conf.ReadBufferCount = 512
} }
@@ -260,67 +209,43 @@ func (conf *Conf) CheckAndFillMissing() error {
} }
if len(conf.Protocols) == 0 { if len(conf.Protocols) == 0 {
conf.Protocols = []string{"udp", "multicast", "tcp"} conf.Protocols = Protocols{
} ProtocolUDP: {},
conf.ProtocolsParsed = make(map[Protocol]struct{}) ProtocolMulticast: {},
for _, proto := range conf.Protocols { ProtocolTCP: {},
switch proto {
case "udp":
conf.ProtocolsParsed[ProtocolUDP] = struct{}{}
case "multicast":
conf.ProtocolsParsed[ProtocolMulticast] = struct{}{}
case "tcp":
conf.ProtocolsParsed[ProtocolTCP] = struct{}{}
default:
return fmt.Errorf("unsupported protocol: %s", proto)
} }
} }
if len(conf.ProtocolsParsed) == 0 {
return fmt.Errorf("no protocols provided") if conf.Encryption == EncryptionStrict {
if _, ok := conf.Protocols[ProtocolUDP]; ok {
return fmt.Errorf("strict encryption can't be used with the UDP stream protocol")
} }
if conf.Encryption == "" {
conf.Encryption = "no"
}
switch conf.Encryption {
case "no", "false":
conf.EncryptionParsed = EncryptionNo
case "optional":
conf.EncryptionParsed = EncryptionOptional
case "strict", "yes", "true":
conf.EncryptionParsed = EncryptionStrict
if _, ok := conf.ProtocolsParsed[ProtocolUDP]; ok {
return fmt.Errorf("encryption can't be used with the UDP stream protocol")
}
default:
return fmt.Errorf("unsupported encryption value: '%s'", conf.Encryption)
} }
if conf.RTSPAddress == "" { if conf.RTSPAddress == "" {
conf.RTSPAddress = ":8554" conf.RTSPAddress = ":8554"
} }
if conf.RTSPSAddress == "" { if conf.RTSPSAddress == "" {
conf.RTSPSAddress = ":8555" conf.RTSPSAddress = ":8555"
} }
if conf.RTPAddress == "" { if conf.RTPAddress == "" {
conf.RTPAddress = ":8000" conf.RTPAddress = ":8000"
} }
if conf.RTCPAddress == "" { if conf.RTCPAddress == "" {
conf.RTCPAddress = ":8001" conf.RTCPAddress = ":8001"
} }
if conf.MulticastIPRange == "" { if conf.MulticastIPRange == "" {
conf.MulticastIPRange = "224.1.0.0/16" conf.MulticastIPRange = "224.1.0.0/16"
} }
if conf.MulticastRTPPort == 0 { if conf.MulticastRTPPort == 0 {
conf.MulticastRTPPort = 8002 conf.MulticastRTPPort = 8002
} }
if conf.MulticastRTCPPort == 0 { if conf.MulticastRTCPPort == 0 {
conf.MulticastRTCPPort = 8003 conf.MulticastRTCPPort = 8003
} }
@@ -328,24 +253,13 @@ func (conf *Conf) CheckAndFillMissing() error {
if conf.ServerKey == "" { if conf.ServerKey == "" {
conf.ServerKey = "server.key" conf.ServerKey = "server.key"
} }
if conf.ServerCert == "" { if conf.ServerCert == "" {
conf.ServerCert = "server.crt" conf.ServerCert = "server.crt"
} }
if len(conf.AuthMethods) == 0 { if len(conf.AuthMethods) == 0 {
conf.AuthMethods = []string{"basic", "digest"} conf.AuthMethods = AuthMethods{headers.AuthBasic, headers.AuthDigest}
}
for _, method := range conf.AuthMethods {
switch method {
case "basic":
conf.AuthMethodsParsed = append(conf.AuthMethodsParsed, headers.AuthBasic)
case "digest":
conf.AuthMethodsParsed = append(conf.AuthMethodsParsed, headers.AuthDigest)
default:
return fmt.Errorf("unsupported authentication method: %s", method)
}
} }
if conf.RTMPAddress == "" { if conf.RTMPAddress == "" {
@@ -355,12 +269,15 @@ func (conf *Conf) CheckAndFillMissing() error {
if conf.HLSAddress == "" { if conf.HLSAddress == "" {
conf.HLSAddress = ":8888" conf.HLSAddress = ":8888"
} }
if conf.HLSSegmentCount == 0 { if conf.HLSSegmentCount == 0 {
conf.HLSSegmentCount = 3 conf.HLSSegmentCount = 3
} }
if conf.HLSSegmentDuration == 0 { if conf.HLSSegmentDuration == 0 {
conf.HLSSegmentDuration = 1 * StringDuration(time.Second) conf.HLSSegmentDuration = 1 * StringDuration(time.Second)
} }
if conf.HLSAllowOrigin == "" { if conf.HLSAllowOrigin == "" {
conf.HLSAllowOrigin = "*" conf.HLSAllowOrigin = "*"
} }

View File

@@ -45,7 +45,6 @@ paths:
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, &PathConf{ require.Equal(t, &PathConf{
Source: "publisher", Source: "publisher",
SourceProtocol: "",
SourceOnDemandStartTimeout: 10 * StringDuration(time.Second), SourceOnDemandStartTimeout: 10 * StringDuration(time.Second),
SourceOnDemandCloseAfter: 10 * StringDuration(time.Second), SourceOnDemandCloseAfter: 10 * StringDuration(time.Second),
RunOnDemandStartTimeout: 5 * StringDuration(time.Second), RunOnDemandStartTimeout: 5 * StringDuration(time.Second),
@@ -69,7 +68,6 @@ func TestConfFromFileAndEnv(t *testing.T) {
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, &PathConf{ require.Equal(t, &PathConf{
Source: "rtsp://testing", Source: "rtsp://testing",
SourceProtocol: "automatic",
SourceOnDemandStartTimeout: 10 * StringDuration(time.Second), SourceOnDemandStartTimeout: 10 * StringDuration(time.Second),
SourceOnDemandCloseAfter: 10 * StringDuration(time.Second), SourceOnDemandCloseAfter: 10 * StringDuration(time.Second),
RunOnDemandStartTimeout: 10 * StringDuration(time.Second), RunOnDemandStartTimeout: 10 * StringDuration(time.Second),
@@ -89,7 +87,6 @@ func TestConfFromEnvOnly(t *testing.T) {
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, &PathConf{ require.Equal(t, &PathConf{
Source: "rtsp://testing", Source: "rtsp://testing",
SourceProtocol: "automatic",
SourceOnDemandStartTimeout: 10 * StringDuration(time.Second), SourceOnDemandStartTimeout: 10 * StringDuration(time.Second),
SourceOnDemandCloseAfter: 10 * StringDuration(time.Second), SourceOnDemandCloseAfter: 10 * StringDuration(time.Second),
RunOnDemandStartTimeout: 10 * StringDuration(time.Second), RunOnDemandStartTimeout: 10 * StringDuration(time.Second),

View File

@@ -0,0 +1,58 @@
package conf
import (
"encoding/json"
"fmt"
)
// Encryption is the encryption parameter.
type Encryption int
// supported encryption policies.
const (
EncryptionNo Encryption = iota
EncryptionOptional
EncryptionStrict
)
// MarshalJSON marshals a Encryption into JSON.
func (d Encryption) MarshalJSON() ([]byte, error) {
var out string
switch d {
case EncryptionNo:
out = "no"
case EncryptionOptional:
out = "optional"
default:
out = "strict"
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a Encryption from JSON.
func (d *Encryption) UnmarshalJSON(b []byte) error {
var in string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
switch in {
case "no", "false":
*d = EncryptionNo
case "optional":
*d = EncryptionOptional
case "strict", "yes", "true":
*d = EncryptionStrict
default:
return fmt.Errorf("unsupported encryption value: '%s'", in)
}
return nil
}

View File

@@ -1,36 +1,35 @@
package conf package conf
import ( import (
"encoding/json"
"fmt" "fmt"
"os" "os"
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"time"
) )
func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) error { func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) error {
rt := rv.Type() rt := rv.Type()
if rt == reflect.TypeOf(StringDuration(0)) { if i, ok := rv.Addr().Interface().(json.Unmarshaler); ok {
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {
d, err := time.ParseDuration(ev) err := i.UnmarshalJSON([]byte(`"` + ev + `"`))
if err != nil { if err != nil {
return fmt.Errorf("%s: %s", prefix, err) return fmt.Errorf("%s: %s", prefix, err)
} }
rv.Set(reflect.ValueOf(StringDuration(d)))
} }
return nil return nil
} }
switch rt.Kind() { switch rt {
case reflect.String: case reflect.TypeOf(""):
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {
rv.SetString(ev) rv.SetString(ev)
} }
return nil return nil
case reflect.Int: case reflect.TypeOf(int(0)):
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {
iv, err := strconv.ParseInt(ev, 10, 64) iv, err := strconv.ParseInt(ev, 10, 64)
if err != nil { if err != nil {
@@ -40,7 +39,7 @@ func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) err
} }
return nil return nil
case reflect.Uint64: case reflect.TypeOf(uint64(0)):
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {
iv, err := strconv.ParseUint(ev, 10, 64) iv, err := strconv.ParseUint(ev, 10, 64)
if err != nil { if err != nil {
@@ -50,7 +49,7 @@ func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) err
} }
return nil return nil
case reflect.Bool: case reflect.TypeOf(bool(false)):
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {
switch strings.ToLower(ev) { switch strings.ToLower(ev) {
case "yes", "true": case "yes", "true":
@@ -64,7 +63,9 @@ func loadEnvInternal(env map[string]string, prefix string, rv reflect.Value) err
} }
} }
return nil return nil
}
switch rt.Kind() {
case reflect.Slice: case reflect.Slice:
if rt.Elem().Kind() == reflect.String { if rt.Elem().Kind() == reflect.String {
if ev, ok := env[prefix]; ok { if ev, ok := env[prefix]; ok {

View File

@@ -0,0 +1,45 @@
package conf
import (
"encoding/json"
"fmt"
"net"
)
// IPsOrNets is a parameter that acceps IPs or subnets.
type IPsOrNets []interface{}
// MarshalJSON marshals a IPsOrNets into JSON.
func (d IPsOrNets) MarshalJSON() ([]byte, error) {
out := make([]string, len(d))
for i, v := range d {
out[i] = v.(fmt.Stringer).String()
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a IPsOrNets from JSON.
func (d *IPsOrNets) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
if len(in) == 0 {
return nil
}
for _, t := range in {
if _, ipnet, err := net.ParseCIDR(t); err == nil {
*d = append(*d, ipnet)
} else if ip := net.ParseIP(t); ip != nil {
*d = append(*d, ip)
} else {
return fmt.Errorf("unable to parse ip/network '%s'", t)
}
}
return nil
}

View File

@@ -0,0 +1,65 @@
package conf
import (
"encoding/json"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/logger"
)
// LogDestinations is the logDestionations parameter.
type LogDestinations map[logger.Destination]struct{}
// MarshalJSON marshals a LogDestinations into JSON.
func (d LogDestinations) MarshalJSON() ([]byte, error) {
out := make([]string, len(d))
i := 0
for p := range d {
var v string
switch p {
case logger.DestinationStdout:
v = "stdout"
case logger.DestinationFile:
v = "file"
default:
v = "syslog"
}
out[i] = v
i++
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a LogDestinations from JSON.
func (d *LogDestinations) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
*d = make(LogDestinations)
for _, proto := range in {
switch proto {
case "stdout":
(*d)[logger.DestinationStdout] = struct{}{}
case "file":
(*d)[logger.DestinationFile] = struct{}{}
case "syslog":
(*d)[logger.DestinationSyslog] = struct{}{}
default:
return fmt.Errorf("unsupported log destination: %s", proto)
}
}
return nil
}

53
internal/conf/loglevel.go Normal file
View File

@@ -0,0 +1,53 @@
package conf
import (
"encoding/json"
"fmt"
"github.com/aler9/rtsp-simple-server/internal/logger"
)
// LogLevel is the logLevel parameter.
type LogLevel logger.Level
// MarshalJSON marshals a LogLevel into JSON.
func (d LogLevel) MarshalJSON() ([]byte, error) {
var out string
switch d {
case LogLevel(logger.Warn):
out = "warn"
case LogLevel(logger.Info):
out = "info"
default:
out = "debug"
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a LogLevel from JSON.
func (d *LogLevel) UnmarshalJSON(b []byte) error {
var in string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
switch in {
case "warn":
*d = LogLevel(logger.Warn)
case "info":
*d = LogLevel(logger.Info)
case "debug":
*d = LogLevel(logger.Debug)
default:
return fmt.Errorf("unsupported log level: %s", in)
}
return nil
}

View File

@@ -3,13 +3,11 @@ package conf
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net"
"net/url" "net/url"
"regexp" "regexp"
"strings" "strings"
"time" "time"
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
) )
@@ -19,30 +17,6 @@ var reUserPass = regexp.MustCompile(`^[a-zA-Z0-9!\$\(\)\*\+\.;<=>\[\]\^_\-\{\}]+
var rePathName = regexp.MustCompile(`^[0-9a-zA-Z_\-/\.~]+$`) var rePathName = regexp.MustCompile(`^[0-9a-zA-Z_\-/\.~]+$`)
func parseIPCidrList(in []string) ([]interface{}, error) {
if len(in) == 0 {
return nil, nil
}
var ret []interface{}
for _, t := range in {
_, ipnet, err := net.ParseCIDR(t)
if err == nil {
ret = append(ret, ipnet)
continue
}
ip := net.ParseIP(t)
if ip != nil {
ret = append(ret, ip)
continue
}
return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
}
return ret, nil
}
// IsValidPathName checks if a path name is valid. // IsValidPathName checks if a path name is valid.
func IsValidPathName(name string) error { func IsValidPathName(name string) error {
if name == "" { if name == "" {
@@ -70,8 +44,7 @@ type PathConf struct {
// source // source
Source string `json:"source"` Source string `json:"source"`
SourceProtocol string `json:"sourceProtocol"` SourceProtocol SourceProtocol `json:"sourceProtocol"`
SourceProtocolParsed *gortsplib.ClientProtocol `json:"-"`
SourceAnyPortEnable bool `json:"sourceAnyPortEnable"` SourceAnyPortEnable bool `json:"sourceAnyPortEnable"`
SourceFingerprint string `json:"sourceFingerprint"` SourceFingerprint string `json:"sourceFingerprint"`
SourceOnDemand bool `json:"sourceOnDemand"` SourceOnDemand bool `json:"sourceOnDemand"`
@@ -84,12 +57,10 @@ type PathConf struct {
// authentication // authentication
PublishUser string `json:"publishUser"` PublishUser string `json:"publishUser"`
PublishPass string `json:"publishPass"` PublishPass string `json:"publishPass"`
PublishIPs []string `json:"publishIPs"` PublishIPs IPsOrNets `json:"publishIPs"`
PublishIPsParsed []interface{} `json:"-"`
ReadUser string `json:"readUser"` ReadUser string `json:"readUser"`
ReadPass string `json:"readPass"` ReadPass string `json:"readPass"`
ReadIPs []string `json:"readIPs"` ReadIPs IPsOrNets `json:"readIPs"`
ReadIPsParsed []interface{} `json:"-"`
// custom commands // custom commands
RunOnInit string `json:"runOnInit"` RunOnInit string `json:"runOnInit"`
@@ -143,29 +114,6 @@ func (pconf *PathConf) checkAndFillMissing(name string) error {
return fmt.Errorf("'%s' is not a valid RTSP URL", pconf.Source) return fmt.Errorf("'%s' is not a valid RTSP URL", pconf.Source)
} }
if pconf.SourceProtocol == "" {
pconf.SourceProtocol = "automatic"
}
switch pconf.SourceProtocol {
case "udp":
v := gortsplib.ClientProtocolUDP
pconf.SourceProtocolParsed = &v
case "multicast":
v := gortsplib.ClientProtocolMulticast
pconf.SourceProtocolParsed = &v
case "tcp":
v := gortsplib.ClientProtocolTCP
pconf.SourceProtocolParsed = &v
case "automatic":
default:
return fmt.Errorf("unsupported protocol '%s'", pconf.SourceProtocol)
}
if strings.HasPrefix(pconf.Source, "rtsps://") && pconf.SourceFingerprint == "" { if strings.HasPrefix(pconf.Source, "rtsps://") && pconf.SourceFingerprint == "" {
return fmt.Errorf("sourceFingerprint is required with a RTSPS URL") return fmt.Errorf("sourceFingerprint is required with a RTSPS URL")
} }
@@ -258,9 +206,11 @@ func (pconf *PathConf) checkAndFillMissing(name string) error {
} }
} }
if (pconf.PublishUser != "" && pconf.PublishPass == "") || (pconf.PublishUser == "" && pconf.PublishPass != "") { if (pconf.PublishUser != "" && pconf.PublishPass == "") ||
(pconf.PublishUser == "" && pconf.PublishPass != "") {
return fmt.Errorf("read username and password must be both filled") return fmt.Errorf("read username and password must be both filled")
} }
if pconf.PublishUser != "" { if pconf.PublishUser != "" {
if pconf.Source != "publisher" { if pconf.Source != "publisher" {
return fmt.Errorf("'publishUser' is useless when source is not 'publisher'") return fmt.Errorf("'publishUser' is useless when source is not 'publisher'")
@@ -270,6 +220,7 @@ func (pconf *PathConf) checkAndFillMissing(name string) error {
return fmt.Errorf("publish username contains unsupported characters (supported are %s)", userPassSupportedChars) return fmt.Errorf("publish username contains unsupported characters (supported are %s)", userPassSupportedChars)
} }
} }
if pconf.PublishPass != "" { if pconf.PublishPass != "" {
if pconf.Source != "publisher" { if pconf.Source != "publisher" {
return fmt.Errorf("'publishPass' is useless when source is not 'publisher', since " + return fmt.Errorf("'publishPass' is useless when source is not 'publisher', since " +
@@ -280,48 +231,28 @@ func (pconf *PathConf) checkAndFillMissing(name string) error {
return fmt.Errorf("publish password contains unsupported characters (supported are %s)", userPassSupportedChars) return fmt.Errorf("publish password contains unsupported characters (supported are %s)", userPassSupportedChars)
} }
} }
if len(pconf.PublishIPs) == 0 {
pconf.PublishIPs = nil
}
var err error
pconf.PublishIPsParsed, err = func() ([]interface{}, error) {
if len(pconf.PublishIPs) == 0 {
return nil, nil
}
if pconf.Source != "publisher" { if len(pconf.PublishIPs) > 0 && pconf.Source != "publisher" {
return nil, fmt.Errorf("'publishIPs' is useless when source is not 'publisher', since " + return fmt.Errorf("'publishIPs' is useless when source is not 'publisher', since " +
"the stream is not provided by a publisher, but by a fixed source") "the stream is not provided by a publisher, but by a fixed source")
} }
return parseIPCidrList(pconf.PublishIPs) if (pconf.ReadUser != "" && pconf.ReadPass == "") ||
}() (pconf.ReadUser == "" && pconf.ReadPass != "") {
if err != nil {
return err
}
if (pconf.ReadUser != "" && pconf.ReadPass == "") || (pconf.ReadUser == "" && pconf.ReadPass != "") {
return fmt.Errorf("read username and password must be both filled") return fmt.Errorf("read username and password must be both filled")
} }
if pconf.ReadUser != "" { if pconf.ReadUser != "" {
if !strings.HasPrefix(pconf.ReadUser, "sha256:") && !reUserPass.MatchString(pconf.ReadUser) { if !strings.HasPrefix(pconf.ReadUser, "sha256:") && !reUserPass.MatchString(pconf.ReadUser) {
return fmt.Errorf("read username contains unsupported characters (supported are %s)", userPassSupportedChars) return fmt.Errorf("read username contains unsupported characters (supported are %s)", userPassSupportedChars)
} }
} }
if pconf.ReadPass != "" { if pconf.ReadPass != "" {
if !strings.HasPrefix(pconf.ReadPass, "sha256:") && !reUserPass.MatchString(pconf.ReadPass) { if !strings.HasPrefix(pconf.ReadPass, "sha256:") && !reUserPass.MatchString(pconf.ReadPass) {
return fmt.Errorf("read password contains unsupported characters (supported are %s)", userPassSupportedChars) return fmt.Errorf("read password contains unsupported characters (supported are %s)", userPassSupportedChars)
} }
} }
if len(pconf.ReadIPs) == 0 {
pconf.ReadIPs = nil
}
pconf.ReadIPsParsed, err = func() ([]interface{}, error) {
return parseIPCidrList(pconf.ReadIPs)
}()
if err != nil {
return err
}
if pconf.RunOnInit != "" && pconf.Regexp != nil { if pconf.RunOnInit != "" && pconf.Regexp != nil {
return fmt.Errorf("a path with a regular expression does not support option 'runOnInit'; use another path") return fmt.Errorf("a path with a regular expression does not support option 'runOnInit'; use another path")

71
internal/conf/protocol.go Normal file
View File

@@ -0,0 +1,71 @@
package conf
import (
"encoding/json"
"fmt"
)
// Protocol is a RTSP stream protocol.
type Protocol int
// supported RTSP protocols.
const (
ProtocolUDP Protocol = iota
ProtocolMulticast
ProtocolTCP
)
// Protocols is the protocols parameter.
type Protocols map[Protocol]struct{}
// MarshalJSON marshals a Protocols into JSON.
func (d Protocols) MarshalJSON() ([]byte, error) {
out := make([]string, len(d))
for p := range d {
var v string
switch p {
case ProtocolUDP:
v = "udp"
case ProtocolMulticast:
v = "multicast"
default:
v = "tcp"
}
out = append(out, v)
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a Protocols from JSON.
func (d *Protocols) UnmarshalJSON(b []byte) error {
var in []string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
*d = make(Protocols)
for _, proto := range in {
switch proto {
case "udp":
(*d)[ProtocolUDP] = struct{}{}
case "multicast":
(*d)[ProtocolMulticast] = struct{}{}
case "tcp":
(*d)[ProtocolTCP] = struct{}{}
default:
return fmt.Errorf("unsupported protocol: %s", proto)
}
}
return nil
}

View File

@@ -0,0 +1,64 @@
package conf
import (
"encoding/json"
"fmt"
"github.com/aler9/gortsplib"
)
// SourceProtocol is the sourceProtocol parameter.
type SourceProtocol struct {
*gortsplib.ClientProtocol
}
// MarshalJSON marshals a SourceProtocol into JSON.
func (d SourceProtocol) MarshalJSON() ([]byte, error) {
var out string
if d.ClientProtocol == nil {
out = "automatic"
} else {
switch *d.ClientProtocol {
case gortsplib.ClientProtocolUDP:
out = "udp"
case gortsplib.ClientProtocolMulticast:
out = "multicast"
default:
out = "tcp"
}
}
return json.Marshal(out)
}
// UnmarshalJSON unmarshals a SourceProtocol from JSON.
func (d *SourceProtocol) UnmarshalJSON(b []byte) error {
var in string
if err := json.Unmarshal(b, &in); err != nil {
return err
}
switch in {
case "udp":
v := gortsplib.ClientProtocolUDP
d.ClientProtocol = &v
case "multicast":
v := gortsplib.ClientProtocolMulticast
d.ClientProtocol = &v
case "tcp":
v := gortsplib.ClientProtocolTCP
d.ClientProtocol = &v
case "automatic":
default:
return fmt.Errorf("unsupported protocol '%s'", in)
}
return nil
}

View File

@@ -2,7 +2,6 @@ package conf
import ( import (
"encoding/json" "encoding/json"
"errors"
"time" "time"
) )
@@ -10,28 +9,23 @@ import (
// Durations are normally unmarshaled from numbers. // Durations are normally unmarshaled from numbers.
type StringDuration time.Duration type StringDuration time.Duration
// MarshalJSON marshals a StringDuration into a string. // MarshalJSON marshals a StringDuration into JSON.
func (d StringDuration) MarshalJSON() ([]byte, error) { func (d StringDuration) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Duration(d).String()) return json.Marshal(time.Duration(d).String())
} }
// UnmarshalJSON unmarshals a StringDuration from a string. // UnmarshalJSON unmarshals a StringDuration from JSON.
func (d *StringDuration) UnmarshalJSON(b []byte) error { func (d *StringDuration) UnmarshalJSON(b []byte) error {
var v interface{} var in string
if err := json.Unmarshal(b, &v); err != nil { if err := json.Unmarshal(b, &in); err != nil {
return err return err
} }
value, ok := v.(string) du, err := time.ParseDuration(in)
if !ok {
return errors.New("invalid duration")
}
du, err := time.ParseDuration(value)
if err != nil { if err != nil {
return err return err
} }
*d = StringDuration(du) *d = StringDuration(du)
return nil return nil
} }

View File

@@ -182,8 +182,8 @@ func (p *Core) createResources(initial bool) error {
if p.logger == nil { if p.logger == nil {
p.logger, err = logger.New( p.logger, err = logger.New(
p.conf.LogLevelParsed, logger.Level(p.conf.LogLevel),
p.conf.LogDestinationsParsed, p.conf.LogDestinations,
p.conf.LogFile) p.conf.LogFile)
if err != nil { if err != nil {
return err return err
@@ -234,15 +234,15 @@ func (p *Core) createResources(initial bool) error {
} }
if !p.conf.RTSPDisable && if !p.conf.RTSPDisable &&
(p.conf.EncryptionParsed == conf.EncryptionNo || (p.conf.Encryption == conf.EncryptionNo ||
p.conf.EncryptionParsed == conf.EncryptionOptional) { p.conf.Encryption == conf.EncryptionOptional) {
if p.rtspServer == nil { if p.rtspServer == nil {
_, useUDP := p.conf.ProtocolsParsed[conf.ProtocolUDP] _, useUDP := p.conf.Protocols[conf.ProtocolUDP]
_, useMulticast := p.conf.ProtocolsParsed[conf.ProtocolMulticast] _, useMulticast := p.conf.Protocols[conf.ProtocolMulticast]
p.rtspServer, err = newRTSPServer( p.rtspServer, err = newRTSPServer(
p.ctx, p.ctx,
p.conf.RTSPAddress, p.conf.RTSPAddress,
p.conf.AuthMethodsParsed, p.conf.AuthMethods,
p.conf.ReadTimeout, p.conf.ReadTimeout,
p.conf.WriteTimeout, p.conf.WriteTimeout,
p.conf.ReadBufferCount, p.conf.ReadBufferCount,
@@ -258,7 +258,7 @@ func (p *Core) createResources(initial bool) error {
"", "",
"", "",
p.conf.RTSPAddress, p.conf.RTSPAddress,
p.conf.ProtocolsParsed, p.conf.Protocols,
p.conf.RunOnConnect, p.conf.RunOnConnect,
p.conf.RunOnConnectRestart, p.conf.RunOnConnectRestart,
p.metrics, p.metrics,
@@ -271,13 +271,13 @@ func (p *Core) createResources(initial bool) error {
} }
if !p.conf.RTSPDisable && if !p.conf.RTSPDisable &&
(p.conf.EncryptionParsed == conf.EncryptionStrict || (p.conf.Encryption == conf.EncryptionStrict ||
p.conf.EncryptionParsed == conf.EncryptionOptional) { p.conf.Encryption == conf.EncryptionOptional) {
if p.rtspsServer == nil { if p.rtspsServer == nil {
p.rtspsServer, err = newRTSPServer( p.rtspsServer, err = newRTSPServer(
p.ctx, p.ctx,
p.conf.RTSPSAddress, p.conf.RTSPSAddress,
p.conf.AuthMethodsParsed, p.conf.AuthMethods,
p.conf.ReadTimeout, p.conf.ReadTimeout,
p.conf.WriteTimeout, p.conf.WriteTimeout,
p.conf.ReadBufferCount, p.conf.ReadBufferCount,
@@ -293,7 +293,7 @@ func (p *Core) createResources(initial bool) error {
p.conf.ServerCert, p.conf.ServerCert,
p.conf.ServerKey, p.conf.ServerKey,
p.conf.RTSPAddress, p.conf.RTSPAddress,
p.conf.ProtocolsParsed, p.conf.Protocols,
p.conf.RunOnConnect, p.conf.RunOnConnect,
p.conf.RunOnConnectRestart, p.conf.RunOnConnectRestart,
p.metrics, p.metrics,
@@ -370,7 +370,7 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
closeLogger := false closeLogger := false
if newConf == nil || if newConf == nil ||
!reflect.DeepEqual(newConf.LogDestinationsParsed, p.conf.LogDestinationsParsed) || !reflect.DeepEqual(newConf.LogDestinations, p.conf.LogDestinations) ||
newConf.LogFile != p.conf.LogFile { newConf.LogFile != p.conf.LogFile {
closeLogger = true closeLogger = true
} }
@@ -406,20 +406,20 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
closeRTSPServer := false closeRTSPServer := false
if newConf == nil || if newConf == nil ||
newConf.RTSPDisable != p.conf.RTSPDisable || newConf.RTSPDisable != p.conf.RTSPDisable ||
newConf.EncryptionParsed != p.conf.EncryptionParsed || newConf.Encryption != p.conf.Encryption ||
newConf.RTSPAddress != p.conf.RTSPAddress || newConf.RTSPAddress != p.conf.RTSPAddress ||
!reflect.DeepEqual(newConf.AuthMethodsParsed, p.conf.AuthMethodsParsed) || !reflect.DeepEqual(newConf.AuthMethods, p.conf.AuthMethods) ||
newConf.ReadTimeout != p.conf.ReadTimeout || newConf.ReadTimeout != p.conf.ReadTimeout ||
newConf.WriteTimeout != p.conf.WriteTimeout || newConf.WriteTimeout != p.conf.WriteTimeout ||
newConf.ReadBufferCount != p.conf.ReadBufferCount || newConf.ReadBufferCount != p.conf.ReadBufferCount ||
!reflect.DeepEqual(newConf.ProtocolsParsed, p.conf.ProtocolsParsed) || !reflect.DeepEqual(newConf.Protocols, p.conf.Protocols) ||
newConf.RTPAddress != p.conf.RTPAddress || newConf.RTPAddress != p.conf.RTPAddress ||
newConf.RTCPAddress != p.conf.RTCPAddress || newConf.RTCPAddress != p.conf.RTCPAddress ||
newConf.MulticastIPRange != p.conf.MulticastIPRange || newConf.MulticastIPRange != p.conf.MulticastIPRange ||
newConf.MulticastRTPPort != p.conf.MulticastRTPPort || newConf.MulticastRTPPort != p.conf.MulticastRTPPort ||
newConf.MulticastRTCPPort != p.conf.MulticastRTCPPort || newConf.MulticastRTCPPort != p.conf.MulticastRTCPPort ||
newConf.RTSPAddress != p.conf.RTSPAddress || newConf.RTSPAddress != p.conf.RTSPAddress ||
!reflect.DeepEqual(newConf.ProtocolsParsed, p.conf.ProtocolsParsed) || !reflect.DeepEqual(newConf.Protocols, p.conf.Protocols) ||
newConf.RunOnConnect != p.conf.RunOnConnect || newConf.RunOnConnect != p.conf.RunOnConnect ||
newConf.RunOnConnectRestart != p.conf.RunOnConnectRestart || newConf.RunOnConnectRestart != p.conf.RunOnConnectRestart ||
closeMetrics || closeMetrics ||
@@ -430,16 +430,16 @@ func (p *Core) closeResources(newConf *conf.Conf, calledByAPI bool) {
closeRTSPSServer := false closeRTSPSServer := false
if newConf == nil || if newConf == nil ||
newConf.RTSPDisable != p.conf.RTSPDisable || newConf.RTSPDisable != p.conf.RTSPDisable ||
newConf.EncryptionParsed != p.conf.EncryptionParsed || newConf.Encryption != p.conf.Encryption ||
newConf.RTSPSAddress != p.conf.RTSPSAddress || newConf.RTSPSAddress != p.conf.RTSPSAddress ||
!reflect.DeepEqual(newConf.AuthMethodsParsed, p.conf.AuthMethodsParsed) || !reflect.DeepEqual(newConf.AuthMethods, p.conf.AuthMethods) ||
newConf.ReadTimeout != p.conf.ReadTimeout || newConf.ReadTimeout != p.conf.ReadTimeout ||
newConf.WriteTimeout != p.conf.WriteTimeout || newConf.WriteTimeout != p.conf.WriteTimeout ||
newConf.ReadBufferCount != p.conf.ReadBufferCount || newConf.ReadBufferCount != p.conf.ReadBufferCount ||
newConf.ServerCert != p.conf.ServerCert || newConf.ServerCert != p.conf.ServerCert ||
newConf.ServerKey != p.conf.ServerKey || newConf.ServerKey != p.conf.ServerKey ||
newConf.RTSPAddress != p.conf.RTSPAddress || newConf.RTSPAddress != p.conf.RTSPAddress ||
!reflect.DeepEqual(newConf.ProtocolsParsed, p.conf.ProtocolsParsed) || !reflect.DeepEqual(newConf.Protocols, p.conf.Protocols) ||
newConf.RunOnConnect != p.conf.RunOnConnect || newConf.RunOnConnect != p.conf.RunOnConnect ||
newConf.RunOnConnectRestart != p.conf.RunOnConnectRestart || newConf.RunOnConnectRestart != p.conf.RunOnConnectRestart ||
closeMetrics || closeMetrics ||

View File

@@ -411,10 +411,10 @@ func (r *hlsMuxer) handleRequest(req hlsMuxerRequest) {
conf := r.path.Conf() conf := r.path.Conf()
if conf.ReadIPsParsed != nil { if conf.ReadIPs != nil {
tmp, _, _ := net.SplitHostPort(req.Req.RemoteAddr) tmp, _, _ := net.SplitHostPort(req.Req.RemoteAddr)
ip := net.ParseIP(tmp) ip := net.ParseIP(tmp)
if !ipEqualOrInRange(ip, conf.ReadIPsParsed) { if !ipEqualOrInRange(ip, conf.ReadIPs) {
r.log(logger.Info, "ERR: ip '%s' not allowed", ip) r.log(logger.Info, "ERR: ip '%s' not allowed", ip)
req.W.WriteHeader(http.StatusUnauthorized) req.W.WriteHeader(http.StatusUnauthorized)
req.Res <- nil req.Res <- nil

View File

@@ -586,7 +586,7 @@ func (pa *path) staticSourceCreate() {
pa.source = newRTSPSource( pa.source = newRTSPSource(
pa.ctx, pa.ctx,
pa.conf.Source, pa.conf.Source,
pa.conf.SourceProtocolParsed, pa.conf.SourceProtocol,
pa.conf.SourceAnyPortEnable, pa.conf.SourceAnyPortEnable,
pa.conf.SourceFingerprint, pa.conf.SourceFingerprint,
pa.readTimeout, pa.readTimeout,

View File

@@ -177,7 +177,7 @@ outer:
req.IP, req.IP,
req.ValidateCredentials, req.ValidateCredentials,
req.PathName, req.PathName,
pathConf.ReadIPsParsed, pathConf.ReadIPs,
pathConf.ReadUser, pathConf.ReadUser,
pathConf.ReadPass, pathConf.ReadPass,
) )
@@ -204,7 +204,7 @@ outer:
req.IP, req.IP,
req.ValidateCredentials, req.ValidateCredentials,
req.PathName, req.PathName,
pathConf.ReadIPsParsed, pathConf.ReadIPs,
pathConf.ReadUser, pathConf.ReadUser,
pathConf.ReadPass, pathConf.ReadPass,
) )
@@ -231,7 +231,7 @@ outer:
req.IP, req.IP,
req.ValidateCredentials, req.ValidateCredentials,
req.PathName, req.PathName,
pathConf.PublishIPsParsed, pathConf.PublishIPs,
pathConf.PublishUser, pathConf.PublishUser,
pathConf.PublishPass, pathConf.PublishPass,
) )

View File

@@ -76,7 +76,7 @@ func TestRTSPServerPublishRead(t *testing.T) {
"hlsDisable: yes\n" + "hlsDisable: yes\n" +
"readTimeout: 20s\n" + "readTimeout: 20s\n" +
"protocols: [tcp]\n" + "protocols: [tcp]\n" +
"encryption: yes\n" + "encryption: \"yes\"\n" +
"serverCert: " + serverCertFpath + "\n" + "serverCert: " + serverCertFpath + "\n" +
"serverKey: " + serverKeyFpath + "\n") "serverKey: " + serverKeyFpath + "\n")
require.Equal(t, true, ok) require.Equal(t, true, ok)

View File

@@ -29,7 +29,7 @@ type rtspSourceParent interface {
type rtspSource struct { type rtspSource struct {
ur string ur string
proto *gortsplib.ClientProtocol proto conf.SourceProtocol
anyPortEnable bool anyPortEnable bool
fingerprint string fingerprint string
readTimeout conf.StringDuration readTimeout conf.StringDuration
@@ -46,7 +46,7 @@ type rtspSource struct {
func newRTSPSource( func newRTSPSource(
parentCtx context.Context, parentCtx context.Context,
ur string, ur string,
proto *gortsplib.ClientProtocol, proto conf.SourceProtocol,
anyPortEnable bool, anyPortEnable bool,
fingerprint string, fingerprint string,
readTimeout conf.StringDuration, readTimeout conf.StringDuration,
@@ -118,7 +118,7 @@ func (s *rtspSource) runInner() bool {
s.log(logger.Debug, "connecting") s.log(logger.Debug, "connecting")
client := &gortsplib.Client{ client := &gortsplib.Client{
Protocol: s.proto, Protocol: s.proto.ClientProtocol,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
VerifyConnection: func(cs tls.ConnectionState) error { VerifyConnection: func(cs tls.ConnectionState) error {