mirror of
https://github.com/antoniomika/sish.git
synced 2025-09-26 19:21:15 +08:00
@@ -31,7 +31,7 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
tlsHello, teeConn, peekErr := utils.PeekTLSHello(cl)
|
||||
tlsHello, buf, teeConn, peekErr := utils.PeekTLSHello(cl)
|
||||
if peekErr != nil && tlsHello == nil {
|
||||
return teeConn, nil
|
||||
}
|
||||
@@ -59,20 +59,20 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
|
||||
connectionLocation, err := balancer.NextServer()
|
||||
if err != nil {
|
||||
log.Println("Unable to load connection location:", err)
|
||||
teeConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
|
||||
if err != nil {
|
||||
log.Println("Unable to decode connection location:", err)
|
||||
teeConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
hostAddr := string(host)
|
||||
|
||||
logLine := fmt.Sprintf("Accepted connection from %s -> %s", teeConn.RemoteAddr().String(), teeConn.LocalAddr().String())
|
||||
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
|
||||
log.Println(logLine)
|
||||
|
||||
if viper.GetBool("log-to-client") {
|
||||
@@ -94,11 +94,18 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
|
||||
conn, err := net.Dial("unix", hostAddr)
|
||||
if err != nil {
|
||||
log.Println("Error connecting to tcp balancer:", err)
|
||||
teeConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
go utils.CopyBoth(conn, teeConn)
|
||||
_, err = conn.Write(buf.Bytes())
|
||||
if err != nil {
|
||||
log.Println("Unable to write to conn:", err)
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
go utils.CopyBoth(conn, cl)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -1,7 +1,6 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
@@ -89,22 +88,44 @@ func (s *SSHConnection) CleanUp(state *State) {
|
||||
|
||||
// TeeConn represents a simple net.Conn interface for SNI Processing.
|
||||
type TeeConn struct {
|
||||
Conn net.Conn
|
||||
Buffer *bufio.ReadWriter
|
||||
Conn net.Conn
|
||||
Reader io.Reader
|
||||
Buffer *bytes.Buffer
|
||||
FirstRead bool
|
||||
Flushed bool
|
||||
}
|
||||
|
||||
// Read implements a reader ontop of the TeeReader.
|
||||
func (conn *TeeConn) Read(p []byte) (int, error) {
|
||||
return conn.Buffer.Read(p)
|
||||
if !conn.FirstRead {
|
||||
conn.FirstRead = true
|
||||
return conn.Reader.Read(p)
|
||||
}
|
||||
|
||||
if conn.FirstRead && !conn.Flushed {
|
||||
conn.Flushed = true
|
||||
copy(p[0:conn.Buffer.Len()], conn.Buffer.Bytes())
|
||||
return conn.Buffer.Len(), nil
|
||||
}
|
||||
|
||||
return conn.Conn.Read(p)
|
||||
}
|
||||
|
||||
// Write is a shim function to fit net.Conn.
|
||||
func (conn *TeeConn) Write(p []byte) (int, error) {
|
||||
return conn.Buffer.Write(p)
|
||||
if !conn.Flushed {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
return conn.Conn.Write(p)
|
||||
}
|
||||
|
||||
// Close is a shim function to fit net.Conn.
|
||||
func (conn *TeeConn) Close() error {
|
||||
if !conn.Flushed {
|
||||
return nil
|
||||
}
|
||||
|
||||
return conn.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -124,19 +145,22 @@ func (conn *TeeConn) SetReadDeadline(t time.Time) error { return conn.Conn.SetRe
|
||||
func (conn *TeeConn) SetWriteDeadline(t time.Time) error { return conn.Conn.SetWriteDeadline(t) }
|
||||
|
||||
// GetBuffer returns the tee'd buffer.
|
||||
func (conn *TeeConn) GetBuffer() *bufio.ReadWriter { return conn.Buffer }
|
||||
func (conn *TeeConn) GetBuffer() *bytes.Buffer { return conn.Buffer }
|
||||
|
||||
func NewTeeConn(conn net.Conn) *TeeConn {
|
||||
teeConn := &TeeConn{
|
||||
Conn: conn,
|
||||
Buffer: bufio.NewReadWriter(bufio.NewReaderSize(conn, 8192), bufio.NewWriterSize(conn, 8192)),
|
||||
Conn: conn,
|
||||
Buffer: bytes.NewBuffer([]byte{}),
|
||||
Flushed: false,
|
||||
}
|
||||
|
||||
teeConn.Reader = io.TeeReader(conn, teeConn.Buffer)
|
||||
|
||||
return teeConn
|
||||
}
|
||||
|
||||
// PeekTLSHello peeks the TLS Connection Hello to proxy based on SNI.
|
||||
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
|
||||
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn, error) {
|
||||
var tlsHello *tls.ClientHelloInfo
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
@@ -148,33 +172,11 @@ func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
|
||||
|
||||
teeConn := NewTeeConn(conn)
|
||||
|
||||
header, err := teeConn.GetBuffer().Peek(5)
|
||||
if err != nil {
|
||||
return tlsHello, teeConn, err
|
||||
}
|
||||
err := tls.Server(teeConn, tlsConfig).Handshake()
|
||||
|
||||
if header[0] != 0x16 {
|
||||
return tlsHello, teeConn, err
|
||||
}
|
||||
|
||||
helloBytes, err := teeConn.GetBuffer().Peek(len(header) + (int(header[3])<<8 | int(header[4])))
|
||||
if err != nil {
|
||||
return tlsHello, teeConn, err
|
||||
}
|
||||
|
||||
err = tls.Server(bufConn{reader: bytes.NewReader(helloBytes)}, tlsConfig).Handshake()
|
||||
|
||||
return tlsHello, teeConn, err
|
||||
return tlsHello, teeConn.GetBuffer(), teeConn, err
|
||||
}
|
||||
|
||||
type bufConn struct {
|
||||
reader io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b bufConn) Read(p []byte) (int, error) { return b.reader.Read(p) }
|
||||
func (bufConn) Write(p []byte) (int, error) { return 0, io.EOF }
|
||||
|
||||
// IdleTimeoutConn handles the connection with a context deadline.
|
||||
// code adapted from https://qiita.com/kwi/items/b38d6273624ad3f6ae79
|
||||
type IdleTimeoutConn struct {
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -93,18 +94,19 @@ func (tH *TCPHolder) Handle(state *State) {
|
||||
continue
|
||||
}
|
||||
|
||||
realConn := cl
|
||||
var firstWrite *bytes.Buffer
|
||||
|
||||
balancerName := ""
|
||||
if tH.SNIProxy {
|
||||
tlsHello, realConn, err := PeekTLSHello(cl)
|
||||
tlsHello, buf, _, err := PeekTLSHello(cl)
|
||||
if err != nil && tlsHello == nil {
|
||||
log.Printf("Unable to read TLS hello: %s", err)
|
||||
realConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
balancerName = tlsHello.ServerName
|
||||
firstWrite = buf
|
||||
}
|
||||
|
||||
pB, ok := tH.Balancers.Load(balancerName)
|
||||
@@ -119,7 +121,7 @@ func (tH *TCPHolder) Handle(state *State) {
|
||||
|
||||
if pB == nil {
|
||||
log.Printf("Unable to load connection location: %s not found on TCP listener %s", balancerName, tH.TCPHost)
|
||||
realConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -129,20 +131,20 @@ func (tH *TCPHolder) Handle(state *State) {
|
||||
connectionLocation, err := balancer.NextServer()
|
||||
if err != nil {
|
||||
log.Println("Unable to load connection location:", err)
|
||||
realConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
|
||||
if err != nil {
|
||||
log.Println("Unable to decode connection location:", err)
|
||||
realConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
hostAddr := string(host)
|
||||
|
||||
logLine := fmt.Sprintf("Accepted connection from %s -> %s", realConn.RemoteAddr().String(), realConn.LocalAddr().String())
|
||||
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
|
||||
log.Println(logLine)
|
||||
|
||||
if viper.GetBool("log-to-client") {
|
||||
@@ -164,11 +166,20 @@ func (tH *TCPHolder) Handle(state *State) {
|
||||
conn, err := net.Dial("unix", hostAddr)
|
||||
if err != nil {
|
||||
log.Println("Error connecting to tcp balancer:", err)
|
||||
realConn.Close()
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
go CopyBoth(conn, realConn)
|
||||
if firstWrite != nil {
|
||||
_, err := conn.Write(firstWrite.Bytes())
|
||||
if err != nil {
|
||||
log.Println("Unable to write to conn:", err)
|
||||
cl.Close()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go CopyBoth(conn, cl)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user