mirror of
https://github.com/fumiama/WireGold.git
synced 2025-09-26 19:21:11 +08:00
134 lines
2.6 KiB
Go
134 lines
2.6 KiB
Go
package tcp
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/fumiama/WireGold/config"
|
|
"github.com/fumiama/WireGold/internal/bin"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ErrInvalidMagic is returned by (*packet).ReadFrom when the first four bytes
// read from the stream do not match the expected magic prefix.
var ErrInvalidMagic = errors.New("invalid magic")
|
|
|
|
// packetType is the kind tag carried as a single byte in every packet header.
type packetType uint8

// Packet kinds. The numeric values are sent on the wire, so existing entries
// must keep their values. packetTypeTop follows the last defined kind.
const (
	packetTypeKeepAlive    packetType = 0
	packetTypeNormal       packetType = 1
	packetTypeSubKeepAlive packetType = 2
	packetTypeTop          packetType = 3
)
|
|
|
|
// magicbuf is the 4-byte prefix that pack places in front of every packet.
// It is shared by all packets and must never be modified.
var magicbuf = []byte("GET ")

// magic is magicbuf decoded as a little-endian uint32, so a received prefix
// can be validated with a single integer comparison.
var magic = binary.LittleEndian.Uint32(magicbuf)
|
|
|
|
// packet is one frame of the TCP transport. It is serialized by pack/WriteTo
// and parsed by ReadFrom as: magic prefix, 1-byte type, 2-byte payload
// length (parsed as little-endian by ReadFrom), then the payload bytes.
type packet struct {
	// typ is the packet kind, sent as a single header byte.
	typ packetType
	// len is the payload length stored in the header.
	// NOTE(review): pack writes this field verbatim rather than len(dat);
	// callers are presumably expected to keep the two equal — confirm.
	len uint16
	// dat is the payload. pack references it without copying.
	dat []byte

	// These embedded interfaces are never assigned in the code visible here.
	// The ReadFrom and WriteTo methods declared on *packet take precedence
	// over the promoted ones, so the fields presumably only document the
	// intended interfaces — TODO confirm before removing.
	io.ReaderFrom
	io.WriterTo
}
|
|
|
|
func (p *packet) pack() *net.Buffers {
|
|
return &net.Buffers{magicbuf, bin.NewWriterF(func(w *bin.Writer) {
|
|
w.WriteByte(byte(p.typ))
|
|
w.WriteUInt16(p.len)
|
|
}).Trans(), p.dat}
|
|
}
|
|
|
|
func (p *packet) Read(_ []byte) (int, error) {
|
|
panic("stub")
|
|
}
|
|
|
|
func (p *packet) Write(_ []byte) (int, error) {
|
|
panic("stub")
|
|
}
|
|
|
|
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
|
|
var buf [4]byte
|
|
cnt, err := io.ReadFull(r, buf[:])
|
|
n = int64(cnt)
|
|
if err != nil {
|
|
return
|
|
}
|
|
if binary.LittleEndian.Uint32(buf[:]) != magic {
|
|
err = ErrInvalidMagic
|
|
if config.ShowDebugLog {
|
|
logrus.Debugf("[tcp] expect magic %08x but got %08x", magic, binary.LittleEndian.Uint32(buf[:]))
|
|
}
|
|
return
|
|
}
|
|
cnt, err = io.ReadFull(r, buf[:3])
|
|
n += int64(cnt)
|
|
if err != nil {
|
|
return
|
|
}
|
|
p.typ = packetType(buf[0])
|
|
p.len = binary.LittleEndian.Uint16(buf[1:3])
|
|
w := bin.SelectWriter()
|
|
copied, err := io.CopyN(w, r, int64(p.len))
|
|
n += copied
|
|
if err != nil {
|
|
return
|
|
}
|
|
p.dat = w.ToBytes().Trans()
|
|
return
|
|
}
|
|
|
|
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
|
|
return io.Copy(w, p.pack())
|
|
}
|
|
|
|
func isvalid(tcpconn *net.TCPConn, timeout time.Duration) (issub, ok bool) {
|
|
pckt := packet{}
|
|
|
|
stopch := make(chan struct{})
|
|
t := time.AfterFunc(timeout, func() {
|
|
stopch <- struct{}{}
|
|
})
|
|
|
|
var err error
|
|
copych := make(chan struct{})
|
|
go func() {
|
|
_, err = io.Copy(&pckt, tcpconn)
|
|
copych <- struct{}{}
|
|
}()
|
|
|
|
select {
|
|
case <-stopch:
|
|
if config.ShowDebugLog {
|
|
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
|
|
}
|
|
return
|
|
case <-copych:
|
|
t.Stop()
|
|
}
|
|
|
|
if err != nil {
|
|
if config.ShowDebugLog {
|
|
logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
|
|
}
|
|
return
|
|
}
|
|
if pckt.typ != packetTypeKeepAlive && pckt.typ != packetTypeSubKeepAlive {
|
|
if config.ShowDebugLog {
|
|
logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
|
|
}
|
|
return
|
|
}
|
|
|
|
if config.ShowDebugLog {
|
|
logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
|
|
}
|
|
return pckt.typ == packetTypeSubKeepAlive, true
|
|
}
|