// Mirror of https://github.com/telanflow/mps.git
// (synced 2025-09-26 20:41:25 +08:00; 385 lines, 9.4 KiB, Go).

package mps
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rsa"
|
|
"crypto/sha1"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/telanflow/mps/cert"
|
|
"github.com/telanflow/mps/pool"
|
|
)
|
|
|
|
// HttpMitmOk is the raw reply that ServeHTTP writes to the hijacked client
// connection before the TLS interception starts.
var HttpMitmOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n")

// httpsRegexp reports whether a URL string already begins with the https
// scheme; transmit uses it to decide whether a decrypted request URL still
// needs the scheme and host prepended.
var httpsRegexp = regexp.MustCompile("^https://")
|
|
|
|
// MitmHandler The Man-in-the-middle proxy type. Implements http.Handler.
|
|
type MitmHandler struct {
|
|
Ctx *Context
|
|
BufferPool httputil.BufferPool
|
|
Certificate tls.Certificate
|
|
// CertContainer is certificate storage container
|
|
CertContainer cert.Container
|
|
}
|
|
|
|
// NewMitmHandler Create a mitmHandler, use default cert.
|
|
func NewMitmHandler() *MitmHandler {
|
|
return &MitmHandler{
|
|
Ctx: NewContext(),
|
|
BufferPool: pool.DefaultBuffer,
|
|
Certificate: cert.DefaultCertificate,
|
|
CertContainer: cert.NewMemProvider(),
|
|
}
|
|
}
|
|
|
|
// NewMitmHandlerWithContext Create a MitmHandler, use default cert.
|
|
func NewMitmHandlerWithContext(ctx *Context) *MitmHandler {
|
|
return &MitmHandler{
|
|
Ctx: ctx,
|
|
BufferPool: pool.DefaultBuffer,
|
|
Certificate: cert.DefaultCertificate,
|
|
CertContainer: cert.NewMemProvider(),
|
|
}
|
|
}
|
|
|
|
// NewMitmHandlerWithCert Create a MitmHandler with cert pem block
|
|
func NewMitmHandlerWithCert(ctx *Context, certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) {
|
|
certificate, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &MitmHandler{
|
|
Ctx: ctx,
|
|
BufferPool: pool.DefaultBuffer,
|
|
Certificate: certificate,
|
|
CertContainer: cert.NewMemProvider(),
|
|
}, nil
|
|
}
|
|
|
|
// NewMitmHandlerWithCertFile Create a MitmHandler with cert file
|
|
func NewMitmHandlerWithCertFile(ctx *Context, certFile, keyFile string) (*MitmHandler, error) {
|
|
certificate, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &MitmHandler{
|
|
Ctx: ctx,
|
|
BufferPool: pool.DefaultBuffer,
|
|
Certificate: certificate,
|
|
CertContainer: cert.NewMemProvider(),
|
|
}, nil
|
|
}
|
|
|
|
// Standard net/http function. You can use it alone
|
|
func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|
// execution middleware
|
|
ctx := mitm.Ctx.WithRequest(req)
|
|
resp, err := ctx.Next(req)
|
|
if err != nil && !errors.Is(err, MethodNotSupportErr) {
|
|
if resp != nil {
|
|
copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders)
|
|
rw.WriteHeader(resp.StatusCode)
|
|
buf := mitm.buffer().Get()
|
|
_, err = io.CopyBuffer(rw, resp.Body, buf)
|
|
mitm.buffer().Put(buf)
|
|
}
|
|
return
|
|
}
|
|
|
|
// get hijacker connection
|
|
clientConn, err := hijacker(rw)
|
|
if err != nil {
|
|
http.Error(rw, err.Error(), 502)
|
|
return
|
|
}
|
|
|
|
// this goes in a separate goroutine, so that the net/http server won't think we're
|
|
// still handling the request even after hijacking the connection. Those HTTP CONNECT
|
|
// request can take forever, and the server will be stuck when "closed".
|
|
// TODO: Allow Server.Close() mechanism to shut down this connection as nicely as possible
|
|
tlsConfig, err := mitm.TLSConfigFromCA(req.URL.Host)
|
|
if err != nil {
|
|
ConnError(clientConn)
|
|
return
|
|
}
|
|
|
|
_, _ = clientConn.Write(HttpMitmOk)
|
|
|
|
// data transmit
|
|
go mitm.transmit(clientConn, req, tlsConfig)
|
|
}
|
|
|
|
func (mitm *MitmHandler) transmit(clientConn net.Conn, originalReq *http.Request, tlsConfig *tls.Config) {
|
|
// TODO: cache connections to the remote website
|
|
rawClientTls := tls.Server(clientConn, tlsConfig)
|
|
if err := rawClientTls.Handshake(); err != nil {
|
|
ConnError(clientConn)
|
|
_ = rawClientTls.Close()
|
|
return
|
|
}
|
|
defer rawClientTls.Close()
|
|
|
|
clientTlsReader := bufio.NewReader(rawClientTls)
|
|
for !isEof(clientTlsReader) {
|
|
req, err := http.ReadRequest(clientTlsReader)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
// since we're converting the request, need to carry over the original connecting IP as well
|
|
req.RemoteAddr = originalReq.RemoteAddr
|
|
|
|
if !httpsRegexp.MatchString(req.URL.String()) {
|
|
req.URL, err = url.Parse("https://" + originalReq.Host + req.URL.String())
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var resp *http.Response
|
|
|
|
// Copying a Context preserves the Transport, Middleware
|
|
ctx := mitm.Ctx.WithRequest(req)
|
|
resp, err = ctx.Next(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var (
|
|
// Body buffer
|
|
buffer = new(bytes.Buffer)
|
|
// Body size
|
|
bufferSize int64
|
|
)
|
|
|
|
buf := mitm.buffer().Get()
|
|
bufferSize, err = io.CopyBuffer(buffer, resp.Body, buf)
|
|
mitm.buffer().Put(buf)
|
|
if err != nil {
|
|
_ = resp.Body.Close()
|
|
return
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
// reset Content-Length
|
|
resp.ContentLength = bufferSize
|
|
resp.Header.Set("Content-Length", strconv.Itoa(int(bufferSize)))
|
|
resp.Body = io.NopCloser(buffer)
|
|
err = resp.Write(rawClientTls)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use registers a Middleware to proxy
|
|
func (mitm *MitmHandler) Use(middleware ...Middleware) {
|
|
mitm.Ctx.Use(middleware...)
|
|
}
|
|
|
|
// UseFunc registers an MiddlewareFunc to proxy
|
|
func (mitm *MitmHandler) UseFunc(fus ...MiddlewareFunc) {
|
|
mitm.Ctx.UseFunc(fus...)
|
|
}
|
|
|
|
// OnRequest filter requests through Filters
|
|
func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
|
|
return &ReqFilterGroup{ctx: mitm.Ctx, filters: filters}
|
|
}
|
|
|
|
// OnResponse filter response through Filters
|
|
func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespFilterGroup {
|
|
return &RespFilterGroup{ctx: mitm.Ctx, filters: filters}
|
|
}
|
|
|
|
// Get buffer pool
|
|
func (mitm *MitmHandler) buffer() httputil.BufferPool {
|
|
if mitm.BufferPool != nil {
|
|
return mitm.BufferPool
|
|
}
|
|
return pool.DefaultBuffer
|
|
}
|
|
|
|
// Get cert.Container instance
|
|
func (mitm *MitmHandler) certContainer() cert.Container {
|
|
if mitm.CertContainer != nil {
|
|
return mitm.CertContainer
|
|
}
|
|
return cert.DefaultMemProvider
|
|
}
|
|
|
|
// Transport
|
|
func (mitm *MitmHandler) Transport() *http.Transport {
|
|
return mitm.Ctx.Transport
|
|
}
|
|
|
|
func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
|
|
host = stripPort(host)
|
|
|
|
// Returned existing certificate for the host
|
|
crt, err := mitm.certContainer().Get(host)
|
|
if err == nil && crt != nil {
|
|
return &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*crt},
|
|
}, nil
|
|
}
|
|
|
|
// Issue a certificate for host
|
|
crt, err = signHost(mitm.Certificate, []string{host})
|
|
if err != nil {
|
|
err = fmt.Errorf("cannot sign host certificate with provided CA: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
// Set certificate to container
|
|
_ = mitm.certContainer().Set(host, crt)
|
|
|
|
return &tls.Config{
|
|
InsecureSkipVerify: true,
|
|
Certificates: []tls.Certificate{*crt},
|
|
}, nil
|
|
}
|
|
|
|
// sign host
|
|
func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
|
|
// Use the provided ca for certificate generation.
|
|
var x509ca *x509.Certificate
|
|
x509ca, err = x509.ParseCertificate(ca.Certificate[0])
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
start := time.Unix(time.Now().Unix()-2592000, 0) // 2592000 = 30 day
|
|
end := time.Unix(time.Now().Unix()+31536000, 0) // 31536000 = 365 day
|
|
|
|
var random CounterEncryptorRand
|
|
random, err = NewCounterEncryptorRand(ca.PrivateKey, hashHosts(hosts))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var pk crypto.Signer
|
|
switch ca.PrivateKey.(type) {
|
|
case *rsa.PrivateKey:
|
|
pk, err = rsa.GenerateKey(&random, 2048)
|
|
case *ecdsa.PrivateKey:
|
|
pk, err = ecdsa.GenerateKey(elliptic.P256(), &random)
|
|
default:
|
|
err = fmt.Errorf("unsupported key type %T", ca.PrivateKey)
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// certificate template
|
|
serial := big.NewInt(mpsRand.Int63())
|
|
tpl := x509.Certificate{
|
|
SerialNumber: serial,
|
|
Issuer: x509ca.Subject,
|
|
Subject: pkix.Name{
|
|
Organization: []string{"MPS untrusted MITM proxy Inc"},
|
|
},
|
|
NotBefore: start,
|
|
NotAfter: end,
|
|
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
BasicConstraintsValid: true,
|
|
EmailAddresses: x509ca.EmailAddresses,
|
|
}
|
|
|
|
total := len(hosts)
|
|
for i := 0; i < total; i++ {
|
|
if ip := net.ParseIP(hosts[i]); ip != nil {
|
|
tpl.IPAddresses = append(tpl.IPAddresses, ip)
|
|
} else {
|
|
tpl.DNSNames = append(tpl.DNSNames, hosts[i])
|
|
tpl.Subject.CommonName = hosts[i]
|
|
}
|
|
}
|
|
|
|
var der []byte
|
|
der, err = x509.CreateCertificate(&random, &tpl, x509ca, pk.Public(), ca.PrivateKey)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
cert = &tls.Certificate{
|
|
Certificate: [][]byte{der, ca.Certificate[0]},
|
|
PrivateKey: pk,
|
|
}
|
|
return
|
|
}
|
|
|
|
// stripPort returns the host part of s with any port removed.
//
// It accepts "host:port", "ip:port" and "[ipv6]:port" as well as the same
// forms without a port. Brackets around an IPv6 literal are removed, so
// "[::1]:443" and "[::1]" both yield "::1". A bare IPv6 literal such as "::1"
// is returned unchanged instead of being cut at a colon.
func stripPort(s string) string {
	if host, _, err := net.SplitHostPort(s); err == nil {
		return host
	}

	// No port present (or s is not a valid host:port pair). Only the
	// brackets of a bracketed IPv6 literal are left to remove.
	if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
		return s[1 : len(s)-1]
	}
	return s
}
|
|
|
|
// hashHosts returns the SHA-1 digest of the host names in lst, sorted and
// joined with commas. Sorting is done on a copy, so lst is left untouched and
// the digest does not depend on the order of the input.
func hashHosts(lst []string) []byte {
	sorted := append([]string(nil), lst...)
	sort.Strings(sorted)

	digest := sha1.Sum([]byte(strings.Join(sorted, ",")))
	return digest[:]
}
|
|
|
|
// cloneTLSConfig returns a clone of cfg made with tls.Config.Clone. When cfg
// is nil it returns a new tls.Config whose only set field is
// InsecureSkipVerify (true).
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	if cfg != nil {
		return cfg.Clone()
	}
	return &tls.Config{InsecureSkipVerify: true}
}
|
|
|
|
// isEof reports whether r has reached end of input. It peeks a single byte
// without consuming it and returns true only when that peek fails with
// io.EOF; any other outcome, including other errors, yields false.
func isEof(r *bufio.Reader) bool {
	_, err := r.Peek(1)
	return err == io.EOF
}
|