diff --git a/conn-client.go b/conn-client.go index dff2f0b0..45ca5e72 100644 --- a/conn-client.go +++ b/conn-client.go @@ -13,6 +13,12 @@ type ConnClientConf struct { // pre-existing TCP connection that will be wrapped NConn net.Conn + // (optional) a username to authenticate with the server + Username string + + // (optional) a password to authenticate with the server + Password string + // (optional) timeout for read requests. // It defaults to 5 seconds ReadTimeout time.Duration @@ -41,7 +47,7 @@ type ConnClient struct { } // NewConnClient allocates a ConnClient. See ConnClientConf for the options. -func NewConnClient(conf ConnClientConf) *ConnClient { +func NewConnClient(conf ConnClientConf) (*ConnClient, error) { if conf.ReadTimeout == time.Duration(0) { conf.ReadTimeout = 5 * time.Second } @@ -55,11 +61,16 @@ func NewConnClient(conf ConnClientConf) *ConnClient { conf.WriteBufferSize = 4096 } + if conf.Username != "" && conf.Password == "" || + conf.Username == "" && conf.Password != "" { + return nil, fmt.Errorf("both username and password must be provided") + } + return &ConnClient{ conf: conf, br: bufio.NewReaderSize(conf.NConn, conf.ReadBufferSize), bw: bufio.NewWriterSize(conf.NConn, conf.WriteBufferSize), - } + }, nil } // NetConn returns the underlying net.Conn. @@ -67,14 +78,6 @@ func (c *ConnClient) NetConn() net.Conn { return c.conf.NConn } -// SetCredentials allows to automatically insert the Authenticate header into every outgoing request. -// The content of the header is computed with the given user, password, realm and nonce. -func (c *ConnClient) SetCredentials(wwwAuthenticateHeader []string, user string, pass string) error { - var err error - c.auth, err = NewAuthClient(wwwAuthenticateHeader, user, pass) - return err -} - // WriteRequest writes a request and reads a response. func (c *ConnClient) WriteRequest(req *Request) (*Response, error) { if req.Header == nil { @@ -108,14 +111,24 @@ func (c *ConnClient) WriteRequest(req *Request) (*Response, error) { } // get session from response - if res.StatusCode == StatusOK { - if sxRaw, ok := res.Header["Session"]; ok && len(sxRaw) == 1 { - sx, err := ReadHeaderSession(sxRaw[0]) - if err != nil { - return nil, fmt.Errorf("unable to parse session header: %s", err) - } - c.session = sx.Session + if sxRaw, ok := res.Header["Session"]; ok && len(sxRaw) == 1 { + sx, err := ReadHeaderSession(sxRaw[0]) + if err != nil { + return nil, fmt.Errorf("unable to parse session header: %s", err) } + c.session = sx.Session + } + + // setup authentication + if res.StatusCode == StatusUnauthorized && c.conf.Username != "" && c.auth == nil { + auth, err := NewAuthClient(res.Header["WWW-Authenticate"], c.conf.Username, c.conf.Password) + if err != nil { + return nil, fmt.Errorf("unable to setup authentication: %s", err) + } + c.auth = auth + + // send request again + return c.WriteRequest(req) } return res, nil diff --git a/response.go b/response.go index d9659ce9..14eaf0bb 100644 --- a/response.go +++ b/response.go @@ -59,6 +59,9 @@ const ( StatusProxyUnavailable StatusCode = 553 ) +// StatusMessages contains the status messages associated with each status code. +var StatusMessages = statusMessages + var statusMessages = map[StatusCode]string{ StatusContinue: "Continue",