diff --git a/handshake.go b/handshake.go index 6cdd336..26cacdb 100644 --- a/handshake.go +++ b/handshake.go @@ -10,11 +10,12 @@ import ( "encoding/binary" "errors" "fmt" - "github.com/zhangpeihao/log" "io" "math/rand" "net" "time" + + "github.com/zhangpeihao/log" ) const ( @@ -272,6 +273,10 @@ func Handshake(c net.Conn, br *bufio.Reader, bw *bufio.Writer, timeout time.Dura err = bw.Flush() CheckError(err, "Handshake() Flush C2") + if timeout > 0 { + c.SetDeadline(time.Time{}) + } + return } @@ -354,18 +359,21 @@ func SHandshake(c net.Conn, br *bufio.Reader, bw *bufio.Writer, timeout time.Dur CheckError(err, "SHandshake() Send S2") if timeout > 0 { - // c.SetWriteDeadline(time.Now().Add(timeout)) + c.SetWriteDeadline(time.Now().Add(timeout)) } err = bw.Flush() CheckError(err, "SHandshake() Flush S2") // Read C2 if timeout > 0 { - // c.SetReadDeadline(time.Now().Add(timeout)) + c.SetReadDeadline(time.Now().Add(timeout)) } c2 := make([]byte, RTMP_SIG_SIZE) _, err = io.ReadAtLeast(br, c2, RTMP_SIG_SIZE) CheckError(err, "SHandshake() Read C2") // TODO: check C2 + if timeout > 0 { + c.SetDeadline(time.Time{}) + } return }