ConnClient: automatically add session header to all requests

This commit is contained in:
aler9
2020-05-08 23:54:41 +02:00
parent 3b233ef7e9
commit 5df784b3ec

View File

@@ -2,6 +2,7 @@ package gortsplib
import ( import (
"bufio" "bufio"
"fmt"
"net" "net"
"strconv" "strconv"
"time" "time"
@@ -66,11 +67,6 @@ func (c *ConnClient) NetConn() net.Conn {
return c.conf.NConn return c.conf.NConn
} }
// SetSession sets a Session header that is automatically inserted into every outgoing request.
func (c *ConnClient) SetSession(v string) {
c.session = v
}
// SetCredentials allows to automatically insert the Authenticate header into every outgoing request. // SetCredentials allows to automatically insert the Authenticate header into every outgoing request.
// The content of the header is computed with the given user, password, realm and nonce. // The content of the header is computed with the given user, password, realm and nonce.
func (c *ConnClient) SetCredentials(wwwAuthenticateHeader []string, user string, pass string) error { func (c *ConnClient) SetCredentials(wwwAuthenticateHeader []string, user string, pass string) error {
@@ -81,24 +77,21 @@ func (c *ConnClient) SetCredentials(wwwAuthenticateHeader []string, user string,
// WriteRequest writes a request and reads a response. // WriteRequest writes a request and reads a response.
func (c *ConnClient) WriteRequest(req *Request) (*Response, error) { func (c *ConnClient) WriteRequest(req *Request) (*Response, error) {
if c.session != "" {
if req.Header == nil {
req.Header = make(Header)
}
req.Header["Session"] = []string{c.session}
}
if c.auth != nil {
if req.Header == nil {
req.Header = make(Header)
}
req.Header["Authorization"] = c.auth.GenerateHeader(req.Method, req.Url)
}
// automatically insert CSeq
if req.Header == nil { if req.Header == nil {
req.Header = make(Header) req.Header = make(Header)
} }
// insert session
if c.session != "" {
req.Header["Session"] = []string{c.session}
}
// insert auth
if c.auth != nil {
req.Header["Authorization"] = c.auth.GenerateHeader(req.Method, req.Url)
}
// insert cseq
c.curCSeq += 1 c.curCSeq += 1
req.Header["CSeq"] = []string{strconv.FormatInt(int64(c.curCSeq), 10)} req.Header["CSeq"] = []string{strconv.FormatInt(int64(c.curCSeq), 10)}
@@ -109,7 +102,23 @@ func (c *ConnClient) WriteRequest(req *Request) (*Response, error) {
} }
c.conf.NConn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout)) c.conf.NConn.SetReadDeadline(time.Now().Add(c.conf.ReadTimeout))
return readResponse(c.br) res, err := readResponse(c.br)
if err != nil {
return nil, err
}
// get session from response
if res.StatusCode == StatusOK {
if sxRaw, ok := res.Header["Session"]; ok && len(sxRaw) == 1 {
sx, err := ReadHeaderSession(sxRaw[0])
if err != nil {
return nil, fmt.Errorf("unable to parse session header: %s", err)
}
c.session = sx.Session
}
}
return res, nil
} }
// ReadInterleavedFrame reads an InterleavedFrame. // ReadInterleavedFrame reads an InterleavedFrame.