mirror of
				https://github.com/xjasonlyu/tun2socks.git
				synced 2025-10-31 20:12:41 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			462 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			462 lines
		
	
	
		
			9.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package core
 | |
| 
 | |
| /*
 | |
| #cgo CFLAGS: -I./c/include
 | |
| #include "lwip/tcp.h"
 | |
| */
 | |
| import "C"
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"math/rand"
 | |
| 	"net"
 | |
| 	"sync"
 | |
| 	"time"
 | |
| 	"unsafe"
 | |
| )
 | |
| 
 | |
// tcpConnState is the lifecycle state of a tcpConn.
//
// The numeric order of the constants below is significant: Read, CloseWrite
// and setLocalClosed compare states with ">= tcpClosing" to mean "closing,
// aborting, closed or errored". Do not reorder them.
type tcpConnState uint

const (
	// tcpNewConn is the initial state.
	tcpNewConn tcpConnState = iota

	// tcpConnecting indicates the handler is still connecting to the remote host.
	tcpConnecting

	// tcpConnected indicates the connection has been established; the handler
	// may write data to TUN and read data from TUN.
	tcpConnected

	// tcpWriteClosed indicates the handler has closed the writing side of the
	// connection: no more data will be sent to TUN, but the handler can still
	// read data from TUN.
	tcpWriteClosed

	// tcpReceiveClosed indicates lwIP has received a FIN segment from the local
	// peer: the reading side is closed and no more data can be read from TUN,
	// but the handler can still write data to TUN.
	tcpReceiveClosed

	// tcpClosing indicates both the reading side and the writing side are
	// closed; resource deallocation will be triggered from a later lwIP callback.
	tcpClosing

	// tcpAborting indicates the connection is being aborted; resource
	// deallocation will be triggered from a later lwIP callback.
	tcpAborting

	// tcpClosed indicates the connection has been closed and its resources
	// were freed.
	tcpClosed

	// tcpErrored indicates a fatal error occurred on the connection and its
	// resources were freed.
	tcpErrored
)
 | |
| 
 | |
| type tcpConn struct {
 | |
| 	sync.Mutex
 | |
| 
 | |
| 	pcb           *C.struct_tcp_pcb
 | |
| 	handler       TCPConnHandler
 | |
| 	remoteAddr    *net.TCPAddr
 | |
| 	localAddr     *net.TCPAddr
 | |
| 	connKeyArg    unsafe.Pointer
 | |
| 	connKey       uint32
 | |
| 	canWrite      *sync.Cond // Condition variable to implement TCP back pressure.
 | |
| 	state         tcpConnState
 | |
| 	sndPipeReader *io.PipeReader
 | |
| 	sndPipeWriter *io.PipeWriter
 | |
| 	closeOnce     sync.Once
 | |
| 	closeErr      error
 | |
| }
 | |
| 
 | |
| func newTCPConn(pcb *C.struct_tcp_pcb, handler TCPConnHandler) (TCPConn, error) {
 | |
| 	connKeyArg := newConnKeyArg()
 | |
| 	connKey := rand.Uint32()
 | |
| 	setConnKeyVal(unsafe.Pointer(connKeyArg), connKey)
 | |
| 
 | |
| 	// Pass the key as arg for subsequent tcp callbacks.
 | |
| 	C.tcp_arg(pcb, unsafe.Pointer(connKeyArg))
 | |
| 
 | |
| 	// Register callbacks.
 | |
| 	setTCPRecvCallback(pcb)
 | |
| 	setTCPSentCallback(pcb)
 | |
| 	setTCPErrCallback(pcb)
 | |
| 	setTCPPollCallback(pcb, C.u8_t(TCP_POLL_INTERVAL))
 | |
| 
 | |
| 	pipeReader, pipeWriter := io.Pipe()
 | |
| 	conn := &tcpConn{
 | |
| 		pcb:           pcb,
 | |
| 		handler:       handler,
 | |
| 		localAddr:     ParseTCPAddr(ipAddrNTOA(pcb.remote_ip), uint16(pcb.remote_port)),
 | |
| 		remoteAddr:    ParseTCPAddr(ipAddrNTOA(pcb.local_ip), uint16(pcb.local_port)),
 | |
| 		connKeyArg:    connKeyArg,
 | |
| 		connKey:       connKey,
 | |
| 		canWrite:      sync.NewCond(&sync.Mutex{}),
 | |
| 		state:         tcpNewConn,
 | |
| 		sndPipeReader: pipeReader,
 | |
| 		sndPipeWriter: pipeWriter,
 | |
| 	}
 | |
| 
 | |
| 	// Associate conn with key and save to the global map.
 | |
| 	tcpConns.Store(connKey, conn)
 | |
| 
 | |
| 	// Connecting remote host could take some time, do it in another goroutine
 | |
| 	// to prevent blocking the lwip thread.
 | |
| 	conn.Lock()
 | |
| 	conn.state = tcpConnecting
 | |
| 	conn.Unlock()
 | |
| 	go func() {
 | |
| 		err := handler.Handle(TCPConn(conn), conn.remoteAddr)
 | |
| 		if err != nil {
 | |
| 			conn.Abort()
 | |
| 		} else {
 | |
| 			conn.Lock()
 | |
| 			conn.state = tcpConnected
 | |
| 			conn.Unlock()
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	return conn, NewLWIPError(LWIP_ERR_OK)
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) RemoteAddr() net.Addr {
 | |
| 	return conn.remoteAddr
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) LocalAddr() net.Addr {
 | |
| 	return conn.localAddr
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) SetDeadline(t time.Time) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) SetReadDeadline(t time.Time) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) SetWriteDeadline(t time.Time) error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) receiveCheck() error {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	switch conn.state {
 | |
| 	case tcpConnected:
 | |
| 		fallthrough
 | |
| 	case tcpWriteClosed:
 | |
| 		return nil
 | |
| 	case tcpNewConn:
 | |
| 		fallthrough
 | |
| 	case tcpConnecting:
 | |
| 		fallthrough
 | |
| 	case tcpAborting:
 | |
| 		fallthrough
 | |
| 	case tcpClosed:
 | |
| 		return NewLWIPError(LWIP_ERR_CONN)
 | |
| 	case tcpReceiveClosed:
 | |
| 		fallthrough
 | |
| 	case tcpClosing:
 | |
| 		return NewLWIPError(LWIP_ERR_CLSD)
 | |
| 	case tcpErrored:
 | |
| 		conn.abortInternal()
 | |
| 		return NewLWIPError(LWIP_ERR_ABRT)
 | |
| 	default:
 | |
| 		panic("unexpected error")
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Receive(data []byte) error {
 | |
| 	if err := conn.receiveCheck(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	n, err := conn.sndPipeWriter.Write(data)
 | |
| 	if err != nil {
 | |
| 		return NewLWIPError(LWIP_ERR_CLSD)
 | |
| 	}
 | |
| 	C.tcp_recved(conn.pcb, C.u16_t(n))
 | |
| 	return NewLWIPError(LWIP_ERR_OK)
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Read(data []byte) (int, error) {
 | |
| 	conn.Lock()
 | |
| 	if conn.state == tcpReceiveClosed {
 | |
| 		conn.Unlock()
 | |
| 		return 0, io.EOF
 | |
| 	}
 | |
| 	if conn.state >= tcpClosing {
 | |
| 		conn.Unlock()
 | |
| 		return 0, io.ErrClosedPipe
 | |
| 	}
 | |
| 	conn.Unlock()
 | |
| 
 | |
| 	// Handler should get EOF.
 | |
| 	n, err := conn.sndPipeReader.Read(data)
 | |
| 	if err == io.ErrClosedPipe {
 | |
| 		err = io.EOF
 | |
| 	}
 | |
| 	return n, err
 | |
| }
 | |
| 
 | |
| // writeInternal enqueues data to snd_buf, and treats ERR_MEM returned by tcp_write not an error,
 | |
| // but instead tells the caller that data is not successfully enqueued, and should try
 | |
| // again another time. By calling this function, the lwIP thread is assumed to be already
 | |
| // locked by the caller.
 | |
| func (conn *tcpConn) writeInternal(data []byte) (int, error) {
 | |
| 	err := C.tcp_write(conn.pcb, unsafe.Pointer(&data[0]), C.u16_t(len(data)), C.TCP_WRITE_FLAG_COPY)
 | |
| 	if err == C.ERR_OK {
 | |
| 		C.tcp_output(conn.pcb)
 | |
| 		return len(data), nil
 | |
| 	} else if err == C.ERR_MEM {
 | |
| 		return 0, nil
 | |
| 	}
 | |
| 	return 0, fmt.Errorf("tcp_write failed (%v)", int(err))
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) writeCheck() error {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	switch conn.state {
 | |
| 	case tcpConnecting:
 | |
| 		fallthrough
 | |
| 	case tcpConnected:
 | |
| 		fallthrough
 | |
| 	case tcpReceiveClosed:
 | |
| 		return nil
 | |
| 	case tcpWriteClosed:
 | |
| 		fallthrough
 | |
| 	case tcpClosing:
 | |
| 		fallthrough
 | |
| 	case tcpClosed:
 | |
| 		fallthrough
 | |
| 	case tcpErrored:
 | |
| 		fallthrough
 | |
| 	case tcpAborting:
 | |
| 		return io.ErrClosedPipe
 | |
| 	default:
 | |
| 		panic("unexpected error")
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Write(data []byte) (int, error) {
 | |
| 	totalWritten := 0
 | |
| 
 | |
| 	conn.canWrite.L.Lock()
 | |
| 	defer conn.canWrite.L.Unlock()
 | |
| 
 | |
| 	for len(data) > 0 {
 | |
| 		if err := conn.writeCheck(); err != nil {
 | |
| 			return totalWritten, err
 | |
| 		}
 | |
| 
 | |
| 		lwipMutex.Lock()
 | |
| 		toWrite := len(data)
 | |
| 		if toWrite > int(conn.pcb.snd_buf) {
 | |
| 			// Write at most the size of the LWIP buffer.
 | |
| 			toWrite = int(conn.pcb.snd_buf)
 | |
| 		}
 | |
| 		if toWrite > 0 {
 | |
| 			written, err := conn.writeInternal(data[0:toWrite])
 | |
| 			totalWritten += written
 | |
| 			if err != nil {
 | |
| 				lwipMutex.Unlock()
 | |
| 				return totalWritten, err
 | |
| 			}
 | |
| 			data = data[written:len(data)]
 | |
| 		}
 | |
| 		lwipMutex.Unlock()
 | |
| 		if len(data) == 0 {
 | |
| 			break // Don't block if all the data has been written.
 | |
| 		}
 | |
| 		conn.canWrite.Wait()
 | |
| 	}
 | |
| 
 | |
| 	return totalWritten, nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) CloseWrite() error {
 | |
| 	conn.Lock()
 | |
| 	if conn.state >= tcpClosing || conn.state == tcpWriteClosed {
 | |
| 		conn.Unlock()
 | |
| 		return nil
 | |
| 	}
 | |
| 	if conn.state == tcpReceiveClosed {
 | |
| 		conn.state = tcpClosing
 | |
| 	} else {
 | |
| 		conn.state = tcpWriteClosed
 | |
| 	}
 | |
| 	conn.Unlock()
 | |
| 
 | |
| 	lwipMutex.Lock()
 | |
| 	// FIXME Handle tcp_shutdown error.
 | |
| 	C.tcp_shutdown(conn.pcb, 0, 1)
 | |
| 	lwipMutex.Unlock()
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) CloseRead() error {
 | |
| 	return conn.sndPipeReader.Close()
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Sent(len uint16) error {
 | |
| 	// Some packets are acknowledged by local client, check if any pending data to send.
 | |
| 	return conn.checkState()
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) checkClosing() error {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	if conn.state == tcpClosing {
 | |
| 		conn.closeInternal()
 | |
| 		return NewLWIPError(LWIP_ERR_OK)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) checkAborting() error {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	if conn.state == tcpAborting {
 | |
| 		conn.abortInternal()
 | |
| 		return NewLWIPError(LWIP_ERR_ABRT)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) isClosed() bool {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	return conn.state == tcpClosed
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) checkState() error {
 | |
| 	if conn.isClosed() {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	err := conn.checkClosing()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	err = conn.checkAborting()
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// Signal the writer to try writting.
 | |
| 	conn.canWrite.Broadcast()
 | |
| 
 | |
| 	return NewLWIPError(LWIP_ERR_OK)
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Close() error {
 | |
| 	conn.closeOnce.Do(conn.close)
 | |
| 	return conn.closeErr
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) close() {
 | |
| 	err := conn.CloseRead()
 | |
| 	if err != nil {
 | |
| 		conn.closeErr = err
 | |
| 	}
 | |
| 	err = conn.CloseWrite()
 | |
| 	if err != nil {
 | |
| 		conn.closeErr = err
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) setLocalClosed() error {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	if conn.state >= tcpClosing || conn.state == tcpReceiveClosed {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	// Causes the read half of the pipe returns.
 | |
| 	conn.sndPipeWriter.Close()
 | |
| 
 | |
| 	if conn.state == tcpWriteClosed {
 | |
| 		conn.state = tcpClosing
 | |
| 	} else {
 | |
| 		conn.state = tcpReceiveClosed
 | |
| 	}
 | |
| 	conn.canWrite.Broadcast()
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // Never call this function outside of the lwIP thread.
 | |
| func (conn *tcpConn) closeInternal() error {
 | |
| 	C.tcp_arg(conn.pcb, nil)
 | |
| 	C.tcp_recv(conn.pcb, nil)
 | |
| 	C.tcp_sent(conn.pcb, nil)
 | |
| 	C.tcp_err(conn.pcb, nil)
 | |
| 	C.tcp_poll(conn.pcb, nil, 0)
 | |
| 
 | |
| 	conn.release()
 | |
| 
 | |
| 	// FIXME Handle error.
 | |
| 	err := C.tcp_close(conn.pcb)
 | |
| 	if err == C.ERR_OK {
 | |
| 		return nil
 | |
| 	} else {
 | |
| 		return errors.New(fmt.Sprintf("close TCP connection failed, lwip error code %d", int(err)))
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // Never call this function outside of the lwIP thread since it calls
 | |
| // tcp_abort() and in that case we must return ERR_ABRT to lwIP.
 | |
| func (conn *tcpConn) abortInternal() {
 | |
| 	conn.release()
 | |
| 	C.tcp_abort(conn.pcb)
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Abort() {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	conn.state = tcpAborting
 | |
| 	conn.canWrite.Broadcast()
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Err(err error) {
 | |
| 	conn.Lock()
 | |
| 	defer conn.Unlock()
 | |
| 
 | |
| 	conn.release()
 | |
| 	conn.state = tcpErrored
 | |
| 	conn.canWrite.Broadcast()
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) LocalClosed() error {
 | |
| 	conn.setLocalClosed()
 | |
| 	return conn.checkState()
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) release() {
 | |
| 	if _, found := tcpConns.Load(conn.connKey); found {
 | |
| 		freeConnKeyArg(conn.connKeyArg)
 | |
| 		tcpConns.Delete(conn.connKey)
 | |
| 	}
 | |
| 	conn.sndPipeWriter.Close()
 | |
| 	conn.sndPipeReader.Close()
 | |
| 	conn.state = tcpClosed
 | |
| }
 | |
| 
 | |
| func (conn *tcpConn) Poll() error {
 | |
| 	return conn.checkState()
 | |
| }
 | 
