mirror of
https://github.com/bolucat/Archive.git
synced 2025-12-24 13:28:37 +08:00
Update On Sat Nov 15 19:34:53 CET 2025
This commit is contained in:
@@ -104,8 +104,8 @@ type ClientConfig struct {
|
||||
|
||||
// If set, the resolver translates proxy server domain name into IP addresses.
|
||||
//
|
||||
// This field is not required, if Dialer is able to do DNS, or proxy server
|
||||
// endpoints are IP addresses rather than domain names.
|
||||
// This field is not required, if Dialer or PacketDialer is able to do DNS,
|
||||
// or proxy server endpoints are IP addresses rather than domain names.
|
||||
// Otherwise, the proxy server won't be reachable.
|
||||
Resolver apicommon.DNSResolver
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ const (
|
||||
oneMinusAlpha = 1 - rttAlpha
|
||||
rttBeta = 0.25
|
||||
oneMinusBeta = 1 - rttBeta
|
||||
defaultInitialRTT = 500 * time.Millisecond
|
||||
defaultInitialRTT = time.Second
|
||||
infDuration = time.Duration(math.MaxInt64)
|
||||
)
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ func TestRTO(t *testing.T) {
|
||||
|
||||
// Reset measurement.
|
||||
s.Reset()
|
||||
if s.RTO() != 1000*time.Millisecond {
|
||||
t.Errorf("RTO() = %v, want %v", s.RTO(), 1000*time.Millisecond)
|
||||
if s.RTO() != 2000*time.Millisecond {
|
||||
t.Errorf("RTO() = %v, want %v", s.RTO(), 2000*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,6 +510,7 @@ func (m *Mux) acceptUnderlayLoop(ctx context.Context, properties UnderlayPropert
|
||||
underlay := &PacketUnderlay{
|
||||
baseUnderlay: *newBaseUnderlay(false, properties.MTU()),
|
||||
conn: conn,
|
||||
packetQueue: make(chan bufferWithAddr, packetChanCapacityServer),
|
||||
sessionCleanTicker: time.NewTicker(sessionCleanInterval),
|
||||
users: m.users,
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -126,6 +127,12 @@ func segmentLessFunc(a, b *segment) bool {
|
||||
return a.Less(b)
|
||||
}
|
||||
|
||||
// bufferWithAddr associate a raw network packet payload with a remote network address.
|
||||
type bufferWithAddr struct {
|
||||
b []byte
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
// segmentIterator processes the given segment.
|
||||
// If it returns false, stop the iteration.
|
||||
type segmentIterator func(*segment) bool
|
||||
|
||||
@@ -54,11 +54,14 @@ const (
|
||||
serverRespTimeout = 10 * time.Second
|
||||
sessionHeartbeatInterval = 5 * time.Second
|
||||
|
||||
earlyRetransmission = 3 // number of ack to trigger early retransmission
|
||||
earlyRetransmissionLimit = 2 // maximum number of early retransmission attempt
|
||||
maxRetransmissionBatchSize = 16 // maximum number of segments in a retransmission batch
|
||||
txTimeoutBackOff = 1.25 // tx timeout back off multiplier
|
||||
maxBackOffMultiplier = 20.0 // maximum back off multiplier
|
||||
// Number of ack to trigger early retransmission.
|
||||
earlyRetransmission = 3
|
||||
// Maximum number of early retransmission attempt.
|
||||
earlyRetransmissionLimit = 1
|
||||
// Send timeout back off multiplier.
|
||||
txTimeoutBackOff = 1.25
|
||||
// Maximum back off multiplier.
|
||||
maxBackOffMultiplier = 20.0
|
||||
)
|
||||
|
||||
type sessionState byte
|
||||
@@ -127,10 +130,10 @@ type Session struct {
|
||||
uploadBytes metrics.Metric // number of bytes from client to server, only used by server
|
||||
downloadBytes metrics.Metric // number of bytes from server to client, only used by server
|
||||
|
||||
rttStat *congestion.RTTStats
|
||||
legacysendAlgorithm *congestion.CubicSendAlgorithm
|
||||
sendAlgorithm *congestion.BBRSender
|
||||
remoteWindowSize uint16
|
||||
rttStat *congestion.RTTStats
|
||||
cubicSendAlgorithm *congestion.CubicSendAlgorithm
|
||||
bbrSendAlgorithm *congestion.BBRSender
|
||||
remoteWindowSize uint16
|
||||
|
||||
wg sync.WaitGroup
|
||||
rLock sync.Mutex // serialize application read
|
||||
@@ -153,31 +156,31 @@ func NewSession(id uint32, isClient bool, mtu int, users map[string]*appctlpb.Us
|
||||
rttStat.SetMaxAckDelay(periodicOutputInterval)
|
||||
rttStat.SetRTOMultiplier(txTimeoutBackOff)
|
||||
return &Session{
|
||||
conn: nil,
|
||||
block: atomic.Pointer[cipher.BlockCipher]{},
|
||||
id: id,
|
||||
isClient: isClient,
|
||||
mtu: mtu,
|
||||
state: sessionInit,
|
||||
status: statusOK,
|
||||
users: users,
|
||||
ready: make(chan struct{}),
|
||||
closedChan: make(chan struct{}),
|
||||
readDeadline: time.Time{},
|
||||
writeDeadline: time.Time{},
|
||||
inputErr: make(chan error),
|
||||
outputErr: make(chan error),
|
||||
sendQueue: newSegmentTree(segmentTreeCapacity),
|
||||
sendBuf: newSegmentTree(segmentTreeCapacity),
|
||||
recvBuf: newSegmentTree(segmentTreeCapacity),
|
||||
recvQueue: newSegmentTree(segmentTreeRecvQueueCapacity),
|
||||
recvChan: make(chan *segment, segmentChanCapacity),
|
||||
lastRXTime: time.Now(),
|
||||
lastTXTime: time.Now(),
|
||||
rttStat: rttStat,
|
||||
legacysendAlgorithm: congestion.NewCubicSendAlgorithm(minWindowSize, maxWindowSize),
|
||||
sendAlgorithm: congestion.NewBBRSender(fmt.Sprintf("%d", id), rttStat),
|
||||
remoteWindowSize: minWindowSize,
|
||||
conn: nil,
|
||||
block: atomic.Pointer[cipher.BlockCipher]{},
|
||||
id: id,
|
||||
isClient: isClient,
|
||||
mtu: mtu,
|
||||
state: sessionInit,
|
||||
status: statusOK,
|
||||
users: users,
|
||||
ready: make(chan struct{}),
|
||||
closedChan: make(chan struct{}),
|
||||
readDeadline: time.Time{},
|
||||
writeDeadline: time.Time{},
|
||||
inputErr: make(chan error),
|
||||
outputErr: make(chan error),
|
||||
sendQueue: newSegmentTree(segmentTreeCapacity),
|
||||
sendBuf: newSegmentTree(segmentTreeCapacity),
|
||||
recvBuf: newSegmentTree(segmentTreeCapacity),
|
||||
recvQueue: newSegmentTree(segmentTreeRecvQueueCapacity),
|
||||
recvChan: make(chan *segment, segmentChanCapacity),
|
||||
lastRXTime: time.Now(),
|
||||
lastTXTime: time.Now(),
|
||||
rttStat: rttStat,
|
||||
cubicSendAlgorithm: congestion.NewCubicSendAlgorithm(minWindowSize, maxWindowSize),
|
||||
bbrSendAlgorithm: congestion.NewBBRSender(fmt.Sprintf("%d", id), rttStat),
|
||||
remoteWindowSize: minWindowSize,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -535,7 +538,7 @@ func (s *Session) writeChunk(b []byte) (n int, err error) {
|
||||
sessionID: s.id,
|
||||
seq: s.nextSend,
|
||||
unAckSeq: s.nextRecv,
|
||||
windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())),
|
||||
windowSize: uint16(s.receiveWindowSize()),
|
||||
fragment: uint8(i),
|
||||
payloadLen: uint16(partLen),
|
||||
},
|
||||
@@ -680,11 +683,14 @@ func (s *Session) runOutputOncePacket() {
|
||||
// Resend segments in sendBuf.
|
||||
//
|
||||
// Iterate all the segments in sendBuf to calculate bytesInFlight,
|
||||
// but only resend first a few segments if needed.
|
||||
// but only resend segments if needed.
|
||||
//
|
||||
// Retransmission is not limited by window.
|
||||
//
|
||||
// To avoid deadlock, session can't be closed inside Ascend().
|
||||
s.oLock.Lock()
|
||||
retransmissionCount := 0
|
||||
totalTransmissionCount := 0
|
||||
skipSendNewSegment := s.sendWindowSize() <= 0
|
||||
s.sendBuf.Ascend(func(iter *segment) bool {
|
||||
bytesInFlight += int64(packetOverhead + len(iter.payload))
|
||||
if iter.txCount >= txCountLimit {
|
||||
@@ -696,8 +702,9 @@ func (s *Session) runOutputOncePacket() {
|
||||
closeSessionReason = err
|
||||
return false
|
||||
}
|
||||
if retransmissionCount <= maxRetransmissionBatchSize && ((iter.ackCount >= earlyRetransmission && iter.txCount <= earlyRetransmissionLimit) || time.Since(iter.txTime) > iter.txTimeout) {
|
||||
if iter.ackCount >= earlyRetransmission {
|
||||
satisfyEarlyRetransmission := iter.ackCount >= earlyRetransmission && iter.txCount <= earlyRetransmissionLimit
|
||||
if satisfyEarlyRetransmission || time.Since(iter.txTime) > iter.txTimeout {
|
||||
if satisfyEarlyRetransmission {
|
||||
hasLoss = true
|
||||
} else {
|
||||
hasTimeout = true
|
||||
@@ -720,7 +727,7 @@ func (s *Session) runOutputOncePacket() {
|
||||
return false
|
||||
}
|
||||
bytesInFlight += int64(packetOverhead + len(iter.payload))
|
||||
retransmissionCount++
|
||||
totalTransmissionCount++
|
||||
return true
|
||||
}
|
||||
return true
|
||||
@@ -729,12 +736,15 @@ func (s *Session) runOutputOncePacket() {
|
||||
if closeSessionReason != nil {
|
||||
s.closeWithError(closeSessionReason)
|
||||
}
|
||||
if hasLoss || hasTimeout {
|
||||
s.legacysendAlgorithm.OnLoss() // OnTimeout() is too aggressive.
|
||||
if hasTimeout {
|
||||
s.cubicSendAlgorithm.OnTimeout()
|
||||
} else if hasLoss {
|
||||
s.cubicSendAlgorithm.OnLoss()
|
||||
}
|
||||
|
||||
// Send new segments in sendQueue.
|
||||
if s.sendQueue.Len() > 0 {
|
||||
skipSendNewSegment = skipSendNewSegment || totalTransmissionCount >= s.sendWindowSize()
|
||||
if s.sendQueue.Len() > 0 && !skipSendNewSegment {
|
||||
s.oLock.Lock()
|
||||
for {
|
||||
if s.sendBuf.Remaining() <= 1 {
|
||||
@@ -743,7 +753,9 @@ func (s *Session) runOutputOncePacket() {
|
||||
}
|
||||
|
||||
seg, deleted := s.sendQueue.DeleteMinIf(func(iter *segment) bool {
|
||||
return s.sendAlgorithm.CanSend(bytesInFlight, int64(packetOverhead+len(iter.payload)))
|
||||
bbrCanSend := s.bbrSendAlgorithm.CanSend(bytesInFlight, int64(packetOverhead+len(iter.payload)))
|
||||
congestionWindowCanSend := totalTransmissionCount < s.sendWindowSize()
|
||||
return bbrCanSend && congestionWindowCanSend
|
||||
})
|
||||
if !deleted {
|
||||
s.oLock.Unlock()
|
||||
@@ -789,15 +801,17 @@ func (s *Session) runOutputOncePacket() {
|
||||
break
|
||||
}
|
||||
newBytesInFlight := int64(packetOverhead + len(seg.payload))
|
||||
s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
|
||||
s.bbrSendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
|
||||
bytesInFlight += newBytesInFlight
|
||||
totalTransmissionCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
s.sendAlgorithm.OnApplicationLimited(bytesInFlight)
|
||||
s.bbrSendAlgorithm.OnApplicationLimited(bytesInFlight)
|
||||
}
|
||||
|
||||
// Send ACK or heartbeat if needed.
|
||||
// ACK is not limited by window.
|
||||
exceedHeartbeatInterval := time.Since(s.lastTXTime) > sessionHeartbeatInterval
|
||||
if s.ackOnDataRecv.Load() || exceedHeartbeatInterval {
|
||||
baseStruct := baseStruct{}
|
||||
@@ -813,7 +827,7 @@ func (s *Session) runOutputOncePacket() {
|
||||
sessionID: s.id,
|
||||
seq: uint32(mathext.Max(0, int(s.nextSend)-1)),
|
||||
unAckSeq: s.nextRecv,
|
||||
windowSize: uint16(mathext.Max(0, int(s.legacysendAlgorithm.CongestionWindowSize())-s.recvBuf.Len())),
|
||||
windowSize: uint16(s.receiveWindowSize()),
|
||||
},
|
||||
transport: s.transportProtocol}
|
||||
if err := s.output(ackSeg, s.RemoteAddr()); err != nil {
|
||||
@@ -827,17 +841,11 @@ func (s *Session) runOutputOncePacket() {
|
||||
} else {
|
||||
seq, err := ackSeg.Seq()
|
||||
if err != nil {
|
||||
s.oLock.Unlock()
|
||||
err = fmt.Errorf("failed to get sequence number from %v: %w", ackSeg, err)
|
||||
log.Debugf("%v %v", s, err)
|
||||
if s.outputHasErr.CompareAndSwap(false, true) {
|
||||
close(s.outputErr)
|
||||
}
|
||||
s.closeWithError(err)
|
||||
panic(fmt.Sprintf("failed to get sequence number from ack segment %v: %v", ackSeg, err))
|
||||
} else {
|
||||
s.oLock.Unlock()
|
||||
newBytesInFlight := int64(packetOverhead + len(ackSeg.payload))
|
||||
s.sendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
|
||||
s.bbrSendAlgorithm.OnPacketSent(time.Now(), bytesInFlight, int64(seq), newBytesInFlight, true)
|
||||
bytesInFlight += newBytesInFlight
|
||||
}
|
||||
}
|
||||
@@ -936,7 +944,7 @@ func (s *Session) inputData(seg *segment) error {
|
||||
break
|
||||
}
|
||||
s.rttStat.UpdateRTT(time.Since(seg2.txTime))
|
||||
s.legacysendAlgorithm.OnAck()
|
||||
s.cubicSendAlgorithm.OnAck()
|
||||
seq, _ := seg2.Seq()
|
||||
ackedPackets = append(ackedPackets, congestion.AckedPacketInfo{
|
||||
PacketNumber: int64(seq),
|
||||
@@ -945,7 +953,7 @@ func (s *Session) inputData(seg *segment) error {
|
||||
})
|
||||
}
|
||||
if len(ackedPackets) > 0 {
|
||||
s.sendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
|
||||
s.bbrSendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
|
||||
}
|
||||
s.remoteWindowSize = das.windowSize
|
||||
}
|
||||
@@ -1060,7 +1068,7 @@ func (s *Session) inputAck(seg *segment) error {
|
||||
break
|
||||
}
|
||||
s.rttStat.UpdateRTT(time.Since(seg2.txTime))
|
||||
s.legacysendAlgorithm.OnAck()
|
||||
s.cubicSendAlgorithm.OnAck()
|
||||
seq, _ := seg2.Seq()
|
||||
ackedPackets = append(ackedPackets, congestion.AckedPacketInfo{
|
||||
PacketNumber: int64(seq),
|
||||
@@ -1069,7 +1077,7 @@ func (s *Session) inputAck(seg *segment) error {
|
||||
})
|
||||
}
|
||||
if len(ackedPackets) > 0 {
|
||||
s.sendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
|
||||
s.bbrSendAlgorithm.OnCongestionEvent(priorInFlight, time.Now(), ackedPackets, nil)
|
||||
}
|
||||
s.remoteWindowSize = das.windowSize
|
||||
|
||||
@@ -1222,6 +1230,24 @@ func (s *Session) closeWithError(err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// sendWindowSize determines how many packets this session can send.
|
||||
func (s *Session) sendWindowSize() int {
|
||||
return mathext.Max(0, mathext.Min(int(s.cubicSendAlgorithm.CongestionWindowSize()), int(s.remoteWindowSize)))
|
||||
}
|
||||
|
||||
// receiveWindowSize determines how many packets this session can receive.
|
||||
func (s *Session) receiveWindowSize() int {
|
||||
var underlayWaitingPackets int
|
||||
if s.conn != nil {
|
||||
packetUnderlay, ok := s.conn.(*PacketUnderlay)
|
||||
if ok {
|
||||
// Other packets sharing the same UDP socket reduce the congestion window.
|
||||
underlayWaitingPackets = len(packetUnderlay.packetQueue)
|
||||
}
|
||||
}
|
||||
return mathext.Max(0, int(s.cubicSendAlgorithm.CongestionWindowSize())-s.recvBuf.Len()-underlayWaitingPackets)
|
||||
}
|
||||
|
||||
func (s *Session) checkQuota(userName string) (ok bool, err error) {
|
||||
if len(s.users) == 0 {
|
||||
return true, fmt.Errorf("no registered user")
|
||||
|
||||
@@ -30,8 +30,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
sessionChanCapacity = 64
|
||||
// Number of ready sessions before they are consumed by Accept().
|
||||
sessionChanCapacity = 64
|
||||
|
||||
sessionCleanInterval = 5 * time.Second
|
||||
|
||||
// Buffer received network packets before they are dropped by OS kernel.
|
||||
packetChanCapacityClient = 4 * 1024
|
||||
packetChanCapacityServer = 16 * 1024
|
||||
)
|
||||
|
||||
// baseUnderlay contains a partial implementation of underlay.
|
||||
|
||||
@@ -49,6 +49,9 @@ type PacketUnderlay struct {
|
||||
baseUnderlay
|
||||
conn net.PacketConn
|
||||
|
||||
// packetQueue stores raw network packets payload not parsed to segments.
|
||||
packetQueue chan bufferWithAddr
|
||||
|
||||
sessionCleanTicker *time.Ticker
|
||||
|
||||
// ---- client fields ----
|
||||
@@ -87,6 +90,7 @@ func NewPacketUnderlay(ctx context.Context, packetDialer apicommon.PacketDialer,
|
||||
u := &PacketUnderlay{
|
||||
baseUnderlay: *newBaseUnderlay(true, mtu),
|
||||
conn: conn,
|
||||
packetQueue: make(chan bufferWithAddr, packetChanCapacityClient),
|
||||
sessionCleanTicker: time.NewTicker(sessionCleanInterval),
|
||||
serverAddr: remoteAddr,
|
||||
block: block,
|
||||
@@ -168,6 +172,28 @@ func (u *PacketUnderlay) RunEventLoop(ctx context.Context) error {
|
||||
return stderror.ErrNullPointer
|
||||
}
|
||||
|
||||
// OS has limited buffer to store received UDP packets.
|
||||
// Move the received UDP packets to user space as quickly as possible,
|
||||
// so we can process them later at a slower pace.
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if err := u.readOneSegment(); err != nil {
|
||||
if stderror.IsTimeout(err) {
|
||||
continue
|
||||
}
|
||||
log.Debugf("%v readOneSegment() failed: %v", u, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -180,12 +206,9 @@ func (u *PacketUnderlay) RunEventLoop(ctx context.Context) error {
|
||||
u.cleanSessions()
|
||||
default:
|
||||
}
|
||||
seg, addr, err := u.readOneSegment()
|
||||
seg, addr, err := u.parseOneSegment()
|
||||
if err != nil {
|
||||
if stderror.IsTimeout(err) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("readOneSegment() failed: %w", err)
|
||||
return fmt.Errorf("parseOneSegment() failed: %w", err)
|
||||
}
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("%v received %v from peer %v", u, seg, addr)
|
||||
@@ -294,28 +317,29 @@ func (u *PacketUnderlay) onCloseSession(seg *segment) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *PacketUnderlay) readOneSegment() (*segment, net.Addr, error) {
|
||||
func (u *PacketUnderlay) readOneSegment() error {
|
||||
var n int
|
||||
var addr net.Addr
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-u.done:
|
||||
return nil, nil, io.ErrClosedPipe
|
||||
return io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
common.SetReadTimeout(u.conn, readOneSegmentTimeout)
|
||||
defer common.SetReadTimeout(u.conn, 0)
|
||||
|
||||
// Peer may select a different MTU.
|
||||
// Use the largest possible value here to avoid error.
|
||||
b := make([]byte, 1500)
|
||||
n, addr, err = u.conn.ReadFrom(b)
|
||||
if err != nil {
|
||||
if stderror.IsTimeout(err) {
|
||||
return nil, nil, stderror.ErrTimeout
|
||||
return stderror.ErrTimeout
|
||||
}
|
||||
return nil, nil, fmt.Errorf("ReadFrom() failed: %w", err)
|
||||
return fmt.Errorf("ReadFrom() failed: %w", err)
|
||||
}
|
||||
if u.isClient && addr.String() != u.serverAddr.String() {
|
||||
UnderlayUnsolicitedUDP.Add(1)
|
||||
@@ -337,153 +361,171 @@ func (u *PacketUnderlay) readOneSegment() (*segment, net.Addr, error) {
|
||||
} else {
|
||||
metrics.UploadBytes.Add(int64(n))
|
||||
}
|
||||
|
||||
// Read encrypted metadata.
|
||||
encryptedMeta := b[:packetNonHeaderPosition]
|
||||
isNewSessionReplay := false
|
||||
if packetReplayCache.IsDuplicate(encryptedMeta[:cipher.DefaultOverhead], addr.String()) {
|
||||
replay.NewSession.Add(1)
|
||||
isNewSessionReplay = true
|
||||
u.packetQueue <- bufferWithAddr{
|
||||
b: b,
|
||||
addr: addr,
|
||||
}
|
||||
nonce := encryptedMeta[:cipher.DefaultNonceSize]
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt metadata.
|
||||
var decryptedMeta []byte
|
||||
var blockCipher cipher.BlockCipher
|
||||
if u.isClient {
|
||||
decryptedMeta, err = u.block.Decrypt(encryptedMeta)
|
||||
cipher.ClientDirectDecrypt.Add(1)
|
||||
if err != nil {
|
||||
cipher.ClientFailedDirectDecrypt.Add(1)
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("%v Decrypt() failed with packet from %v", u, addr)
|
||||
}
|
||||
continue
|
||||
func (u *PacketUnderlay) parseOneSegment() (*segment, net.Addr, error) {
|
||||
var err error
|
||||
for {
|
||||
select {
|
||||
case <-u.done:
|
||||
return nil, nil, io.ErrClosedPipe
|
||||
case raw := <-u.packetQueue:
|
||||
b := raw.b
|
||||
addr := raw.addr
|
||||
|
||||
// Read encrypted metadata.
|
||||
encryptedMeta := b[:packetNonHeaderPosition]
|
||||
isNewSessionReplay := false
|
||||
if packetReplayCache.IsDuplicate(encryptedMeta[:cipher.DefaultOverhead], addr.String()) {
|
||||
replay.NewSession.Add(1)
|
||||
isNewSessionReplay = true
|
||||
}
|
||||
} else {
|
||||
var decrypted bool
|
||||
var err error
|
||||
// Try existing sessions.
|
||||
cipher.ServerIterateDecrypt.Add(1)
|
||||
u.sessionMap.Range(func(k, v any) bool {
|
||||
session := v.(*Session)
|
||||
if session.block.Load() != nil && session.RemoteAddr().String() == addr.String() {
|
||||
decryptedMeta, err = (*session.block.Load()).Decrypt(encryptedMeta)
|
||||
if err == nil {
|
||||
decrypted = true
|
||||
blockCipher = *session.block.Load()
|
||||
return false
|
||||
nonce := encryptedMeta[:cipher.DefaultNonceSize]
|
||||
|
||||
// Decrypt metadata.
|
||||
var decryptedMeta []byte
|
||||
var blockCipher cipher.BlockCipher
|
||||
if u.isClient {
|
||||
decryptedMeta, err = u.block.Decrypt(encryptedMeta)
|
||||
cipher.ClientDirectDecrypt.Add(1)
|
||||
if err != nil {
|
||||
cipher.ClientFailedDirectDecrypt.Add(1)
|
||||
if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("%v Decrypt() failed with packet from %v", u, addr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
var decrypted bool
|
||||
var err error
|
||||
// Try existing sessions.
|
||||
cipher.ServerIterateDecrypt.Add(1)
|
||||
u.sessionMap.Range(func(k, v any) bool {
|
||||
session := v.(*Session)
|
||||
if session.block.Load() != nil && session.RemoteAddr().String() == addr.String() {
|
||||
decryptedMeta, err = (*session.block.Load()).Decrypt(encryptedMeta)
|
||||
if err == nil {
|
||||
decrypted = true
|
||||
blockCipher = *session.block.Load()
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !decrypted {
|
||||
// This is a new session. Try all registered users.
|
||||
for _, user := range u.users {
|
||||
var password []byte
|
||||
password, err = hex.DecodeString(user.GetHashedPassword())
|
||||
if err != nil {
|
||||
log.Debugf("Unable to decode hashed password %q from user %q", user.GetHashedPassword(), user.GetName())
|
||||
continue
|
||||
}
|
||||
if len(password) == 0 {
|
||||
password = cipher.HashPassword([]byte(user.GetPassword()), []byte(user.GetName()))
|
||||
}
|
||||
blockCipher, decryptedMeta, err = cipher.TryDecrypt(encryptedMeta, password, true)
|
||||
if err == nil {
|
||||
decrypted = true
|
||||
blockCipher.SetBlockContext(cipher.BlockContext{
|
||||
UserName: user.GetName(),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !decrypted {
|
||||
// This is a new session. Try all registered users.
|
||||
for _, user := range u.users {
|
||||
var password []byte
|
||||
password, err = hex.DecodeString(user.GetHashedPassword())
|
||||
if err != nil {
|
||||
log.Debugf("Unable to decode hashed password %q from user %q", user.GetHashedPassword(), user.GetName())
|
||||
if !decrypted {
|
||||
cipher.ServerFailedIterateDecrypt.Add(1)
|
||||
if isNewSessionReplay {
|
||||
log.Debugf("found possible replay attack in %v from %v", u, addr)
|
||||
} else if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("%v TryDecrypt() failed with packet from %v", u, addr)
|
||||
}
|
||||
continue
|
||||
} else {
|
||||
if blockCipher == nil {
|
||||
panic("PacketUnderlay parseOneSegment(): block cipher is nil after decryption is successful")
|
||||
}
|
||||
if isNewSessionReplay {
|
||||
replay.NewSessionDecrypted.Add(1)
|
||||
log.Debugf("found possible replay attack with payload decrypted in %v from %v", u, addr)
|
||||
continue
|
||||
}
|
||||
if len(password) == 0 {
|
||||
password = cipher.HashPassword([]byte(user.GetPassword()), []byte(user.GetName()))
|
||||
}
|
||||
blockCipher, decryptedMeta, err = cipher.TryDecrypt(encryptedMeta, password, true)
|
||||
if err == nil {
|
||||
decrypted = true
|
||||
blockCipher.SetBlockContext(cipher.BlockContext{
|
||||
UserName: user.GetName(),
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !decrypted {
|
||||
cipher.ServerFailedIterateDecrypt.Add(1)
|
||||
if isNewSessionReplay {
|
||||
log.Debugf("found possible replay attack in %v from %v", u, addr)
|
||||
} else if log.IsLevelEnabled(log.TraceLevel) {
|
||||
log.Tracef("%v TryDecrypt() failed with packet from %v", u, addr)
|
||||
}
|
||||
if len(decryptedMeta) != MetadataLength {
|
||||
log.Debugf("decrypted metadata size %d is unexpected", len(decryptedMeta))
|
||||
continue
|
||||
} else {
|
||||
if blockCipher == nil {
|
||||
panic("PacketUnderlay readOneSegment(): block cipher is nil after decryption is successful")
|
||||
}
|
||||
if isNewSessionReplay {
|
||||
replay.NewSessionDecrypted.Add(1)
|
||||
log.Debugf("found possible replay attack with payload decrypted in %v from %v", u, addr)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(decryptedMeta) != MetadataLength {
|
||||
log.Debugf("decrypted metadata size %d is unexpected", len(decryptedMeta))
|
||||
continue
|
||||
}
|
||||
|
||||
// Read payload and construct segment.
|
||||
var seg *segment
|
||||
p := decryptedMeta[0]
|
||||
if isSessionProtocol(protocolType(p)) {
|
||||
ss := &sessionStruct{}
|
||||
if err := ss.Unmarshal(decryptedMeta); err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("Unmarshal() to sessionStruct failed: %w", err)
|
||||
} else {
|
||||
log.Debugf("%v Unmarshal() to sessionStruct failed: %v", u, err)
|
||||
continue
|
||||
// Read payload and construct segment.
|
||||
var seg *segment
|
||||
p := decryptedMeta[0]
|
||||
if isSessionProtocol(protocolType(p)) {
|
||||
ss := &sessionStruct{}
|
||||
if err := ss.Unmarshal(decryptedMeta); err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("Unmarshal() to sessionStruct failed: %w", err)
|
||||
} else {
|
||||
log.Debugf("%v Unmarshal() to sessionStruct failed: %v", u, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
seg, err = u.readSessionSegment(ss, nonce, b[packetNonHeaderPosition:], blockCipher)
|
||||
if err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, err
|
||||
} else {
|
||||
log.Debugf("%v readSessionSegment() failed: %v", u, err)
|
||||
continue
|
||||
seg, err = u.parseSessionSegment(ss, nonce, b[packetNonHeaderPosition:], blockCipher)
|
||||
if err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, err
|
||||
} else {
|
||||
log.Debugf("%v parseSessionSegment() failed: %v", u, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if blockCipher != nil {
|
||||
seg.block = blockCipher
|
||||
}
|
||||
return seg, addr, nil
|
||||
} else if isDataAckProtocol(protocolType(p)) {
|
||||
das := &dataAckStruct{}
|
||||
if err := das.Unmarshal(decryptedMeta); err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("Unmarshal() to dataAckStruct failed: %w", err)
|
||||
} else {
|
||||
log.Debugf("%v Unmarshal() to dataAckStruct failed: %v", u, err)
|
||||
continue
|
||||
if blockCipher != nil {
|
||||
seg.block = blockCipher
|
||||
}
|
||||
}
|
||||
seg, err = u.readDataAckSegment(das, nonce, b[packetNonHeaderPosition:], blockCipher)
|
||||
if err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, err
|
||||
} else {
|
||||
log.Debugf("%v readDataAckSegment() failed: %v", u, err)
|
||||
continue
|
||||
return seg, addr, nil
|
||||
} else if isDataAckProtocol(protocolType(p)) {
|
||||
das := &dataAckStruct{}
|
||||
if err := das.Unmarshal(decryptedMeta); err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("Unmarshal() to dataAckStruct failed: %w", err)
|
||||
} else {
|
||||
log.Debugf("%v Unmarshal() to dataAckStruct failed: %v", u, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if blockCipher != nil {
|
||||
seg.block = blockCipher
|
||||
}
|
||||
return seg, addr, nil
|
||||
} else {
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("unable to handle protocol %d", p)
|
||||
seg, err = u.parseDataAckSegment(das, nonce, b[packetNonHeaderPosition:], blockCipher)
|
||||
if err != nil {
|
||||
if u.isClient {
|
||||
return nil, nil, err
|
||||
} else {
|
||||
log.Debugf("%v parseDataAckSegment() failed: %v", u, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if blockCipher != nil {
|
||||
seg.block = blockCipher
|
||||
}
|
||||
return seg, addr, nil
|
||||
} else {
|
||||
log.Debugf("%v unable to handle protocol %d", u, p)
|
||||
continue
|
||||
if u.isClient {
|
||||
return nil, nil, fmt.Errorf("unable to handle protocol %d", p)
|
||||
} else {
|
||||
log.Debugf("%v unable to handle protocol %d", u, p)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *PacketUnderlay) readSessionSegment(ss *sessionStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
||||
func (u *PacketUnderlay) parseSessionSegment(ss *sessionStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
||||
var decryptedPayload []byte
|
||||
var err error
|
||||
|
||||
@@ -529,7 +571,7 @@ func (u *PacketUnderlay) readSessionSegment(ss *sessionStruct, nonce, remaining
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (u *PacketUnderlay) readDataAckSegment(das *dataAckStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
||||
func (u *PacketUnderlay) parseDataAckSegment(das *dataAckStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
||||
var decryptedPayload []byte
|
||||
var err error
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ type StreamUnderlay struct {
|
||||
send cipher.BlockCipher
|
||||
recv cipher.BlockCipher
|
||||
|
||||
// Candidates are block ciphers that can be used to encrypt or decrypt data.
|
||||
// candidates are block ciphers that can be used to encrypt or decrypt data.
|
||||
// When isClient is true, there must be exactly 1 element in the slice.
|
||||
candidates []cipher.BlockCipher
|
||||
|
||||
|
||||
Reference in New Issue
Block a user