Files
tun2socks/core/tcp_conn.go
2019-08-11 18:00:23 +08:00

462 lines
9.8 KiB
Go

package core
/*
#cgo CFLAGS: -I./c/include
#include "lwip/tcp.h"
*/
import "C"
import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"sync"
"time"
"unsafe"
)
// tcpConnState identifies the lifecycle state of a tcpConn. The constants
// below are declared in ascending order and some methods compare states
// with >= (for example state >= tcpClosing), so the declaration order is
// significant.
type tcpConnState uint
const (
// tcpNewConn is the initial state.
tcpNewConn tcpConnState = iota
// tcpConnecting indicates the handler is still connecting to the remote host.
tcpConnecting
// tcpConnected indicates the connection has been established, handler
// may write data to TUN, and read data from TUN.
tcpConnected
// tcpWriteClosed indicates the handler has closed the writing side
// of the connection, no more data will be sent to TUN, but handler can still
// read data from TUN.
tcpWriteClosed
// tcpReceiveClosed indicates lwIP has received a FIN segment from
// local peer, the reading side is closed, no more data can be read
// from TUN, but handler can still write data to TUN.
tcpReceiveClosed
// tcpClosing indicates both reading side and writing side are closed,
// resources deallocation will be triggered at any time in lwIP callbacks.
tcpClosing
// tcpAborting indicates the connection is aborting, resources deallocation
// will be triggered at any time in lwIP callbacks.
tcpAborting
// tcpClosed indicates the connection has been closed, resources were freed.
tcpClosed
// tcpErrored indicates a fatal error occurred on the connection, resources
// were freed.
tcpErrored
)
// tcpConn couples an lwIP TCP protocol control block with the handler that
// serves it and tracks the connection's lifecycle state.
type tcpConn struct {
// The embedded mutex is locked by the methods of tcpConn around reads and
// writes of state.
sync.Mutex
// pcb is the lwIP protocol control block given to newTCPConn.
pcb *C.struct_tcp_pcb
// handler is the TCPConnHandler whose Handle method newTCPConn invokes.
handler TCPConnHandler
// remoteAddr is built by newTCPConn from the pcb's local_ip/local_port.
remoteAddr *net.TCPAddr
// localAddr is built by newTCPConn from the pcb's remote_ip/remote_port.
localAddr *net.TCPAddr
// connKeyArg is the pointer registered with tcp_arg; release frees it.
connKeyArg unsafe.Pointer
// connKey is the random key under which the conn is stored in tcpConns.
connKey uint32
canWrite *sync.Cond // Condition variable to implement TCP back pressure.
// state is the current lifecycle state, see tcpConnState.
state tcpConnState
// sndPipeReader is the pipe end that Read consumes from.
sndPipeReader *io.PipeReader
// sndPipeWriter is the pipe end that Receive writes incoming data to.
sndPipeWriter *io.PipeWriter
// closeOnce makes Close run the close method at most once.
closeOnce sync.Once
// closeErr holds the error recorded by close and returned by Close.
closeErr error
}
// newTCPConn wraps the given lwIP pcb in a tcpConn, registers the lwIP
// callbacks for it, stores it in the global tcpConns map under a random key
// and starts handler.Handle in its own goroutine.
//
// The address fields are deliberately crossed: conn.localAddr is taken from
// the pcb's remote_ip/remote_port and conn.remoteAddr from the pcb's
// local_ip/local_port. NOTE(review): presumably because lwIP sits on the
// opposite side of the connection from the handler; confirm against callers.
//
// The returned error is always NewLWIPError(LWIP_ERR_OK).
func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error) {
	connKeyArg := newConnKeyArg()
	connKey := rand.Uint32()
	setConnKeyVal(unsafe.Pointer(connKeyArg), connKey)

	// Pass the key as arg for subsequent tcp callbacks.
	C.tcp_arg(pcb, unsafe.Pointer(connKeyArg))

	// Register callbacks.
	setTCPRecvCallback(pcb)
	setTCPSentCallback(pcb)
	setTCPErrCallback(pcb)
	setTCPPollCallback(pcb, C.u8_t(TCP_POLL_INTERVAL))

	pipeReader, pipeWriter := io.Pipe()
	conn := &tcpConn{
		pcb:           pcb,
		handler:       handler,
		localAddr:     ParseTCPAddr(ipAddrNTOA(pcb.remote_ip), uint16(pcb.remote_port)),
		remoteAddr:    ParseTCPAddr(ipAddrNTOA(pcb.local_ip), uint16(pcb.local_port)),
		connKeyArg:    connKeyArg,
		connKey:       connKey,
		canWrite:      sync.NewCond(&sync.Mutex{}),
		state:         tcpNewConn,
		sndPipeReader: pipeReader,
		sndPipeWriter: pipeWriter,
	}

	// Associate conn with key and save to the global map.
	tcpConns.Store(connKey, conn)

	// Connecting remote host could take some time, do it in another goroutine
	// to prevent blocking the lwip thread.
	conn.Lock()
	conn.state = tcpConnecting
	conn.Unlock()

	go func() {
		if err := handler.Handle(TCPConn(conn), conn.remoteAddr); err != nil {
			conn.Abort()
			return
		}

		conn.Lock()
		// Only promote a connection that is still connecting. While Handle
		// was running the state may already have moved on (FIN received,
		// write side closed, abort requested, error reported, resources
		// released); overwriting it with tcpConnected would lose that
		// transition and could revive a connection whose pcb is gone.
		if conn.state == tcpConnecting {
			conn.state = tcpConnected
		}
		conn.Unlock()
	}()

	return conn, NewLWIPError(LWIP_ERR_OK)
}
// RemoteAddr returns conn.remoteAddr, which newTCPConn derives from the
// pcb's local_ip and local_port.
func (conn *tcpConn) RemoteAddr() net.Addr {
return conn.remoteAddr
}
// LocalAddr returns conn.localAddr, which newTCPConn derives from the
// pcb's remote_ip and remote_port.
func (conn *tcpConn) LocalAddr() net.Addr {
return conn.localAddr
}
// SetDeadline is a no-op: t is ignored and nil is always returned.
func (conn *tcpConn) SetDeadline(t time.Time) error {
return nil
}
// SetReadDeadline is a no-op: t is ignored and nil is always returned.
func (conn *tcpConn) SetReadDeadline(t time.Time) error {
return nil
}
// SetWriteDeadline is a no-op: t is ignored and nil is always returned.
func (conn *tcpConn) SetWriteDeadline(t time.Time) error {
return nil
}
// receiveCheck decides, under the connection lock, whether incoming data may
// currently be handed to the handler.
//
// It returns nil in tcpConnected and tcpWriteClosed, an LWIP_ERR_CONN error in
// tcpNewConn, tcpConnecting, tcpAborting and tcpClosed, and an LWIP_ERR_CLSD
// error in tcpReceiveClosed and tcpClosing. In tcpErrored it calls
// abortInternal and returns an LWIP_ERR_ABRT error. Any other state panics.
func (conn *tcpConn) receiveCheck() error {
	conn.Lock()
	defer conn.Unlock()

	switch conn.state {
	case tcpConnected, tcpWriteClosed:
		return nil
	case tcpNewConn, tcpConnecting, tcpAborting, tcpClosed:
		return NewLWIPError(LWIP_ERR_CONN)
	case tcpReceiveClosed, tcpClosing:
		return NewLWIPError(LWIP_ERR_CLSD)
	case tcpErrored:
		conn.abortInternal()
		return NewLWIPError(LWIP_ERR_ABRT)
	default:
		panic("unexpected error")
	}
}
// Receive forwards data to the handler by writing it into the send pipe and
// then reports the consumed byte count to lwIP through tcp_recved.
//
// It returns the error from receiveCheck when the state does not allow
// receiving, an LWIP_ERR_CLSD error when the pipe write fails, and
// NewLWIPError(LWIP_ERR_OK) otherwise.
func (conn *tcpConn) Receive(data []byte) error {
	checkErr := conn.receiveCheck()
	if checkErr != nil {
		return checkErr
	}

	consumed, writeErr := conn.sndPipeWriter.Write(data)
	if writeErr != nil {
		return NewLWIPError(LWIP_ERR_CLSD)
	}

	// Tell lwIP how many bytes were taken.
	C.tcp_recved(conn.pcb, C.u16_t(consumed))
	return NewLWIPError(LWIP_ERR_OK)
}
// Read reads data from the send pipe into data.
//
// It returns io.EOF without touching the pipe when the state is
// tcpReceiveClosed, and io.ErrClosedPipe when the state is tcpClosing or any
// later state. An io.ErrClosedPipe coming from the pipe itself is reported to
// the caller as io.EOF.
func (conn *tcpConn) Read(data []byte) (int, error) {
	// Snapshot the state so the lock is not held during the pipe read.
	conn.Lock()
	current := conn.state
	conn.Unlock()

	switch {
	case current == tcpReceiveClosed:
		return 0, io.EOF
	case current >= tcpClosing:
		return 0, io.ErrClosedPipe
	}

	n, err := conn.sndPipeReader.Read(data)
	if err == io.ErrClosedPipe {
		// Handler should get EOF.
		err = io.EOF
	}
	return n, err
}
// writeInternal hands data to tcp_write with the copy flag set. On ERR_OK it
// triggers tcp_output and returns len(data). ERR_MEM is not treated as an
// error: it returns (0, nil) so the caller knows nothing was enqueued and can
// retry later. Any other result is returned as an error.
//
// data must be non-empty because the address of its first element is passed
// to lwIP. The caller is assumed to already hold the lwIP lock.
func (conn *tcpConn) writeInternal(data []byte) (int, error) {
	rc := C.tcp_write(conn.pcb, unsafe.Pointer(&data[0]), C.u16_t(len(data)), C.TCP_WRITE_FLAG_COPY)
	switch {
	case rc == C.ERR_OK:
		C.tcp_output(conn.pcb)
		return len(data), nil
	case rc == C.ERR_MEM:
		return 0, nil
	default:
		return 0, fmt.Errorf("tcp_write failed (%v)", int(rc))
	}
}
// writeCheck decides, under the connection lock, whether the handler may
// currently write.
//
// It returns nil in tcpConnecting, tcpConnected and tcpReceiveClosed, and
// io.ErrClosedPipe in tcpWriteClosed, tcpClosing, tcpClosed, tcpErrored and
// tcpAborting. Any other state, including tcpNewConn, panics.
func (conn *tcpConn) writeCheck() error {
	conn.Lock()
	defer conn.Unlock()

	switch conn.state {
	case tcpConnecting, tcpConnected, tcpReceiveClosed:
		return nil
	case tcpWriteClosed, tcpClosing, tcpClosed, tcpErrored, tcpAborting:
		return io.ErrClosedPipe
	default:
		panic("unexpected error")
	}
}
// Write enqueues data into the lwIP send buffer, waiting on canWrite whenever
// not all of it could be enqueued. It returns the number of bytes enqueued
// and the first error from writeCheck or writeInternal, if any.
func (conn *tcpConn) Write(data []byte) (int, error) {
totalWritten := 0
// canWrite.L is held for the whole call and only released inside Wait.
conn.canWrite.L.Lock()
defer conn.canWrite.L.Unlock()
for len(data) > 0 {
// Re-validate the state on every pass, including after each wake-up.
if err := conn.writeCheck(); err != nil {
return totalWritten, err
}
// snd_buf is read and tcp_write is called with lwipMutex held.
lwipMutex.Lock()
toWrite := len(data)
if toWrite > int(conn.pcb.snd_buf) {
// Write at most the size of the LWIP buffer.
toWrite = int(conn.pcb.snd_buf)
}
if toWrite > 0 {
// written is 0 with a nil error when tcp_write reported ERR_MEM.
written, err := conn.writeInternal(data[0:toWrite])
totalWritten += written
if err != nil {
lwipMutex.Unlock()
return totalWritten, err
}
data = data[written:len(data)]
}
lwipMutex.Unlock()
if len(data) == 0 {
break // Don't block if all the data has been written.
}
// Sleep until checkState, setLocalClosed, Abort or Err broadcasts.
conn.canWrite.Wait()
}
return totalWritten, nil
}
// CloseWrite closes the writing side of the connection.
//
// It does nothing when the state is tcpWriteClosed, tcpClosing or any later
// state. Otherwise the state becomes tcpClosing if the receive side was
// already closed, or tcpWriteClosed if not, and tcp_shutdown is called for
// the transmit direction with lwipMutex held. It always returns nil.
func (conn *tcpConn) CloseWrite() error {
	conn.Lock()
	switch {
	case conn.state >= tcpClosing, conn.state == tcpWriteClosed:
		conn.Unlock()
		return nil
	case conn.state == tcpReceiveClosed:
		// Both directions are now shut.
		conn.state = tcpClosing
	default:
		conn.state = tcpWriteClosed
	}
	conn.Unlock()

	lwipMutex.Lock()
	defer lwipMutex.Unlock()
	// FIXME Handle tcp_shutdown error.
	C.tcp_shutdown(conn.pcb, 0, 1)
	return nil
}
// CloseRead closes the reader end of the send pipe and returns the result of
// that Close call. It does not change conn.state.
func (conn *tcpConn) CloseRead() error {
return conn.sndPipeReader.Close()
}
// Sent runs checkState and returns its result. The len argument is not used.
func (conn *tcpConn) Sent(len uint16) error {
// Some packets are acknowledged by local client, check if any pending data to send.
return conn.checkState()
}
// checkClosing performs the deferred close when the connection is in
// tcpClosing: it calls closeInternal (ignoring its result) and returns
// NewLWIPError(LWIP_ERR_OK). In any other state it returns nil.
func (conn *tcpConn) checkClosing() error {
	conn.Lock()
	defer conn.Unlock()

	if conn.state != tcpClosing {
		return nil
	}
	conn.closeInternal()
	return NewLWIPError(LWIP_ERR_OK)
}
// checkAborting performs the deferred abort when the connection is in
// tcpAborting: it calls abortInternal and returns an LWIP_ERR_ABRT error.
// In any other state it returns nil.
func (conn *tcpConn) checkAborting() error {
	conn.Lock()
	defer conn.Unlock()

	if conn.state != tcpAborting {
		return nil
	}
	conn.abortInternal()
	return NewLWIPError(LWIP_ERR_ABRT)
}
// isClosed reports, under the connection lock, whether the state is exactly
// tcpClosed.
func (conn *tcpConn) isClosed() bool {
	conn.Lock()
	defer conn.Unlock()

	closed := conn.state == tcpClosed
	return closed
}
// checkState carries out any pending close or abort and otherwise wakes up
// blocked writers.
//
// It returns nil when the connection is already closed, the non-nil result of
// checkClosing or checkAborting when one of them acted, and
// NewLWIPError(LWIP_ERR_OK) after broadcasting on canWrite.
func (conn *tcpConn) checkState() error {
	if conn.isClosed() {
		return nil
	}
	if err := conn.checkClosing(); err != nil {
		return err
	}
	if err := conn.checkAborting(); err != nil {
		return err
	}

	// Signal the writer to try writing.
	conn.canWrite.Broadcast()
	return NewLWIPError(LWIP_ERR_OK)
}
// Close runs the close method at most once, guarded by closeOnce, and returns
// the error it recorded in closeErr; repeated calls return the same value.
func (conn *tcpConn) Close() error {
conn.closeOnce.Do(conn.close)
return conn.closeErr
}
// close shuts the read side and then the write side, storing any failure in
// closeErr. When both steps fail, the CloseWrite error replaces the CloseRead
// error.
func (conn *tcpConn) close() {
	if readErr := conn.CloseRead(); readErr != nil {
		conn.closeErr = readErr
	}
	if writeErr := conn.CloseWrite(); writeErr != nil {
		conn.closeErr = writeErr
	}
}
// setLocalClosed records that the receive side has been closed.
//
// It does nothing when the state is tcpReceiveClosed, tcpClosing or any later
// state. Otherwise it closes the writer end of the send pipe, moves the state
// to tcpClosing if the write side was already closed or to tcpReceiveClosed if
// not, and broadcasts on canWrite. It always returns nil.
func (conn *tcpConn) setLocalClosed() error {
	conn.Lock()
	defer conn.Unlock()

	if conn.state == tcpReceiveClosed || conn.state >= tcpClosing {
		return nil
	}

	// Closing the writer makes a pending read on the other half return.
	conn.sndPipeWriter.Close()

	next := tcpReceiveClosed
	if conn.state == tcpWriteClosed {
		next = tcpClosing
	}
	conn.state = next

	conn.canWrite.Broadcast()
	return nil
}
// closeInternal clears the pcb's arg and callbacks, releases the connection's
// resources and then calls tcp_close. It returns nil when tcp_close reports
// ERR_OK and an error carrying the lwIP error code otherwise.
//
// Never call this function outside of the lwIP thread.
func (conn *tcpConn) closeInternal() error {
	// Detach everything previously registered on the pcb.
	C.tcp_arg(conn.pcb, nil)
	C.tcp_recv(conn.pcb, nil)
	C.tcp_sent(conn.pcb, nil)
	C.tcp_err(conn.pcb, nil)
	C.tcp_poll(conn.pcb, nil, 0)

	conn.release()

	// FIXME Handle error.
	if rc := C.tcp_close(conn.pcb); rc != C.ERR_OK {
		return errors.New(fmt.Sprintf("close TCP connection failed, lwip error code %d", int(rc)))
	}
	return nil
}
// abortInternal releases the connection's resources and then calls
// tcp_abort on the pcb.
//
// Never call this function outside of the lwIP thread since it calls
// tcp_abort() and in that case we must return ERR_ABRT to lwIP.
func (conn *tcpConn) abortInternal() {
conn.release()
C.tcp_abort(conn.pcb)
}
// Abort marks the connection as tcpAborting, regardless of its current state,
// and wakes any goroutine waiting on canWrite. It does not call tcp_abort
// itself; checkAborting does that once it observes the tcpAborting state.
func (conn *tcpConn) Abort() {
conn.Lock()
defer conn.Unlock()
conn.state = tcpAborting
conn.canWrite.Broadcast()
}
// Err releases the connection's resources, sets the state to tcpErrored and
// wakes any goroutine waiting on canWrite. The err argument is not used.
func (conn *tcpConn) Err(err error) {
conn.Lock()
defer conn.Unlock()
// release leaves the state at tcpClosed; it is overwritten just below.
conn.release()
conn.state = tcpErrored
conn.canWrite.Broadcast()
}
// LocalClosed records the receive-side close via setLocalClosed, whose result
// is discarded (it always returns nil), and then returns the result of
// checkState.
func (conn *tcpConn) LocalClosed() error {
conn.setLocalClosed()
return conn.checkState()
}
// release frees the resources tied to the connection: it frees connKeyArg and
// removes the conn from tcpConns if it is still registered there, closes both
// ends of the send pipe and sets the state to tcpClosed.
//
// release does not lock conn itself; every call path in this file reaches it
// with the connection lock already held.
func (conn *tcpConn) release() {
// The lookup guards against freeing connKeyArg more than once.
if _, found := tcpConns.Load(conn.connKey); found {
freeConnKeyArg(conn.connKeyArg)
tcpConns.Delete(conn.connKey)
}
conn.sndPipeWriter.Close()
conn.sndPipeReader.Close()
conn.state = tcpClosed
}
// Poll runs checkState and returns its result.
func (conn *tcpConn) Poll() error {
return conn.checkState()
}