diff --git a/opennotrd/core/tcpforward_test.go b/opennotrd/core/tcpforward_test.go index b642856..95c1d64 100644 --- a/opennotrd/core/tcpforward_test.go +++ b/opennotrd/core/tcpforward_test.go @@ -4,13 +4,20 @@ import ( "fmt" "io" "net" + "net/http" + _ "net/http/pprof" "os" + "sync" "testing" "time" "github.com/hashicorp/yamux" ) +func init() { + go http.ListenAndServe("127.0.0.1:6060", nil) +} + // client -----> tproxy | opennotr server <------ opennotr client var backendAddr = "127.0.0.1:8522" @@ -37,30 +44,22 @@ func (c *mockConn) LocalAddr() net.Addr { return c.addr } -func (c *mockConn) Write(buf []byte) (int, error) { - fmt.Printf("receive %d bytes\n", len(buf)) - return len(buf), nil -} - -func runBackend(t *testing.T) { +func runBackend() { conn, err := net.Dial("tcp", serverAddr) if err != nil { - t.Error(err) - return + panic(err) } defer conn.Close() sess, err := yamux.Client(conn, nil) if err != nil { - t.Error(err) - t.FailNow() + panic(err) } defer sess.Close() for { stream, err := sess.AcceptStream() if err != nil { - t.Error(err) - t.FailNow() + panic(err) } go func() { @@ -69,7 +68,7 @@ func runBackend(t *testing.T) { for { nr, err := stream.Read(buf) if err != nil { - fmt.Println(err) + fmt.Println("read stream fail:", err) break } stream.Write(buf[:nr]) @@ -78,7 +77,7 @@ func runBackend(t *testing.T) { } } -func runserver(t *testing.T, listener net.Listener) { +func runserver(listener net.Listener) { for { conn, err := listener.Accept() if err != nil { @@ -88,18 +87,17 @@ func runserver(t *testing.T, listener net.Listener) { go func() { sess, err := yamux.Server(conn, nil) if err != nil { - t.Error(err) - t.FailNow() + panic(err) } sessMgr := GetSessionManager() sessMgr.AddSession(vip, &Session{conn: sess}) - t.Log("add session: ", vip) + fmt.Println("add session: ", vip) }() } } -func runtproxy(t *testing.T, tcpfw *TCPForward, listener net.Listener) { +func runtproxy(tcpfw *TCPForward, listener net.Listener) { for { conn, err := listener.Accept() if err != nil { @@ -123,7 +121,7 @@ func TestTCPForward(t *testing.T) { t.Error(err) return } - defer listener.Close() + // defer listener.Close() srvlistener, err := net.Listen("tcp", serverAddr) if err != nil { @@ -132,10 +130,11 @@ func TestTCPForward(t *testing.T) { } defer srvlistener.Close() - go runBackend(t) - go runserver(t, srvlistener) - go runtproxy(t, tcpfw, listener) - + go runBackend() + go runserver(srvlistener) + go runtproxy(tcpfw, listener) + // wait for session created + time.Sleep(time.Second * 1) conn, err := net.Dial("tcp", tproxyAddr) if err != nil { t.FailNow() @@ -143,11 +142,87 @@ func TestTCPForward(t *testing.T) { defer conn.Close() go func() { - for i := 0; i < 100; i++ { + defer conn.Close() + for i := 0; i < 10; i++ { conn.Write([]byte("ping\n")) time.Sleep(time.Second * 1) } + fmt.Println("connection close") }() - io.Copy(os.Stdout, conn) + buf := make([]byte, 128) + c := 0 + for { + nr, err := conn.Read(buf) + if err != nil { + break + } + fmt.Printf("receive %d %s\n", c+1, string(buf[:nr])) + c += 1 + } +} + +func benchmark(t *testing.B, nconn int) { + // listen tproxy + tcpfw := NewTCPForward() + listener, err := tcpfw.Listen(tproxyAddr) + if err != nil { + t.Error(err) + return + } + // defer listener.Close() + + srvlistener, err := net.Listen("tcp", serverAddr) + if err != nil { + t.Error(err) + return + } + defer srvlistener.Close() + + go runBackend() + go runserver(srvlistener) + go runtproxy(tcpfw, listener) + + // wait for session created + time.Sleep(time.Second * 1) + wg := sync.WaitGroup{} + wg.Add(nconn) + defer wg.Wait() + for i := 0; i < nconn; i++ { + go func() { + defer wg.Done() + conn, err := net.Dial("tcp", tproxyAddr) + if err != nil { + t.FailNow() + } + defer conn.Close() + + go func() { + defer conn.Close() + for i := 0; i < 10; i++ { + conn.Write([]byte("ping\n")) + time.Sleep(time.Second * 1) + } + }() + fp, _ := os.Open(os.DevNull) + defer fp.Close() + io.Copy(fp, conn) + }() + } +} + +func Benchmark1K(b *testing.B) { + benchmark(b, 1024) +} + +func Benchmark2K(b *testing.B) { + benchmark(b, 1024*2) +} + +func Benchmark4K(b *testing.B) { + benchmark(b, 1024*4) +} + +func Benchmark8K(b *testing.B) { + benchmark(b, 1024*8) }