mirror of
https://github.com/bolucat/Archive.git
synced 2025-12-24 13:28:37 +08:00
Update On Fri Aug 29 20:33:54 CEST 2025
This commit is contained in:
@@ -29,7 +29,7 @@ var errSniffingTimeout = errors.New("timeout on sniffing")
|
||||
|
||||
type cachedReader struct {
|
||||
sync.Mutex
|
||||
reader *pipe.Reader
|
||||
reader buf.TimeoutReader // *pipe.Reader or *buf.TimeoutWrapperReader
|
||||
cache buf.MultiBuffer
|
||||
}
|
||||
|
||||
@@ -87,7 +87,9 @@ func (r *cachedReader) Interrupt() {
|
||||
r.cache = buf.ReleaseMulti(r.cache)
|
||||
}
|
||||
r.Unlock()
|
||||
r.reader.Interrupt()
|
||||
if p, ok := r.reader.(*pipe.Reader); ok {
|
||||
p.Interrupt()
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultDispatcher is a default implementation of Dispatcher.
|
||||
@@ -319,7 +321,7 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
|
||||
d.routedDispatch(ctx, outbound, destination)
|
||||
} else {
|
||||
cReader := &cachedReader{
|
||||
reader: outbound.Reader.(*pipe.Reader),
|
||||
reader: outbound.Reader.(buf.TimeoutReader),
|
||||
}
|
||||
outbound.Reader = cReader
|
||||
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
|
||||
|
||||
@@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
|
||||
|
||||
// HandleResponse handles udp response packet from remote DNS server.
|
||||
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
||||
ipRec, err := parseResponse(packet.Payload.Bytes())
|
||||
payload := packet.Payload
|
||||
ipRec, err := parseResponse(payload.Bytes())
|
||||
payload.Release()
|
||||
if err != nil {
|
||||
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
|
||||
return
|
||||
@@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
||||
newReq.msg = &newMsg
|
||||
s.addPendingRequest(&newReq)
|
||||
b, _ := dns.PackMessage(newReq.msg)
|
||||
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
|
||||
b.UDP = ©Dest
|
||||
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
|
||||
return
|
||||
}
|
||||
@@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
|
||||
}
|
||||
s.addPendingRequest(udpReq)
|
||||
b, _ := dns.PackMessage(req.msg)
|
||||
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
|
||||
b.UDP = ©Dest
|
||||
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
|
||||
}
|
||||
out:
|
||||
err := h.proxy.Process(ctx, link, h)
|
||||
var errC error
|
||||
if err != nil {
|
||||
if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) {
|
||||
errC = errors.Cause(err)
|
||||
if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
@@ -251,7 +253,11 @@ out:
|
||||
errors.LogInfo(ctx, err.Error())
|
||||
common.Interrupt(link.Writer)
|
||||
} else {
|
||||
common.Close(link.Writer)
|
||||
if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
|
||||
common.Interrupt(link.Writer)
|
||||
} else {
|
||||
common.Close(link.Writer)
|
||||
}
|
||||
}
|
||||
common.Interrupt(link.Reader)
|
||||
}
|
||||
|
||||
@@ -24,9 +24,46 @@ var ErrReadTimeout = errors.New("IO timeout")
|
||||
|
||||
// TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
|
||||
type TimeoutReader interface {
|
||||
Reader
|
||||
ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
|
||||
}
|
||||
|
||||
type TimeoutWrapperReader struct {
|
||||
Reader
|
||||
mb MultiBuffer
|
||||
err error
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func (r *TimeoutWrapperReader) ReadMultiBuffer() (MultiBuffer, error) {
|
||||
if r.done != nil {
|
||||
<-r.done
|
||||
r.done = nil
|
||||
return r.mb, r.err
|
||||
}
|
||||
r.mb = nil
|
||||
r.err = nil
|
||||
return r.Reader.ReadMultiBuffer()
|
||||
}
|
||||
|
||||
func (r *TimeoutWrapperReader) ReadMultiBufferTimeout(duration time.Duration) (MultiBuffer, error) {
|
||||
if r.done == nil {
|
||||
r.done = make(chan struct{})
|
||||
go func() {
|
||||
r.mb, r.err = r.Reader.ReadMultiBuffer()
|
||||
close(r.done)
|
||||
}()
|
||||
}
|
||||
time.Sleep(duration)
|
||||
select {
|
||||
case <-r.done:
|
||||
r.done = nil
|
||||
return r.mb, r.err
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Writer extends io.Writer with MultiBuffer.
|
||||
type Writer interface {
|
||||
// WriteMultiBuffer writes a MultiBuffer into underlying writer.
|
||||
|
||||
@@ -2,6 +2,7 @@ package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
goerrors "errors"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
|
||||
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
|
||||
errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
|
||||
if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
|
||||
errC := errors.Cause(errP)
|
||||
if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
|
||||
errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
|
||||
}
|
||||
}
|
||||
common.Must(c.Close())
|
||||
cancel()
|
||||
@@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
|
||||
select {
|
||||
case <-m.done.Wait():
|
||||
m.sessionManager.Close()
|
||||
common.Close(m.link.Writer)
|
||||
common.Interrupt(m.link.Writer)
|
||||
common.Interrupt(m.link.Reader)
|
||||
return
|
||||
case <-m.timer.C:
|
||||
@@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
|
||||
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
|
||||
outbounds := session.OutboundsFromContext(ctx)
|
||||
ob := outbounds[len(outbounds)-1]
|
||||
transferType := protocol.TransferTypeStream
|
||||
@@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
|
||||
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
|
||||
defer s.Close(false)
|
||||
defer writer.Close()
|
||||
defer timer.Reset(time.Second * 16)
|
||||
|
||||
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
|
||||
if err := writeFirstPayload(s.input, writer); err != nil {
|
||||
@@ -307,7 +312,11 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
|
||||
}
|
||||
s.input = link.Reader
|
||||
s.output = link.Writer
|
||||
go fetchInput(ctx, s, m.link.Writer)
|
||||
if _, ok := link.Reader.(*pipe.Reader); ok {
|
||||
go fetchInput(ctx, s, m.link.Writer, m.timer)
|
||||
} else {
|
||||
fetchInput(ctx, s, m.link.Writer, m.timer)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -87,7 +87,14 @@ func NewServerWorker(ctx context.Context, d routing.Dispatcher, link *transport.
|
||||
link: link,
|
||||
sessionManager: NewSessionManager(),
|
||||
}
|
||||
go worker.run(ctx)
|
||||
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
||||
inbound.CanSpliceCopy = 3
|
||||
}
|
||||
if _, ok := link.Reader.(*pipe.Reader); ok {
|
||||
go worker.run(ctx)
|
||||
} else {
|
||||
worker.run(ctx)
|
||||
}
|
||||
return worker, nil
|
||||
}
|
||||
|
||||
@@ -311,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
|
||||
reader := &buf.BufferedReader{Reader: w.link.Reader}
|
||||
|
||||
defer w.sessionManager.Close()
|
||||
defer common.Close(w.link.Writer)
|
||||
defer common.Interrupt(w.link.Reader)
|
||||
defer common.Interrupt(w.link.Writer)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
var (
|
||||
Version_x byte = 25
|
||||
Version_y byte = 8
|
||||
Version_z byte = 3
|
||||
Version_z byte = 29
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -2,10 +2,8 @@ package dokodemo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/xtls/xray-core/common"
|
||||
"github.com/xtls/xray-core/common/buf"
|
||||
@@ -14,11 +12,10 @@ import (
|
||||
"github.com/xtls/xray-core/common/net"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
"github.com/xtls/xray-core/transport/internet/tls"
|
||||
)
|
||||
@@ -144,39 +141,11 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
|
||||
})
|
||||
errors.LogInfo(ctx, "received request for ", conn.RemoteAddr())
|
||||
|
||||
plcy := d.policy()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
||||
|
||||
if inbound != nil {
|
||||
inbound.Timer = timer
|
||||
}
|
||||
|
||||
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
|
||||
link, err := dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
return errors.New("failed to dispatch request").Base(err)
|
||||
}
|
||||
|
||||
requestCount := int32(1)
|
||||
requestDone := func() error {
|
||||
defer func() {
|
||||
if atomic.AddInt32(&requestCount, -1) == 0 {
|
||||
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
}
|
||||
}()
|
||||
|
||||
var reader buf.Reader
|
||||
if dest.Network == net.Network_UDP {
|
||||
reader = buf.NewPacketReader(conn)
|
||||
} else {
|
||||
reader = buf.NewReader(conn)
|
||||
}
|
||||
if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport request").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
var reader buf.Reader
|
||||
if dest.Network == net.Network_TCP {
|
||||
reader = buf.NewReader(conn)
|
||||
} else {
|
||||
reader = buf.NewPacketReader(conn)
|
||||
}
|
||||
|
||||
var writer buf.Writer
|
||||
@@ -208,72 +177,17 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn st
|
||||
return err
|
||||
}
|
||||
writer = NewPacketWriter(pConn, &dest, mark, back)
|
||||
defer func() {
|
||||
runtime.Gosched()
|
||||
common.Interrupt(link.Reader) // maybe duplicated
|
||||
runtime.Gosched()
|
||||
writer.(*PacketWriter).Close() // close fake UDP conns
|
||||
}()
|
||||
/*
|
||||
sockopt := &internet.SocketConfig{
|
||||
Tproxy: internet.SocketConfig_TProxy,
|
||||
}
|
||||
if dest.Address.Family().IsIP() {
|
||||
sockopt.BindAddress = dest.Address.IP()
|
||||
sockopt.BindPort = uint32(dest.Port)
|
||||
}
|
||||
if d.sockopt != nil {
|
||||
sockopt.Mark = d.sockopt.Mark
|
||||
}
|
||||
tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tConn.Close()
|
||||
|
||||
writer = &buf.SequentialWriter{Writer: tConn}
|
||||
tReader := buf.NewPacketReader(tConn)
|
||||
requestCount++
|
||||
tproxyRequest = func() error {
|
||||
defer func() {
|
||||
if atomic.AddInt32(&requestCount, -1) == 0 {
|
||||
timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
}
|
||||
}()
|
||||
if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport request (TPROXY conn)").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
*/
|
||||
defer writer.(*PacketWriter).Close() // close fake UDP conns
|
||||
}
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||
|
||||
if network == net.Network_UDP && destinationOverridden {
|
||||
buf.Copy(link.Reader, writer) // respect upload's timeout
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport response").Base(err)
|
||||
}
|
||||
return nil
|
||||
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||
Reader: &buf.TimeoutWrapperReader{Reader: reader},
|
||||
Writer: writer},
|
||||
); err != nil {
|
||||
return errors.New("failed to dispatch request").Base(err)
|
||||
}
|
||||
|
||||
if err := task.Run(ctx,
|
||||
task.OnSuccess(func() error { return task.Run(ctx, requestDone) }, task.Close(link.Writer)),
|
||||
responseDone); err != nil {
|
||||
runtime.Gosched()
|
||||
common.Interrupt(link.Writer)
|
||||
runtime.Gosched()
|
||||
common.Interrupt(link.Reader)
|
||||
return errors.New("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil // Unlike Dispatch(), DispatchLink() will not return until the outbound finishes Process()
|
||||
}
|
||||
|
||||
func NewPacketWriter(conn net.PacketConn, d *net.Destination, mark int, back *net.UDPAddr) buf.Writer {
|
||||
|
||||
@@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
|
||||
}
|
||||
|
||||
a := addr.AsAddress()
|
||||
return a != net.AnyIP
|
||||
return a != net.AnyIP && a != net.AnyIPv6
|
||||
}
|
||||
|
||||
// Process implements proxy.Outbound.
|
||||
@@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
|
||||
destAddr := b.UDP.RawNetAddr()
|
||||
if destAddr == nil {
|
||||
b.Release()
|
||||
continue
|
||||
|
||||
@@ -18,12 +18,12 @@ import (
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
http_proto "github.com/xtls/xray-core/common/protocol/http"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/proxy"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
)
|
||||
|
||||
@@ -173,64 +173,31 @@ Start:
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) handleConnect(ctx context.Context, _ *http.Request, reader *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
|
||||
func (s *Server) handleConnect(ctx context.Context, _ *http.Request, buffer *bufio.Reader, conn stat.Connection, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
|
||||
_, err := conn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
|
||||
if err != nil {
|
||||
return errors.New("failed to write back OK response").Base(err)
|
||||
}
|
||||
|
||||
plcy := s.policy()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, plcy.Timeouts.ConnectionIdle)
|
||||
|
||||
if inbound != nil {
|
||||
inbound.Timer = timer
|
||||
}
|
||||
|
||||
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
|
||||
link, err := dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if reader.Buffered() > 0 {
|
||||
payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
|
||||
reader := buf.NewReader(conn)
|
||||
if buffer.Buffered() > 0 {
|
||||
payload, err := buf.ReadFrom(io.LimitReader(buffer, int64(buffer.Buffered())))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := link.Writer.WriteMultiBuffer(payload); err != nil {
|
||||
return err
|
||||
}
|
||||
reader = nil
|
||||
reader = &buf.BufferedReader{Reader: reader, Buffer: payload}
|
||||
buffer = nil
|
||||
}
|
||||
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
|
||||
return buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer))
|
||||
if inbound.CanSpliceCopy == 2 {
|
||||
inbound.CanSpliceCopy = 1
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
if inbound.CanSpliceCopy == 2 {
|
||||
inbound.CanSpliceCopy = 1
|
||||
}
|
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||
|
||||
v2writer := buf.NewWriter(conn)
|
||||
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||
Reader: &buf.TimeoutWrapperReader{Reader: reader},
|
||||
Writer: buf.NewWriter(conn)},
|
||||
); err != nil {
|
||||
return errors.New("failed to dispatch request").Base(err)
|
||||
}
|
||||
|
||||
closeWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, closeWriter, responseDone); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
return errors.New("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if errors.Cause(err) == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
||||
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
|
||||
request := protocol.RequestHeaderFromContext(ctx)
|
||||
payload := packet.Payload
|
||||
if request == nil {
|
||||
payload.Release()
|
||||
return
|
||||
}
|
||||
|
||||
payload := packet.Payload
|
||||
|
||||
if payload.UDP != nil {
|
||||
request = &protocol.RequestHeader{
|
||||
User: request.User,
|
||||
@@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
|
||||
errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
|
||||
return
|
||||
}
|
||||
defer data.Release()
|
||||
|
||||
conn.Write(data.Bytes())
|
||||
data.Release()
|
||||
})
|
||||
defer udpServer.RemoveRay()
|
||||
|
||||
|
||||
@@ -14,13 +14,12 @@ import (
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
udp_proto "github.com/xtls/xray-core/common/protocol/udp"
|
||||
"github.com/xtls/xray-core/common/session"
|
||||
"github.com/xtls/xray-core/common/signal"
|
||||
"github.com/xtls/xray-core/common/task"
|
||||
"github.com/xtls/xray-core/core"
|
||||
"github.com/xtls/xray-core/features/policy"
|
||||
"github.com/xtls/xray-core/features/routing"
|
||||
"github.com/xtls/xray-core/proxy"
|
||||
"github.com/xtls/xray-core/proxy/http"
|
||||
"github.com/xtls/xray-core/transport"
|
||||
"github.com/xtls/xray-core/transport/internet/stat"
|
||||
"github.com/xtls/xray-core/transport/internet/udp"
|
||||
)
|
||||
@@ -158,8 +157,16 @@ func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatche
|
||||
Reason: "",
|
||||
})
|
||||
}
|
||||
|
||||
return s.transport(ctx, reader, conn, dest, dispatcher, inbound)
|
||||
if inbound.CanSpliceCopy == 2 {
|
||||
inbound.CanSpliceCopy = 1
|
||||
}
|
||||
if err := dispatcher.DispatchLink(ctx, dest, &transport.Link{
|
||||
Reader: &buf.TimeoutWrapperReader{Reader: reader},
|
||||
Writer: buf.NewWriter(conn)},
|
||||
); err != nil {
|
||||
return errors.New("failed to dispatch request").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if request.Command == protocol.RequestCommandUDP {
|
||||
@@ -178,54 +185,6 @@ func (*Server) handleUDP(c io.Reader) error {
|
||||
return common.Error2(io.Copy(buf.DiscardBytes, c))
|
||||
}
|
||||
|
||||
func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
|
||||
|
||||
if inbound != nil {
|
||||
inbound.Timer = timer
|
||||
}
|
||||
|
||||
plcy := s.policy()
|
||||
ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
|
||||
link, err := dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestDone := func() error {
|
||||
defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
|
||||
if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all TCP request").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
responseDone := func() error {
|
||||
if inbound.CanSpliceCopy == 2 {
|
||||
inbound.CanSpliceCopy = 1
|
||||
}
|
||||
defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
|
||||
|
||||
v2writer := buf.NewWriter(writer)
|
||||
if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
|
||||
return errors.New("failed to transport all TCP response").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
|
||||
if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
|
||||
common.Interrupt(link.Reader)
|
||||
common.Interrupt(link.Writer)
|
||||
return errors.New("connection ends").Base(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
|
||||
if s.udpFilter != nil && !s.udpFilter.Check(conn.RemoteAddr()) {
|
||||
errors.LogDebug(ctx, "Unauthorized UDP access from ", conn.RemoteAddr().String())
|
||||
@@ -237,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
|
||||
|
||||
request := protocol.RequestHeaderFromContext(ctx)
|
||||
if request == nil {
|
||||
payload.Release()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -255,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
|
||||
errors.LogWarningInner(ctx, err, "failed to write UDP response")
|
||||
return
|
||||
}
|
||||
defer udpMessage.Release()
|
||||
|
||||
conn.Write(udpMessage.Bytes())
|
||||
udpMessage.Release()
|
||||
})
|
||||
defer udpServer.RemoveRay()
|
||||
|
||||
@@ -265,9 +225,6 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
|
||||
if inbound != nil && inbound.Source.IsValid() {
|
||||
errors.LogInfo(ctx, "client UDP connection from ", inbound.Source)
|
||||
}
|
||||
if inbound.CanSpliceCopy == 2 {
|
||||
inbound.CanSpliceCopy = 1
|
||||
}
|
||||
|
||||
var dest *net.Destination
|
||||
|
||||
|
||||
@@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
|
||||
target = b.UDP
|
||||
}
|
||||
if _, err := w.writePacket(b.Bytes(), *target); err != nil {
|
||||
b.Release()
|
||||
buf.ReleaseMulti(mb)
|
||||
return err
|
||||
}
|
||||
b.Release()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
||||
sessionPolicy = s.policyManager.ForLevel(user.Level)
|
||||
|
||||
if destination.Network == net.Network_UDP { // handle udp request
|
||||
return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
|
||||
return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
|
||||
}
|
||||
|
||||
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
|
||||
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
|
||||
}
|
||||
|
||||
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
|
||||
func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
|
||||
defer timer.SetTimeout(0)
|
||||
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
|
||||
udpPayload := packet.Payload
|
||||
if udpPayload.UDP == nil {
|
||||
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
|
||||
|
||||
if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
|
||||
errors.LogWarningInner(ctx, err, "failed to write response")
|
||||
cancel()
|
||||
} else {
|
||||
timer.Update()
|
||||
}
|
||||
})
|
||||
defer udpServer.RemoveRay()
|
||||
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
|
||||
|
||||
var dest *net.Destination
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
mb, err := clientReader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
return errors.New("unexpected EOF").Base(err)
|
||||
}
|
||||
requestDone := func() error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
mb, err := clientReader.ReadMultiBuffer()
|
||||
if err != nil {
|
||||
if errors.Cause(err) != io.EOF {
|
||||
return errors.New("unexpected EOF").Base(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
mb2, b := buf.SplitFirst(mb)
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
destination := *b.UDP
|
||||
mb2, b := buf.SplitFirst(mb)
|
||||
if b == nil {
|
||||
continue
|
||||
}
|
||||
timer.Update()
|
||||
destination := *b.UDP
|
||||
|
||||
currentPacketCtx := ctx
|
||||
if inbound.Source.IsValid() {
|
||||
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
From: inbound.Source,
|
||||
To: destination,
|
||||
Status: log.AccessAccepted,
|
||||
Reason: "",
|
||||
Email: user.Email,
|
||||
})
|
||||
}
|
||||
errors.LogInfo(ctx, "tunnelling request to ", destination)
|
||||
currentPacketCtx := ctx
|
||||
if inbound.Source.IsValid() {
|
||||
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
|
||||
From: inbound.Source,
|
||||
To: destination,
|
||||
Status: log.AccessAccepted,
|
||||
Reason: "",
|
||||
Email: user.Email,
|
||||
})
|
||||
}
|
||||
errors.LogInfo(ctx, "tunnelling request to ", destination)
|
||||
|
||||
if !s.cone || dest == nil {
|
||||
dest = &destination
|
||||
}
|
||||
if !s.cone || dest == nil {
|
||||
dest = &destination
|
||||
}
|
||||
|
||||
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
|
||||
for _, payload := range mb2 {
|
||||
udpServer.Dispatch(currentPacketCtx, *dest, payload)
|
||||
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
|
||||
for _, payload := range mb2 {
|
||||
udpServer.Dispatch(currentPacketCtx, *dest, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err := task.Run(ctx, requestDone); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/xtls/xray-core/common/crypto"
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"github.com/xtls/xray-core/common/protocol"
|
||||
"lukechampine.com/blake3"
|
||||
)
|
||||
|
||||
@@ -66,7 +67,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if i.NfsPKeys == nil {
|
||||
return nil, errors.New("uninitialized")
|
||||
}
|
||||
c := NewCommonConn(conn)
|
||||
c := NewCommonConn(conn, protocol.HasAESGCMHardwareSupport)
|
||||
|
||||
ivAndRealysLength := 16 + i.RelaysLength
|
||||
pfsKeyExchangeLength := 18 + 1184 + 32 + 16
|
||||
@@ -108,18 +109,18 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
lastCTR.XORKeyStream(relays[index:], i.Hash32s[j+1][:])
|
||||
relays = relays[index+32:]
|
||||
}
|
||||
nfsGCM := NewGCM(iv, nfsKey)
|
||||
nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
|
||||
|
||||
if i.Seconds > 0 {
|
||||
i.RWLock.RLock()
|
||||
if time.Now().Before(i.Expire) {
|
||||
c.Client = i
|
||||
c.UnitedKey = append(i.PfsKey, nfsKey...) // different unitedKey for each connection
|
||||
nfsGCM.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
|
||||
nfsGCM.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
|
||||
nfsAEAD.Seal(clientHello[:ivAndRealysLength], nil, EncodeLength(32), nil)
|
||||
nfsAEAD.Seal(clientHello[:ivAndRealysLength+18], nil, i.Ticket, nil)
|
||||
i.RWLock.RUnlock()
|
||||
c.PreWrite = clientHello[:ivAndRealysLength+18+32]
|
||||
c.GCM = NewGCM(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey)
|
||||
c.AEAD = NewAEAD(clientHello[ivAndRealysLength+18:ivAndRealysLength+18+32], c.UnitedKey, c.UseAES)
|
||||
if i.XorMode == 2 {
|
||||
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, iv), nil, len(c.PreWrite), 16)
|
||||
}
|
||||
@@ -129,15 +130,15 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
}
|
||||
|
||||
pfsKeyExchange := clientHello[ivAndRealysLength : ivAndRealysLength+pfsKeyExchangeLength]
|
||||
nfsGCM.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
|
||||
nfsAEAD.Seal(pfsKeyExchange[:0], nil, EncodeLength(pfsKeyExchangeLength-18), nil)
|
||||
mlkem768DKey, _ := mlkem.GenerateKey768()
|
||||
x25519SKey, _ := ecdh.X25519().GenerateKey(rand.Reader)
|
||||
pfsPublicKey := append(mlkem768DKey.EncapsulationKey().Bytes(), x25519SKey.PublicKey().Bytes()...)
|
||||
nfsGCM.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
|
||||
nfsAEAD.Seal(pfsKeyExchange[:18], nil, pfsPublicKey, nil)
|
||||
|
||||
padding := clientHello[ivAndRealysLength+pfsKeyExchangeLength:]
|
||||
nfsGCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
|
||||
nfsGCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
|
||||
nfsAEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
|
||||
nfsAEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
|
||||
|
||||
if _, err := conn.Write(clientHello); err != nil {
|
||||
return nil, err
|
||||
@@ -148,7 +149,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nfsGCM.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
|
||||
nfsAEAD.Open(encryptedPfsPublicKey[:0], MaxNonce, encryptedPfsPublicKey, nil)
|
||||
mlkem768Key, err := mlkem768DKey.Decapsulate(encryptedPfsPublicKey[:1088])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -165,14 +166,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
copy(pfsKey, mlkem768Key)
|
||||
copy(pfsKey[32:], x25519Key)
|
||||
c.UnitedKey = append(pfsKey, nfsKey...)
|
||||
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
|
||||
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1088+32], c.UnitedKey)
|
||||
c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
|
||||
c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1088+32], c.UnitedKey, c.UseAES)
|
||||
|
||||
encryptedTicket := make([]byte, 32)
|
||||
if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := c.PeerGCM.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
|
||||
if _, err := c.PeerAEAD.Open(encryptedTicket[:0], nil, encryptedTicket, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seconds := DecodeLength(encryptedTicket)
|
||||
@@ -189,7 +190,7 @@ func (i *ClientInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := c.PeerGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
if _, err := c.PeerAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := DecodeLength(encryptedLength[:2])
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/xtls/xray-core/common/errors"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"lukechampine.com/blake3"
|
||||
)
|
||||
|
||||
@@ -23,19 +24,21 @@ var OutBytesPool = sync.Pool{
|
||||
|
||||
type CommonConn struct {
|
||||
net.Conn
|
||||
UseAES bool
|
||||
Client *ClientInstance
|
||||
UnitedKey []byte
|
||||
PreWrite []byte
|
||||
GCM *GCM
|
||||
PeerGCM *GCM
|
||||
AEAD *AEAD
|
||||
PeerAEAD *AEAD
|
||||
PeerPadding []byte
|
||||
PeerInBytes []byte
|
||||
PeerCache []byte
|
||||
}
|
||||
|
||||
func NewCommonConn(conn net.Conn) *CommonConn {
|
||||
func NewCommonConn(conn net.Conn, useAES bool) *CommonConn {
|
||||
return &CommonConn{
|
||||
Conn: conn,
|
||||
UseAES: useAES,
|
||||
PeerInBytes: make([]byte, 5+17000), // no need to use sync.Pool, because we are always reading
|
||||
}
|
||||
}
|
||||
@@ -55,12 +58,12 @@ func (c *CommonConn) Write(b []byte) (int, error) {
|
||||
headerAndData := outBytes[:5+len(b)+16]
|
||||
EncodeHeader(headerAndData, len(b)+16)
|
||||
max := false
|
||||
if bytes.Equal(c.GCM.Nonce[:], MaxNonce) {
|
||||
if bytes.Equal(c.AEAD.Nonce[:], MaxNonce) {
|
||||
max = true
|
||||
}
|
||||
c.GCM.Seal(headerAndData[:5], nil, b, headerAndData[:5])
|
||||
c.AEAD.Seal(headerAndData[:5], nil, b, headerAndData[:5])
|
||||
if max {
|
||||
c.GCM = NewGCM(headerAndData, c.UnitedKey)
|
||||
c.AEAD = NewAEAD(headerAndData, c.UnitedKey, c.UseAES)
|
||||
}
|
||||
if c.PreWrite != nil {
|
||||
headerAndData = append(c.PreWrite, headerAndData...)
|
||||
@@ -77,12 +80,12 @@ func (c *CommonConn) Read(b []byte) (int, error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if c.PeerGCM == nil { // client's 0-RTT
|
||||
if c.PeerAEAD == nil { // client's 0-RTT
|
||||
serverRandom := make([]byte, 16)
|
||||
if _, err := io.ReadFull(c.Conn, serverRandom); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.PeerGCM = NewGCM(serverRandom, c.UnitedKey)
|
||||
c.PeerAEAD = NewAEAD(serverRandom, c.UnitedKey, c.UseAES)
|
||||
if xorConn, ok := c.Conn.(*XorConn); ok {
|
||||
xorConn.PeerCTR = NewCTR(c.UnitedKey, serverRandom)
|
||||
}
|
||||
@@ -91,7 +94,7 @@ func (c *CommonConn) Read(b []byte) (int, error) {
|
||||
if _, err := io.ReadFull(c.Conn, c.PeerPadding); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if _, err := c.PeerGCM.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
|
||||
if _, err := c.PeerAEAD.Open(c.PeerPadding[:0], nil, c.PeerPadding, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
c.PeerPadding = nil
|
||||
@@ -126,13 +129,13 @@ func (c *CommonConn) Read(b []byte) (int, error) {
|
||||
if len(dst) <= len(b) {
|
||||
dst = b[:len(dst)] // avoids another copy()
|
||||
}
|
||||
var newGCM *GCM
|
||||
if bytes.Equal(c.PeerGCM.Nonce[:], MaxNonce) {
|
||||
newGCM = NewGCM(c.PeerInBytes[:5+l], c.UnitedKey)
|
||||
var newAEAD *AEAD
|
||||
if bytes.Equal(c.PeerAEAD.Nonce[:], MaxNonce) {
|
||||
newAEAD = NewAEAD(c.PeerInBytes[:5+l], c.UnitedKey, c.UseAES)
|
||||
}
|
||||
_, err = c.PeerGCM.Open(dst[:0], nil, peerData, peerHeader)
|
||||
if newGCM != nil {
|
||||
c.PeerGCM = newGCM
|
||||
_, err = c.PeerAEAD.Open(dst[:0], nil, peerData, peerHeader)
|
||||
if newAEAD != nil {
|
||||
c.PeerAEAD = newAEAD
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -144,28 +147,32 @@ func (c *CommonConn) Read(b []byte) (int, error) {
|
||||
return len(dst), nil
|
||||
}
|
||||
|
||||
type GCM struct {
|
||||
type AEAD struct {
|
||||
cipher.AEAD
|
||||
Nonce [12]byte
|
||||
}
|
||||
|
||||
func NewGCM(ctx, key []byte) *GCM {
|
||||
func NewAEAD(ctx, key []byte, useAES bool) *AEAD {
|
||||
k := make([]byte, 32)
|
||||
blake3.DeriveKey(k, string(ctx), key)
|
||||
block, _ := aes.NewCipher(k)
|
||||
aead, _ := cipher.NewGCM(block)
|
||||
return &GCM{AEAD: aead}
|
||||
//chacha20poly1305.New()
|
||||
var aead cipher.AEAD
|
||||
if useAES {
|
||||
block, _ := aes.NewCipher(k)
|
||||
aead, _ = cipher.NewGCM(block)
|
||||
} else {
|
||||
aead, _ = chacha20poly1305.New(k)
|
||||
}
|
||||
return &AEAD{AEAD: aead}
|
||||
}
|
||||
|
||||
func (a *GCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
|
||||
func (a *AEAD) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
|
||||
if nonce == nil {
|
||||
nonce = IncreaseNonce(a.Nonce[:])
|
||||
}
|
||||
return a.AEAD.Seal(dst, nonce, plaintext, additionalData)
|
||||
}
|
||||
|
||||
func (a *GCM) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
func (a *AEAD) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
|
||||
if nonce == nil {
|
||||
nonce = IncreaseNonce(a.Nonce[:])
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if i.NfsSKeys == nil {
|
||||
return nil, errors.New("uninitialized")
|
||||
}
|
||||
c := NewCommonConn(conn)
|
||||
c := NewCommonConn(conn, true)
|
||||
|
||||
ivAndRelays := make([]byte, 16+i.RelaysLength)
|
||||
if _, err := io.ReadFull(conn, ivAndRelays); err != nil {
|
||||
@@ -151,16 +151,21 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
}
|
||||
relays = relays[32:]
|
||||
}
|
||||
nfsGCM := NewGCM(iv, nfsKey)
|
||||
nfsAEAD := NewAEAD(iv, nfsKey, c.UseAES)
|
||||
|
||||
encryptedLength := make([]byte, 18)
|
||||
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
return nil, err
|
||||
decryptedLength := make([]byte, 2)
|
||||
if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
c.UseAES = !c.UseAES
|
||||
nfsAEAD = NewAEAD(iv, nfsKey, c.UseAES)
|
||||
if _, err := nfsAEAD.Open(decryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
length := DecodeLength(encryptedLength[:2])
|
||||
length := DecodeLength(decryptedLength)
|
||||
|
||||
if length == 32 {
|
||||
if i.Seconds == 0 {
|
||||
@@ -170,7 +175,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if _, err := io.ReadFull(conn, encryptedTicket); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ticket, err := nfsGCM.Open(nil, nil, encryptedTicket, nil)
|
||||
ticket, err := nfsAEAD.Open(nil, nil, encryptedTicket, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -193,8 +198,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
c.UnitedKey = append(s.PfsKey, nfsKey...) // the same nfsKey links the upload & download (prevents server -> client's another request)
|
||||
c.PreWrite = make([]byte, 16)
|
||||
rand.Read(c.PreWrite) // always trust yourself, not the client (also prevents being parsed as TLS thus causing false interruption for "native" and "xorpub")
|
||||
c.GCM = NewGCM(c.PreWrite, c.UnitedKey)
|
||||
c.PeerGCM = NewGCM(encryptedTicket, c.UnitedKey) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
|
||||
c.AEAD = NewAEAD(c.PreWrite, c.UnitedKey, c.UseAES)
|
||||
c.PeerAEAD = NewAEAD(encryptedTicket, c.UnitedKey, c.UseAES) // unchangeable ctx (prevents server -> server), and different ctx length for upload / download (prevents client -> client)
|
||||
if i.XorMode == 2 {
|
||||
c.Conn = NewXorConn(conn, NewCTR(c.UnitedKey, c.PreWrite), NewCTR(c.UnitedKey, iv), 16, 0) // it doesn't matter if the attacker sends client's iv back to the client
|
||||
}
|
||||
@@ -208,7 +213,7 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if _, err := io.ReadFull(conn, encryptedPfsPublicKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := nfsGCM.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
|
||||
if _, err := nfsAEAD.Open(encryptedPfsPublicKey[:0], nil, encryptedPfsPublicKey, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mlkem768EKey, err := mlkem.NewEncapsulationKey768(encryptedPfsPublicKey[:1184])
|
||||
@@ -230,8 +235,8 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
copy(pfsKey[32:], x25519Key)
|
||||
pfsPublicKey := append(encapsulatedPfsKey, x25519SKey.PublicKey().Bytes()...)
|
||||
c.UnitedKey = append(pfsKey, nfsKey...)
|
||||
c.GCM = NewGCM(pfsPublicKey, c.UnitedKey)
|
||||
c.PeerGCM = NewGCM(encryptedPfsPublicKey[:1184+32], c.UnitedKey)
|
||||
c.AEAD = NewAEAD(pfsPublicKey, c.UnitedKey, c.UseAES)
|
||||
c.PeerAEAD = NewAEAD(encryptedPfsPublicKey[:1184+32], c.UnitedKey, c.UseAES)
|
||||
ticket := make([]byte, 16)
|
||||
rand.Read(ticket)
|
||||
copy(ticket, EncodeLength(int(i.Seconds*4/5)))
|
||||
@@ -240,11 +245,11 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
encryptedTicketLength := 32
|
||||
paddingLength := int(crypto.RandBetween(100, 1000))
|
||||
serverHello := make([]byte, pfsKeyExchangeLength+encryptedTicketLength+paddingLength)
|
||||
nfsGCM.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
|
||||
c.GCM.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
|
||||
nfsAEAD.Seal(serverHello[:0], MaxNonce, pfsPublicKey, nil)
|
||||
c.AEAD.Seal(serverHello[:pfsKeyExchangeLength], nil, ticket, nil)
|
||||
padding := serverHello[pfsKeyExchangeLength+encryptedTicketLength:]
|
||||
c.GCM.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
|
||||
c.GCM.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
|
||||
c.AEAD.Seal(padding[:0], nil, EncodeLength(paddingLength-18), nil)
|
||||
c.AEAD.Seal(padding[:18], nil, padding[18:paddingLength-16], nil)
|
||||
|
||||
if _, err := conn.Write(serverHello); err != nil {
|
||||
return nil, err
|
||||
@@ -264,14 +269,14 @@ func (i *ServerInstance) Handshake(conn net.Conn) (*CommonConn, error) {
|
||||
if _, err := io.ReadFull(conn, encryptedLength); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := nfsGCM.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
if _, err := nfsAEAD.Open(encryptedLength[:0], nil, encryptedLength, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptedPadding := make([]byte, DecodeLength(encryptedLength[:2]))
|
||||
if _, err := io.ReadFull(conn, encryptedPadding); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := nfsGCM.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
|
||||
if _, err := nfsAEAD.Open(encryptedPadding[:0], nil, encryptedPadding, nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
|
||||
|
||||
type connEntry struct {
|
||||
link *transport.Link
|
||||
timer signal.ActivityUpdater
|
||||
timer *signal.ActivityTimer
|
||||
cancel context.CancelFunc
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (c *connEntry) Close() error {
|
||||
c.timer.SetTimeout(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *connEntry) terminate() {
|
||||
if c.closed {
|
||||
panic("terminate called more than once")
|
||||
}
|
||||
c.closed = true
|
||||
c.cancel()
|
||||
common.Interrupt(c.link.Reader)
|
||||
common.Interrupt(c.link.Writer)
|
||||
}
|
||||
|
||||
type Dispatcher struct {
|
||||
@@ -32,6 +48,7 @@ type Dispatcher struct {
|
||||
dispatcher routing.Dispatcher
|
||||
callback ResponseCallback
|
||||
callClose func() error
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
|
||||
@@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
|
||||
func (v *Dispatcher) RemoveRay() {
|
||||
v.Lock()
|
||||
defer v.Unlock()
|
||||
v.removeRay()
|
||||
}
|
||||
|
||||
func (v *Dispatcher) removeRay() {
|
||||
v.closed = true
|
||||
if v.conn != nil {
|
||||
common.Interrupt(v.conn.link.Reader)
|
||||
common.Close(v.conn.link.Writer)
|
||||
v.conn.Close()
|
||||
v.conn = nil
|
||||
}
|
||||
}
|
||||
@@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
|
||||
v.Lock()
|
||||
defer v.Unlock()
|
||||
|
||||
if v.closed {
|
||||
return nil, errors.New("dispatcher is closed")
|
||||
}
|
||||
|
||||
if v.conn != nil {
|
||||
return v.conn, nil
|
||||
if v.conn.closed {
|
||||
v.conn = nil
|
||||
} else {
|
||||
return v.conn, nil
|
||||
}
|
||||
}
|
||||
|
||||
errors.LogInfo(ctx, "establishing new connection for ", dest)
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
entry := &connEntry{}
|
||||
removeRay := func() {
|
||||
v.Lock()
|
||||
defer v.Unlock()
|
||||
// sometimes the entry is already removed by others, don't close again
|
||||
if entry == v.conn {
|
||||
cancel()
|
||||
v.removeRay()
|
||||
}
|
||||
}
|
||||
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
|
||||
|
||||
link, err := v.dispatcher.Dispatch(ctx, dest)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return nil, errors.New("failed to dispatch request to ", dest).Base(err)
|
||||
}
|
||||
|
||||
*entry = connEntry{
|
||||
entry := &connEntry{
|
||||
link: link,
|
||||
timer: timer,
|
||||
cancel: removeRay,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
|
||||
v.conn = entry
|
||||
go handleInput(ctx, entry, dest, v.callback, v.callClose)
|
||||
return entry, nil
|
||||
@@ -106,7 +118,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
|
||||
if outputStream != nil {
|
||||
if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
|
||||
errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
|
||||
conn.cancel()
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -114,7 +126,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
|
||||
|
||||
func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
|
||||
defer func() {
|
||||
conn.cancel()
|
||||
conn.Close()
|
||||
if callClose != nil {
|
||||
callClose()
|
||||
}
|
||||
|
||||
@@ -200,16 +200,19 @@ func (p *pipe) Interrupt() {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if !p.data.IsEmpty() {
|
||||
buf.ReleaseMulti(p.data)
|
||||
p.data = nil
|
||||
if p.state == closed {
|
||||
p.state = errord
|
||||
}
|
||||
}
|
||||
|
||||
if p.state == closed || p.state == errord {
|
||||
return
|
||||
}
|
||||
|
||||
p.state = errord
|
||||
|
||||
if !p.data.IsEmpty() {
|
||||
buf.ReleaseMulti(p.data)
|
||||
p.data = nil
|
||||
}
|
||||
|
||||
common.Must(p.done.Close())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user