mirror of
https://github.com/bolucat/Archive.git
synced 2025-09-26 20:21:35 +08:00
730 lines
21 KiB
Go
730 lines
21 KiB
Go
// Copyright (C) 2023 mieru authors
|
|
//
|
|
// This program is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// This program is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with this program. If not, see <https://www.gnu.org/licenses/>
|
|
|
|
package protocol
|
|
|
|
import (
|
|
"context"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"github.com/enfein/mieru/pkg/appctl/appctlpb"
|
|
"github.com/enfein/mieru/pkg/cipher"
|
|
"github.com/enfein/mieru/pkg/log"
|
|
"github.com/enfein/mieru/pkg/mathext"
|
|
"github.com/enfein/mieru/pkg/metrics"
|
|
"github.com/enfein/mieru/pkg/replay"
|
|
"github.com/enfein/mieru/pkg/stderror"
|
|
"github.com/enfein/mieru/pkg/util"
|
|
"github.com/enfein/mieru/pkg/util/sockopts"
|
|
)
|
|
|
|
const (
|
|
udpOverhead = cipher.DefaultNonceSize + MetadataLength + cipher.DefaultOverhead*2
|
|
udpNonHeaderPosition = cipher.DefaultNonceSize + MetadataLength + cipher.DefaultOverhead
|
|
|
|
idleSessionTickerInterval = 5 * time.Second
|
|
idleSessionTimeout = time.Minute
|
|
|
|
readOneSegmentTimeout = 5 * time.Second
|
|
)
|
|
|
|
var udpReplayCache = replay.NewCache(4*1024*1024, cipher.KeyRefreshInterval*3)
|
|
|
|
type UDPUnderlay struct {
|
|
// ---- common fields ----
|
|
baseUnderlay
|
|
conn *net.UDPConn
|
|
|
|
idleSessionTicker *time.Ticker
|
|
|
|
// ---- client fields ----
|
|
serverAddr *net.UDPAddr
|
|
block cipher.BlockCipher
|
|
|
|
// ---- server fields ----
|
|
users map[string]*appctlpb.User
|
|
}
|
|
|
|
var _ Underlay = &UDPUnderlay{}
|
|
|
|
// NewUDPUnderlay connects to the remote address "raddr" on the network "udp"
|
|
// with packet encryption. If "laddr" is empty, an automatic address is used.
|
|
// "block" is the block encryption algorithm to encrypt packets.
|
|
func NewUDPUnderlay(ctx context.Context, network, laddr, raddr string, mtu int, block cipher.BlockCipher) (*UDPUnderlay, error) {
|
|
switch network {
|
|
case "udp", "udp4", "udp6":
|
|
default:
|
|
return nil, fmt.Errorf("network %s is not supported by UDP underlay", network)
|
|
}
|
|
if !block.IsStateless() {
|
|
return nil, fmt.Errorf("UDP block cipher must be stateless")
|
|
}
|
|
var localAddr *net.UDPAddr
|
|
var err error
|
|
if laddr != "" {
|
|
localAddr, err = net.ResolveUDPAddr("udp", laddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("net.ResolveUDPAddr() failed: %w", err)
|
|
}
|
|
}
|
|
remoteAddr, err := net.ResolveUDPAddr("udp", raddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("net.ResolveUDPAddr() failed: %w", err)
|
|
}
|
|
|
|
conn, err := net.ListenUDP(network, localAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("net.ListenUDP() failed: %w", err)
|
|
}
|
|
if err := sockopts.ApplyUDPControls(conn); err != nil {
|
|
return nil, fmt.Errorf("ApplyUDPControls() failed: %w", err)
|
|
}
|
|
u := &UDPUnderlay{
|
|
baseUnderlay: *newBaseUnderlay(true, mtu),
|
|
conn: conn,
|
|
idleSessionTicker: time.NewTicker(idleSessionTickerInterval),
|
|
serverAddr: remoteAddr,
|
|
block: block,
|
|
}
|
|
// The block cipher expires after this time.
|
|
u.scheduler.SetRemainingTime(cipher.KeyRefreshInterval / 2)
|
|
return u, nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) String() string {
|
|
if u.conn == nil {
|
|
return "UDPUnderlay{}"
|
|
}
|
|
if u.isClient {
|
|
return fmt.Sprintf("UDPUnderlay{local=%v, remote=%v, mtu=%v, ipVersion=%v}", u.LocalAddr(), u.RemoteAddr(), u.mtu, u.IPVersion())
|
|
} else {
|
|
return fmt.Sprintf("UDPUnderlay{local=%v, mtu=%v, ipVersion=%v}", u.LocalAddr(), u.mtu, u.IPVersion())
|
|
}
|
|
}
|
|
|
|
func (u *UDPUnderlay) Close() error {
|
|
u.closeMutex.Lock()
|
|
defer u.closeMutex.Unlock()
|
|
select {
|
|
case <-u.done:
|
|
return nil
|
|
default:
|
|
}
|
|
|
|
log.Debugf("Closing %v", u)
|
|
u.idleSessionTicker.Stop()
|
|
u.baseUnderlay.Close()
|
|
return u.conn.Close()
|
|
}
|
|
|
|
func (u *UDPUnderlay) IPVersion() util.IPVersion {
|
|
u.ipVersionMutex.Lock()
|
|
defer u.ipVersionMutex.Unlock()
|
|
if u.conn == nil {
|
|
return util.IPVersionUnknown
|
|
}
|
|
if u.ipVersion == util.IPVersionUnknown {
|
|
u.ipVersion = util.GetIPVersion(u.conn.LocalAddr().String())
|
|
}
|
|
return u.ipVersion
|
|
}
|
|
|
|
func (u *UDPUnderlay) TransportProtocol() util.TransportProtocol {
|
|
return util.UDPTransport
|
|
}
|
|
|
|
func (u *UDPUnderlay) LocalAddr() net.Addr {
|
|
return u.conn.LocalAddr()
|
|
}
|
|
|
|
func (u *UDPUnderlay) RemoteAddr() net.Addr {
|
|
if u.isClient && u.serverAddr != nil {
|
|
return u.serverAddr
|
|
}
|
|
return util.NilNetAddr()
|
|
}
|
|
|
|
func (u *UDPUnderlay) AddSession(s *Session, remoteAddr net.Addr) error {
|
|
if err := u.baseUnderlay.AddSession(s, remoteAddr); err != nil {
|
|
return err
|
|
}
|
|
s.conn = u // override base underlay
|
|
close(s.ready)
|
|
log.Debugf("Adding session %d to %v", s.id, u)
|
|
|
|
s.wg.Add(2)
|
|
go func() {
|
|
if err := s.runInputLoop(context.Background()); err != nil && !stderror.IsEOF(err) && !stderror.IsClosed(err) {
|
|
log.Debugf("%v runInputLoop(): %v", s, err)
|
|
}
|
|
s.wg.Done()
|
|
}()
|
|
go func() {
|
|
if err := s.runOutputLoop(context.Background()); err != nil && !stderror.IsEOF(err) && !stderror.IsClosed(err) {
|
|
log.Debugf("%v runOutputLoop(): %v", s, err)
|
|
}
|
|
s.wg.Done()
|
|
}()
|
|
return nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) RunEventLoop(ctx context.Context) error {
|
|
if u.conn == nil {
|
|
return stderror.ErrNullPointer
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case <-u.done:
|
|
return nil
|
|
case <-u.idleSessionTicker.C:
|
|
// Close idle sessions.
|
|
u.sessionMap.Range(func(k, v any) bool {
|
|
session := v.(*Session)
|
|
select {
|
|
case <-session.done:
|
|
log.Debugf("Found closed %v", session)
|
|
if err := u.RemoveSession(session); err != nil {
|
|
log.Debugf("%v RemoveSession() failed: %v", u, err)
|
|
}
|
|
default:
|
|
}
|
|
if time.Since(session.lastRXTime) > idleSessionTimeout {
|
|
log.Debugf("Found idle %v", session)
|
|
if err := u.RemoveSession(session); err != nil {
|
|
log.Debugf("%v RemoveSession() failed: %v", u, err)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
default:
|
|
}
|
|
seg, addr, err := u.readOneSegment()
|
|
if err != nil {
|
|
if stderror.IsTimeout(err) {
|
|
continue
|
|
}
|
|
return fmt.Errorf("readOneSegment() failed: %w", err)
|
|
}
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v received %v from peer %v", u, seg, addr)
|
|
}
|
|
if isSessionProtocol(seg.metadata.Protocol()) {
|
|
switch seg.metadata.Protocol() {
|
|
case openSessionRequest:
|
|
if err := u.onOpenSessionRequest(seg, addr); err != nil {
|
|
return fmt.Errorf("onOpenSessionRequest() failed: %w", err)
|
|
}
|
|
case openSessionResponse:
|
|
if err := u.onOpenSessionResponse(seg); err != nil {
|
|
return fmt.Errorf("onOpenSessionResponse() failed: %w", err)
|
|
}
|
|
case closeSessionRequest, closeSessionResponse:
|
|
if err := u.onCloseSession(seg); err != nil {
|
|
return fmt.Errorf("onCloseSession() failed: %w", err)
|
|
}
|
|
default:
|
|
panic(fmt.Sprintf("Protocol %d is a session protocol but not recognized by UDP underlay", seg.metadata.Protocol()))
|
|
}
|
|
} else if isDataAckProtocol(seg.metadata.Protocol()) {
|
|
das, _ := toDataAckStruct(seg.metadata)
|
|
session, ok := u.sessionMap.Load(das.sessionID)
|
|
if !ok {
|
|
log.Debugf("Session %d is not registered to %v", das.sessionID, u)
|
|
if seg.block != nil {
|
|
// Request the peer to close the session.
|
|
closeReq := &segment{
|
|
metadata: &sessionStruct{
|
|
baseStruct: baseStruct{
|
|
protocol: uint8(closeSessionRequest),
|
|
},
|
|
sessionID: das.sessionID,
|
|
seq: das.unAckSeq,
|
|
statusCode: 0,
|
|
payloadLen: 0,
|
|
},
|
|
transport: u.TransportProtocol(),
|
|
block: seg.block,
|
|
}
|
|
if err := u.writeOneSegment(closeReq, addr); err != nil {
|
|
return fmt.Errorf("writeOneSegment() failed: %w", err)
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
session.(*Session).recvChan <- seg
|
|
} else {
|
|
log.Debugf("Ignore unknown protocol %d", seg.metadata.Protocol())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *UDPUnderlay) onOpenSessionRequest(seg *segment, remoteAddr net.Addr) error {
|
|
if u.isClient {
|
|
return stderror.ErrInvalidOperation
|
|
}
|
|
|
|
// Create a new session.
|
|
sessionID := seg.metadata.(*sessionStruct).sessionID
|
|
if sessionID == 0 {
|
|
// 0 is reserved and can't be used.
|
|
return fmt.Errorf("reserved session ID %d is used", sessionID)
|
|
}
|
|
_, found := u.sessionMap.Load(sessionID)
|
|
if found {
|
|
log.Debugf("%v received open session request, but session ID %d is already used", u, sessionID)
|
|
return nil
|
|
}
|
|
session := NewSession(sessionID, false, u.MTU())
|
|
session.users = u.users
|
|
u.AddSession(session, remoteAddr)
|
|
session.recvChan <- seg
|
|
u.readySessions <- session
|
|
return nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) onOpenSessionResponse(seg *segment) error {
|
|
if !u.isClient {
|
|
return stderror.ErrInvalidOperation
|
|
}
|
|
|
|
sessionID := seg.metadata.(*sessionStruct).sessionID
|
|
session, found := u.sessionMap.Load(sessionID)
|
|
if !found {
|
|
return fmt.Errorf("session ID %d is not found", sessionID)
|
|
}
|
|
session.(*Session).recvChan <- seg
|
|
return nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) onCloseSession(seg *segment) error {
|
|
ss := seg.metadata.(*sessionStruct)
|
|
sessionID := ss.sessionID
|
|
session, found := u.sessionMap.Load(sessionID)
|
|
if !found {
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v received close session request or response, but session ID %d is not found", u, sessionID)
|
|
}
|
|
return nil
|
|
}
|
|
s := session.(*Session)
|
|
s.recvChan <- seg
|
|
s.wg.Wait()
|
|
u.RemoveSession(s)
|
|
return nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) readOneSegment() (*segment, *net.UDPAddr, error) {
|
|
var n int
|
|
var addr *net.UDPAddr
|
|
var err error
|
|
for {
|
|
select {
|
|
case <-u.done:
|
|
return nil, nil, io.ErrClosedPipe
|
|
default:
|
|
}
|
|
|
|
util.SetReadTimeout(u.conn, readOneSegmentTimeout)
|
|
defer util.SetReadTimeout(u.conn, 0)
|
|
// Peer may select a different MTU.
|
|
// Use the largest possible value here to avoid error.
|
|
b := make([]byte, 1500)
|
|
n, addr, err = u.conn.ReadFromUDP(b)
|
|
if err != nil {
|
|
if stderror.IsTimeout(err) {
|
|
return nil, nil, stderror.ErrTimeout
|
|
}
|
|
return nil, nil, fmt.Errorf("ReadFromUDP() failed: %w", err)
|
|
}
|
|
if u.isClient && addr.String() != u.serverAddr.String() {
|
|
UnderlayUnsolicitedUDP.Add(1)
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v received unsolicited UDP packet from %v", u, addr)
|
|
}
|
|
continue
|
|
}
|
|
if n < udpNonHeaderPosition {
|
|
UnderlayMalformedUDP.Add(1)
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v received UDP packet from %v with only %d bytes, which is too short", u, addr, n)
|
|
}
|
|
continue
|
|
}
|
|
b = b[:n]
|
|
if u.isClient {
|
|
metrics.DownloadBytes.Add(int64(n))
|
|
} else {
|
|
metrics.UploadBytes.Add(int64(n))
|
|
}
|
|
|
|
// Read encrypted metadata.
|
|
encryptedMeta := b[:udpNonHeaderPosition]
|
|
isNewSessionReplay := false
|
|
if udpReplayCache.IsDuplicate(encryptedMeta[:cipher.DefaultOverhead], addr.String()) {
|
|
replay.NewSession.Add(1)
|
|
isNewSessionReplay = true
|
|
}
|
|
nonce := encryptedMeta[:cipher.DefaultNonceSize]
|
|
|
|
// Decrypt metadata.
|
|
var decryptedMeta []byte
|
|
var blockCipher cipher.BlockCipher
|
|
if u.isClient {
|
|
decryptedMeta, err = u.block.Decrypt(encryptedMeta)
|
|
cipher.ClientDirectDecrypt.Add(1)
|
|
if err != nil {
|
|
cipher.ClientFailedDirectDecrypt.Add(1)
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v Decrypt() failed with UDP packet from %v", u, addr)
|
|
}
|
|
continue
|
|
}
|
|
} else {
|
|
var decrypted bool
|
|
var err error
|
|
// Try existing sessions.
|
|
cipher.ServerIterateDecrypt.Add(1)
|
|
u.sessionMap.Range(func(k, v any) bool {
|
|
session := v.(*Session)
|
|
if session.block != nil && session.RemoteAddr().String() == addr.String() {
|
|
decryptedMeta, err = session.block.Decrypt(encryptedMeta)
|
|
if err == nil {
|
|
decrypted = true
|
|
blockCipher = session.block
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if !decrypted {
|
|
// This is a new session. Try all registered users.
|
|
for _, user := range u.users {
|
|
var password []byte
|
|
password, err = hex.DecodeString(user.GetHashedPassword())
|
|
if err != nil {
|
|
log.Debugf("Unable to decode hashed password %q from user %q", user.GetHashedPassword(), user.GetName())
|
|
continue
|
|
}
|
|
if len(password) == 0 {
|
|
password = cipher.HashPassword([]byte(user.GetPassword()), []byte(user.GetName()))
|
|
}
|
|
blockCipher, decryptedMeta, err = cipher.TryDecrypt(encryptedMeta, password, true)
|
|
if err == nil {
|
|
decrypted = true
|
|
blockCipher.SetBlockContext(cipher.BlockContext{
|
|
UserName: user.GetName(),
|
|
})
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if !decrypted {
|
|
cipher.ServerFailedIterateDecrypt.Add(1)
|
|
if isNewSessionReplay {
|
|
log.Debugf("found possible replay attack in %v from %v", u, addr)
|
|
} else if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v TryDecrypt() failed with UDP packet from %v", u, addr)
|
|
}
|
|
continue
|
|
} else {
|
|
if blockCipher == nil {
|
|
panic("UDPUnderlay readOneSegment(): block cipher is nil after decryption is successful")
|
|
}
|
|
if isNewSessionReplay {
|
|
replay.NewSessionDecrypted.Add(1)
|
|
log.Debugf("found possible replay attack with payload decrypted in %v from %v", u, addr)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
if len(decryptedMeta) != MetadataLength {
|
|
return nil, nil, fmt.Errorf("decrypted metadata size %d is unexpected", len(decryptedMeta))
|
|
}
|
|
|
|
// Read payload and construct segment.
|
|
var seg *segment
|
|
p := decryptedMeta[0]
|
|
if isSessionProtocol(protocolType(p)) {
|
|
ss := &sessionStruct{}
|
|
if err := ss.Unmarshal(decryptedMeta); err != nil {
|
|
if u.isClient {
|
|
return nil, nil, fmt.Errorf("Unmarshal() to sessionStruct failed: %w", err)
|
|
} else {
|
|
log.Debugf("%v Unmarshal() to sessionStruct failed: %v", u, err)
|
|
continue
|
|
}
|
|
}
|
|
seg, err = u.readSessionSegment(ss, nonce, b[udpNonHeaderPosition:], blockCipher)
|
|
if err != nil {
|
|
if u.isClient {
|
|
return nil, nil, err
|
|
} else {
|
|
log.Debugf("%v readSessionSegment() failed: %v", u, err)
|
|
continue
|
|
}
|
|
}
|
|
if blockCipher != nil {
|
|
seg.block = blockCipher
|
|
}
|
|
return seg, addr, nil
|
|
} else if isDataAckProtocol(protocolType(p)) {
|
|
das := &dataAckStruct{}
|
|
if err := das.Unmarshal(decryptedMeta); err != nil {
|
|
if u.isClient {
|
|
return nil, nil, fmt.Errorf("Unmarshal() to dataAckStruct failed: %w", err)
|
|
} else {
|
|
log.Debugf("%v Unmarshal() to dataAckStruct failed: %v", u, err)
|
|
continue
|
|
}
|
|
}
|
|
seg, err = u.readDataAckSegment(das, nonce, b[udpNonHeaderPosition:], blockCipher)
|
|
if err != nil {
|
|
if u.isClient {
|
|
return nil, nil, err
|
|
} else {
|
|
log.Debugf("%v readDataAckSegment() failed: %v", u, err)
|
|
continue
|
|
}
|
|
}
|
|
if blockCipher != nil {
|
|
seg.block = blockCipher
|
|
}
|
|
return seg, addr, nil
|
|
}
|
|
return nil, nil, fmt.Errorf("unable to handle protocol %d", p)
|
|
}
|
|
}
|
|
|
|
func (u *UDPUnderlay) readSessionSegment(ss *sessionStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
|
var decryptedPayload []byte
|
|
var err error
|
|
|
|
if ss.payloadLen > 0 {
|
|
if len(remaining) < int(ss.payloadLen)+cipher.DefaultOverhead {
|
|
return nil, fmt.Errorf("payload: received incomplete UDP packet")
|
|
}
|
|
if blockCipher == nil {
|
|
if u.isClient {
|
|
blockCipher = u.block
|
|
} else {
|
|
panic("UDPUnderlay readSessionSegment(): block is nil")
|
|
}
|
|
}
|
|
encryptedPayload := remaining[:ss.payloadLen+cipher.DefaultOverhead]
|
|
decryptedPayload, err = blockCipher.DecryptWithNonce(encryptedPayload, nonce)
|
|
if u.isClient {
|
|
cipher.ClientDirectDecrypt.Add(1)
|
|
} else {
|
|
cipher.ServerDirectDecrypt.Add(1)
|
|
}
|
|
if err != nil {
|
|
if u.isClient {
|
|
cipher.ClientFailedDirectDecrypt.Add(1)
|
|
} else {
|
|
cipher.ServerFailedDirectDecrypt.Add(1)
|
|
}
|
|
return nil, fmt.Errorf("DecryptWithNonce() failed: %w", err)
|
|
}
|
|
if int(ss.payloadLen)+cipher.DefaultOverhead+int(ss.suffixLen) != len(remaining) {
|
|
return nil, fmt.Errorf("padding: size not match")
|
|
}
|
|
} else {
|
|
if int(ss.suffixLen) != len(remaining) {
|
|
return nil, fmt.Errorf("padding: size not match")
|
|
}
|
|
}
|
|
|
|
return &segment{
|
|
metadata: ss,
|
|
payload: decryptedPayload,
|
|
transport: util.UDPTransport,
|
|
}, nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) readDataAckSegment(das *dataAckStruct, nonce, remaining []byte, blockCipher cipher.BlockCipher) (*segment, error) {
|
|
var decryptedPayload []byte
|
|
var err error
|
|
|
|
if das.prefixLen > 0 {
|
|
remaining = remaining[das.prefixLen:]
|
|
}
|
|
if das.payloadLen > 0 {
|
|
if len(remaining) < int(das.payloadLen)+cipher.DefaultOverhead {
|
|
return nil, fmt.Errorf("payload: received incomplete UDP packet")
|
|
}
|
|
if blockCipher == nil {
|
|
if u.isClient {
|
|
blockCipher = u.block
|
|
} else {
|
|
panic("UDPUnderlay readDataAckSegment(): block is nil")
|
|
}
|
|
}
|
|
encryptedPayload := remaining[:das.payloadLen+cipher.DefaultOverhead]
|
|
decryptedPayload, err = blockCipher.DecryptWithNonce(encryptedPayload, nonce)
|
|
if u.isClient {
|
|
cipher.ClientDirectDecrypt.Add(1)
|
|
} else {
|
|
cipher.ServerDirectDecrypt.Add(1)
|
|
}
|
|
if err != nil {
|
|
if u.isClient {
|
|
cipher.ClientFailedDirectDecrypt.Add(1)
|
|
} else {
|
|
cipher.ServerFailedDirectDecrypt.Add(1)
|
|
}
|
|
return nil, fmt.Errorf("DecryptWithNonce() failed: %w", err)
|
|
}
|
|
if int(das.payloadLen)+cipher.DefaultOverhead+int(das.suffixLen) != len(remaining) {
|
|
return nil, fmt.Errorf("padding: size not match")
|
|
}
|
|
} else {
|
|
if int(das.suffixLen) != len(remaining) {
|
|
return nil, fmt.Errorf("padding: size not match")
|
|
}
|
|
}
|
|
|
|
return &segment{
|
|
metadata: das,
|
|
payload: decryptedPayload,
|
|
transport: util.UDPTransport,
|
|
}, nil
|
|
}
|
|
|
|
func (u *UDPUnderlay) writeOneSegment(seg *segment, addr *net.UDPAddr) error {
|
|
if seg == nil {
|
|
return stderror.ErrNullPointer
|
|
}
|
|
if u.isClient && addr.String() != u.serverAddr.String() {
|
|
return fmt.Errorf("can't write to %v, UDP server address is %v", addr, u.serverAddr)
|
|
}
|
|
|
|
u.sendMutex.Lock()
|
|
defer u.sendMutex.Unlock()
|
|
|
|
var blockCipher cipher.BlockCipher
|
|
if u.isClient {
|
|
if u.block == nil {
|
|
panic(fmt.Sprintf("%v cipher block is not ready", u))
|
|
} else {
|
|
blockCipher = u.block
|
|
}
|
|
} else {
|
|
if seg.block != nil {
|
|
blockCipher = seg.block
|
|
} else {
|
|
sessionID, err := seg.SessionID()
|
|
if err != nil {
|
|
return fmt.Errorf("%v SessionID() failed: %v", seg, err)
|
|
}
|
|
session, ok := u.sessionMap.Load(sessionID)
|
|
if !ok {
|
|
return fmt.Errorf("session %d not found", sessionID)
|
|
}
|
|
s := session.(*Session)
|
|
if s.block == nil {
|
|
// stderror.ErrNotReady is needed to trigger stderror.ShouldRetry.
|
|
return fmt.Errorf("%v cipher block is not ready, please try again later: %w", s, stderror.ErrNotReady)
|
|
} else {
|
|
blockCipher = s.block
|
|
}
|
|
}
|
|
}
|
|
|
|
if ss, ok := toSessionStruct(seg.metadata); ok {
|
|
maxPaddingSize := MaxPaddingSize(u.mtu, u.IPVersion(), u.TransportProtocol(), int(ss.payloadLen), 0)
|
|
padding := newPadding(paddingOpts{
|
|
maxLen: maxPaddingSize,
|
|
minConsecutiveASCIILen: mathext.Min(maxPaddingSize, recommendedConsecutiveASCIILen),
|
|
})
|
|
ss.suffixLen = uint8(len(padding))
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v is sending %v", u, seg)
|
|
}
|
|
|
|
plaintextMetadata := seg.metadata.Marshal()
|
|
encryptedMetadata, err := blockCipher.Encrypt(plaintextMetadata)
|
|
if err != nil {
|
|
return fmt.Errorf("Encrypt() failed: %w", err)
|
|
}
|
|
nonce := encryptedMetadata[:cipher.DefaultNonceSize]
|
|
dataToSend := encryptedMetadata
|
|
if len(seg.payload) > 0 {
|
|
encryptedPayload, err := blockCipher.EncryptWithNonce(seg.payload, nonce)
|
|
if err != nil {
|
|
return fmt.Errorf("EncryptWithNonce() failed: %w", err)
|
|
}
|
|
dataToSend = append(dataToSend, encryptedPayload...)
|
|
}
|
|
dataToSend = append(dataToSend, padding...)
|
|
if _, err := u.conn.WriteToUDP(dataToSend, addr); err != nil {
|
|
return fmt.Errorf("WriteToUDP() failed: %w", err)
|
|
}
|
|
if u.isClient {
|
|
metrics.UploadBytes.Add(int64(len(dataToSend)))
|
|
} else {
|
|
metrics.DownloadBytes.Add(int64(len(dataToSend)))
|
|
}
|
|
metrics.OutputPaddingBytes.Add(int64(len(padding)))
|
|
} else if das, ok := toDataAckStruct(seg.metadata); ok {
|
|
padding1 := newPadding(paddingOpts{
|
|
maxLen: MaxPaddingSize(u.mtu, u.IPVersion(), u.TransportProtocol(), int(das.payloadLen), 0),
|
|
})
|
|
padding2 := newPadding(paddingOpts{
|
|
maxLen: MaxPaddingSize(u.mtu, u.IPVersion(), u.TransportProtocol(), int(das.payloadLen), len(padding1)),
|
|
})
|
|
das.prefixLen = uint8(len(padding1))
|
|
das.suffixLen = uint8(len(padding2))
|
|
if log.IsLevelEnabled(log.TraceLevel) {
|
|
log.Tracef("%v is sending %v", u, seg)
|
|
}
|
|
|
|
plaintextMetadata := seg.metadata.Marshal()
|
|
encryptedMetadata, err := blockCipher.Encrypt(plaintextMetadata)
|
|
if err != nil {
|
|
return fmt.Errorf("Encrypt() failed: %w", err)
|
|
}
|
|
nonce := encryptedMetadata[:cipher.DefaultNonceSize]
|
|
dataToSend := append(encryptedMetadata, padding1...)
|
|
if len(seg.payload) > 0 {
|
|
encryptedPayload, err := blockCipher.EncryptWithNonce(seg.payload, nonce)
|
|
if err != nil {
|
|
return fmt.Errorf("EncryptWithNonce() failed: %w", err)
|
|
}
|
|
dataToSend = append(dataToSend, encryptedPayload...)
|
|
}
|
|
dataToSend = append(dataToSend, padding2...)
|
|
if _, err := u.conn.WriteToUDP(dataToSend, addr); err != nil {
|
|
return fmt.Errorf("WriteToUDP() failed: %w", err)
|
|
}
|
|
if u.isClient {
|
|
metrics.UploadBytes.Add(int64(len(dataToSend)))
|
|
} else {
|
|
metrics.DownloadBytes.Add(int64(len(dataToSend)))
|
|
}
|
|
metrics.OutputPaddingBytes.Add(int64(len(padding1)))
|
|
metrics.OutputPaddingBytes.Add(int64(len(padding2)))
|
|
} else {
|
|
return stderror.ErrInvalidArgument
|
|
}
|
|
return nil
|
|
}
|