Files
WireGold/gold/p2p/tcp/pdu.go
2025-03-12 22:20:02 +09:00

134 lines
2.6 KiB
Go

package tcp
import (
"encoding/binary"
"errors"
"io"
"net"
"time"
"github.com/fumiama/WireGold/config"
"github.com/fumiama/WireGold/internal/bin"
"github.com/sirupsen/logrus"
)
// ErrInvalidMagic is returned by (*packet).ReadFrom when the first four bytes
// of the stream do not match the expected packet magic.
var ErrInvalidMagic = errors.New("invalid magic")
// packetType is the one-byte type tag carried in every packet header.
type packetType uint8
// Packet type values, numbered from 0 in declaration order.
const (
// packetTypeKeepAlive is accepted by isvalid as a validation packet
// (reported with issub == false).
packetTypeKeepAlive packetType = iota
// packetTypeNormal is rejected by isvalid; presumably it marks a packet
// that carries payload data — confirm against the callers.
packetTypeNormal
// packetTypeSubKeepAlive is accepted by isvalid as a validation packet
// (reported with issub == true).
packetTypeSubKeepAlive
// packetTypeTop is not referenced in this file; presumably an exclusive
// upper bound for valid type values — TODO confirm.
packetTypeTop
)
var (
// magicbuf is the 4-byte marker that pack places at the start of every
// serialized packet.
magicbuf = []byte("GET ")
// magic is magicbuf decoded as a little-endian uint32; ReadFrom compares
// the first four bytes it reads against this value.
magic = binary.LittleEndian.Uint32(magicbuf)
)
// packet is a single framed unit on the TCP stream: the magic marker, a
// one-byte type, a 16-bit length and the payload.
type packet struct {
// typ is the type tag read from / written to the header.
typ packetType
// len is the payload length in bytes stored in the header. ReadFrom reads
// exactly this many payload bytes; pack writes it as-is and does not derive
// it from dat.
len uint16
// dat is the payload.
dat []byte
// The embedded interfaces are left nil by packet{}; the ReadFrom and WriteTo
// methods declared on *packet take precedence over the promoted ones.
io.ReaderFrom
io.WriterTo
}
// pack assembles the wire form of p as three buffers, in order: the magic
// marker, the header (type byte followed by the 16-bit length) and the
// payload. The payload slice is referenced, not copied.
func (p *packet) pack() *net.Buffers {
	// Header bytes: typ first, then len, produced through the bin writer.
	hdr := bin.NewWriterF(func(w *bin.Writer) {
		w.WriteByte(byte(p.typ))
		w.WriteUInt16(p.len)
	}).Trans()

	bufs := net.Buffers{magicbuf, hdr, p.dat}
	return &bufs
}
// Read is a placeholder that always panics with "stub".
//
// NOTE(review): presumably it exists only so that *packet satisfies
// io.Reader; nothing in this file calls it — confirm before relying on it.
func (*packet) Read([]byte) (int, error) {
	panic("stub")
}
// Write is a placeholder that always panics with "stub".
//
// It lets *packet be passed as the io.Writer destination of io.Copy (as
// isvalid does). io.Copy uses the destination's ReadFrom when available, so
// this method is not expected to be reached on that path.
func (*packet) Write([]byte) (int, error) {
	panic("stub")
}
// ReadFrom decodes exactly one packet from r into p.
//
// Wire layout, in read order:
//
//	4 bytes  magic, must decode (little-endian) to the package magic
//	1 byte   packet type
//	2 bytes  payload length, little-endian
//	N bytes  payload, N being the length just read
//
// n is the total number of bytes consumed from r, including on failure.
// err is ErrInvalidMagic when the magic does not match, otherwise whatever
// error the underlying reads produced. Unlike the usual io.ReaderFrom
// contract, reading stops after one packet instead of continuing until EOF.
func (p *packet) ReadFrom(r io.Reader) (n int64, err error) {
	var hdr [4]byte

	// Magic marker.
	read, err := io.ReadFull(r, hdr[:])
	n = int64(read)
	if err != nil {
		return n, err
	}
	if got := binary.LittleEndian.Uint32(hdr[:]); got != magic {
		if config.ShowDebugLog {
			logrus.Debugf("[tcp] expect magic %08x but got %08x", magic, got)
		}
		return n, ErrInvalidMagic
	}

	// Type and length; hdr is reused and only its first three bytes matter.
	read, err = io.ReadFull(r, hdr[:3])
	n += int64(read)
	if err != nil {
		return n, err
	}
	p.typ = packetType(hdr[0])
	p.len = binary.LittleEndian.Uint16(hdr[1:3])

	// Payload: stream exactly p.len bytes into a writer obtained from bin.
	body := bin.SelectWriter()
	m, err := io.CopyN(body, r, int64(p.len))
	n += m
	if err != nil {
		return n, err
	}
	p.dat = body.ToBytes().Trans()
	return n, nil
}
// WriteTo serializes p with pack and writes the resulting buffers to w,
// returning the number of bytes written and any write error.
func (p *packet) WriteTo(w io.Writer) (n int64, err error) {
	// io.Copy would dispatch to the source's WriteTo first; call it directly.
	frame := p.pack()
	return frame.WriteTo(w)
}
// isvalid waits up to timeout for the first packet on tcpconn and reports
// whether it is an acceptable validation packet.
//
// ok is true only when a packet was read without error and its type is
// packetTypeKeepAlive or packetTypeSubKeepAlive; issub is true only for
// packetTypeSubKeepAlive. On timeout, read error or any other packet type
// both results are false.
//
// On timeout the background read is still pending on tcpconn; it ends once a
// packet arrives or the connection is closed.
// NOTE(review): presumably the caller closes tcpconn when ok is false — confirm.
func isvalid(tcpconn *net.TCPConn, timeout time.Duration) (issub, ok bool) {
	pckt := packet{}

	// Buffered so the reader goroutine can always hand over its result and
	// exit, even if the timeout branch below has already returned and nobody
	// will ever receive from the channel.
	copych := make(chan error, 1)
	go func() {
		_, err := io.Copy(&pckt, tcpconn)
		copych <- err
	}()

	// A Timer signals through its own channel, so no goroutine is left
	// blocked on a send when the packet arrives just as the timer fires.
	t := time.NewTimer(timeout)
	defer t.Stop()

	var err error
	select {
	case <-t.C:
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "timeout")
		}
		return false, false
	case err = <-copych:
	}

	if err != nil {
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] validate recv from", tcpconn.RemoteAddr(), "err:", err)
		}
		return false, false
	}
	if pckt.typ != packetTypeKeepAlive && pckt.typ != packetTypeSubKeepAlive {
		if config.ShowDebugLog {
			logrus.Debugln("[tcp] validate got invalid typ", pckt.typ, "from", tcpconn.RemoteAddr())
		}
		return false, false
	}
	if config.ShowDebugLog {
		logrus.Debugln("[tcp] passed validate recv from", tcpconn.RemoteAddr())
	}
	return pckt.typ == packetTypeSubKeepAlive, true
}