// Copyright 2018 Google Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package httpclient is used by the upstreaming forwardproxy to establish connections to http(s) upstreams. // it implements x/net/proxy.Dialer interface package httpclient import ( "bufio" "context" "crypto/tls" "encoding/base64" "errors" "io" "net" "net/http" "net/url" "sync" "golang.org/x/net/http2" ) // HTTPConnectDialer allows to configure one-time use HTTP CONNECT client type HTTPConnectDialer struct { ProxyURL url.URL DefaultHeader http.Header // TODO: If spkiFp is set, use it as SPKI fingerprint to confirm identity of the // proxy, instead of relying on standard PKI CA roots SpkiFP []byte Dialer net.Dialer // overridden dialer allow to control establishment of TCP connection // overridden DialTLS allows user to control establishment of TLS connection // MUST return connection with completed Handshake, and NegotiatedProtocol DialTLS func(network string, address string) (net.Conn, string, error) EnableH2ConnReuse bool cacheH2Mu sync.Mutex cachedH2ClientConn *http2.ClientConn cachedH2RawConn net.Conn } // NewHTTPConnectDialer creates a client to issue CONNECT requests and tunnel traffic via HTTPS proxy. // proxyURLStr must provide Scheme and Host, may provide credentials and port. // Example: https://username:password@golang.org:443 func NewHTTPConnectDialer(proxyURLStr string) (*HTTPConnectDialer, error) { proxyURL, err := url.Parse(proxyURLStr) if err != nil { return nil, err } if proxyURL.Host == "" { return nil, errors.New("misparsed `url=" + proxyURLStr + "`, make sure to specify full url like https://username:password@hostname.com:443/") } switch proxyURL.Scheme { case "http": if proxyURL.Port() == "" { proxyURL.Host = net.JoinHostPort(proxyURL.Host, "80") } case "https": if proxyURL.Port() == "" { proxyURL.Host = net.JoinHostPort(proxyURL.Host, "443") } case "": return nil, errors.New("specify scheme explicitly (https://)") default: return nil, errors.New("scheme " + proxyURL.Scheme + " is not supported") } client := &HTTPConnectDialer{ ProxyURL: *proxyURL, DefaultHeader: make(http.Header), SpkiFP: nil, EnableH2ConnReuse: true, } if proxyURL.User != nil { if proxyURL.User.Username() != "" { password, _ := proxyURL.User.Password() client.DefaultHeader.Set("Proxy-Authorization", "Basic "+ base64.StdEncoding.EncodeToString([]byte(proxyURL.User.Username()+":"+password))) } } return client, nil } func (c *HTTPConnectDialer) Dial(network, address string) (net.Conn, error) { return c.DialContext(context.Background(), network, address) } // Users of context.WithValue should define their own types for keys type ContextKeyHeader struct{} // ctx.Value will be inspected for optional ContextKeyHeader{} key, with `http.Header` value, // which will be added to outgoing request headers, overriding any colliding c.DefaultHeader func (c *HTTPConnectDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { req := (&http.Request{ Method: "CONNECT", URL: &url.URL{Host: address}, Header: make(http.Header), Host: address, }).WithContext(ctx) for k, v := range c.DefaultHeader { req.Header[k] = v } if ctxHeader, ctxHasHeader := ctx.Value(ContextKeyHeader{}).(http.Header); ctxHasHeader { for k, v := range ctxHeader { req.Header[k] = v } } connectHttp2 := func(rawConn net.Conn, h2clientConn *http2.ClientConn) (net.Conn, error) { req.Proto = "HTTP/2.0" req.ProtoMajor = 2 req.ProtoMinor = 0 pr, pw := io.Pipe() req.Body = pr resp, err := h2clientConn.RoundTrip(req) if err != nil { err = rawConn.Close() return nil, err } if resp.StatusCode != http.StatusOK { _ = rawConn.Close() return nil, errors.New("Proxy responded with non 200 code: " + resp.Status) } return NewHttp2Conn(rawConn, pw, resp.Body), nil } connectHttp1 := func(rawConn net.Conn) (net.Conn, error) { req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 err := req.Write(rawConn) if err != nil { err = rawConn.Close() return nil, err } resp, err := http.ReadResponse(bufio.NewReader(rawConn), req) if err != nil { err = rawConn.Close() return nil, err } if resp.StatusCode != http.StatusOK { _ = rawConn.Close() return nil, errors.New("Proxy responded with non 200 code: " + resp.Status) } return rawConn, nil } if c.EnableH2ConnReuse { c.cacheH2Mu.Lock() unlocked := false if c.cachedH2ClientConn != nil && c.cachedH2RawConn != nil { if c.cachedH2ClientConn.CanTakeNewRequest() { rc := c.cachedH2RawConn cc := c.cachedH2ClientConn c.cacheH2Mu.Unlock() unlocked = true proxyConn, err := connectHttp2(rc, cc) if err == nil { return proxyConn, err } // else: carry on and try again } } if !unlocked { c.cacheH2Mu.Unlock() } } var err error var rawConn net.Conn negotiatedProtocol := "" switch c.ProxyURL.Scheme { case "http": rawConn, err = c.Dialer.DialContext(ctx, network, c.ProxyURL.Host) if err != nil { return nil, err } case "https": if c.DialTLS != nil { rawConn, negotiatedProtocol, err = c.DialTLS(network, c.ProxyURL.Host) if err != nil { return nil, err } } else { tlsConf := tls.Config{ NextProtos: []string{"h2", "http/1.1"}, ServerName: c.ProxyURL.Hostname(), MinVersion: tls.VersionTLS12, } tlsConn, err := tls.Dial(network, c.ProxyURL.Host, &tlsConf) if err != nil { return nil, err } err = tlsConn.Handshake() if err != nil { return nil, err } negotiatedProtocol = tlsConn.ConnectionState().NegotiatedProtocol rawConn = tlsConn } default: return nil, errors.New("scheme " + c.ProxyURL.Scheme + " is not supported") } switch negotiatedProtocol { case "": fallthrough case "http/1.1": return connectHttp1(rawConn) case "h2": t := http2.Transport{} h2clientConn, err := t.NewClientConn(rawConn) if err != nil { err = rawConn.Close() return nil, err } proxyConn, err := connectHttp2(rawConn, h2clientConn) if err != nil { err = rawConn.Close() return nil, err } if c.EnableH2ConnReuse { c.cacheH2Mu.Lock() c.cachedH2ClientConn = h2clientConn c.cachedH2RawConn = rawConn c.cacheH2Mu.Unlock() } return proxyConn, err default: _ = rawConn.Close() return nil, errors.New("negotiated unsupported application layer protocol: " + negotiatedProtocol) } } func NewHttp2Conn(c net.Conn, pipedReqBody *io.PipeWriter, respBody io.ReadCloser) net.Conn { return &http2Conn{Conn: c, in: pipedReqBody, out: respBody} } type http2Conn struct { net.Conn in *io.PipeWriter out io.ReadCloser } func (h *http2Conn) Read(p []byte) (n int, err error) { return h.out.Read(p) } func (h *http2Conn) Write(p []byte) (n int, err error) { return h.in.Write(p) } func (h *http2Conn) Close() error { inErr := h.in.Close() outErr := h.out.Close() if inErr != nil { return inErr } return outErr } func (h *http2Conn) CloseConn() error { return h.Conn.Close() } func (h *http2Conn) CloseWrite() error { return h.in.Close() } func (h *http2Conn) CloseRead() error { return h.out.Close() }