Update On Fri Aug 29 20:33:54 CEST 2025

This commit is contained in:
github-action[bot]
2025-08-29 20:33:55 +02:00
parent 47e66d71f3
commit 8245d77671
185 changed files with 136347 additions and 83134 deletions

View File

@@ -29,7 +29,7 @@ var errSniffingTimeout = errors.New("timeout on sniffing")
type cachedReader struct {
sync.Mutex
reader *pipe.Reader
reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
cache buf.MultiBuffer
}
@@ -87,7 +87,9 @@ func (r *cachedReader) Interrupt() {
r.cache = buf.ReleaseMulti(r.cache)
}
r.Unlock()
r.reader.Interrupt()
if p, ok := r.reader.(*pipe.Reader); ok {
p.Interrupt()
}
}
// DefaultDispatcher is a default implementation of Dispatcher.
@@ -319,7 +321,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
d.routedDispatch(ctx, outbound, destination)
} else {
cReader := &cachedReader{
reader: outbound.Reader.(*pipe.Reader),
reader: outbound.Reader.(buf.TimeoutReader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)

View File

@@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
// HandleResponse handles udp response packet from remote DNS server.
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
ipRec, err := parseResponse(packet.Payload.Bytes())
payload := packet.Payload
ipRec, err := parseResponse(payload.Bytes())
payload.Release()
if err != nil {
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
return
@@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
newReq.msg = &newMsg
s.addPendingRequest(&newReq)
b, _ := dns.PackMessage(newReq.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
return
}
@@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
}
s.addPendingRequest(udpReq)
b, _ := dns.PackMessage(req.msg)
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
b.UDP = &copyDest
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
}
}

View File

@@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
}
out:
err := h.proxy.Process(ctx, link, h)
var errC error
if err != nil {
if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) {
errC = errors.Cause(err)
if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
err = nil
}
}
@@ -251,7 +253,11 @@ out:
errors.LogInfo(ctx, err.Error())
common.Interrupt(link.Writer)
} else {
common.Close(link.Writer)
if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
common.Interrupt(link.Writer)
} else {
common.Close(link.Writer)
}
}
common.Interrupt(link.Reader)
}

View File

@@ -24,9 +24,46 @@ var ErrReadTimeout = errors.New("IO timeout")
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
type TimeoutReader interface {
Reader
ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
}
type TimeoutWrapperReader struct {
Reader
mb MultiBuffer
err error
done chan struct{}
}
func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
if r.done != nil {
<-r.done
r.done = nil
return r.mb, r.err
}
r.mb = nil
r.err = nil
return r.Reader.ReadMultiBuffer()
}
func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
if r.done == nil {
r.done = make(chan struct{})
go func() {
r.mb, r.err = r.Reader.ReadMultiBuffer()
close(r.done)
}()
}
time.Sleep(duration)
select {
case <-r.done:
r.done = nil
return r.mb, r.err
default:
return nil, nil
}
}
// Writer extends io.Writer with MultiBuffer.
type Writer interface {
// WriteMultiBuffer writes a MultiBuffer into underlying writer.

View File

@@ -2,6 +2,7 @@ package mux
import (
"context"
goerrors "errors"
"io"
"sync"
"time"
@@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
ctx, cancel := context.WithCancel(ctx)
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
errC := errors.Cause(errP)
if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
}
}
common.Must(c.Close())
cancel()
@@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
select {
case <-m.done.Wait():
m.sessionManager.Close()
common.Close(m.link.Writer)
common.Interrupt(m.link.Writer)
common.Interrupt(m.link.Reader)
return
case <-m.timer.C:
@@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
return nil
}
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
outbounds := session.OutboundsFromContext(ctx)
ob := outbounds[len(outbounds)-1]
transferType := protocol.TransferTypeStream
@@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
defer s.Close(false)
defer writer.Close()
defer timer.Reset(time.Second * 16)
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
if err := writeFirstPayload(s.input, writer); err != nil {
@@ -307,7 +312,11 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
}
s.input = link.Reader
s.output = link.Writer
go fetchInput(ctx, s, m.link.Writer)
if _, ok := link.Reader.(*pipe.Reader); ok {
go fetchInput(ctx, s, m.link.Writer, m.timer)
} else {
fetchInput(ctx, s, m.link.Writer, m.timer)
}
return true
}

View File

@@ -87,7 +87,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
link: link,
sessionManager: NewSessionManager(),
}
go worker.run(ctx)
if inbound := session.InboundFromContext(ctx); inbound != nil {
inbound.CanSpliceCopy = 3
}
if _, ok := link.Reader.(*pipe.Reader); ok {
go worker.run(ctx)
} else {
worker.run(ctx)
}
return worker, nil
}
@@ -311,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
reader := &buf.BufferedReader{Reader: w.link.Reader}
defer w.sessionManager.Close()
defer common.Close(w.link.Writer)
defer common.Interrupt(w.link.Reader)
defer common.Interrupt(w.link.Writer)
for {
select {

View File

@@ -19,7 +19,7 @@ import (
var (
Version_x byte = 25
Version_y byte = 8
Version_z byte = 3
Version_z byte = 29
)
var (

View File

@@ -2,10 +2,8 @@ package dokodemo
import (
"context"
"runtime"
"strconv"
"strings"
"sync/atomic"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
@@ -14,11 +12,10 @@ import (
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls"
)
@@ -144,39 +141,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
})
errors.LogInfo(ctx, "received request for ", conn.RemoteAddr())
plcy := d.policy()
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
if inbound != nil {
inbound.Timer = timer
}
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return errors.New("failed to dispatch request").Base(err)
}
requestCount := int32(1)
requestDone := func() error {
defer func() {
if atomic.AddInt32(&requestCount, -1) == 0 {
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
}
}()
var reader buf.Reader
if dest.Network == net.Network_UDP {
reader = buf.NewPacketReader(conn)
} else {
reader = buf.NewReader(conn)
}
if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport request").Base(err)
}
return nil
var reader buf.Reader
if dest.Network == net.Network_TCP {
reader = buf.NewReader(conn)
} else {
reader = buf.NewPacketReader(conn)
}
var writer buf.Writer
@@ -208,72 +177,17 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
return err
}
writer = NewPacketWriter(pConn, &dest, mark, back)
defer func() {
runtime.Gosched()
common.Interrupt(link.Reader) // maybe duplicated
runtime.Gosched()
writer.(*PacketWriter).Close() // close fake UDP conns
}()
/*
sockopt := &internet.SocketConfig{
Tproxy: internet.SocketConfig_TProxy,
}
if dest.Address.Family().IsIP() {
sockopt.BindAddress = dest.Address.IP()
sockopt.BindPort = uint32(dest.Port)
}
if d.sockopt != nil {
sockopt.Mark = d.sockopt.Mark
}
tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
if err != nil {
return err
}
defer tConn.Close()
writer = &buf.SequentialWriter{Writer: tConn}
tReader := buf.NewPacketReader(tConn)
requestCount++
tproxyRequest = func() error {
defer func() {
if atomic.AddInt32(&requestCount, -1) == 0 {
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
}
}()
if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport request (TPROXY conn)").Base(err)
}
return nil
}
*/
defer writer.(*PacketWriter).Close() // close fake UDP conns
}
}
responseDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
if network == net.Network_UDP && destinationOverridden {
buf.Copy(link.Reader, writer) // respect upload's timeout
return nil
}
if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport response").Base(err)
}
return nil
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: reader},
Writer: writer},
); err != nil {
return errors.New("failed to dispatch request").Base(err)
}
if err := task.Run(ctx,
task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)),
responseDone); err != nil {
runtime.Gosched()
common.Interrupt(link.Writer)
runtime.Gosched()
common.Interrupt(link.Reader)
return errors.New("connection ends").Base(err)
}
return nil
return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process()
}
func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {

View File

@@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
}
a := addr.AsAddress()
return a != net.AnyIP
return a != net.AnyIP && a != net.AnyIPv6
}
// Process implements proxy.Outbound.
@@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
}
}
}
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
destAddr := b.UDP.RawNetAddr()
if destAddr == nil {
b.Release()
continue

View File

@@ -18,12 +18,12 @@ import (
"github.com/xtls/xray-core/common/protocol"
http_proto "github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat"
)
@@ -173,64 +173,31 @@ Start:
return err
}
func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
if err != nil {
return errors.New("failed to write back OK response").Base(err)
}
plcy := s.policy()
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
if inbound != nil {
inbound.Timer = timer
}
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err
}
if reader.Buffered() > 0 {
payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
reader := buf.NewReader(conn)
if buffer.Buffered() > 0 {
payload, err := buf.ReadFrom(io.LimitReader(buffer, int64(buffer.Buffered())))
if err != nil {
return err
}
if err := link.Writer.WriteMultiBuffer(payload); err != nil {
return err
}
reader = nil
reader = &buf.BufferedReader{Reader: reader, Buffer: payload}
buffer = nil
}
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
responseDone := func() error {
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(conn)
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
return err
}
return nil
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: reader},
Writer: buf.NewWriter(conn)},
); err != nil {
return errors.New("failed to dispatch request").Base(err)
}
closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return errors.New("connection ends").Base(err)
}
return nil
}

View File

@@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
}
}
if err != nil {
if errors.Cause(err) == io.EOF {
return nil
}
return err
}
}

View File

@@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
request := protocol.RequestHeaderFromContext(ctx)
payload := packet.Payload
if request == nil {
payload.Release()
return
}
payload := packet.Payload
if payload.UDP != nil {
request = &protocol.RequestHeader{
User: request.User,
@@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
return
}
defer data.Release()
conn.Write(data.Bytes())
data.Release()
})
defer udpServer.RemoveRay()

View File

@@ -14,13 +14,12 @@ import (
"github.com/xtls/xray-core/common/protocol"
udp_proto "github.com/xtls/xray-core/common/protocol/udp"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/signal"
"github.com/xtls/xray-core/common/task"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/proxy"
"github.com/xtls/xray-core/proxy/http"
"github.com/xtls/xray-core/transport"
"github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/udp"
)
@@ -158,8 +157,16 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
Reason: "",
})
}
return s.transport(ctx, reader, conn, dest, dispatcher, inbound)
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
Reader: &buf.TimeoutWrapperReader{Reader: reader},
Writer: buf.NewWriter(conn)},
); err != nil {
return errors.New("failed to dispatch request").Base(err)
}
return nil
}
if request.Command == protocol.RequestCommandUDP {
@@ -178,54 +185,6 @@ func (*Server) handleUDP(c io.Reader) error {
return common.Error2(io.Copy(buf.DiscardBytes, c))
}
func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
ctx, cancel := context.WithCancel(ctx)
timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
if inbound != nil {
inbound.Timer = timer
}
plcy := s.policy()
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
link, err := dispatcher.Dispatch(ctx, dest)
if err != nil {
return err
}
requestDone := func() error {
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP request").Base(err)
}
return nil
}
responseDone := func() error {
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
v2writer := buf.NewWriter(writer)
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
return errors.New("failed to transport all TCP response").Base(err)
}
return nil
}
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
common.Interrupt(link.Reader)
common.Interrupt(link.Writer)
return errors.New("connection ends").Base(err)
}
return nil
}
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
@@ -237,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
request := protocol.RequestHeaderFromContext(ctx)
if request == nil {
payload.Release()
return
}
@@ -255,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
errors.LogWarningInner(ctx, err, "failed to write UDP response")
return
}
defer udpMessage.Release()
conn.Write(udpMessage.Bytes())
udpMessage.Release()
})
defer udpServer.RemoveRay()
@@ -265,9 +225,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
if inbound != nil && inbound.Source.IsValid() {
errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
}
if inbound.CanSpliceCopy == 2 {
inbound.CanSpliceCopy = 1
}
var dest *net.Destination

View File

@@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
target = b.UDP
}
if _, err := w.writePacket(b.Bytes(), *target); err != nil {
b.Release()
buf.ReleaseMulti(mb)
return err
}
b.Release()
}
return nil
}

View File

@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
sessionPolicy = s.policyManager.ForLevel(user.Level)
if destination.Network == net.Network_UDP { // handle udp request
return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
}
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
}
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
defer timer.SetTimeout(0)
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
udpPayload := packet.Payload
if udpPayload.UDP == nil {
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
errors.LogWarningInner(ctx, err, "failed to write response")
cancel()
} else {
timer.Update()
}
})
defer udpServer.RemoveRay()
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
var dest *net.Destination
for {
select {
case <-ctx.Done():
return nil
default:
mb, err := clientReader.ReadMultiBuffer()
if err != nil {
if errors.Cause(err) != io.EOF {
return errors.New("unexpected EOF").Base(err)
}
requestDone := func() error {
for {
select {
case <-ctx.Done():
return nil
}
default:
mb, err := clientReader.ReadMultiBuffer()
if err != nil {
if errors.Cause(err) != io.EOF {
return errors.New("unexpected EOF").Base(err)
}
return nil
}
mb2, b := buf.SplitFirst(mb)
if b == nil {
continue
}
destination := *b.UDP
mb2, b := buf.SplitFirst(mb)
if b == nil {
continue
}
timer.Update()
destination := *b.UDP
currentPacketCtx := ctx
if inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source,
To: destination,
Status: log.AccessAccepted,
Reason: "",
Email: user.Email,
})
}
errors.LogInfo(ctx, "tunnelling request to ", destination)
currentPacketCtx := ctx
if inbound.Source.IsValid() {
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
From: inbound.Source,
To: destination,
Status: log.AccessAccepted,
Reason: "",
Email: user.Email,
})
}
errors.LogInfo(ctx, "tunnelling request to ", destination)
if !s.cone || dest == nil {
dest = &destination
}
if !s.cone || dest == nil {
dest = &destination
}
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
for _, payload := range mb2 {
udpServer.Dispatch(currentPacketCtx, *dest, payload)
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
for _, payload := range mb2 {
udpServer.Dispatch(currentPacketCtx, *dest, payload)
}
}
}
}
if err := task.Run(ctx, requestDone); err != nil {
return err
}
return nil
}
func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,

View File

@@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/common/crypto"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/protocol"
"lukechampine.com/blake3"
)
@@ -66,7 +67,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsPKeys == nil {
return nil, errors.New("uninitialized")
}
c := NewCommonConn(conn)
c := NewCommonConn(conn, protocol.HasAESGCMHardwareSupport)
ivAndRealysLength := 16 + i.RelaysLength
pfsKeyExchangeLength := 18 + 1184 + 32 + 16
@@ -108,18 +109,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:])
relays = relays[index+32:]
}
nfsGCM := NewGCM(iv, nfsKey)
nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
if i.Seconds > 0 {
i.RWLock.RLock()
if time.Now().Before(i.Expire) {
c.Client = i
c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection
nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
i.RWLock.RUnlock()
c.PreWrite = clientHello[:ivAndRealysLength+18+32]
c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey)
c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES)
if i.XorMode == 2 {
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16)
}
@@ -129,15 +130,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
}
pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength]
nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
mlkem768DKey, _ := mlkem.GenerateKey768()
x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...)
nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:]
nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
if _, err := conn.Write(clientHello); err != nil {
return nil, err
@@ -148,7 +149,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
return nil, err
}
nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088])
if err != nil {
return nil, err
@@ -165,14 +166,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
copy(pfsKey, mlkem768Key)
copy(pfsKey[32:], x25519Key)
c.UnitedKey = append(pfsKey, nfsKey...)
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey)
c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES)
encryptedTicket := make([]byte, 32)
if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
return nil, err
}
if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
return nil, err
}
seconds := DecodeLength(encryptedTicket)
@@ -189,7 +190,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err
}
if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err
}
length := DecodeLength(encryptedLength[:2])

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/xtls/xray-core/common/errors"
"golang.org/x/crypto/chacha20poly1305"
"lukechampine.com/blake3"
)
@@ -23,19 +24,21 @@ var OutBytesPool = sync.Pool{
type CommonConn struct {
net.Conn
UseAES bool
Client *ClientInstance
UnitedKey []byte
PreWrite []byte
GCM *GCM
PeerGCM *GCM
AEAD *AEAD
PeerAEAD *AEAD
PeerPadding []byte
PeerInBytes []byte
PeerCache []byte
}
func NewCommonConn(conn net.Conn) *CommonConn {
func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
return &CommonConn{
Conn: conn,
UseAES: useAES,
PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
}
}
@@ -55,12 +58,12 @@ func (c *CommonConn) Write(b []byte) (int, error) {
headerAndData := outBytes[:5+len(b)+16]
EncodeHeader(headerAndData, len(b)+16)
max := false
if bytes.Equal(c.GCM.Nonce[:], MaxNonce) {
if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
max = true
}
c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5])
c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
if max {
c.GCM = NewGCM(headerAndData, c.UnitedKey)
c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
}
if c.PreWrite != nil {
headerAndData = append(c.PreWrite, headerAndData...)
@@ -77,12 +80,12 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
}
if c.PeerGCM == nil { // client's 0-RTT
if c.PeerAEAD == nil { // client's 0-RTT
serverRandom := make([]byte, 16)
if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
return 0, err
}
c.PeerGCM = NewGCM(serverRandom, c.UnitedKey)
c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
if xorConn, ok := c.Conn.(*XorConn); ok {
xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
}
@@ -91,7 +94,7 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
return 0, err
}
if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
return 0, err
}
c.PeerPadding = nil
@@ -126,13 +129,13 @@ func (c *CommonConn) Read(b []byte) (int, error) {
if len(dst) <= len(b) {
dst = b[:len(dst)] // avoids another copy()
}
var newGCM *GCM
if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) {
newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey)
var newAEAD *AEAD
if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
}
_, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader)
if newGCM != nil {
c.PeerGCM = newGCM
_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
if newAEAD != nil {
c.PeerAEAD = newAEAD
}
if err != nil {
return 0, err
@@ -144,28 +147,32 @@ func (c *CommonConn) Read(b []byte) (int, error) {
return len(dst), nil
}
type GCM struct {
type AEAD struct {
cipher.AEAD
Nonce [12]byte
}
func NewGCM(ctx, key []byte) *GCM {
func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
k := make([]byte, 32)
blake3.DeriveKey(k, string(ctx), key)
block, _ := aes.NewCipher(k)
aead, _ := cipher.NewGCM(block)
return &GCM{AEAD: aead}
//chacha20poly1305.New()
var aead cipher.AEAD
if useAES {
block, _ := aes.NewCipher(k)
aead, _ = cipher.NewGCM(block)
} else {
aead, _ = chacha20poly1305.New(k)
}
return &AEAD{AEAD: aead}
}
func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if nonce == nil {
nonce = IncreaseNonce(a.Nonce[:])
}
return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
}
func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
if nonce == nil {
nonce = IncreaseNonce(a.Nonce[:])
}

View File

@@ -102,7 +102,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if i.NfsSKeys == nil {
return nil, errors.New("uninitialized")
}
c := NewCommonConn(conn)
c := NewCommonConn(conn, true)
ivAndRelays := make([]byte, 16+i.RelaysLength)
if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
@@ -151,16 +151,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
}
relays = relays[32:]
}
nfsGCM := NewGCM(iv, nfsKey)
nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
encryptedLength := make([]byte, 18)
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err
}
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err
decryptedLength := make([]byte, 2)
if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
c.UseAES = !c.UseAES
nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err
}
}
length := DecodeLength(encryptedLength[:2])
length := DecodeLength(decryptedLength)
if length == 32 {
if i.Seconds == 0 {
@@ -170,7 +175,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
return nil, err
}
ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil)
ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
if err != nil {
return nil, err
}
@@ -193,8 +198,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
c.PreWrite = make([]byte, 16)
rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
c.GCM = NewGCM(c.PreWrite, c.UnitedKey)
c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
if i.XorMode == 2 {
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
}
@@ -208,7 +213,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
return nil, err
}
if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
return nil, err
}
mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
@@ -230,8 +235,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
copy(pfsKey[32:], x25519Key)
pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
c.UnitedKey = append(pfsKey, nfsKey...)
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey)
c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
ticket := make([]byte, 16)
rand.Read(ticket)
copy(ticket, EncodeLength(int(i.Seconds*4/5)))
@@ -240,11 +245,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
encryptedTicketLength := 32
paddingLength := int(crypto.RandBetween(100, 1000))
serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
if _, err := conn.Write(serverHello); err != nil {
return nil, err
@@ -264,14 +269,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
return nil, err
}
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
return nil, err
}
encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
return nil, err
}
if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
return nil, err
}

View File

@@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
type connEntry struct {
link *transport.Link
timer signal.ActivityUpdater
timer *signal.ActivityTimer
cancel context.CancelFunc
closed bool
}
func (c *connEntry) Close() error {
c.timer.SetTimeout(0)
return nil
}
func (c *connEntry) terminate() {
if c.closed {
panic("terminate called more than once")
}
c.closed = true
c.cancel()
common.Interrupt(c.link.Reader)
common.Interrupt(c.link.Writer)
}
type Dispatcher struct {
@@ -32,6 +48,7 @@ type Dispatcher struct {
dispatcher routing.Dispatcher
callback ResponseCallback
callClose func() error
closed bool
}
func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
func (v *Dispatcher) RemoveRay() {
v.Lock()
defer v.Unlock()
v.removeRay()
}
func (v *Dispatcher) removeRay() {
v.closed = true
if v.conn != nil {
common.Interrupt(v.conn.link.Reader)
common.Close(v.conn.link.Writer)
v.conn.Close()
v.conn = nil
}
}
@@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
v.Lock()
defer v.Unlock()
if v.closed {
return nil, errors.New("dispatcher is closed")
}
if v.conn != nil {
return v.conn, nil
if v.conn.closed {
v.conn = nil
} else {
return v.conn, nil
}
}
errors.LogInfo(ctx, "establishing new connection for ", dest)
ctx, cancel := context.WithCancel(ctx)
entry := &connEntry{}
removeRay := func() {
v.Lock()
defer v.Unlock()
// sometimes the entry is already removed by others, don't close again
if entry == v.conn {
cancel()
v.removeRay()
}
}
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
link, err := v.dispatcher.Dispatch(ctx, dest)
if err != nil {
cancel()
return nil, errors.New("failed to dispatch request to ", dest).Base(err)
}
*entry = connEntry{
entry := &connEntry{
link: link,
timer: timer,
cancel: removeRay,
cancel: cancel,
}
entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
v.conn = entry
go handleInput(ctx, entry, dest, v.callback, v.callClose)
return entry, nil
@@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
if outputStream != nil {
if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
conn.cancel()
conn.Close()
return
}
}
@@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
defer func() {
conn.cancel()
conn.Close()
if callClose != nil {
callClose()
}

View File

@@ -200,16 +200,19 @@ func (p *pipe) Interrupt() {
p.Lock()
defer p.Unlock()
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
if p.state == closed {
p.state = errord
}
}
if p.state == closed || p.state == errord {
return
}
p.state = errord
if !p.data.IsEmpty() {
buf.ReleaseMulti(p.data)
p.data = nil
}
common.Must(p.done.Close())
}