Implement better erroring and add descriptive comments

This commit is contained in:
Cassandra
2018-12-06 18:16:48 +01:00
parent e3ccd145e3
commit 060bdd4e43

View File

@@ -27,7 +27,7 @@ func RunServer(bind string, p []*protocols.Protocol) {
if err != nil {
fmt.Printf("Error while accepting connection: %s\n", err)
}
fmt.Printf("Accepted connection from %s.\n", conn.RemoteAddr())
fmt.Printf("%s: Connection accepted.\n", conn.RemoteAddr())
go connectionHandler(conn, p)
}
}
@@ -38,10 +38,12 @@ func connectionHandler(conn net.Conn, p []*protocols.Protocol) {
identifyBuffer := make([]byte, 1024) // at max 1KB buffer to identify payload
var protocol *protocols.Protocol
// read a byte and add it to our internal buffer
// read the handshake into our buffer
conn.SetReadDeadline(time.Now().Add(15 * time.Second)) // 15-second timeout to identify
n, err := conn.Read(identifyBuffer)
if err != nil {
conn.Close()
fmt.Printf("%s: Identify read error (%s). Connection closed.\n", connectionId, err)
return
}
conn.SetReadDeadline(time.Time{}) // reset our timeout
@@ -59,13 +61,13 @@ func connectionHandler(conn net.Conn, p []*protocols.Protocol) {
targetConn, err := net.Dial("tcp", protocol.Target)
if err != nil {
conn.Close()
fmt.Printf("%s: %s rejected our connection. Connection closed.\n", connectionId, protocol.Target)
fmt.Printf("%s: %s error (%s). Connection closed.\n", connectionId, protocol.Target, err)
return // we were unable to establish the connection with the proxy target
}
_, err = targetConn.Write(identifyBuffer[:n]) // tell them everything they just told us
if err != nil {
conn.Close()
fmt.Printf("%s: %s cut off our identification payload. Connection closed.\n", connectionId, protocol.Target)
fmt.Printf("%s: %s error (%s). Connection closed.\n", connectionId, protocol.Target, err)
return // remote rejected us?? okay.
}
@@ -90,6 +92,7 @@ func determineProtocol(data []byte, p []*protocols.Protocol) *protocols.Protocol
continue // avoids unnecessary comparisons
}
// compare against bytestrings first for efficiency
for _, byteSlice := range protocol.MatchBytes {
byteSliceLength := len(byteSlice)
if dataLength < byteSliceLength {