diff --git a/buf.go b/buf.go new file mode 100644 index 0000000..96b4c60 --- /dev/null +++ b/buf.go @@ -0,0 +1,9 @@ +package mps + +type bufferPool struct { + get func() []byte + put func([]byte) +} + +func (bp bufferPool) Get() []byte { return bp.get() } +func (bp bufferPool) Put(v []byte) { bp.put(v) } diff --git a/cert/cert.go b/cert/cert.go new file mode 100644 index 0000000..8aab753 --- /dev/null +++ b/cert/cert.go @@ -0,0 +1,93 @@ +package cert + +import "crypto/tls" + +const CertPEM = `-----BEGIN CERTIFICATE----- +MIIF7jCCA9agAwIBAgIJAP/+a5pIA2lJMA0GCSqGSIb3DQEBCwUAMIGLMQswCQYD +VQQGEwJDTjERMA8GA1UECAwIWmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQww +CgYDVQQKDANtcHMxDDAKBgNVBAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5p +bzEiMCAGCSqGSIb3DQEJARYTdGVsYW5mbG93QGdtYWlsLmNvbTAeFw0yMDA4MDYx +MTE4MThaFw00MDA4MDExMTE4MThaMIGLMQswCQYDVQQGEwJDTjERMA8GA1UECAwI +WmhlSmlhbmcxETAPBgNVBAcMCEhhbmdaaG91MQwwCgYDVQQKDANtcHMxDDAKBgNV +BAsMA21wczEWMBQGA1UEAwwNbXBzLmdpdGh1Yi5pbzEiMCAGCSqGSIb3DQEJARYT +dGVsYW5mbG93QGdtYWlsLmNvbTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoC +ggIBANJZU0vyrS7aROi5+0e6AR4VBulFEjoivLrYaa1Pl1ENHHTgfjjmnLf2+22G +ImMp95RUDYIT2tZ2GhksLJil+fJEvv7HMihsWYYTjGzr5u3kPke0+fB/7dbRYJ+h +FvlsLEkItYPT9iBHryStu5CRV3P1VNtR9/7FF8YdX3kOqMQASnHQhBYNZ7av2OuR +3pDPLD0PKccqMeTXW+yMsB+z0L03RQQG3LOmi/7nWogvqVrnuwP7JbybOtHEvLO0 +rLEoAdXwdCCSAHdBCz2qat/I9CubGlKdUlgVw8eXVWZeYJeVOOQy8f7L9AEPoc5k +uXpEyRPCzpo/T/6KSxi2oxaEI4BSZUtyxRS/Laezdgs+GnKkjO56Z3lMPCvwwLFO +DNdtxA3OgLIvcZSA9zWPgoOSVQ0nCIQl3L3qEJ/TqyUWkcPINhiLNgnVSdu1dQ7q +rFZegmi5RAKAyl0M1rSlmTAB3Q/Mf4BMzPaNUajW7bjx4MbU9LxknVlRUb1vv2Jv +Pd6mUm0vLy6P/zl8/pZRpcnn91omFJ+PgZMoRzUPTBNDrgUEeXNsLzaLHg6t4fLb +xd1QMsg99Upo643Q/Hb8Xfz2ogm82jRURXkiHhQgxPjUvk76N4obNW9noMlZEUpF +/68/WwMc2CrWvWZ1HKWfpJDN6C2hjOqWvVWBng6LssVZdIBnAgMBAAGjUzBRMB0G +A1UdDgQWBBSPklwhHPcnDnP8tNSQ2i+VAs/gvjAfBgNVHSMEGDAWgBSPklwhHPcn +DnP8tNSQ2i+VAs/gvjAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IC +AQA6KnQfPV1gS53ZwakZAzE3XEDx+ef1C0iFZx282PWcIwnBPYbkswTt8RJj5806 +MiKyBtHSDN3Agde5LP6C2BhCx2GeguUDcTDPY0PGj5/TaES2gRiu6rsKkQJhSUTs +RSPekDT30yCJMJSz/1QnOqGXwToSpyr5rsWxAyYGAAz0fSAwpJ2XuE77vHCk+p3k +zIOjWxrkkLTSxoqIhOjmq8hO5qvwojudUTn0PBcAakND6r5csWLR2i+Am7u36WJc +m8dX8L+SXmWYw85Fs4tLDwUsORFuJclY32g7fKOvrgV2rhHXwwWw8pjpbjs0AAKA ++Gk3QQMT0cF4FpnH8VBdK82/nOtcbNvWz994K18kEzGUzN5Dq7zi+7n99HH4rwjO +2eyWNl7hsatvtKfcpwHhIt63EHG3owhc+Wf0iG3i6BR1b9jO3XLJmIlgkQ+4pSpV ++9BB2HNfklUOvJdsWwSxdhChavHJokl0rdGRf8weWhqhLkUCC3z2lIIylXgay71W +++48SxdaMbiqnuEZrt4cMlOt+KAxvtEl1krWLAIi6URHLSvdER7w/Dtkvg+PPb1t +yLiugEvmBIpagsbw3zirMza8Rg1CchRB+0sRGVSE4ppRe5EiIAe4aZUKufIk0TtS +8yO67j6Lx35sfhqzg/Jl31HOk42M8MpZqAGy13Cyw1kGmg== +-----END CERTIFICATE-----` + +const KeyPEM = `-----BEGIN RSA PRIVATE KEY----- +MIIJKgIBAAKCAgEA0llTS/KtLtpE6Ln7R7oBHhUG6UUSOiK8uthprU+XUQ0cdOB+ +OOact/b7bYYiYyn3lFQNghPa1nYaGSwsmKX58kS+/scyKGxZhhOMbOvm7eQ+R7T5 +8H/t1tFgn6EW+WwsSQi1g9P2IEevJK27kJFXc/VU21H3/sUXxh1feQ6oxABKcdCE +Fg1ntq/Y65HekM8sPQ8pxyox5Ndb7IywH7PQvTdFBAbcs6aL/udaiC+pWue7A/sl +vJs60cS8s7SssSgB1fB0IJIAd0ELPapq38j0K5saUp1SWBXDx5dVZl5gl5U45DLx +/sv0AQ+hzmS5ekTJE8LOmj9P/opLGLajFoQjgFJlS3LFFL8tp7N2Cz4acqSM7npn +eUw8K/DAsU4M123EDc6Asi9xlID3NY+Cg5JVDScIhCXcveoQn9OrJRaRw8g2GIs2 +CdVJ27V1DuqsVl6CaLlEAoDKXQzWtKWZMAHdD8x/gEzM9o1RqNbtuPHgxtT0vGSd +WVFRvW+/Ym893qZSbS8vLo//OXz+llGlyef3WiYUn4+BkyhHNQ9ME0OuBQR5c2wv +NoseDq3h8tvF3VAyyD31SmjrjdD8dvxd/PaiCbzaNFRFeSIeFCDE+NS+Tvo3ihs1 +b2egyVkRSkX/rz9bAxzYKta9ZnUcpZ+kkM3oLaGM6pa9VYGeDouyxVl0gGcCAwEA +AQKCAgArbEc2wXUg2+wnwuTtrKc4Z4zSsPCPUcZ2J+DA51JMaBF8yy8jXe/yRikn +Ne55XBuA4k0bki+14BGJKsZWCMVtTuXCwKpJD/z3Iaf2gEheyaRVtzV1gWM+2mBA +88dDXCJUPVkDSslfZozwXHEA6hAMnxOSZvxz+onq2vtviSgrtgeoMSxjRQco/mog +Ty+L40i1niC4vawpGpAeZ/ifwsYPmY5Ew4niCDqUN3xH6tbiLj48Fyd2JPFihmOS +EXUo6SJf4NCIPLud4q6IX1rKsbg+HDm13kY2at/MnyABDvCPuj1RVncAa2gGpAx6 +B+8GH5cG3ks6KmHAIRpZkrJeHo8ZOZOVsfokBDjTEWUn3sQH8RrVXtNnm/W66dAl +m4LnKyBWvyVaOHn65Jq06XTaUrT/9MY2RDmLPehzhcPczcZZn8RQikrarStQuHk2 +DOiiCvjSVnh+O13RdCKBMXfG4A482LucFnSSweiuFrXDU87GO++jKaZoYPeXsQul +jlTNFUyr7zHO5gqVf+JzboRG+pwashiFZBCfVqu/h9Abx7BTOfIXi5k5f0r1elZw +hQwJT0WgJKX4MjehojABNi+t4i6xqcsmCB9D68FBONSLHY4qpa1s9rhHEzd6BTJS +Fg4GPFooxQ2lwhjAyz2ZhG6HbF3QNuV6HgoimkcYlLiVowx5qQKCAQEA7XiDMQoA +4J+ZVlseGPinIVpOuAY8ehcd25I9xLmaqvk0CGeGBvpC125KGnG42m4l133AicOa ++Dz0yU0UudvfvnJEyPd6ojpAYMxX5/MH+85hJ8ARPwQQ+K+TlIPKH/8jXhS8D25D +pcvY4MJftwiuBhFTYlveNnmbH7QCAge/lS0BSnlOMGUI/yBg4DSOt0sAZOblWNlv +1FKG3aKCdd3atY1VvTyGnqpUqFiuU6ENwSbam77hTuQHjE3rULESj/wHIxpj/2Gr +VsjD/o29h2jjUApseUQqi55TllBj60K/DGuTiKSj8PXmXaBy20MDpef9HEYOVD/j +lsB4aPHqqn6+BQKCAQEA4sMNrA7xAfFRLJphzXd/6k9ISbKTjcp57f4Akib3FqCo +BJPz1F8cQJ5BHZLBJum/jyfbPEd36owr5bn/JlMqXsPgzb3eik/ZoIA/woucFosh +8MrebpARuSMmNtC0F2VfEDG4G/p7c+/aYWLPJJbte0XmIvsIVXwtlAt6k+HCAVAW +PA+MLAelEC0gtHOk1ea9NN2VCfsfpsbw/4GSUlL6Efev5ufXyH2z+tPmZyusPnBS +fAGZ78d000mH3RVsGN3o22Tzv8Hpx62MV8U08TvCESZsjggX2lCpVIsi3GRDR5wj +NFU7LOEy7TljYMES8GIyNc8U9csIgyh2+WLSjcSkewKCAQEAvuFa2uVGlUfUkoSF +ad8dQIL9uZBRtnW0a1Vezy299GaB+6tzIVKyvcYKTL1SsElPo6qSRGp1u8oLnW+X +FFp3u/bP8ZZz/cjDDMvUcT56EV7v22rYsgWLusou33cb1qJYBHy4OdMRD0kO2IOF +OnQApiHxG6Pqt3ECTvZ7krQ1vCxD2GAviFj+ZUzacf3tJcpk07aBbezBpjJ789V3 +9lRRRBQKciUftJQHnpZB8jkH/FVF7WD+bFKA+rd7Sg47dH9KIV5KOPKCLi0M1iWK +zjhyV1k5njQ72qR2XeHanzW0qcAjA/gLS1ntRR7+k96HJSmX2804IWKFhxzI7Npg +HZHpHQKCAQEAiabMGuErDfnWQ9QngJmE7dBY2lvr1EvP/leNMysyHOtDcxv5DLb7 +qIIolvIqDBwi65zPKeVcduXGE/r3VuVvN/2B7oLOn3lfa13O1qL3CnxFCy2rHsSX +7aHXpbjFSdqAfY0g7OL9o+A62Zkok1aHLKi+zgdDBNmPtWnObAzEPxXFmYn6lhPB +8HLkgoYczrf1rSzBN0DY8t2bGA8oqo6yPMv1XJ7qT0t3QND28TQCqBh5CcvTDUov +sb7WGa/SYbn7i4rZqFLnPg4svm7492NGKDEB/qoNCLqkP60CaXT3nnW6rR77//9o +cba/i9FIVOHXBvEBET/BmBStPC/wDp0LFwKCAQEA1sZKLH3IhWMbdqOcWbO8H1VC +4NT74peijfTjJ1JxcLllyv1H0MXW9qXG0Sksmy41CdyPBZUQuYbzi78p8S1aNIlx +sj1VGbGIHk+YNMJYHBlTBpn9hjDjXP3tZHtVHRzZN0rjpFV76ODpxatvKxryPl3X +mAPMhTvwnxnQ2rNF7RBPC8H8qJVBbG98k9HCqHolbNiMhV2Iow2SKSuX57IjT0cY +7mxps5zU94dTyJARNZaP7nlGHv1qx2ihRqxksIWxFetQ+U1JrM/14aeFUtS23HM+ +MJQxyIWYaidDHJzHy1MiZBZ5dpC2hwqNSjLq/OoDEV2cAYHEMnLnXCiz7ZulUA== +-----END RSA PRIVATE KEY-----` + +// default certificate +var DefaultCertificate, _ = tls.X509KeyPair([]byte(CertPEM), []byte(KeyPEM)) diff --git a/cert/container.go b/cert/container.go new file mode 100644 index 0000000..5031bd7 --- /dev/null +++ b/cert/container.go @@ -0,0 +1,13 @@ +package cert + +import "crypto/tls" + +// certificate storage Container +type Container interface { + + // Get the certificate for host + Get(host string) (*tls.Certificate, error) + + // Set the certificate for host + Set(host string, cert *tls.Certificate) error +} diff --git a/cert/mem_provider.go b/cert/mem_provider.go new file mode 100644 index 0000000..bdcef44 --- /dev/null +++ b/cert/mem_provider.go @@ -0,0 +1,32 @@ +package cert + +import ( + "crypto/tls" + "fmt" + "strings" +) + +type MemProvider struct { + cache map[string]*tls.Certificate +} + +func NewMemProvider() *MemProvider { + return &MemProvider{ + cache: make(map[string]*tls.Certificate), + } +} + +func (m *MemProvider) Get(host string) (cert *tls.Certificate, err error) { + var ok bool + cert, ok = m.cache[strings.TrimSpace(host)] + if !ok { + err = fmt.Errorf("cert not exist") + } + return +} + +func (m *MemProvider) Set(host string, cert *tls.Certificate) error { + host = strings.TrimSpace(host) + m.cache[host] = cert + return nil +} diff --git a/chunked.go b/chunked.go new file mode 100644 index 0000000..e96d6d4 --- /dev/null +++ b/chunked.go @@ -0,0 +1,57 @@ +// Taken from $GOROOT/src/pkg/net/http/chunked +// needed to write https responses to client. +package mps + +import ( + "io" + "strconv" +) + +// newChunkedWriter returns a new chunkedWriter that translates writes into HTTP +// "chunked" format before writing them to w. Closing the returned chunkedWriter +// sends the final 0-length chunk that marks the end of the stream. +// +// newChunkedWriter is not needed by normal applications. The http +// package adds chunking automatically if handlers don't set a +// Content-Length header. Using newChunkedWriter inside a handler +// would result in double chunking or chunking with a Content-Length +// length, both of which are wrong. +func newChunkedWriter(w io.Writer) io.WriteCloser { + return &chunkedWriter{w} +} + +// Writing to chunkedWriter translates to writing in HTTP chunked Transfer +// Encoding wire format to the underlying Wire chunkedWriter. +type chunkedWriter struct { + Wire io.Writer +} + +// Write the contents of data as one chunk to Wire. +// NOTE: Note that the corresponding chunk-writing procedure in Conn.Write has +// a bug since it does not check for success of io.WriteString +func (cw *chunkedWriter) Write(data []byte) (n int, err error) { + // Don't send 0-length data. It looks like EOF for chunked encoding. + if len(data) == 0 { + return 0, nil + } + + head := strconv.FormatInt(int64(len(data)), 16) + "\r\n" + + if _, err = io.WriteString(cw.Wire, head); err != nil { + return 0, err + } + if n, err = cw.Wire.Write(data); err != nil { + return + } + if n != len(data) { + err = io.ErrShortWrite + return + } + _, err = io.WriteString(cw.Wire, "\r\n") + return +} + +func (cw *chunkedWriter) Close() error { + _, err := io.WriteString(cw.Wire, "0\r\n") + return err +} diff --git a/context.go b/context.go index d71db34..a40ca76 100644 --- a/context.go +++ b/context.go @@ -1,5 +1,100 @@ package mps -type ProxyContext struct { +import ( + "context" + "crypto/tls" + "net/http" +) +type Context struct { + Context context.Context + Request *http.Request + Response *http.Response + Transport *http.Transport + + // In some cases it is not always necessary to remove the Proxy Header. + // For example, cascade proxy + KeepHeader bool + + // KeepDestinationHeaders indicates the proxy should retain any headers + // present in the http.Response before proxying + KeepDestinationHeaders bool + + // requests Middleware + mi int + middlewares []Middleware +} + +func NewContext() *Context { + return &Context{ + Context: context.Background(), + Transport: &http.Transport{ + //DialContext: (&net.Dialer{ + // Timeout: 15 * time.Second, + // KeepAlive: 30 * time.Second, + // DualStack: true, + //}).DialContext, + ////ForceAttemptHTTP2: true, + //MaxIdleConns: 100, + //IdleConnTimeout: 90 * time.Second, + //TLSHandshakeTimeout: 10 * time.Second, + //ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + Proxy: http.ProxyFromEnvironment, + }, + Request: nil, + Response: nil, + KeepHeader: false, + KeepDestinationHeaders: false, + mi: -1, + middlewares: make([]Middleware, 0), + } +} + +func (ctx *Context) Use(middleware ...Middleware) { + if ctx.middlewares == nil { + ctx.middlewares = make([]Middleware, 0) + } + + ctx.middlewares = append(ctx.middlewares, middleware...) +} + +func (ctx *Context) UseFunc(fns ...MiddlewareFunc) { + if ctx.middlewares == nil { + ctx.middlewares = make([]Middleware, 0) + } + + for _, fn := range fns { + ctx.middlewares = append(ctx.middlewares, fn) + } +} + +func (ctx *Context) Next(req *http.Request) (*http.Response, error) { + var ( + err error + total = len(ctx.middlewares) + ) + ctx.mi++ + if ctx.mi >= total { + ctx.mi = -1 + return ctx.Transport.RoundTrip(req) + } + + middleware := ctx.middlewares[ctx.mi] + ctx.Response, err = middleware.Handle(req, ctx) + ctx.mi = -1 + return ctx.Response, err +} + +func (ctx *Context) Copy() *Context { + return &Context{ + Context: context.Background(), + Request: nil, + Response: nil, + KeepHeader: false, + KeepDestinationHeaders: false, + Transport: ctx.Transport, + mi: -1, + middlewares: ctx.middlewares, + } } diff --git a/counter_encryptor.go b/counter_encryptor.go new file mode 100644 index 0000000..4542a64 --- /dev/null +++ b/counter_encryptor.go @@ -0,0 +1,73 @@ +package mps + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "errors" +) + +type CounterEncryptorRand struct { + cipher cipher.Block + counter []byte + rand []byte + ix int +} + +func NewCounterEncryptorRand(key interface{}, seed []byte) (r CounterEncryptorRand, err error) { + var keyBytes []byte + switch key := key.(type) { + case *rsa.PrivateKey: + keyBytes = x509.MarshalPKCS1PrivateKey(key) + case *ecdsa.PrivateKey: + if keyBytes, err = x509.MarshalECPrivateKey(key); err != nil { + return + } + default: + err = errors.New("only RSA and ECDSA keys supported") + return + } + h := sha256.New() + if r.cipher, err = aes.NewCipher(h.Sum(keyBytes)[:aes.BlockSize]); err != nil { + return + } + r.counter = make([]byte, r.cipher.BlockSize()) + if seed != nil { + copy(r.counter, h.Sum(seed)[:r.cipher.BlockSize()]) + } + r.rand = make([]byte, r.cipher.BlockSize()) + r.ix = len(r.rand) + return +} + +func (c *CounterEncryptorRand) Seed(b []byte) { + if len(b) != len(c.counter) { + panic("SetCounter: wrong counter size") + } + copy(c.counter, b) +} + +func (c *CounterEncryptorRand) refill() { + c.cipher.Encrypt(c.rand, c.counter) + for i := 0; i < len(c.counter); i++ { + if c.counter[i]++; c.counter[i] != 0 { + break + } + } + c.ix = 0 +} + +func (c *CounterEncryptorRand) Read(b []byte) (n int, err error) { + if c.ix == len(c.rand) { + c.refill() + } + if n = len(c.rand) - c.ix; n > len(b) { + n = len(b) + } + copy(b, c.rand[c.ix:c.ix+n]) + c.ix += n + return +} diff --git a/counter_encryptor_test.go b/counter_encryptor_test.go new file mode 100644 index 0000000..63c5de6 --- /dev/null +++ b/counter_encryptor_test.go @@ -0,0 +1,104 @@ +package mps + +import ( + "bytes" + "crypto/rsa" + "encoding/binary" + "io" + "math" + "math/rand" + "testing" +) + +type RandSeedReader struct { + r rand.Rand +} + +func (r *RandSeedReader) Read(b []byte) (n int, err error) { + for i := range b { + b[i] = byte(r.r.Int() & 0xFF) + } + return len(b), nil +} + +func TestCounterEncDifferentConsecutive(t *testing.T) { + k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) + fatalOnErr(err, "rsa.GenerateKey", t) + c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + for i := 0; i < 100*1000; i++ { + var a, b int64 + binary.Read(&c, binary.BigEndian, &a) + binary.Read(&c, binary.BigEndian, &b) + if a == b { + t.Fatal("two consecutive equal int64", a, b) + } + } +} + +func TestCounterEncIdenticalStreams(t *testing.T) { + k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) + fatalOnErr(err, "rsa.GenerateKey", t) + c1, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + c2, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + nout := 1000 + out1, out2 := make([]byte, nout), make([]byte, nout) + io.ReadFull(&c1, out1) + tmp := out2[:] + rand.Seed(0xFF43109) + for len(tmp) > 0 { + n := 1 + rand.Intn(256) + if n > len(tmp) { + n = len(tmp) + } + n, err := c2.Read(tmp[:n]) + fatalOnErr(err, "CounterEncryptorRand.Read", t) + tmp = tmp[n:] + } + if !bytes.Equal(out1, out2) { + t.Error("identical CSPRNG does not produce the same output") + } +} + +func stddev(data []int) float64 { + var sum, sum_sqr float64 = 0, 0 + for _, h := range data { + sum += float64(h) + sum_sqr += float64(h) * float64(h) + } + n := float64(len(data)) + variance := (sum_sqr - ((sum * sum) / n)) / (n - 1) + return math.Sqrt(variance) +} + +func TestCounterEncStreamHistogram(t *testing.T) { + k, err := rsa.GenerateKey(&RandSeedReader{*rand.New(rand.NewSource(0xFF43109))}, 128) + fatalOnErr(err, "rsa.GenerateKey", t) + c, err := NewCounterEncryptorRand(k, []byte("the quick brown fox run over the lazy dog")) + fatalOnErr(err, "NewCounterEncryptorRandFromKey", t) + nout := 100 * 1000 + out := make([]byte, nout) + io.ReadFull(&c, out) + refhist := make([]int, 512) + for i := 0; i < nout; i++ { + refhist[rand.Intn(256)]++ + } + hist := make([]int, 512) + for _, b := range out { + hist[int(b)]++ + } + refstddev, stddev := stddev(refhist), stddev(hist) + // due to lack of time, I guestimate + t.Logf("ref:%v - act:%v = %v", refstddev, stddev, math.Abs(refstddev-stddev)) + if math.Abs(refstddev-stddev) >= 1 { + t.Errorf("stddev of ref histogram different than regular PRNG: %v %v", refstddev, stddev) + } +} + +func fatalOnErr(err error, msg string, t *testing.T) { + if err != nil { + t.Fatal(msg, err) + } +} diff --git a/examples/generateCert/openssl-gen.sh b/examples/generateCert/openssl-gen.sh new file mode 100755 index 0000000..fdaab6a --- /dev/null +++ b/examples/generateCert/openssl-gen.sh @@ -0,0 +1,8 @@ +#!/bin/bash +set -ex +# generate CA's key +openssl genrsa -aes256 -passout pass:1 -out ca.key.pem 4096 +openssl rsa -passin pass:1 -in ca.key.pem -out ca.key.pem.tmp +mv ca.key.pem.tmp ca.key.pem + +openssl req -config openssl.cnf -key ca.key.pem -new -x509 -days 7300 -sha256 -extensions v3_ca -out ca.pem diff --git a/examples/generateCert/openssl.cnf b/examples/generateCert/openssl.cnf new file mode 100644 index 0000000..54b5e5d --- /dev/null +++ b/examples/generateCert/openssl.cnf @@ -0,0 +1,39 @@ +[ ca ] +default_ca = CA_default +[ CA_default ] +default_md = sha256 +[ v3_ca ] +subjectKeyIdentifier=hash +authorityKeyIdentifier=keyid:always,issuer +basicConstraints = critical,CA:true +[ req ] +distinguished_name = req_distinguished_name +[ req_distinguished_name ] +countryName = Country Name (2 letter code) +countryName_default = CN +countryName_min = 2 +countryName_max = 2 + +stateOrProvinceName = State or Province Name (full name) +stateOrProvinceName_default = ZheJiang + +localityName = Locality Name (eg, city) +localityName_default = HangZhou + +0.organizationName = Organization Name (eg, company) +0.organizationName_default = mps + +# we can do this but it is not needed normally :-) +#1.organizationName = Second Organization Name (eg, company) +#1.organizationName_default = World Wide Web Pty Ltd + +organizationalUnitName = Organizational Unit Name (eg, section) +organizationalUnitName_default = mps + +commonName = Common Name (e.g. server FQDN or YOUR name) +commonName_default = mps.github.io +commonName_max = 64 + +emailAddress = Email Address +emailAddress_default = telanflow@gmail.com +emailAddress_max = 64 diff --git a/examples/main.go b/examples/main.go index e53dc72..7905807 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,6 +2,4 @@ package main func main() { - - } diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..c0769ed --- /dev/null +++ b/filter.go @@ -0,0 +1,31 @@ +package mps + +import ( + "net/http" + "regexp" +) + +type Filter interface { + Match(expr string) bool +} + +type FilterFunc func(expr string) bool + +func (f FilterFunc) Match(expr string) bool { + return f(expr) +} + +// 匹配域名 +var MatchIsHost = func(expr string, req *http.Request) Filter { + exp, err := regexp.Compile(expr) + if err != nil { + panic(err) + } + return FilterFunc(func(expr string) bool { + return exp.MatchString(req.Host) + }) +} + +type ReqHandler interface { + Handler(ctx *Context) +} diff --git a/forward_handler.go b/forward_handler.go new file mode 100644 index 0000000..1c97528 --- /dev/null +++ b/forward_handler.go @@ -0,0 +1,61 @@ +package mps + +import ( + "io" + "net/http" +) + +type ForwardHandler struct { + Ctx *Context +} + +func NewForwardHandler() *ForwardHandler { + return &ForwardHandler{ + Ctx: NewContext(), + } +} + +func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler { + return &ForwardHandler{ + Ctx: ctx, + } +} + +func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // Copying a Context preserves the Transport, Middleware + ctx := forward.Ctx.Copy() + ctx.Request = req + + // In some cases it is not always necessary to remove the Proxy Header. + // For example, cascade proxy + if !forward.Ctx.KeepHeader { + removeProxyHeaders(req) + } + + resp, err := ctx.Next(req) + if err != nil { + http.Error(rw, err.Error(), 500) + return + } + + origBody := resp.Body + defer origBody.Close() + + // http.ResponseWriter will take care of filling the correct response length + // Setting it now, might impose wrong value, contradicting the actual new + // body the user returned. + // We keep the original body to remove the header only if things changed. + // This will prevent problems with HEAD requests where there's no body, yet, + // the Content-Length header should be set. + if origBody != resp.Body { + resp.Header.Del("Content-Length") + } + copyHeaders(rw.Header(), resp.Header, forward.Ctx.KeepDestinationHeaders) + rw.WriteHeader(resp.StatusCode) + io.Copy(rw, resp.Body) + resp.Body.Close() +} + +func (forward *ForwardHandler) Transport() *http.Transport { + return forward.Ctx.Transport +} diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/http_proxy.go b/http_proxy.go new file mode 100644 index 0000000..67cc119 --- /dev/null +++ b/http_proxy.go @@ -0,0 +1,110 @@ +package mps + +import ( + "errors" + "fmt" + "net" + "net/http" +) + +type HttpProxy struct { + // HTTPS requests use the TunnelHandler proxy by default + HttpsHandler http.Handler + + // HTTP requests use the ForwardHandler proxy by default + HttpHandler http.Handler + + Ctx *Context +} + +func NewHttpProxy() *HttpProxy { + // default Context with Proxy + ctx := NewContext() + + return &HttpProxy{ + Ctx: ctx, + // default HTTP proxy + HttpHandler: &ForwardHandler{Ctx: ctx}, + // default HTTPS proxy + HttpsHandler: &TunnelHandler{Ctx: ctx}, + } +} + +func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if req.Method == http.MethodConnect { + proxy.HttpsHandler.ServeHTTP(rw, req) + } + proxy.HttpHandler.ServeHTTP(rw, req) +} + +func (proxy *HttpProxy) Use(middleware ...Middleware) { + proxy.Ctx.Use(middleware...) +} + +func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) { + proxy.Ctx.UseFunc(fus...) +} + +func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) { + hij, ok := rw.(http.Hijacker) + if !ok { + err = errors.New("not a hijacker") + return + } + + conn, _, err = hij.Hijack() + if err != nil { + err = fmt.Errorf("cannot hijack connection %v", err) + } + return +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +func removeProxyHeaders(r *http.Request) { + r.RequestURI = "" // this must be reset when serving a request with the client + // If no Accept-Encoding header exists, Transport will add the headers it can accept + // and would wrap the response body with the relevant reader. + r.Header.Del("Accept-Encoding") + // curl can add that, see + // https://jdebp.eu./FGA/web-proxy-connection-header.html + + // RFC 2616 (section 13.5.1) + // https://www.ietf.org/rfc/rfc2616.txt + r.Header.Del("Proxy-Connection") + r.Header.Del("Proxy-Authenticate") + r.Header.Del("Proxy-Authorization") + // Connection, Authenticate and Authorization are single hop Header: + // http://www.w3.org/Protocols/rfc2616/rfc2616.txt + // 14.10 Connection + // The Connection general-header field allows the sender to specify + // options that are desired for that particular connection and MUST NOT + // be communicated by proxies over further connections. + + // When server reads http request it sets req.Close to true if + // "Connection" header contains "close". + // https://github.com/golang/go/blob/master/src/net/http/request.go#L1080 + // Later, transfer.go adds "Connection: close" back when req.Close is true + // https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275 + // That's why tests that checks "Connection: close" removal fail + if r.Header.Get("Connection") == "close" { + r.Close = false + } + r.Header.Del("Connection") +} + +func copyHeaders(dst, src http.Header, keepDestHeaders bool) { + if !keepDestHeaders { + for k := range dst { + dst.Del(k) + } + } + for k, vs := range src { + for _, v := range vs { + dst.Add(k, v) + } + } +} diff --git a/http_proxy_test.go b/http_proxy_test.go new file mode 100644 index 0000000..afe03f9 --- /dev/null +++ b/http_proxy_test.go @@ -0,0 +1,64 @@ +package mps + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestNewHttpProxy(t *testing.T) { + proxy := NewHttpProxy() + srv := httptest.NewServer(proxy) + defer srv.Close() + + req, _ := http.NewRequest(http.MethodGet, "http://httpbin.org/get", nil) + http.DefaultClient.Transport = &http.Transport{ + Proxy: func(r *http.Request) (*url.URL, error) { + return url.Parse(srv.URL) + }, + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := ioutil.ReadAll(resp.Body) + + log.Println(err) + log.Println(resp.Status) + log.Println(string(body)) +} + +func TestMiddlewareFunc(t *testing.T) { + proxy := NewHttpProxy() + proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + log.Println(req.URL.String()) + return ctx.Next(req) + }) + srv := httptest.NewServer(proxy) + defer srv.Close() + + req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) + http.DefaultClient.Transport = &http.Transport{ + Proxy: func(r *http.Request) (*url.URL, error) { + return url.Parse(srv.URL) + }, + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := ioutil.ReadAll(resp.Body) + + log.Println(err) + log.Println(resp.Status) + log.Println(string(body)) +} \ No newline at end of file diff --git a/middleware.go b/middleware.go index 7629859..8287db5 100644 --- a/middleware.go +++ b/middleware.go @@ -2,6 +2,12 @@ package mps import "net/http" -type Middleware func(req *http.Request, resp *http.Response) +type Middleware interface { + Handle(req *http.Request, ctx *Context) (*http.Response, error) +} -type a http.HandlerFunc \ No newline at end of file +type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error) + +func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) { + return f(req, ctx) +} diff --git a/mitm_handler.go b/mitm_handler.go new file mode 100644 index 0000000..7c12a36 --- /dev/null +++ b/mitm_handler.go @@ -0,0 +1,300 @@ +package mps + +import ( + "bufio" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha1" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "github.com/telanflow/mps/cert" + "io" + "math/big" + "math/rand" + "net" + "net/http" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "time" +) + +var ( + HttpMitmOk = []byte("HTTP/1.0 200 OK\r\n\r\n") + httpsRegexp = regexp.MustCompile("^https://") +) + +type MitmHandler struct { + Certificate tls.Certificate + Ctx *Context + + TLSConfig *tls.Config + CertContainer cert.Container +} + +func NewMitmHandler() *MitmHandler { + return &MitmHandler{ + Ctx: NewContext(), + // default MPS Certificate + Certificate: cert.DefaultCertificate, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + // Certificate cache storage container + CertContainer: cert.NewMemProvider(), + } +} + +func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // get hijacker connection + proxyClient, err := hijacker(w) + if err != nil { + http.Error(w, err.Error(), 502) + return + } + + // this goes in a separate goroutine, so that the net/http server won't think we're + // still handling the request even after hijacking the connection. Those HTTP CONNECT + // request can take forever, and the server will be stuck when "closed". + // TODO: Allow Server.Close() mechanism to shut down this connection as nicely as possible + tlsConfig, err := mitm.TLSConfigFromCA(r.URL.Host) + if err != nil { + ConnError(proxyClient) + return + } + + _, _ = proxyClient.Write(HttpMitmOk) + + go func() { + // TODO: cache connections to the remote website + rawClientTls := tls.Server(proxyClient, tlsConfig) + if err := rawClientTls.Handshake(); err != nil { + ConnError(proxyClient) + //ctx.Warnf("Cannot handshake client %v %v", r.Host, err) + return + } + defer rawClientTls.Close() + + clientTlsReader := bufio.NewReader(rawClientTls) + for !isEof(clientTlsReader) { + req, err := http.ReadRequest(clientTlsReader) + if err != nil { + //ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err) + break + } + + // since we're converting the request, need to carry over the original connecting IP as well + req.RemoteAddr = r.RemoteAddr + + if !httpsRegexp.MatchString(req.URL.String()) { + req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) + } + if err != nil { + //ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path) + return + } + + // Copying a Context preserves the Transport, Middleware + ctx := mitm.Ctx.Copy() + ctx.Request = req + + // In some cases it is not always necessary to remove the Proxy Header. + // For example, cascade proxy + if !mitm.Ctx.KeepHeader { + removeProxyHeaders(req) + } + + var resp *http.Response + resp, err = ctx.Next(req) + if err != nil { + //ctx.Warnf("Cannot read TLS response from mitm'd server %v", err) + return + } + defer resp.Body.Close() + + status := resp.Status + statusCode := strconv.Itoa(resp.StatusCode) + " " + if strings.HasPrefix(status, statusCode) { + status = status[len(statusCode):] + } + + // always use 1.1 to support chunked encoding + if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+status+"\r\n"); err != nil { + //ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) + return + } + + // Since we don't know the length of resp, return chunked encoded response + // TODO: use a more reasonable scheme + resp.Header.Del("Content-Length") + resp.Header.Set("Transfer-Encoding", "chunked") + + // Force connection close otherwise chrome will keep CONNECT tunnel open forever + resp.Header.Set("Connection", "close") + + err = resp.Header.Write(rawClientTls) + if err != nil { + //ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err) + return + } + _, err = io.WriteString(rawClientTls, "\r\n") + if err != nil { + //ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err) + return + } + + chunked := newChunkedWriter(rawClientTls) + _, err = io.Copy(chunked, resp.Body) + if err != nil { + //ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err) + return + } + if err := chunked.Close(); err != nil { + //ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err) + return + } + if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { + //ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err) + return + } + } + + }() +} + +func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) { + host = stripPort(host) + + // Returned existing certificate for the host + crt, err := mitm.CertContainer.Get(host) + if err == nil { + config := cloneTLSConfig(mitm.TLSConfig) + config.Certificates = append(config.Certificates, *crt) + return config, nil + } + + // Issue a certificate for host + crt, err = signHost(mitm.Certificate, []string{host}) + if err != nil { + err = fmt.Errorf("cannot sign host certificate with provided CA: %v", err) + return nil, err + } + + // Set certificate to container + mitm.CertContainer.Set(host, crt) + + config := &tls.Config{ + Certificates: []tls.Certificate{*crt}, + } + return config, nil +} + +func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) { + // Use the provided ca for certificate generation. + var x509ca *x509.Certificate + x509ca, err = x509.ParseCertificate(ca.Certificate[0]) + if err != nil { + return + } + + var random CounterEncryptorRand + random, err = NewCounterEncryptorRand(ca.PrivateKey, hashHosts(hosts)) + if err != nil { + return + } + + var pk crypto.Signer + switch ca.PrivateKey.(type) { + case *rsa.PrivateKey: + pk, err = rsa.GenerateKey(&random, 2048) + case *ecdsa.PrivateKey: + pk, err = ecdsa.GenerateKey(elliptic.P256(), &random) + default: + err = fmt.Errorf("unsupported key type %T", ca.PrivateKey) + } + if err != nil { + return + } + + // certificate template + tpl := x509.Certificate{ + // SerialNumber 是 CA 颁布的唯一序列号,在此使用一个大随机数来代表它 + SerialNumber: big.NewInt(rand.Int63()), + Issuer: x509ca.Subject, + // pkix.Name代表一个X.509识别名。只包含识别名的公共属性,额外的属性被忽略。 + Subject: pkix.Name{ + Organization: []string{"MPS untrusted MITM proxy Inc"}, + }, + NotBefore: time.Unix(0, 0), + NotAfter: time.Now().AddDate(20, 0, 0), + // KeyUsage 与 ExtKeyUsage 用来表明该证书是用来做服务器认证的 + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + // 密钥扩展用途的序列 + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + EmailAddresses: x509ca.EmailAddresses, + } + + total := len(hosts) + for i := 0; i < total; i++ { + if ip := net.ParseIP(hosts[i]); ip != nil { + tpl.IPAddresses = append(tpl.IPAddresses, ip) + } else { + tpl.DNSNames = append(tpl.DNSNames, hosts[i]) + tpl.Subject.CommonName = hosts[i] + } + } + + var der []byte + der, err = x509.CreateCertificate(&random, &tpl, x509ca, pk.Public(), ca.PrivateKey) + if err != nil { + return + } + + cert = &tls.Certificate{ + Certificate: [][]byte{der, ca.Certificate[0]}, + PrivateKey: pk, + } + return +} + +func stripPort(s string) string { + ix := strings.IndexRune(s, ':') + if ix == -1 { + return s + } + return s[:ix] +} + +func hashHosts(lst []string) []byte { + c := make([]string, len(lst)) + copy(c, lst) + sort.Strings(c) + h := sha1.New() + h.Write([]byte(strings.Join(c, ","))) + return h.Sum(nil) +} + +// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if +// cfg is nil. This is safe to call even if cfg is in active use by a TLS +// client or server. +func cloneTLSConfig(cfg *tls.Config) *tls.Config { + if cfg == nil { + return &tls.Config{} + } + return cfg.Clone() +} + +func isEof(r *bufio.Reader) bool { + _, err := r.Peek(1) + if err == io.EOF { + return true + } + return false +} diff --git a/mitm_handler_test.go b/mitm_handler_test.go new file mode 100644 index 0000000..f325c7b --- /dev/null +++ b/mitm_handler_test.go @@ -0,0 +1,49 @@ +package mps + +import ( + "crypto/tls" + "crypto/x509" + "github.com/telanflow/mps/cert" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestNewMitmHandler(t *testing.T) { + mitm := NewMitmHandler() + mitmSrv := httptest.NewServer(mitm) + defer mitmSrv.Close() + + clientCertPool := x509.NewCertPool() + ok := clientCertPool.AppendCertsFromPEM([]byte(cert.CertPEM)) + if !ok { + panic("failed to parse root certificate") + } + + req, _ := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) + http.DefaultClient.Transport = &http.Transport{ + Proxy: func(r *http.Request) (*url.URL, error) { + return url.Parse(mitmSrv.URL) + }, + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{cert.DefaultCertificate}, + ClientAuth: tls.RequireAndVerifyClientCert, + RootCAs: clientCertPool, + }, + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := ioutil.ReadAll(resp.Body) + + log.Println(err) + log.Println(resp.Status) + log.Println(string(body)) +} diff --git a/mps.go b/mps.go index b80fbdc..a027f75 100644 --- a/mps.go +++ b/mps.go @@ -1,2 +1 @@ package mps - diff --git a/proxy_server.go b/proxy_server.go deleted file mode 100644 index e943a2b..0000000 --- a/proxy_server.go +++ /dev/null @@ -1,122 +0,0 @@ -package mps - -import ( - "bufio" - "crypto/tls" - "io" - "net/http" -) - -type ProxyServer struct { - middleware []Middleware - Transport *http.Transport - KeepHeader bool - // KeepDestinationHeaders indicates the proxy should retain any headers present in the http.Response before proxying - KeepDestinationHeaders bool -} - -func NewProxyServer() *ProxyServer { - return &ProxyServer{ - middleware: nil, - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - Proxy: http.ProxyFromEnvironment, - }, - } -} - -func (p *ProxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodConnect { - p.handlerHttps(w, r) - return - } - - p.handlerHttp(w, r) -} - -func (p *ProxyServer) handlerHttp(w http.ResponseWriter, r *http.Request) { - var err error - - if !p.KeepHeader { - removeProxyHeaders(r) - } - - resp, err := p.Transport.RoundTrip(r) - if err != nil { - http.Error(w, err.Error(), 500) - return - } - - origBody := resp.Body - defer origBody.Close() - - // http.ResponseWriter will take care of filling the correct response length - // Setting it now, might impose wrong value, contradicting the actual new - // body the user returned. - // We keep the original body to remove the header only if things changed. - // This will prevent problems with HEAD requests where there's no body, yet, - // the Content-Length header should be set. - if origBody != resp.Body { - resp.Header.Del("Content-Length") - } - copyHeaders(w.Header(), resp.Header, p.KeepDestinationHeaders) - w.WriteHeader(resp.StatusCode) - io.Copy(w, resp.Body) - resp.Body.Close() -} - -func (p *ProxyServer) handlerHttps(w http.ResponseWriter, r *http.Request) { - hij, ok := w.(http.Hijacker) - if !ok { - http.Error(w, "Not a hijacker?", 500) - return - } - - proxyClientConn, _, e := hij.Hijack() - if e != nil { - http.Error(w, "Cannot hijack connection " + e.Error(), 500) - return - } - - -} - -func removeProxyHeaders(r *http.Request) { - r.RequestURI = "" // this must be reset when serving a request with the client - // If no Accept-Encoding header exists, Transport will add the headers it can accept - // and would wrap the response body with the relevant reader. - r.Header.Del("Accept-Encoding") - // curl can add that, see - // https://jdebp.eu./FGA/web-proxy-connection-header.html - r.Header.Del("Proxy-Connection") - r.Header.Del("Proxy-Authenticate") - r.Header.Del("Proxy-Authorization") - // Connection, Authenticate and Authorization are single hop Header: - // http://www.w3.org/Protocols/rfc2616/rfc2616.txt - // 14.10 Connection - // The Connection general-header field allows the sender to specify - // options that are desired for that particular connection and MUST NOT - // be communicated by proxies over further connections. - r.Header.Del("Connection") -} - -func copyHeaders(dst, src http.Header, keepDestHeaders bool) { - if !keepDestHeaders { - for k := range dst { - dst.Del(k) - } - } - for k, vs := range src { - for _, v := range vs { - dst.Add(k, v) - } - } -} - -func isEof(r *bufio.Reader) bool { - _, err := r.Peek(1) - if err == io.EOF { - return true - } - return false -} \ No newline at end of file diff --git a/proxy_server_test.go b/proxy_server_test.go deleted file mode 100644 index cbd9f10..0000000 --- a/proxy_server_test.go +++ /dev/null @@ -1,9 +0,0 @@ -package mps - -import "testing" - -func TestProxyServer_ServeHTTP(t *testing.T) { - - - -} \ No newline at end of file diff --git a/request_handle.go b/request_handle.go new file mode 100644 index 0000000..ab166ef --- /dev/null +++ b/request_handle.go @@ -0,0 +1,13 @@ +package mps + +import "net/http" + +type RequestHandle interface { + Handle(req *http.Request) (*http.Request, *http.Response) +} + +type RequestHandleFunc func(req *http.Request) (*http.Request, *http.Response) + +func (f RequestHandleFunc) Handle(req *http.Request) (*http.Request, *http.Response) { + return f(req) +} diff --git a/response_handle.go b/response_handle.go new file mode 100644 index 0000000..ab2f29e --- /dev/null +++ b/response_handle.go @@ -0,0 +1,13 @@ +package mps + +import "net/http" + +type ResponseHandle interface { + Handle(resp *http.Response) *http.Response +} + +type ResponseHandleFunc func(resp *http.Response) *http.Response + +func (f ResponseHandleFunc) Handle(resp *http.Response) *http.Response { + return f(resp) +} diff --git a/reverse_handler.go b/reverse_handler.go new file mode 100644 index 0000000..39d0769 --- /dev/null +++ b/reverse_handler.go @@ -0,0 +1,17 @@ +package mps + +import "net/http" + +type ReverseHandler struct { + Ctx *Context +} + +func NewReverseHandler() *ReverseHandler { + return &ReverseHandler{ + Ctx: NewContext(), + } +} + +func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + +} diff --git a/tunnel_handler.go b/tunnel_handler.go new file mode 100644 index 0000000..c30b3aa --- /dev/null +++ b/tunnel_handler.go @@ -0,0 +1,101 @@ +package mps + +import ( + "context" + "io" + "net" + "net/http" + "net/url" + "regexp" +) + +var ( + HttpTunnelOk = []byte("HTTP/1.0 200 OK\r\n\r\n") + HttpTunnelFail = []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n") + hasPort = regexp.MustCompile(`:\d+$`) +) + +type TunnelHandler struct { + Ctx *Context +} + +func NewTunnelHandler() *TunnelHandler { + return &TunnelHandler{ + Ctx: NewContext(), + } +} + +func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + // hijacker connection + proxyClient, err := hijacker(rw) + if err != nil { + http.Error(rw, err.Error(), 502) + return + } + + var ( + u *url.URL = nil + targetConn net.Conn = nil + targetAddr = hostAndPort(req.URL.Host) + ) + if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Proxy != nil { + u, err = tunnel.Ctx.Transport.Proxy(req) + if err != nil { + ConnError(proxyClient) + return + } + if u != nil { + // connect addr eg. "localhost:80" + targetAddr = hostAndPort(u.Host) + } + + } else { + _, _ = proxyClient.Write(HttpTunnelOk) + } + + // connect to targetAddr + targetConn, err = tunnel.ConnectDial("tcp", targetAddr) + if err != nil { + ConnError(proxyClient) + return + } + + go func() { + buf := make([]byte, 2048) + _, _ = io.CopyBuffer(targetConn, proxyClient, buf) + targetConn.Close() + proxyClient.Close() + }() + buf := make([]byte, 2048) + _, _ = io.CopyBuffer(proxyClient, targetConn, buf) +} + +func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) { + //if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.Dial != nil { + // return tunnel.Ctx.Transport.Dial(network, addr) + //} + return net.Dial(network, addr) +} + +func (tunnel *TunnelHandler) Context() context.Context { + if tunnel.Ctx.Context != nil { + return tunnel.Ctx.Context + } + return context.Background() +} + +func (tunnel *TunnelHandler) Transport() *http.Transport { + return tunnel.Ctx.Transport +} + +func hostAndPort(addr string) string { + if !hasPort.MatchString(addr) { + addr += ":80" + } + return addr +} + +func ConnError(w net.Conn) { + _, _ = w.Write(HttpTunnelFail) + _ = w.Close() +} diff --git a/tunnel_handler_test.go b/tunnel_handler_test.go new file mode 100644 index 0000000..62372c6 --- /dev/null +++ b/tunnel_handler_test.go @@ -0,0 +1,40 @@ +package mps + +import ( + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestNewTunnelHandler(t *testing.T) { + tunnel := NewTunnelHandler() + //tunnel.Transport().Proxy = func(r *http.Request) (*url.URL, error) { + // //return url.Parse("http://59.58.58.92:4235") + // return url.Parse("http://127.0.0.1:7890") + //} + tunnel.Transport().Dial = nil + tunnelSrv := httptest.NewServer(tunnel) + defer tunnelSrv.Close() + + req, _ := http.NewRequest(http.MethodGet, "http://httpbin.org/get", nil) + http.DefaultClient.Transport = &http.Transport{ + Proxy: func(r *http.Request) (*url.URL, error) { + return url.Parse(tunnelSrv.URL) + }, + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + body, _ := ioutil.ReadAll(resp.Body) + + log.Println(err) + log.Println(resp.Status) + log.Println(string(body)) +}