Revert "Use bufio reader"

This reverts commit c8003d469e.
This commit is contained in:
Antonio Mika
2024-04-30 10:45:33 -04:00
parent c8003d469e
commit ff656b0c94
3 changed files with 68 additions and 48 deletions

View File

@@ -31,7 +31,7 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
continue
}
tlsHello, teeConn, peekErr := utils.PeekTLSHello(cl)
tlsHello, buf, teeConn, peekErr := utils.PeekTLSHello(cl)
if peekErr != nil && tlsHello == nil {
return teeConn, nil
}
@@ -59,20 +59,20 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
connectionLocation, err := balancer.NextServer()
if err != nil {
log.Println("Unable to load connection location:", err)
teeConn.Close()
cl.Close()
continue
}
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
if err != nil {
log.Println("Unable to decode connection location:", err)
teeConn.Close()
cl.Close()
continue
}
hostAddr := string(host)
logLine := fmt.Sprintf("Accepted connection from %s -> %s", teeConn.RemoteAddr().String(), teeConn.LocalAddr().String())
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
log.Println(logLine)
if viper.GetBool("log-to-client") {
@@ -94,11 +94,18 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
conn, err := net.Dial("unix", hostAddr)
if err != nil {
log.Println("Error connecting to tcp balancer:", err)
teeConn.Close()
cl.Close()
continue
}
go utils.CopyBoth(conn, teeConn)
_, err = conn.Write(buf.Bytes())
if err != nil {
log.Println("Unable to write to conn:", err)
cl.Close()
continue
}
go utils.CopyBoth(conn, cl)
}
}

View File

@@ -1,7 +1,6 @@
package utils
import (
"bufio"
"bytes"
"crypto/tls"
"io"
@@ -89,22 +88,44 @@ func (s *SSHConnection) CleanUp(state *State) {
// TeeConn represents a simple net.Conn interface for SNI Processing.
type TeeConn struct {
Conn net.Conn
Buffer *bufio.ReadWriter
Conn net.Conn
Reader io.Reader
Buffer *bytes.Buffer
FirstRead bool
Flushed bool
}
// Read implements a reader ontop of the TeeReader.
func (conn *TeeConn) Read(p []byte) (int, error) {
return conn.Buffer.Read(p)
if !conn.FirstRead {
conn.FirstRead = true
return conn.Reader.Read(p)
}
if conn.FirstRead && !conn.Flushed {
conn.Flushed = true
copy(p[0:conn.Buffer.Len()], conn.Buffer.Bytes())
return conn.Buffer.Len(), nil
}
return conn.Conn.Read(p)
}
// Write is a shim function to fit net.Conn.
func (conn *TeeConn) Write(p []byte) (int, error) {
return conn.Buffer.Write(p)
if !conn.Flushed {
return 0, io.ErrClosedPipe
}
return conn.Conn.Write(p)
}
// Close is a shim function to fit net.Conn.
func (conn *TeeConn) Close() error {
if !conn.Flushed {
return nil
}
return conn.Conn.Close()
}
@@ -124,19 +145,22 @@ func (conn *TeeConn) SetReadDeadline(t time.Time) error { return conn.Conn.SetRe
func (conn *TeeConn) SetWriteDeadline(t time.Time) error { return conn.Conn.SetWriteDeadline(t) }
// GetBuffer returns the tee'd buffer.
func (conn *TeeConn) GetBuffer() *bufio.ReadWriter { return conn.Buffer }
func (conn *TeeConn) GetBuffer() *bytes.Buffer { return conn.Buffer }
func NewTeeConn(conn net.Conn) *TeeConn {
teeConn := &TeeConn{
Conn: conn,
Buffer: bufio.NewReadWriter(bufio.NewReaderSize(conn, 8192), bufio.NewWriterSize(conn, 8192)),
Conn: conn,
Buffer: bytes.NewBuffer([]byte{}),
Flushed: false,
}
teeConn.Reader = io.TeeReader(conn, teeConn.Buffer)
return teeConn
}
// PeekTLSHello peeks the TLS Connection Hello to proxy based on SNI.
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn, error) {
var tlsHello *tls.ClientHelloInfo
tlsConfig := &tls.Config{
@@ -148,33 +172,11 @@ func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
teeConn := NewTeeConn(conn)
header, err := teeConn.GetBuffer().Peek(5)
if err != nil {
return tlsHello, teeConn, err
}
err := tls.Server(teeConn, tlsConfig).Handshake()
if header[0] != 0x16 {
return tlsHello, teeConn, err
}
helloBytes, err := teeConn.GetBuffer().Peek(len(header) + (int(header[3])<<8 | int(header[4])))
if err != nil {
return tlsHello, teeConn, err
}
err = tls.Server(bufConn{reader: bytes.NewReader(helloBytes)}, tlsConfig).Handshake()
return tlsHello, teeConn, err
return tlsHello, teeConn.GetBuffer(), teeConn, err
}
type bufConn struct {
reader io.Reader
net.Conn
}
func (b bufConn) Read(p []byte) (int, error) { return b.reader.Read(p) }
func (bufConn) Write(p []byte) (int, error) { return 0, io.EOF }
// IdleTimeoutConn handles the connection with a context deadline.
// code adapted from https://qiita.com/kwi/items/b38d6273624ad3f6ae79
type IdleTimeoutConn struct {

View File

@@ -1,6 +1,7 @@
package utils
import (
"bytes"
"encoding/base64"
"fmt"
"io"
@@ -93,18 +94,19 @@ func (tH *TCPHolder) Handle(state *State) {
continue
}
realConn := cl
var firstWrite *bytes.Buffer
balancerName := ""
if tH.SNIProxy {
tlsHello, realConn, err := PeekTLSHello(cl)
tlsHello, buf, _, err := PeekTLSHello(cl)
if err != nil && tlsHello == nil {
log.Printf("Unable to read TLS hello: %s", err)
realConn.Close()
cl.Close()
continue
}
balancerName = tlsHello.ServerName
firstWrite = buf
}
pB, ok := tH.Balancers.Load(balancerName)
@@ -119,7 +121,7 @@ func (tH *TCPHolder) Handle(state *State) {
if pB == nil {
log.Printf("Unable to load connection location: %s not found on TCP listener %s", balancerName, tH.TCPHost)
realConn.Close()
cl.Close()
continue
}
}
@@ -129,20 +131,20 @@ func (tH *TCPHolder) Handle(state *State) {
connectionLocation, err := balancer.NextServer()
if err != nil {
log.Println("Unable to load connection location:", err)
realConn.Close()
cl.Close()
continue
}
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
if err != nil {
log.Println("Unable to decode connection location:", err)
realConn.Close()
cl.Close()
continue
}
hostAddr := string(host)
logLine := fmt.Sprintf("Accepted connection from %s -> %s", realConn.RemoteAddr().String(), realConn.LocalAddr().String())
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
log.Println(logLine)
if viper.GetBool("log-to-client") {
@@ -164,11 +166,20 @@ func (tH *TCPHolder) Handle(state *State) {
conn, err := net.Dial("unix", hostAddr)
if err != nil {
log.Println("Error connecting to tcp balancer:", err)
realConn.Close()
cl.Close()
continue
}
go CopyBoth(conn, realConn)
if firstWrite != nil {
_, err := conn.Write(firstWrite.Bytes())
if err != nil {
log.Println("Unable to write to conn:", err)
cl.Close()
continue
}
}
go CopyBoth(conn, cl)
}
}