// Copyright (C) 2023 mieru authors
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package protocol

import (
	"context"
	"fmt"
	"io"
	"math"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/enfein/mieru/pkg/appctl/appctlpb"
	"github.com/enfein/mieru/pkg/cipher"
	"github.com/enfein/mieru/pkg/congestion"
	"github.com/enfein/mieru/pkg/log"
	"github.com/enfein/mieru/pkg/mathext"
	"github.com/enfein/mieru/pkg/metrics"
	"github.com/enfein/mieru/pkg/stderror"
	"github.com/enfein/mieru/pkg/util"
)

const (
	segmentTreeCapacity = 4096
	segmentChanCapacity = 1024

	minWindowSize = 16
	maxWindowSize = segmentTreeCapacity

	serverRespTimeout        = 10 * time.Second
	sessionHeartbeatInterval = 5 * time.Second

	earlyRetransmission      = 3 // number of duplicate ACKs that triggers an early retransmission
	earlyRetransmissionLimit = 5 // maximum number of early retransmission attempts
	txTimeoutBackOff         = 1.25
	maxBackOffMultiplier     = 20.0
)

type sessionState byte

const (
	sessionInit        sessionState = 0
	sessionAttached    sessionState = 1
	sessionEstablished sessionState = 2
	sessionClosed      sessionState = 3
)

func (ss sessionState) String() string {
	switch ss {
	case sessionInit:
		return "INITIALIZED"
	case sessionAttached:
		return "ATTACHED"
	case sessionEstablished:
		return "ESTABLISHED"
	case sessionClosed:
		return "CLOSED"
	default:
		return "UNKNOWN"
	}
}

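// The states above form a one-way progression enforced by forwardStateTo():
// sessionInit -> sessionAttached -> sessionEstablished -> sessionClosed.
// A minimal sketch of how a state check and a transition interact
// (illustrative only, not additional API):
//
//	if s.isState(sessionAttached) {
//		s.forwardStateTo(sessionEstablished) // moving a state backward panics
//	}
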
type Session struct {
	conn  Underlay           // underlay connection
	block cipher.BlockCipher // cipher to encrypt and decrypt data

	id         uint32       // session ID number
	isClient   bool         // whether this session is owned by the client
	mtu        int          // L2 maximum transmission unit
	remoteAddr net.Addr     // remote network address, used by UDP
	state      sessionState // session state
	status     statusCode   // session status
	users      map[string]*appctlpb.User

	ready         chan struct{} // indicates the session is ready to use
	done          chan struct{} // indicates the session is complete
	readDeadline  time.Time     // read deadline
	writeDeadline time.Time     // write deadline
	inputErr      chan error    // input error
	outputErr     chan error    // output error

	sendQueue *segmentTree  // segments waiting to be sent
	sendBuf   *segmentTree  // segments sent but not yet acknowledged
	recvBuf   *segmentTree  // segments received whose acknowledgement has not been sent
	recvQueue *segmentTree  // segments waiting to be read by the application
	recvChan  chan *segment // channel to receive segments from the underlay

	nextSend      uint32      // next sequence number to send a segment
	nextRecv      uint32      // next sequence number to receive
	lastRXTime    time.Time   // timestamp when the last segment was received
	lastTXTime    time.Time   // timestamp when the last segment was sent
	ackOnDataRecv atomic.Bool // whether an ACK should be sent because new data was received
	unreadBuf     []byte      // payload removed from the recvQueue that hasn't been read by the application

	uploadBytes   metrics.Metric // number of bytes from client to server, only used by server
	downloadBytes metrics.Metric // number of bytes from server to client, only used by server

	rttStat             *congestion.RTTStats
	legacysendAlgorithm *congestion.CubicSendAlgorithm
	sendAlgorithm       *congestion.BBRSender
	remoteWindowSize    uint16

	wg    sync.WaitGroup
	rLock sync.Mutex
	wLock sync.Mutex
	cLock sync.Mutex
	sLock sync.Mutex
}

// Session must implement the net.Conn interface.
var _ net.Conn = &Session{}

// NewSession creates a new session.
func NewSession(id uint32, isClient bool, mtu int) *Session {
	rttStat := congestion.NewRTTStats()
	rttStat.SetMaxAckDelay(outputLoopInterval)
	rttStat.SetRTOMultiplier(txTimeoutBackOff)
	return &Session{
		conn:                nil,
		block:               nil,
		id:                  id,
		isClient:            isClient,
		mtu:                 mtu,
		state:               sessionInit,
		status:              statusOK,
		ready:               make(chan struct{}),
		done:                make(chan struct{}),
		readDeadline:        time.Time{},
		writeDeadline:       time.Time{},
		inputErr:            make(chan error, 2), // allow nested
		outputErr:           make(chan error, 2), // allow nested
		sendQueue:           newSegmentTree(segmentTreeCapacity),
		sendBuf:             newSegmentTree(segmentTreeCapacity),
		recvBuf:             newSegmentTree(segmentTreeCapacity),
		recvQueue:           newSegmentTree(segmentTreeCapacity),
		recvChan:            make(chan *segment, segmentChanCapacity),
		lastRXTime:          time.Now(),
		lastTXTime:          time.Now(),
		rttStat:             rttStat,
		legacysendAlgorithm: congestion.NewCubicSendAlgorithm(minWindowSize, maxWindowSize),
		sendAlgorithm:       congestion.NewBBRSender(fmt.Sprintf("%d", id), rttStat),
		remoteWindowSize:    minWindowSize,
	}
}

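// A minimal usage sketch. Because Session satisfies net.Conn, it can be
// handed to anything that expects one (io.Copy, bufio.NewReader, etc.).
// The id and MTU values below are illustrative assumptions, not constants
// defined in this file:
//
//	s := NewSession(1, true /* isClient */, 1400)
//	// ... attach s to an underlay, then treat it as a net.Conn:
//	// go io.Copy(io.Discard, s)
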
func (s *Session) String() string {
	if s.conn == nil {
		return fmt.Sprintf("Session{id=%v}", s.id)
	}
	return fmt.Sprintf("Session{id=%v, local=%v, remote=%v}", s.id, s.LocalAddr(), s.RemoteAddr())
}

// Read lets a user read data from the receive queue.
func (s *Session) Read(b []byte) (n int, err error) {
	s.rLock.Lock()
	defer s.rLock.Unlock()
	if s.isStateBefore(sessionAttached, false) {
		return 0, fmt.Errorf("%v is not ready for Read()", s)
	}
	if s.isStateAfter(sessionClosed, true) {
		return 0, io.ErrClosedPipe
	}
	defer func() {
		s.readDeadline = time.Time{}
	}()
	if log.IsLevelEnabled(log.TraceLevel) {
		log.Tracef("%v trying to read %d bytes", s, len(b))
	}

	// Read the remaining data that the application failed to read last time.
	if len(s.unreadBuf) > 0 {
		n = copy(b, s.unreadBuf)
		if n == len(s.unreadBuf) {
			s.unreadBuf = nil
		} else {
			s.unreadBuf = s.unreadBuf[n:]
		}
		if log.IsLevelEnabled(log.TraceLevel) {
			log.Tracef("%v read %d bytes", s, n)
		}
		if !s.isClient && s.uploadBytes != nil {
			s.uploadBytes.Add(int64(n))
		}
		return n, nil
	}

	// Stop reading when the deadline is reached.
	var timeC <-chan time.Time
	if !s.readDeadline.IsZero() {
		timeC = time.After(time.Until(s.readDeadline))
	}

	for {
		if s.recvQueue.Len() > 0 {
			// Some segments in the segment tree are ready to read.
			for {
				seg, ok := s.recvQueue.DeleteMin()
				if !ok {
					break
				}
				if s.isClient && seg.metadata.Protocol() == openSessionResponse && s.isState(sessionAttached) {
					s.forwardStateTo(sessionEstablished)
				}
				if s.unreadBuf == nil {
					s.unreadBuf = make([]byte, 0)
				}
				s.unreadBuf = append(s.unreadBuf, seg.payload...)
			}
			if len(s.unreadBuf) > 0 {
				break
			}
		} else {
			// Wait for incoming segments.
			select {
			case <-s.done:
				return 0, io.EOF
			case <-s.inputErr:
				return 0, io.ErrUnexpectedEOF
			case <-timeC:
				return 0, stderror.ErrTimeout
			case <-s.recvQueue.chanNotEmptyEvent:
				// New segments are ready to read.
			}
		}
	}

	n = copy(b, s.unreadBuf)
	if n == len(s.unreadBuf) {
		s.unreadBuf = nil
	} else {
		s.unreadBuf = s.unreadBuf[n:]
	}
	if log.IsLevelEnabled(log.TraceLevel) {
		log.Tracef("%v read %d bytes", s, n)
	}
	if !s.isClient && s.uploadBytes != nil {
		s.uploadBytes.Add(int64(n))
	}
	return n, nil
}

// Write stores the data in the send queue.
func (s *Session) Write(b []byte) (n int, err error) {
	s.wLock.Lock()
	defer s.wLock.Unlock()

	if s.isStateBefore(sessionAttached, false) {
		return 0, fmt.Errorf("%v is not ready for Write()", s)
	}
	if s.isStateAfter(sessionClosed, true) {
		return 0, io.ErrClosedPipe
	}
	defer func() {
		s.writeDeadline = time.Time{}
	}()

	if s.isClient && s.isState(sessionAttached) {
		// Before the first write, the client needs to send an open session request.
		seg := &segment{
			metadata: &sessionStruct{
				baseStruct: baseStruct{
					protocol: uint8(openSessionRequest),
				},
				sessionID: s.id,
				seq:       s.nextSend,
			},
			transport: s.conn.TransportProtocol(),
		}
		s.nextSend++
		if len(b) <= MaxSessionOpenPayload {
			seg.metadata.(*sessionStruct).payloadLen = uint16(len(b))
			seg.payload = make([]byte, len(b))
			copy(seg.payload, b)
		}
		if log.IsLevelEnabled(log.TraceLevel) {
			log.Tracef("%v writing %d bytes with open session request", s, len(seg.payload))
		}
		s.sendQueue.InsertBlocking(seg)
		if len(seg.payload) > 0 {
			return len(seg.payload), nil
		}
	}

	n = len(b)
	if log.IsLevelEnabled(log.TraceLevel) {
		log.Tracef("%v writing %d bytes", s, n)
	}
	for len(b) > 0 {
		sizeToSend := mathext.Min(len(b), maxPDU)
		if _, err = s.writeChunk(b[:sizeToSend]); err != nil {
			return 0, err
		}
		b = b[sizeToSend:]
	}
	if log.IsLevelEnabled(log.TraceLevel) {
		log.Tracef("%v wrote %d bytes", s, n)
	}
	if !s.isClient && s.downloadBytes != nil {
		s.downloadBytes.Add(int64(n))
	}
	return n, nil
}

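// A minimal read/write sketch with deadlines (illustrative only; the
// 5-second timeout and the req/resp buffers are assumptions, not part of
// this file):
//
//	s.SetWriteDeadline(time.Now().Add(5 * time.Second))
//	if _, err := s.Write(req); err != nil {
//		// handle error
//	}
//	s.SetReadDeadline(time.Now().Add(5 * time.Second))
//	n, err := s.Read(resp)
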
// Close terminates the session.
func (s *Session) Close() error {
	s.cLock.Lock()
	defer s.cLock.Unlock()
	select {
	case <-s.done:
		s.forwardStateTo(sessionClosed)
		return nil
	default:
	}

	log.Debugf("Closing %v", s)
	s.sendQueue.DeleteAll()
	s.sendBuf.DeleteAll()
	s.recvBuf.DeleteAll()
	s.recvQueue.DeleteAll()
	if s.isState(sessionAttached) || s.isState(sessionEstablished) {
		// Send closeSessionRequest, but don't wait for closeSessionResponse,
		// because the underlay connection may already be broken.
		// The closeSessionRequest won't be sent again.
		seg := &segment{
			metadata: &sessionStruct{
				baseStruct: baseStruct{
					protocol: uint8(closeSessionRequest),
				},
				sessionID:  s.id,
				seq:        s.nextSend,
				statusCode: uint8(s.status),
			},
			transport: s.conn.TransportProtocol(),
		}
		s.nextSend++
		switch s.conn.TransportProtocol() {
		case util.TCPTransport:
			s.sendQueue.InsertBlocking(seg)
		case util.UDPTransport:
			if err := s.output(seg, s.RemoteAddr()); err != nil {
				log.Debugf("output() failed: %v", err)
			}
		default:
			log.Debugf("Unsupported transport protocol %v", s.conn.TransportProtocol())
		}
	}

	// Wait for the sendQueue to flush.
	timeC := time.After(time.Second)
	select {
	case <-timeC:
	case <-s.sendQueue.chanEmptyEvent:
	}

	s.forwardStateTo(sessionClosed)
	close(s.done)
	log.Debugf("Closed %v", s)
	metrics.CurrEstablished.Add(-1)
	return nil
}

func (s *Session) LocalAddr() net.Addr {
	return s.conn.LocalAddr()
}

func (s *Session) RemoteAddr() net.Addr {
	if !util.IsNilNetAddr(s.remoteAddr) {
		return s.remoteAddr
	}
	return s.conn.RemoteAddr()
}

// SetDeadline implements net.Conn.
func (s *Session) SetDeadline(t time.Time) error {
	s.readDeadline = t
	s.writeDeadline = t
	return nil
}

// SetReadDeadline implements net.Conn.
func (s *Session) SetReadDeadline(t time.Time) error {
	s.readDeadline = t
	return nil
}

// SetWriteDeadline implements net.Conn.
func (s *Session) SetWriteDeadline(t time.Time) error {
	s.writeDeadline = t
	return nil
}

// ToSessionInfo creates the related SessionInfo structure.
func (s *Session) ToSessionInfo() SessionInfo {
	info := SessionInfo{
		ID:         fmt.Sprintf("%d", s.id),
		LocalAddr:  s.LocalAddr().String(),
		RemoteAddr: s.RemoteAddr().String(),
		State:      s.state.String(),
		RecvQBuf:   fmt.Sprintf("%d+%d", s.recvQueue.Len(), s.recvBuf.Len()),
		SendQBuf:   fmt.Sprintf("%d+%d", s.sendQueue.Len(), s.sendBuf.Len()),
		LastSend:   fmt.Sprintf("%v (%d)", time.Since(s.lastTXTime).Truncate(time.Second), s.nextSend-1),
	}
	if _, ok := s.conn.(*TCPUnderlay); ok {
		info.Protocol = "TCP"
		info.LastRecv = fmt.Sprintf("%v", time.Since(s.lastRXTime).Truncate(time.Second)) // TCP nextRecv is not used
	} else if _, ok := s.conn.(*UDPUnderlay); ok {
		info.Protocol = "UDP"
		info.LastRecv = fmt.Sprintf("%v (%d)", time.Since(s.lastRXTime).Truncate(time.Second), s.nextRecv-1)
	} else {
		info.Protocol = "UNKNOWN"
	}
	return info
}

func (s *Session) isState(target sessionState) bool {
	s.sLock.Lock()
	defer s.sLock.Unlock()
	return s.state == target
}

func (s *Session) isStateBefore(target sessionState, include bool) bool {
	s.sLock.Lock()
	defer s.sLock.Unlock()
	if include {
		return s.state <= target
	}
	return s.state < target
}

func (s *Session) isStateAfter(target sessionState, include bool) bool {
	s.sLock.Lock()
	defer s.sLock.Unlock()
	if include {
		return s.state >= target
	}
	return s.state > target
}

func (s *Session) forwardStateTo(new sessionState) {
	s.sLock.Lock()
	defer s.sLock.Unlock()
	if new < s.state {
		panic(fmt.Sprintf("Can't move state back from %v to %v", s.state, new))
	}
	if log.IsLevelEnabled(log.TraceLevel) && s.state != new {
		log.Tracef("%v %v => %v", s, s.state, new)
	}
	s.state = new
}

func (s *Session) writeChunk(b []byte) (n int, err error) {
	if len(b) > maxPDU {
		return 0, io.ErrShortWrite
	}

	// Stop writing when the deadline is reached.
	var timeC <-chan time.Time
	if !s.writeDeadline.IsZero() {
		timeC = time.After(time.Until(s.writeDeadline))
	}

	seqBeforeWrite, _ := s.sendQueue.MinSeq()

	// Determine the number of fragments to write.
	nFragment := 1
	fragmentSize := MaxFragmentSize(s.mtu, s.conn.IPVersion(), s.conn.TransportProtocol())
	if len(b) > fragmentSize {
		nFragment = (len(b)-1)/fragmentSize + 1
	}

	ptr := b
	for i := nFragment - 1; i >= 0; i-- {
		select {
		case <-s.done:
			return 0, io.EOF
		case <-s.outputErr:
			return 0, io.ErrClosedPipe
		case <-timeC:
			return 0, stderror.ErrTimeout
		default:
		}
		var protocol uint8
		if s.isClient {
			protocol = uint8(dataClientToServer)
		} else {
			protocol = uint8(dataServerToClient)
		}
		partLen := mathext.Min(fragmentSize, len(ptr))
		part := ptr[:partLen]
		seg := &segment{
			metadata: &dataAckStruct{
				baseStruct: baseStruct{
					protocol: protocol,
				},
				sessionID:  s.id,
				seq:        s.nextSend,
				unAckSeq:   s.nextRecv,
				windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())),
				fragment:   uint8(i),
				payloadLen: uint16(partLen),
			},
			payload:   make([]byte, partLen),
			transport: s.conn.TransportProtocol(),
		}
		copy(seg.payload, part)
		s.nextSend++
		s.sendQueue.InsertBlocking(seg)
		ptr = ptr[partLen:]
	}

	// To create back pressure, wait until the sendQueue is moving.
	for {
		shouldReturn := false
		select {
		case <-s.sendQueue.chanEmptyEvent:
			shouldReturn = true
		case <-s.sendQueue.chanNotEmptyEvent:
			seqAfterWrite, err := s.sendQueue.MinSeq()
			if err != nil || seqAfterWrite > seqBeforeWrite {
				shouldReturn = true
			}
		}
		if shouldReturn {
			break
		}
	}

	if s.isClient {
		s.readDeadline = time.Now().Add(serverRespTimeout)
	}
	return len(b), nil
}

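// The fragment count in writeChunk() is a ceiling division: nFragment =
// ceil(len(b) / fragmentSize), written as (len(b)-1)/fragmentSize + 1.
// For example, with an assumed fragmentSize of 1400 bytes, a 3000-byte
// chunk yields (3000-1)/1400 + 1 = 3 fragments, numbered 2, 1, 0 so the
// receiver knows fragment 0 is the final piece.
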
func (s *Session) runInputLoop(ctx context.Context) error {
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-s.done:
			return nil
		case seg := <-s.recvChan:
			if err := s.input(seg); err != nil {
				err = fmt.Errorf("input() failed: %w", err)
				log.Debugf("%v %v", s, err)
				s.inputErr <- err
				s.Close()
				return err
			}
		}
	}
}

func (s *Session) runOutputLoop(ctx context.Context) error {
	ticker := time.NewTicker(outputLoopInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil
		case <-s.done:
			return nil
		case <-ticker.C:
		case <-s.sendQueue.chanNotEmptyEvent:
		}

		switch s.conn.TransportProtocol() {
		case util.TCPTransport:
			for {
				seg, ok := s.sendQueue.DeleteMin()
				if !ok {
					break
				}
				if err := s.output(seg, nil); err != nil {
					err = fmt.Errorf("output() failed: %w", err)
					log.Debugf("%v %v", s, err)
					s.outputErr <- err
					s.Close()
					break
				}
			}
		case util.UDPTransport:
			closeSession := false
			hasLoss := false
			hasTimeout := false

			// Resend segments in sendBuf.
			// To avoid deadlock, the session can't be closed inside Ascend().
			var bytesInFlight int64
			s.sendBuf.Ascend(func(iter *segment) bool {
				bytesInFlight += int64(udpOverhead + len(iter.payload))
				if iter.txCount >= txCountLimit {
					err := fmt.Errorf("too many retransmissions of %v", iter)
					log.Debugf("%v is unhealthy: %v", s, err)
					s.outputErr <- err
					closeSession = true
					return false
				}
				if (iter.ackCount >= earlyRetransmission && iter.txCount <= earlyRetransmissionLimit) || time.Since(iter.txTime) > iter.txTimeout {
					if iter.ackCount >= earlyRetransmission {
						hasLoss = true
					} else {
						hasTimeout = true
					}
					iter.ackCount = 0
					iter.txCount++
					iter.txTime = time.Now()
					iter.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(iter.txCount)), maxBackOffMultiplier))
					if isDataAckProtocol(iter.metadata.Protocol()) {
						das, _ := toDataAckStruct(iter.metadata)
						das.unAckSeq = s.nextRecv
					}
					if err := s.output(iter, s.RemoteAddr()); err != nil {
						err = fmt.Errorf("output() failed: %w", err)
						log.Debugf("%v %v", s, err)
						s.outputErr <- err
						closeSession = true
						return false
					}
					bytesInFlight += int64(udpOverhead + len(iter.payload))
					return true
				}
				return true
			})
			if closeSession {
				s.Close()
			}
			if hasLoss || hasTimeout {
				s.legacysendAlgorithm.OnLoss() // OnTimeout() is too aggressive.
			}

			// Send new segments in sendQueue.
			if s.sendQueue.Len() > 0 {
				for {
					seg, deleted := s.sendQueue.DeleteMinIf(func(iter *segment) bool {
						return s.sendAlgorithm.CanSend(bytesInFlight, int64(udpOverhead+len(iter.payload)))
					})
					if !deleted {
						break
					}
					seg.txCount++
					seg.txTime = time.Now()
					seg.txTimeout = s.rttStat.RTO() * time.Duration(mathext.Min(math.Pow(txTimeoutBackOff, float64(seg.txCount)), maxBackOffMultiplier))
					if isDataAckProtocol(seg.metadata.Protocol()) {
						das, _ := toDataAckStruct(seg.metadata)
						das.unAckSeq = s.nextRecv
					}
					s.sendBuf.InsertBlocking(seg)
					if err := s.output(seg, s.RemoteAddr()); err != nil {
						err = fmt.Errorf("output() failed: %w", err)
						log.Debugf("%v %v", s, err)
						s.outputErr <- err
						s.Close()
						break
					} else {
						seq, err := seg.Seq()
						if err != nil {
							err = fmt.Errorf("failed to get sequence number from %v: %w", seg, err)
							log.Debugf("%v %v", s, err)
							s.outputErr <- err
							s.Close()
							break
						}
						newBytesInFlight := int64(udpOverhead + len(seg.payload))
						s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
						bytesInFlight += newBytesInFlight
					}
				}
			} else {
				s.sendAlgorithm.OnApplicationLimited(bytesInFlight)
			}

			// Send an ACK or a heartbeat if needed.
			exceedHeartbeatInterval := time.Since(s.lastTXTime) > sessionHeartbeatInterval
			if s.ackOnDataRecv.Load() || exceedHeartbeatInterval {
				baseStruct := baseStruct{}
				if s.isClient {
					baseStruct.protocol = uint8(ackClientToServer)
				} else {
					baseStruct.protocol = uint8(ackServerToClient)
				}
				ackSeg := &segment{
					metadata: &dataAckStruct{
						baseStruct: baseStruct,
						sessionID:  s.id,
						seq:        uint32(mathext.Max(0, int(s.nextSend)-1)),
						unAckSeq:   s.nextRecv,
						windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())),
					},
					transport: s.conn.TransportProtocol(),
				}
				if err := s.output(ackSeg, s.RemoteAddr()); err != nil {
					err = fmt.Errorf("output() failed: %w", err)
					log.Debugf("%v %v", s, err)
					s.outputErr <- err
					s.Close()
				} else {
					seq, err := ackSeg.Seq()
					if err != nil {
						err = fmt.Errorf("failed to get sequence number from %v: %w", ackSeg, err)
						log.Debugf("%v %v", s, err)
						s.outputErr <- err
						s.Close()
					}
					newBytesInFlight := int64(udpOverhead + len(ackSeg.payload))
					s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
					bytesInFlight += newBytesInFlight
				}
				s.ackOnDataRecv.Store(false)
			}
		default:
			err := fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol())
			log.Debugf("%v %v", s, err)
			s.outputErr <- err
			s.Close()
		}
	}
}

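// The retransmission timeout used above grows exponentially but is capped:
// txTimeout = RTO * min(txTimeoutBackOff^txCount, maxBackOffMultiplier).
// With txTimeoutBackOff = 1.25, for example, the third transmission attempt
// of a segment waits about RTO * 1.25^3 ≈ RTO * 1.95, and the multiplier
// never exceeds 20.0 no matter how large txCount grows.
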
// input processes an incoming segment from the network, assembling data
// into the receive buffer and receive queue.
func (s *Session) input(seg *segment) error {
	protocol := seg.Protocol()
	if s.isClient {
		if protocol != openSessionResponse && protocol != dataServerToClient && protocol != ackServerToClient && protocol != closeSessionRequest && protocol != closeSessionResponse {
			return stderror.ErrInvalidArgument
		}
	} else {
		if protocol != openSessionRequest && protocol != dataClientToServer && protocol != ackClientToServer && protocol != closeSessionRequest && protocol != closeSessionResponse {
			return stderror.ErrInvalidArgument
		}
	}
	if seg.block != nil {
		s.block = seg.block
		if !s.isClient {
			if s.uploadBytes == nil && s.block.BlockContext().UserName != "" {
				s.uploadBytes = metrics.RegisterMetric(fmt.Sprintf(metrics.UserMetricGroupFormat, s.block.BlockContext().UserName), metrics.UserMetricUploadBytes, metrics.COUNTER_TIME_SERIES)
			}
			if s.downloadBytes == nil && s.block.BlockContext().UserName != "" {
				s.downloadBytes = metrics.RegisterMetric(fmt.Sprintf(metrics.UserMetricGroupFormat, s.block.BlockContext().UserName), metrics.UserMetricDownloadBytes, metrics.COUNTER_TIME_SERIES)
			}
		}
	}
	s.lastRXTime = time.Now()
	if protocol == openSessionRequest || protocol == openSessionResponse || protocol == dataServerToClient || protocol == dataClientToServer {
		return s.inputData(seg)
	} else if protocol == ackServerToClient || protocol == ackClientToServer {
		return s.inputAck(seg)
	} else if protocol == closeSessionRequest || protocol == closeSessionResponse {
		return s.inputClose(seg)
	}
	return nil
}

func (s *Session) inputData(seg *segment) error {
	switch s.conn.TransportProtocol() {
	case util.TCPTransport:
		// Deliver the segment directly to recvQueue.
		s.recvQueue.InsertBlocking(seg)
	case util.UDPTransport:
		// Delete all previously acknowledged segments from sendBuf.
		var priorInFlight int64
		s.sendBuf.Ascend(func(iter *segment) bool {
			priorInFlight += int64(udpOverhead + len(iter.payload))
			return true
		})
		das, ok := seg.metadata.(*dataAckStruct)
		if ok {
			unAckSeq := das.unAckSeq
			ackedPackets := make([]congestion.AckedPacketInfo, 0)
			for {
				seg2, deleted := s.sendBuf.DeleteMinIf(func(iter *segment) bool {
					seq, _ := iter.Seq()
					return seq < unAckSeq
				})
				if !deleted {
					break
				}
				s.rttStat.UpdateRTT(time.Since(seg2.txTime))
				s.legacysendAlgorithm.OnAck()
				seq, _ := seg2.Seq()
				ackedPackets = append(ackedPackets, congestion.AckedPacketInfo{
					PacketNumber:     int64(seq),
					BytesAcked:       int64(udpOverhead + len(seg2.payload)),
					ReceiveTimestamp: time.Now(),
				})
			}
			if len(ackedPackets) > 0 {
				s.sendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
			}
			s.remoteWindowSize = das.windowSize
		}

		// Deliver the segment to recvBuf.
		s.recvBuf.InsertBlocking(seg)

		// Move recvBuf to recvQueue.
		for {
			seg3, deleted := s.recvBuf.DeleteMinIf(func(iter *segment) bool {
				seq, _ := iter.Seq()
				return seq <= s.nextRecv
			})
			if seg3 == nil || !deleted {
				break
			}
			seq, _ := seg3.Seq()
			if seq == s.nextRecv {
				s.recvQueue.InsertBlocking(seg3)
				s.nextRecv++
				das, ok := seg3.metadata.(*dataAckStruct)
				if ok {
					s.remoteWindowSize = das.windowSize
				}
			}
		}
		s.ackOnDataRecv.Store(true)
	default:
		return fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol())
	}

	if !s.isClient && seg.metadata.Protocol() == openSessionRequest {
		s.wLock.Lock()
		if s.isState(sessionAttached) {
			// The server needs to send an open session response.
			// Check the user quota if we can identify the user.
			var userName string
			if seg.block != nil && seg.block.BlockContext().UserName != "" {
				userName = seg.block.BlockContext().UserName
			} else if s.block != nil && s.block.BlockContext().UserName != "" {
				userName = s.block.BlockContext().UserName
			}
			if userName != "" {
				quotaOK, err := s.checkQuota(userName)
				if err != nil {
					log.Debugf("%v checkQuota() failed: %v", s, err)
				}
				if !quotaOK {
					s.status = statusQuotaExhausted
					log.Debugf("Closing %v because user %s used all the quota", s, userName)
					s.wLock.Unlock()
					s.Close()
					return nil
				}
			}
			seg4 := &segment{
				metadata: &sessionStruct{
					baseStruct: baseStruct{
						protocol: uint8(openSessionResponse),
					},
					sessionID: s.id,
					seq:       s.nextSend,
				},
				transport: s.conn.TransportProtocol(),
			}
			s.nextSend++
			if log.IsLevelEnabled(log.TraceLevel) {
				log.Tracef("%v writing open session response", s)
			}
			s.sendQueue.InsertBlocking(seg4)
			s.forwardStateTo(sessionEstablished)
		}
		s.wLock.Unlock()
	}
	return nil
}

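// The unAckSeq field carries cumulative acknowledgement semantics: every
// segment in sendBuf with seq < unAckSeq is treated as delivered and removed.
// For example, if segments 5, 6, and 7 are in flight and the peer reports
// unAckSeq = 7, segments 5 and 6 are acknowledged while segment 7 stays
// buffered for possible retransmission.
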
func (s *Session) inputAck(seg *segment) error {
	switch s.conn.TransportProtocol() {
	case util.TCPTransport:
		// Do nothing when an ACK is received over TCP.
		return nil
	case util.UDPTransport:
		// Delete all previously acknowledged segments from sendBuf.
		var priorInFlight int64
		s.sendBuf.Ascend(func(iter *segment) bool {
			priorInFlight += int64(udpOverhead + len(iter.payload))
			return true
		})
		das := seg.metadata.(*dataAckStruct)
		unAckSeq := das.unAckSeq
		ackedPackets := make([]congestion.AckedPacketInfo, 0)
		for {
			seg2, deleted := s.sendBuf.DeleteMinIf(func(iter *segment) bool {
				seq, _ := iter.Seq()
				return seq < unAckSeq
			})
			if !deleted {
				break
			}
			s.rttStat.UpdateRTT(time.Since(seg2.txTime))
			s.legacysendAlgorithm.OnAck()
			seq, _ := seg2.Seq()
			ackedPackets = append(ackedPackets, congestion.AckedPacketInfo{
				PacketNumber:     int64(seq),
				BytesAcked:       int64(udpOverhead + len(seg2.payload)),
				ReceiveTimestamp: time.Now(),
			})
		}
		if len(ackedPackets) > 0 {
			s.sendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
		}
		s.remoteWindowSize = das.windowSize

		// Update the acknowledgement count.
		s.sendBuf.Ascend(func(iter *segment) bool {
			seq, _ := iter.Seq()
			if seq > unAckSeq {
				return false
			}
			if seq == unAckSeq {
				iter.ackCount++
			}
			return true
		})
		return nil
	default:
		return fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol())
	}
}

func (s *Session) inputClose(seg *segment) error {
	s.wLock.Lock()
	if seg.metadata.Protocol() == closeSessionRequest {
		// Send the close session response.
		seg2 := &segment{
			metadata: &sessionStruct{
				baseStruct: baseStruct{
					protocol: uint8(closeSessionResponse),
				},
				sessionID:  s.id,
				seq:        s.nextSend,
				statusCode: uint8(statusOK),
				payloadLen: 0,
			},
			transport: s.conn.TransportProtocol(),
		}
		s.nextSend++
		// The response will not be retried if it is not delivered.
		if err := s.output(seg2, s.RemoteAddr()); err != nil {
			s.wLock.Unlock()
			return fmt.Errorf("output() failed: %v", err)
		}
		// Immediately shut down the event loop.
		if seg.metadata.(*sessionStruct).statusCode == uint8(statusQuotaExhausted) {
			log.Infof("Remote requested to shut down the session because the user has exhausted the quota")
		} else {
			log.Debugf("Remote requested to shut down %v", s)
		}
		s.wLock.Unlock()
		s.Close()
	} else if seg.metadata.Protocol() == closeSessionResponse {
		// Immediately shut down the event loop.
		log.Debugf("Remote received the request from %v to shut down", s)
		s.wLock.Unlock()
		s.Close()
	}
	return nil
}

func (s *Session) output(seg *segment, remoteAddr net.Addr) error {
	switch s.conn.TransportProtocol() {
	case util.TCPTransport:
		if err := s.conn.(*TCPUnderlay).writeOneSegment(seg); err != nil {
			return fmt.Errorf("TCPUnderlay.writeOneSegment() failed: %v", err)
		}
	case util.UDPTransport:
		err := s.conn.(*UDPUnderlay).writeOneSegment(seg, remoteAddr.(*net.UDPAddr))
		if err != nil {
			if !stderror.IsNotReady(err) {
				return fmt.Errorf("UDPUnderlay.writeOneSegment() failed: %v", err)
			}
			if log.IsLevelEnabled(log.TraceLevel) {
				log.Tracef("UDPUnderlay.writeOneSegment() failed: %v. Will retry later.", err)
			}
			return nil
		}
	default:
		return fmt.Errorf("unsupported transport protocol %v", s.conn.TransportProtocol())
	}
	s.lastTXTime = time.Now()
	return nil
}

func (s *Session) checkQuota(userName string) (ok bool, err error) {
	if len(s.users) == 0 {
		return true, fmt.Errorf("no registered user")
	}
	user, found := s.users[userName]
	if !found {
		return true, fmt.Errorf("user %s is not found", userName)
	}
	if len(user.GetQuotas()) == 0 {
		return true, nil
	}

	metricGroupName := fmt.Sprintf(metrics.UserMetricGroupFormat, userName)
	metricGroup := metrics.GetMetricGroupByName(metricGroupName)
	if metricGroup == nil {
		return true, fmt.Errorf("metric group %s is not found", metricGroupName)
	}
	uploadBytes, found := metricGroup.GetMetric(metrics.UserMetricUploadBytes)
	if !found {
		return true, fmt.Errorf("metric %s in group %s is not found", metrics.UserMetricUploadBytes, metricGroupName)
	}
	downloadBytes, found := metricGroup.GetMetric(metrics.UserMetricDownloadBytes)
	if !found {
		return true, fmt.Errorf("metric %s in group %s is not found", metrics.UserMetricDownloadBytes, metricGroupName)
	}
	for _, quota := range user.GetQuotas() {
		now := time.Now()
		then := now.Add(-time.Duration(quota.GetDays()) * 24 * time.Hour)
		totalBytes := uploadBytes.(*metrics.Counter).DeltaBetween(then, now)
		totalBytes += downloadBytes.(*metrics.Counter).DeltaBetween(then, now)
		if totalBytes/1048576 > int64(quota.GetMegabytes()) {
			return false, nil
		}
	}
	return true, nil
}

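// checkQuota() sums upload and download bytes over each quota's rolling
// window and compares the total, in mebibytes (1048576 bytes), against the
// limit. For example, with an assumed quota of 1024 megabytes over 7 days,
// a user who transferred 1.2 GiB in that window is over the limit:
// 1288490188 / 1048576 = 1228 > 1024.
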
// SessionInfo provides a string representation of a Session.
type SessionInfo struct {
	ID         string
	Protocol   string
	LocalAddr  string
	RemoteAddr string
	State      string
	RecvQBuf   string
	SendQBuf   string
	LastRecv   string
	LastSend   string
}