diff --git a/internal/fdbased_darwin/endpoint.go b/internal/fdbased_darwin/endpoint.go
new file mode 100644
index 0000000..b542793
--- /dev/null
+++ b/internal/fdbased_darwin/endpoint.go
@@ -0,0 +1,648 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package fdbased provides the implementation of data-link layer endpoints
+// backed by boundary-preserving file descriptors (e.g., TUN devices,
+// seqpacket/datagram sockets).
+//
+// FD based endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+//
+// FD based endpoints can use more than one file descriptor to read incoming
+// packets. If more than one FD is specified and the underlying FD is an
+// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the
+// host kernel will consistently hash the packets to the sockets. This ensures
+// that packets for the same TCP streams are not reordered.
+//
+// Similarly, if more than one FD is specified where the underlying FD is not
+// AF_PACKET then it's the caller's responsibility to ensure that all inbound
+// packets on the descriptors are consistently 5 tuple hashed to one of the
+// descriptors to prevent TCP reordering.
+//
+// Since netstack today does not compute 5 tuple hashes for outgoing packets we
+// only use the first FD to write outbound packets. Once 5 tuple hashes for
+// all outbound packets are available we will make use of all underlying FDs to
+// write outbound packets.
+package fdbased
+
+import (
+	"fmt"
+	"runtime"
+
+	"github.com/sagernet/gvisor/pkg/buffer"
+	"github.com/sagernet/gvisor/pkg/sync"
+	"github.com/sagernet/gvisor/pkg/tcpip"
+	"github.com/sagernet/gvisor/pkg/tcpip/header"
+	"github.com/sagernet/gvisor/pkg/tcpip/stack"
+	"github.com/sagernet/sing-tun/internal/rawfile_darwin"
+	"github.com/sagernet/sing/common"
+
+	"golang.org/x/sys/unix"
+)
+
+// linkDispatcher reads packets from the link FD and dispatches them to the
+// NetworkDispatcher.
+type linkDispatcher interface {
+	Stop()
+	dispatch() (bool, tcpip.Error)
+	release()
+}
+
+// PacketDispatchMode enumerates the supported methods of receiving and
+// dispatching packets from the underlying FD.
+type PacketDispatchMode int
+
+// BatchSize is the number of packets to write in each syscall. It is 47
+// because when GVisorGSO is in use then a single 65KB TCP segment can get
+// split into 46 segments of 1420 bytes and a single 216 byte segment.
+const BatchSize = 47
+
+const (
+	// Readv is the default dispatch mode and is the least performant of the
+	// dispatch options but the one that is supported by all underlying FD
+	// types.
+	Readv PacketDispatchMode = iota
+)
+
+func (p PacketDispatchMode) String() string {
+	switch p {
+	case Readv:
+		return "Readv"
+	default:
+		return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
+	}
+}
+
+var (
+	_ stack.LinkEndpoint = (*endpoint)(nil)
+	_ stack.GSOEndpoint  = (*endpoint)(nil)
+)
+
+// +stateify savable
+type fdInfo struct {
+	fd       int
+	isSocket bool // when false, sendBatch falls back to per-packet writePacket
+}
+
+// +stateify savable
+type endpoint struct {
+	// fds is the set of file descriptors each identifying one inbound/outbound
+	// channel. The endpoint will dispatch from all inbound channels as well as
+	// hash outbound packets to specific channels based on the packet hash.
+	fds []fdInfo
+
+	// hdrSize specifies the link-layer header size. If set to 0, no header
+	// is added/removed; otherwise an ethernet header is used.
+	hdrSize int
+
+	// caps holds the endpoint capabilities.
+	caps stack.LinkEndpointCapabilities
+
+	// closed is a function to be called when the FD's peer (if any) closes
+	// its end of the communication pipe.
+	closed func(tcpip.Error) `state:"nosave"`
+
+	inboundDispatchers []linkDispatcher // one per FD, appended in New
+
+	mu endpointRWMutex `state:"nosave"`
+	// +checklocks:mu
+	dispatcher stack.NetworkDispatcher
+
+	// packetDispatchMode controls the packet dispatcher used by this
+	// endpoint.
+	packetDispatchMode PacketDispatchMode
+
+	// wg keeps track of running goroutines.
+	wg sync.WaitGroup `state:"nosave"`
+
+	// maxSyscallHeaderBytes has the same meaning as
+	// Options.MaxSyscallHeaderBytes.
+	maxSyscallHeaderBytes uintptr
+
+	// writevMaxIovs is the maximum number of iovecs that may be passed to
+	// rawfile.NonBlockingWriteIovec, as possibly limited by
+	// maxSyscallHeaderBytes. (No analogous limit is defined for
+	// rawfile.NonBlockingSendMMsg, since in that case the maximum number of
+	// iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
+	// encounters a packet whose iovec count is limited by
+	// maxSyscallHeaderBytes, it falls back to writing the packet using writev
+	// via writePacket.)
+	writevMaxIovs int
+
+	// addr is the address of the endpoint.
+	//
+	// +checklocks:mu
+	addr tcpip.LinkAddress
+
+	// mtu (maximum transmission unit) is the maximum size of a packet.
+	// +checklocks:mu
+	mtu uint32
+
+	batchSize int // capacity WritePackets preallocates for a batch; derived from MTU in New
+}
+
+// Options specify the details about the fd-based endpoint to be created.
+//
+// +stateify savable
+type Options struct {
+	// FDs is a set of FDs used to read/write packets.
+	FDs []int
+
+	// MTU is the mtu to use for this endpoint. Must be non-zero (New divides by it).
+	MTU uint32
+
+	// EthernetHeader if true, indicates that the endpoint should read/write
+	// ethernet frames instead of IP packets.
+	EthernetHeader bool
+
+	// ClosedFunc is a function to be called when an endpoint's peer (if
+	// any) closes its end of the communication pipe.
+	ClosedFunc func(tcpip.Error)
+
+	// Address is the link address for this endpoint. Only used if
+	// EthernetHeader is true.
+	Address tcpip.LinkAddress
+
+	// SaveRestore if true, indicates that this NIC capability set should
+	// include CapabilitySaveRestore.
+	SaveRestore bool
+
+	// DisconnectOk if true, indicates that this NIC capability set should
+	// include CapabilityDisconnectOk.
+	DisconnectOk bool
+
+	// PacketDispatchMode specifies the type of inbound dispatcher to be
+	// used for this endpoint.
+	PacketDispatchMode PacketDispatchMode
+
+	// TXChecksumOffload if true, indicates that this endpoint's capability
+	// set should include CapabilityTXChecksumOffload.
+	TXChecksumOffload bool
+
+	// RXChecksumOffload if true, indicates that this endpoint's capability
+	// set should include CapabilityRXChecksumOffload.
+	RXChecksumOffload bool
+
+	// If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
+	// of struct iovec, msghdr, and mmsghdr that may be passed by each host
+	// system call.
+	MaxSyscallHeaderBytes int
+
+	// InterfaceIndex is the interface index of the underlying device.
+	InterfaceIndex int
+
+	// ProcessorsPerChannel is the number of goroutines used to handle packets
+	// from each FD. If zero, New sets it to max(1, GOMAXPROCS/len(FDs)).
+	ProcessorsPerChannel int
+}
+
+// New creates a new fd-based endpoint.
+//
+// Makes fd non-blocking, but does not take ownership of fd, which must remain
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being used and Wait returns).
+func New(opts *Options) (stack.LinkEndpoint, error) {
+	caps := stack.LinkEndpointCapabilities(0)
+	if opts.RXChecksumOffload {
+		caps |= stack.CapabilityRXChecksumOffload
+	}
+
+	if opts.TXChecksumOffload {
+		caps |= stack.CapabilityTXChecksumOffload
+	}
+
+	hdrSize := 0
+	if opts.EthernetHeader {
+		hdrSize = header.EthernetMinimumSize
+		caps |= stack.CapabilityResolutionRequired
+	}
+
+	if opts.SaveRestore {
+		caps |= stack.CapabilitySaveRestore
+	}
+
+	if opts.DisconnectOk {
+		caps |= stack.CapabilityDisconnectOk
+	}
+
+	switch {
+	case len(opts.FDs) == 0:
+		return nil, fmt.Errorf("opts.FDs is empty, at least one FD must be specified")
+	case opts.MTU == 0: // batchSize below divides by opts.MTU; reject instead of panicking.
+		return nil, fmt.Errorf("opts.MTU is zero, a non-zero MTU must be specified")
+	case opts.MaxSyscallHeaderBytes < 0:
+		return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
+	}
+	e := &endpoint{
+		mtu:                   opts.MTU,
+		caps:                  caps,
+		closed:                opts.ClosedFunc,
+		addr:                  opts.Address,
+		hdrSize:               hdrSize,
+		packetDispatchMode:    opts.PacketDispatchMode,
+		maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
+		writevMaxIovs:         rawfile.MaxIovs,
+		batchSize:             int((512*1024)/(opts.MTU)) + 1,
+	}
+	if e.maxSyscallHeaderBytes != 0 {
+		if limit := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); limit < e.writevMaxIovs {
+			e.writevMaxIovs = limit
+		}
+	}
+
+	// Create per channel dispatchers.
+	for _, fd := range opts.FDs {
+		if err := unix.SetNonblock(fd, true); err != nil {
+			return nil, fmt.Errorf("unix.SetNonblock(%v) failed: %w", fd, err)
+		}
+
+		e.fds = append(e.fds, fdInfo{fd: fd, isSocket: true})
+		if opts.ProcessorsPerChannel == 0 {
+			opts.ProcessorsPerChannel = common.Max(1, runtime.GOMAXPROCS(0)/len(opts.FDs))
+		}
+
+		inboundDispatcher, err := newRecvMMsgDispatcher(fd, e, opts)
+		if err != nil {
+			return nil, fmt.Errorf("createInboundDispatcher(...) = %w", err)
+		}
+		e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
+	}
+
+	return e, nil
+}
+
+func isSocketFD(fd int) (bool, error) {
+	var stat unix.Stat_t
+	if err := unix.Fstat(fd, &stat); err != nil {
+		return false, fmt.Errorf("unix.Fstat(%v,...) failed: %w", fd, err)
+	}
+	return stat.Mode&unix.S_IFMT == unix.S_IFSOCK, nil // mask the type field; S_IFSOCK shares bits with other types
+}
+
+// Attach launches the goroutine that reads packets from the file descriptor and
+// dispatches them via the provided dispatcher. If one is already attached,
+// then nothing happens.
+//
+// Attach implements stack.LinkEndpoint.Attach.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	e.mu.Lock()
+
+	// nil means the NIC is being removed.
+	if dispatcher == nil && e.dispatcher != nil {
+		for _, d := range e.inboundDispatchers {
+			d.Stop()
+		}
+		e.dispatcher = nil
+		// NOTE(gvisor.dev/issue/11456): Unlock e.mu before e.Wait().
+		e.mu.Unlock()
+		e.Wait()
+		return
+	}
+	defer e.mu.Unlock()
+	if dispatcher != nil && e.dispatcher == nil {
+		e.dispatcher = dispatcher
+		// Link endpoints are not savable. When transportation endpoints are
+		// saved, they stop sending outgoing packets and all incoming packets
+		// are rejected.
+		for i := range e.inboundDispatchers {
+			e.wg.Add(1)
+			go func(i int) { // S/R-SAFE: See above.
+				e.dispatchLoop(e.inboundDispatchers[i])
+				e.wg.Done()
+			}(i)
+		}
+	}
+}
+
+// IsAttached implements stack.LinkEndpoint.IsAttached.
+func (e *endpoint) IsAttached() bool {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+	return e.dispatcher != nil
+}
+
+// MTU implements stack.LinkEndpoint.MTU.
+func (e *endpoint) MTU() uint32 {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+	return e.mtu
+}
+
+// SetMTU implements stack.LinkEndpoint.SetMTU.
+func (e *endpoint) SetMTU(mtu uint32) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.mtu = mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return e.caps
+}
+
+// MaxHeaderLength reports the link-layer header size (hdrSize) as a uint16.
+func (e *endpoint) MaxHeaderLength() uint16 {
+	return uint16(e.hdrSize)
+}
+
+// LinkAddress returns the endpoint's current link address, read under e.mu.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+	return e.addr
+}
+
+// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress.
+func (e *endpoint) SetLinkAddress(addr tcpip.LinkAddress) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.addr = addr
+}
+
+// Wait implements stack.LinkEndpoint.Wait. It blocks until every goroutine
+// tracked by e.wg has finished.
+func (e *endpoint) Wait() {
+	e.wg.Wait()
+}
+
+// AddHeader implements stack.LinkEndpoint.AddHeader.
+func (e *endpoint) AddHeader(pkt *stack.PacketBuffer) {
+	if e.hdrSize <= 0 {
+		return // No link-layer header is configured.
+	}
+	ethHdr := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+	ethHdr.Encode(&header.EthernetFields{
+		SrcAddr: pkt.EgressRoute.LocalLinkAddress,
+		DstAddr: pkt.EgressRoute.RemoteLinkAddress,
+		Type:    pkt.NetworkProtocolNumber,
+	})
+}
+
+func (e *endpoint) parseHeader(pkt *stack.PacketBuffer) (header.Ethernet, bool) {
+	if e.hdrSize <= 0 {
+		return nil, true
+	}
+	raw, consumed := pkt.LinkHeader().Consume(e.hdrSize)
+	if !consumed {
+		return nil, false
+	}
+	ethHdr := header.Ethernet(raw)
+	pkt.NetworkProtocolNumber = ethHdr.Type()
+	return ethHdr, true
+}
+
+// parseInboundHeader parses the link header of pkt and returns true if the
+// header is well-formed and addressed to wantAddr or to a group (multicast
+// or broadcast) address.
+func (e *endpoint) parseInboundHeader(pkt *stack.PacketBuffer, wantAddr tcpip.LinkAddress) bool {
+	hdr, ok := e.parseHeader(pkt)
+	if !ok || e.hdrSize <= 0 {
+		return ok
+	}
+	dstAddr := hdr.DestinationAddress()
+	// Per RFC 9542 2.1 on the least significant bit of the first octet of
+	// a MAC address: "If it is zero, the MAC address is unicast. If it is
+	// a one, the address is groupcast (multicast or broadcast)." Multicast
+	// and broadcast are the same thing to ethernet; they are both sent to
+	// everyone.
+	return dstAddr == wantAddr || byte(dstAddr[0])&0x01 == 1
+}
+
+// ParseHeader implements stack.LinkEndpoint.ParseHeader.
+func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
+	_, ok := e.parseHeader(pkt)
+	return ok
+}
+
+var (
+	packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET}  // 4-byte prefix written before IPv4 packets
+	packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6} // 4-byte prefix written before all other packets
+)
+
+// writePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error {
+	fdInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
+	fd := fdInfo.fd
+	var vnetHdrBuf []byte
+	if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+		vnetHdrBuf = packetHeader4
+	} else {
+		vnetHdrBuf = packetHeader6
+	}
+	views := pkt.AsSlices()
+	numIovecs := len(views)
+	if len(vnetHdrBuf) != 0 {
+		numIovecs++
+	}
+	if numIovecs > e.writevMaxIovs {
+		numIovecs = e.writevMaxIovs
+	}
+
+	// Allocate small iovec arrays on the stack.
+	var iovecsArr [8]unix.Iovec
+	iovecs := iovecsArr[:0]
+	if numIovecs > len(iovecsArr) {
+		iovecs = make([]unix.Iovec, 0, numIovecs)
+	}
+	iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+	for _, v := range views {
+		iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
+	}
+	if errno := rawfile.NonBlockingWriteIovec(fd, iovecs); errno != 0 {
+		return TranslateErrno(errno)
+	}
+	return nil
+}
+
+func (e *endpoint) sendBatch(batchFDInfo fdInfo, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
+	// Degrade to writePacket if underlying fd is not a socket.
+	if !batchFDInfo.isSocket {
+		var written int
+		var err tcpip.Error
+		for written < len(pkts) {
+			if err = e.writePacket(pkts[written]); err != nil {
+				break
+			}
+			written++
+		}
+		return written, err
+	}
+
+	// Send a batch of packets through batchFD.
+	batchFD := batchFDInfo.fd
+	mmsgHdrsStorage := make([]rawfile.MsgHdrX, 0, len(pkts))
+	packets := 0
+	for packets < len(pkts) {
+		mmsgHdrs := mmsgHdrsStorage
+		batch := pkts[packets:]
+		syscallHeaderBytes := uintptr(0)
+		for _, pkt := range batch {
+			var vnetHdrBuf []byte
+			if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
+				vnetHdrBuf = packetHeader4
+			} else {
+				vnetHdrBuf = packetHeader6
+			}
+			views, offset := pkt.AsViewList()
+			var skipped int
+			var view *buffer.View
+			for view = views.Front(); view != nil && offset >= view.Size(); view = view.Next() {
+				offset -= view.Size()
+				skipped++
+			}
+
+			// We've made it to the usable views.
+			numIovecs := views.Len() - skipped
+			if len(vnetHdrBuf) != 0 {
+				numIovecs++
+			}
+			if numIovecs > rawfile.MaxIovs {
+				numIovecs = rawfile.MaxIovs
+			}
+			if e.maxSyscallHeaderBytes != 0 {
+				syscallHeaderBytes += rawfile.SizeofMsgHdrX + uintptr(numIovecs)*rawfile.SizeofIovec
+				if syscallHeaderBytes > e.maxSyscallHeaderBytes {
+					// We can't fit this packet into this call to sendmmsg().
+					// We could potentially do so if we reduced numIovecs
+					// further, but this might incur considerable extra
+					// copying. Leave it to the next batch instead.
+					break
+				}
+			}
+
+			// We can't easily allocate iovec arrays on the stack here since
+			// they will escape this loop iteration via mmsgHdrs.
+			iovecs := make([]unix.Iovec, 0, numIovecs)
+			iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
+			// At most one slice has a non-zero offset.
+			iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice()[offset:], numIovecs)
+			for view = view.Next(); view != nil; view = view.Next() {
+				iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice(), numIovecs)
+			}
+
+			var mmsgHdr rawfile.MsgHdrX
+			mmsgHdr.Msg.Iov = &iovecs[0]
+			mmsgHdr.Msg.SetIovlen(len(iovecs))
+			// mmsgHdr.DataLen = uint32(len(iovecs))
+			mmsgHdrs = append(mmsgHdrs, mmsgHdr)
+		}
+
+		if len(mmsgHdrs) == 0 {
+			// We can't fit batch[0] into a mmsghdr while staying under
+			// e.maxSyscallHeaderBytes. Use writePacket, which will avoid the
+			// mmsghdr (by using writev) and re-buffer iovecs more aggressively
+			// if necessary (by using e.writevMaxIovs instead of
+			// rawfile.MaxIovs).
+			pkt := batch[0]
+			if err := e.writePacket(pkt); err != nil {
+				return packets, err
+			}
+			packets++
+		} else {
+			for len(mmsgHdrs) > 0 {
+				sent, errno := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
+				if errno != 0 {
+					return packets, TranslateErrno(errno)
+				}
+				packets += sent
+				mmsgHdrs = mmsgHdrs[sent:]
+			}
+		}
+	}
+
+	return packets, nil
+}
+
+// WritePackets writes outbound packets to the underlying file descriptors. If
+// one is not currently writable, the packet is dropped.
+//
+// Being a batch API, each packet in pkts should have the following
+// fields populated:
+//   - pkt.EgressRoute
+//   - pkt.GSOOptions
+//   - pkt.NetworkProtocolNumber
+func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
+	// Preallocate to avoid repeated reallocation as we append to batch.
+	batch := make([]*stack.PacketBuffer, 0, e.batchSize)
+	batchFDInfo := fdInfo{fd: -1, isSocket: false}
+	sentPackets := 0
+	for _, pkt := range pkts.AsSlice() {
+		if len(batch) == 0 {
+			batchFDInfo = e.fds[pkt.Hash%uint32(len(e.fds))]
+		}
+		pktFDInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
+		if sendNow := pktFDInfo != batchFDInfo; !sendNow {
+			batch = append(batch, pkt)
+			continue
+		}
+		n, err := e.sendBatch(batchFDInfo, batch) // pkt hashes to a different FD: flush what we have first
+		sentPackets += n
+		if err != nil {
+			return sentPackets, err
+		}
+		batch = batch[:0]
+		batch = append(batch, pkt)
+		batchFDInfo = pktFDInfo
+	}
+
+	if len(batch) != 0 {
+		n, err := e.sendBatch(batchFDInfo, batch)
+		sentPackets += n
+		if err != nil {
+			return sentPackets, err
+		}
+	}
+	return sentPackets, nil
+}
+
+// dispatchLoop reads packets from the file descriptor in a loop and dispatches
+// them to the network stack.
+func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
+	for {
+		cont, err := inboundDispatcher.dispatch()
+		if err != nil || !cont {
+			if e.closed != nil {
+				e.closed(err) // err is nil when the dispatcher merely asked to stop
+			}
+			inboundDispatcher.release()
+			return err
+		}
+	}
+}
+
+// GSOMaxSize implements stack.GSOEndpoint.
+func (e *endpoint) GSOMaxSize() uint32 {
+	return 0
+}
+
+// SupportedGSO implements stack.GSOEndpoint.
+func (e *endpoint) SupportedGSO() stack.SupportedGSO {
+	return stack.GSONotSupported
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
+func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
+	if e.hdrSize > 0 {
+		return header.ARPHardwareEther
+	}
+	return header.ARPHardwareNone
+}
+
+// Close implements stack.LinkEndpoint.
+func (e *endpoint) Close() {}
+
+// SetOnCloseAction implements stack.LinkEndpoint.
+func (*endpoint) SetOnCloseAction(func()) {}
diff --git a/internal/fdbased_darwin/endpoint_mutex.go b/internal/fdbased_darwin/endpoint_mutex.go
new file mode 100644
index 0000000..d05b264
--- /dev/null
+++ b/internal/fdbased_darwin/endpoint_mutex.go
@@ -0,0 +1,96 @@
+package fdbased
+
+import (
+	"reflect"
+
+	"github.com/sagernet/gvisor/pkg/sync"
+	"github.com/sagernet/gvisor/pkg/sync/locking"
+)
+
+// endpointRWMutex is sync.RWMutex with the correctness validator.
+type endpointRWMutex struct {
+	mu sync.RWMutex
+}
+
+// endpointlockNames is a list of user-friendly lock names.
+// Populated in init.
+var endpointlockNames []string
+
+// endpointlockNameIndex is used as an index passed to NestedLock and NestedUnlock,
+// referring to an index within endpointlockNames.
+// Values are specified using the "consts" field of go_template_instance.
+type endpointlockNameIndex int
+
+// DO NOT REMOVE: The following function automatically replaced with lock index constants.
+// LOCK_NAME_INDEX_CONSTANTS
+const ()
+
+// Lock locks m.
+// +checklocksignore
+func (m *endpointRWMutex) Lock() {
+	locking.AddGLock(endpointprefixIndex, -1)
+	m.mu.Lock()
+}
+
+// NestedLock locks m knowing that another lock of the same type is held.
+// +checklocksignore
+func (m *endpointRWMutex) NestedLock(i endpointlockNameIndex) {
+	locking.AddGLock(endpointprefixIndex, int(i))
+	m.mu.Lock()
+}
+
+// Unlock unlocks m.
+// +checklocksignore
+func (m *endpointRWMutex) Unlock() {
+	m.mu.Unlock()
+	locking.DelGLock(endpointprefixIndex, -1)
+}
+
+// NestedUnlock unlocks m knowing that another lock of the same type is held.
+// +checklocksignore
+func (m *endpointRWMutex) NestedUnlock(i endpointlockNameIndex) {
+	m.mu.Unlock()
+	locking.DelGLock(endpointprefixIndex, int(i))
+}
+
+// RLock locks m for reading.
+// +checklocksignore
+func (m *endpointRWMutex) RLock() {
+	locking.AddGLock(endpointprefixIndex, -1)
+	m.mu.RLock()
+}
+
+// RUnlock undoes a single RLock call.
+// +checklocksignore
+func (m *endpointRWMutex) RUnlock() {
+	m.mu.RUnlock()
+	locking.DelGLock(endpointprefixIndex, -1)
+}
+
+// RLockBypass locks m for reading without executing the validator.
+// +checklocksignore
+func (m *endpointRWMutex) RLockBypass() {
+	m.mu.RLock()
+}
+
+// RUnlockBypass undoes a single RLockBypass call.
+// +checklocksignore
+func (m *endpointRWMutex) RUnlockBypass() {
+	m.mu.RUnlock()
+}
+
+// DowngradeLock atomically unlocks rw for writing and locks it for reading.
+// +checklocksignore
+func (m *endpointRWMutex) DowngradeLock() {
+	m.mu.DowngradeLock()
+}
+
+var endpointprefixIndex *locking.MutexClass
+
+// DO NOT REMOVE: The following function is automatically replaced.
+func endpointinitLockNames() {}
+
+func init() {
+	endpointinitLockNames()
+	endpointprefixIndex = locking.NewMutexClass(reflect.TypeOf(endpointRWMutex{}), endpointlockNames)
+}
diff --git a/internal/fdbased_darwin/errno.go b/internal/fdbased_darwin/errno.go
new file mode 100644
index 0000000..074f4e2
--- /dev/null
+++ b/internal/fdbased_darwin/errno.go
@@ -0,0 +1,56 @@
+package fdbased
+
+import (
+	"github.com/sagernet/gvisor/pkg/tcpip"
+
+	"golang.org/x/sys/unix"
+)
+
+// TranslateErrno maps a host errno to the corresponding tcpip.Error. Errnos
+// without a specific mapping translate to ErrInvalidEndpointState.
+func TranslateErrno(e unix.Errno) tcpip.Error {
+	switch e {
+	case unix.EEXIST:
+		return &tcpip.ErrDuplicateAddress{}
+	case unix.ENETUNREACH:
+		return &tcpip.ErrHostUnreachable{}
+	case unix.EINVAL:
+		return &tcpip.ErrInvalidEndpointState{}
+	case unix.EALREADY:
+		return &tcpip.ErrAlreadyConnecting{}
+	case unix.EISCONN:
+		return &tcpip.ErrAlreadyConnected{}
+	case unix.EADDRINUSE:
+		return &tcpip.ErrPortInUse{}
+	case unix.EADDRNOTAVAIL:
+		return &tcpip.ErrBadLocalAddress{}
+	case unix.EPIPE:
+		return &tcpip.ErrClosedForSend{}
+	case unix.EWOULDBLOCK:
+		return &tcpip.ErrWouldBlock{}
+	case unix.ECONNREFUSED:
+		return &tcpip.ErrConnectionRefused{}
+	case unix.ETIMEDOUT:
+		return &tcpip.ErrTimeout{}
+	case unix.EINPROGRESS:
+		return &tcpip.ErrConnectStarted{}
+	case unix.EDESTADDRREQ:
+		return &tcpip.ErrDestinationRequired{}
+	case unix.ENOTSUP:
+		return &tcpip.ErrNotSupported{}
+	case unix.ENOTTY:
+		return &tcpip.ErrQueueSizeNotSupported{}
+	case unix.ENOTCONN:
+		return &tcpip.ErrNotConnected{}
+	case unix.ECONNRESET:
+		return &tcpip.ErrConnectionReset{}
+	case unix.ECONNABORTED:
+		return &tcpip.ErrConnectionAborted{}
+	case unix.EMSGSIZE:
+		return &tcpip.ErrMessageTooLong{}
+	case unix.ENOBUFS:
+		return &tcpip.ErrNoBufferSpace{}
+	default:
+		return &tcpip.ErrInvalidEndpointState{}
+	}
+}
diff --git a/internal/fdbased_darwin/packet_dispatchers.go b/internal/fdbased_darwin/packet_dispatchers.go
new file mode 100644
index 0000000..967f2a8
--- /dev/null
+++ b/internal/fdbased_darwin/packet_dispatchers.go
@@ -0,0 +1,229 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fdbased
+
+import (
+	"github.com/sagernet/gvisor/pkg/buffer"
+	"github.com/sagernet/gvisor/pkg/tcpip"
+	"github.com/sagernet/gvisor/pkg/tcpip/stack"
+	"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
+	"github.com/sagernet/sing-tun/internal/rawfile_darwin"
+	"github.com/sagernet/sing-tun/internal/stopfd_darwin"
+
+	"golang.org/x/sys/unix"
+)
+
+// BufConfig defines the shape of the buffer used to read packets from the NIC.
+var BufConfig = []int{4, 128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
+
+// +stateify savable
+type iovecBuffer struct {
+	// views are the actual buffers that hold the packet contents. Some views
+	// are reused across calls to pullBuffer if number of requested bytes is
+	// smaller than the number of bytes allocated in the views.
+	views []*buffer.View
+
+	// iovecs are initialized with base pointers/len of the corresponding
+	// entries in the views defined above, except when GSO is enabled
+	// (skipsVnetHdr) then the first iovec points to a buffer for the vnet header
+	// which is stripped before the views are passed up the stack for further
+	// processing.
+	iovecs []unix.Iovec `state:"nosave"`
+
+	// sizes is an array of buffer sizes for the underlying views. sizes is
+	// immutable.
+	sizes []int
+
+	// pulledIndex is the index of the last []byte buffer pulled from the
+	// underlying buffer storage during a call to pullBuffer. It is -1
+	// if no buffer is pulled.
+	pulledIndex int
+}
+
+func newIovecBuffer(sizes []int) *iovecBuffer {
+	b := &iovecBuffer{
+		views:  make([]*buffer.View, len(sizes)),
+		iovecs: make([]unix.Iovec, len(sizes)),
+		sizes:  sizes,
+	}
+	return b
+}
+
+func (b *iovecBuffer) nextIovecs() []unix.Iovec {
+	for i := range b.views {
+		if b.views[i] != nil {
+			break
+		}
+		v := buffer.NewViewSize(b.sizes[i])
+		b.views[i] = v
+		b.iovecs[i] = unix.Iovec{Base: v.BasePtr()}
+		b.iovecs[i].SetLen(v.Size())
+	}
+	return b.iovecs
+}
+
+// pullBuffer extracts enough underlying storage from b.views to hold n
+// bytes. It removes this storage from b.views and returns a new buffer
+// that holds the storage; the slots it took are set to nil so that they
+// are reallocated during the next call to
+// nextIovecs.
+func (b *iovecBuffer) pullBuffer(n int) buffer.Buffer {
+	var views []*buffer.View
+	c := 0
+	// Remove the used views from the buffer.
+	for i, v := range b.views {
+		c += v.Size()
+		if c >= n {
+			b.views[i].CapLength(v.Size() - (c - n))
+			views = append(views, b.views[:i+1]...)
+			break
+		}
+	}
+	for i := range views {
+		b.views[i] = nil
+	}
+	pulled := buffer.Buffer{}
+	for _, v := range views {
+		pulled.Append(v)
+	}
+	pulled.Truncate(int64(n))
+	return pulled
+}
+
+func (b *iovecBuffer) release() {
+	for i, v := range b.views {
+		if v != nil {
+			v.Release()
+			b.views[i] = nil // clear the slot itself; assigning to the loop copy v was a no-op
+		}
+	}
+}
+
+// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
+// dispatches them.
+//
+// +stateify savable
+type recvMMsgDispatcher struct {
+	stopfd.StopFD
+	// fd is the file descriptor used to send and receive packets.
+	fd int
+
+	// e is the endpoint this dispatcher is attached to.
+	e *endpoint
+
+	// bufs is an array of iovec buffers that contain packet contents.
+	bufs []*iovecBuffer
+
+	// msgHdrs is an array of MsgHdrX objects where each MsgHdrX is used to
+	// reference an array of iovecs in the bufs field defined above. This
+	// array is passed as the parameter to recvmmsg call to retrieve
+	// potentially more than 1 packet per syscall.
+	msgHdrs []rawfile.MsgHdrX `state:"nosave"`
+
+	// pkts is reused to avoid allocations.
+	pkts stack.PacketBufferList
+
+	// gro coalesces incoming packets to increase throughput.
+	gro gro.GRO
+
+	// mgr is the processor goroutine manager.
+	mgr *processorManager
+}
+
+func newRecvMMsgDispatcher(fd int, e *endpoint, opts *Options) (linkDispatcher, error) {
+	stopFD, err := stopfd.New()
+	if err != nil {
+		return nil, err
+	}
+	batchSize := int((512*1024)/(opts.MTU)) + 1
+	d := &recvMMsgDispatcher{
+		StopFD:  stopFD,
+		fd:      fd,
+		e:       e,
+		bufs:    make([]*iovecBuffer, batchSize),
+		msgHdrs: make([]rawfile.MsgHdrX, batchSize),
+	}
+	bufConfig := []int{4, int(opts.MTU)}
+	for i := range d.bufs {
+		d.bufs[i] = newIovecBuffer(bufConfig)
+	}
+	d.gro.Init(false)
+	d.mgr = newProcessorManager(opts, e)
+	d.mgr.start()
+
+	return d, nil
+}
+
+func (d *recvMMsgDispatcher) release() {
+	for _, iov := range d.bufs {
+		iov.release()
+	}
+	d.mgr.close()
+}
+
+// dispatch reads more than one packet at a time from the file
+// descriptor and dispatches it.
+func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
+	// Fill message headers.
+	for k := range d.msgHdrs {
+		if d.msgHdrs[k].Msg.Iovlen > 0 {
+			break
+		}
+		iovecs := d.bufs[k].nextIovecs()
+		iovLen := len(iovecs)
+		d.msgHdrs[k].DataLen = 0
+		d.msgHdrs[k].Msg.Iov = &iovecs[0]
+		d.msgHdrs[k].Msg.SetIovlen(iovLen)
+	}
+
+	nMsgs, errno := rawfile.BlockingRecvMMsgUntilStopped(d.ReadFD, d.fd, d.msgHdrs)
+	if errno != 0 {
+		return false, TranslateErrno(errno)
+	}
+	if nMsgs == -1 {
+		return false, nil
+	}
+
+	// Process each of received packets.
+
+	d.e.mu.RLock()
+	addr := d.e.addr
+	dsp := d.e.dispatcher
+	d.e.mu.RUnlock()
+
+	d.gro.Dispatcher = dsp
+	defer d.pkts.Reset()
+
+	for k := 0; k < nMsgs; k++ {
+		n := int(d.msgHdrs[k].DataLen)
+		payload := d.bufs[k].pullBuffer(n)
+		payload.TrimFront(4)
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Payload: payload,
+		})
+		d.pkts.PushBack(pkt)
+
+		// Mark that this iovec has been processed.
+		d.msgHdrs[k].Msg.Iovlen = 0
+
+		if d.e.parseInboundHeader(pkt, addr) {
+			pkt.RXChecksumValidated = d.e.caps&stack.CapabilityRXChecksumOffload != 0
+			d.mgr.queuePacket(pkt, d.e.hdrSize > 0)
+		}
+	}
+	d.mgr.wakeReady()
+
+	return true, nil
+}
diff --git a/internal/fdbased_darwin/processor_mutex.go b/internal/fdbased_darwin/processor_mutex.go
new file mode 100644
index 0000000..cd297d2
--- /dev/null
+++ b/internal/fdbased_darwin/processor_mutex.go
@@ -0,0 +1,64 @@
+package fdbased
+
+import (
+	"reflect"
+
+	"github.com/sagernet/gvisor/pkg/sync"
+	"github.com/sagernet/gvisor/pkg/sync/locking"
+)
+
+// Mutex is sync.Mutex with the correctness validator.
+type processorMutex struct {
+	mu sync.Mutex
+}
+
+var processorprefixIndex *locking.MutexClass
+
+// lockNames is a list of user-friendly lock names.
+// Populated in init.
+var processorlockNames []string
+
+// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
+// referring to an index within lockNames.
+// Values are specified using the "consts" field of go_template_instance.
+type processorlockNameIndex int
+
+// DO NOT REMOVE: The following function automatically replaced with lock index constants.
+// LOCK_NAME_INDEX_CONSTANTS
+const ()
+
+// Lock locks m.
+// +checklocksignore
+func (m *processorMutex) Lock() {
+	locking.AddGLock(processorprefixIndex, -1)
+	m.mu.Lock()
+}
+
+// NestedLock locks m knowing that another lock of the same type is held.
+// +checklocksignore
+func (m *processorMutex) NestedLock(i processorlockNameIndex) {
+	locking.AddGLock(processorprefixIndex, int(i))
+	m.mu.Lock()
+}
+
+// Unlock unlocks m.
+// +checklocksignore
+func (m *processorMutex) Unlock() {
+	locking.DelGLock(processorprefixIndex, -1)
+	m.mu.Unlock()
+}
+
+// NestedUnlock unlocks m knowing that another lock of the same type is held.
+// +checklocksignore +func (m *processorMutex) NestedUnlock(i processorlockNameIndex) { + locking.DelGLock(processorprefixIndex, int(i)) + m.mu.Unlock() +} + +// DO NOT REMOVE: The following function is automatically replaced. +func processorinitLockNames() {} + +func init() { + processorinitLockNames() + processorprefixIndex = locking.NewMutexClass(reflect.TypeOf(processorMutex{}), processorlockNames) +} diff --git a/internal/fdbased_darwin/processors.go b/internal/fdbased_darwin/processors.go new file mode 100644 index 0000000..9df6cfa --- /dev/null +++ b/internal/fdbased_darwin/processors.go @@ -0,0 +1,275 @@ +// Copyright 2024 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package fdbased + +import ( + "context" + "encoding/binary" + + "github.com/sagernet/gvisor/pkg/rand" + "github.com/sagernet/gvisor/pkg/sleep" + "github.com/sagernet/gvisor/pkg/sync" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/hash/jenkins" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/stack/gro" +) + +// +stateify savable +type processor struct { + mu processorMutex `state:"nosave"` + // +checklocks:mu + pkts stack.PacketBufferList + + e *endpoint + gro gro.GRO + sleeper sleep.Sleeper + packetWaker sleep.Waker + closeWaker sleep.Waker +} + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + for { + switch w := p.sleeper.Fetch(true); { + case w == &p.packetWaker: + p.deliverPackets() + case w == &p.closeWaker: + p.mu.Lock() + p.pkts.Reset() + p.mu.Unlock() + return + } + } +} + +func (p *processor) deliverPackets() { + p.e.mu.RLock() + p.gro.Dispatcher = p.e.dispatcher + p.e.mu.RUnlock() + if p.gro.Dispatcher == nil { + p.mu.Lock() + p.pkts.Reset() + p.mu.Unlock() + return + } + + p.mu.Lock() + for p.pkts.Len() > 0 { + pkt := p.pkts.PopFront() + p.mu.Unlock() + p.gro.Enqueue(pkt) + pkt.DecRef() + p.mu.Lock() + } + p.mu.Unlock() + p.gro.Flush() +} + +// processorManager handles starting, closing, and queuing packets on processor +// goroutines. +// +// +stateify savable +type processorManager struct { + processors []processor + seed uint32 + wg sync.WaitGroup `state:"nosave"` + e *endpoint + ready []bool +} + +// newProcessorManager creates a new processor manager. 
+func newProcessorManager(opts *Options, e *endpoint) *processorManager { + m := &processorManager{} + m.seed = rand.Uint32() + m.ready = make([]bool, opts.ProcessorsPerChannel) + m.processors = make([]processor, opts.ProcessorsPerChannel) + m.e = e + m.wg.Add(opts.ProcessorsPerChannel) + + for i := range m.processors { + p := &m.processors[i] + p.sleeper.AddWaker(&p.packetWaker) + p.sleeper.AddWaker(&p.closeWaker) + p.gro.Init(false) + p.e = e + } + + return m +} + +// start starts the processor goroutines if the processor manager is configured +// with more than one processor. +func (m *processorManager) start() { + for i := range m.processors { + p := &m.processors[i] + // Only start processor in a separate goroutine if we have multiple of them. + if len(m.processors) > 1 { + go p.start(&m.wg) + } + } +} + +// afterLoad is invoked by stateify. +func (m *processorManager) afterLoad(context.Context) { + m.wg.Add(len(m.processors)) + m.start() +} + +func (m *processorManager) connectionHash(cid *connectionID) uint32 { + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], cid.srcPort) + binary.LittleEndian.PutUint16(payload[2:], cid.dstPort) + + h := jenkins.Sum32(m.seed) + h.Write(payload[:]) + h.Write(cid.srcAddr) + h.Write(cid.dstAddr) + return h.Sum32() +} + +// queuePacket queues a packet to be delivered to the appropriate processor. +func (m *processorManager) queuePacket(pkt *stack.PacketBuffer, hasEthHeader bool) { + var pIdx uint32 + cid, nonConnectionPkt := tcpipConnectionID(pkt) + if !hasEthHeader { + if nonConnectionPkt { + // If there's no eth header this should be a standard tcpip packet. If + // it isn't the packet is invalid so drop it. + return + } + pkt.NetworkProtocolNumber = cid.proto + } + if len(m.processors) == 1 || nonConnectionPkt { + // If the packet is not associated with an active connection, use the + // first processor. 
+ pIdx = 0 + } else { + pIdx = m.connectionHash(&cid) % uint32(len(m.processors)) + } + p := &m.processors[pIdx] + p.mu.Lock() + defer p.mu.Unlock() + p.pkts.PushBack(pkt.IncRef()) + m.ready[pIdx] = true +} + +type connectionID struct { + srcAddr, dstAddr []byte + srcPort, dstPort uint16 + proto tcpip.NetworkProtocolNumber +} + +// tcpipConnectionID returns a tcpip connection id tuple based on the data found +// in the packet. It returns true if the packet is not associated with an active +// connection (e.g ARP, NDP, etc). The method assumes link headers have already +// been processed if they were present. +func tcpipConnectionID(pkt *stack.PacketBuffer) (connectionID, bool) { + var cid connectionID + h, ok := pkt.Data().PullUp(1) + if !ok { + // Skip this packet. + return cid, true + } + + const tcpSrcDstPortLen = 4 + switch header.IPVersion(h) { + case header.IPv4Version: + hdrLen := header.IPv4(h).HeaderLength() + h, ok = pkt.Data().PullUp(int(hdrLen) + tcpSrcDstPortLen) + if !ok { + return cid, true + } + ipHdr := header.IPv4(h[:hdrLen]) + tcpHdr := header.TCP(h[hdrLen:][:tcpSrcDstPortLen]) + + cid.srcAddr = ipHdr.SourceAddressSlice() + cid.dstAddr = ipHdr.DestinationAddressSlice() + // All fragment packets need to be processed by the same goroutine, so + // only record the TCP ports if this is not a fragment packet. + if ipHdr.IsValid(pkt.Data().Size()) && !ipHdr.More() && ipHdr.FragmentOffset() == 0 { + cid.srcPort = tcpHdr.SourcePort() + cid.dstPort = tcpHdr.DestinationPort() + } + cid.proto = header.IPv4ProtocolNumber + case header.IPv6Version: + h, ok = pkt.Data().PullUp(header.IPv6FixedHeaderSize + tcpSrcDstPortLen) + if !ok { + return cid, true + } + ipHdr := header.IPv6(h) + + var tcpHdr header.TCP + if tcpip.TransportProtocolNumber(ipHdr.NextHeader()) == header.TCPProtocolNumber { + tcpHdr = header.TCP(h[header.IPv6FixedHeaderSize:][:tcpSrcDstPortLen]) + } else { + // Slow path for IPv6 extension headers :(. 
+ dataBuf := pkt.Data().ToBuffer() + dataBuf.TrimFront(header.IPv6MinimumSize) + it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf) + defer it.Release() + for { + hdr, done, err := it.Next() + if done || err != nil { + break + } + hdr.Release() + } + h, ok = pkt.Data().PullUp(int(it.HeaderOffset()) + tcpSrcDstPortLen) + if !ok { + return cid, true + } + tcpHdr = header.TCP(h[it.HeaderOffset():][:tcpSrcDstPortLen]) + } + cid.srcAddr = ipHdr.SourceAddressSlice() + cid.dstAddr = ipHdr.DestinationAddressSlice() + cid.srcPort = tcpHdr.SourcePort() + cid.dstPort = tcpHdr.DestinationPort() + cid.proto = header.IPv6ProtocolNumber + default: + return cid, true + } + return cid, false +} + +func (m *processorManager) close() { + if len(m.processors) < 2 { + return + } + for i := range m.processors { + p := &m.processors[i] + p.closeWaker.Assert() + } +} + +// wakeReady wakes up all processors that have a packet queued. If there is only +// one processor, the method delivers the packet inline without waking a +// goroutine. +func (m *processorManager) wakeReady() { + for i, ready := range m.ready { + if !ready { + continue + } + p := &m.processors[i] + if len(m.processors) > 1 { + p.packetWaker.Assert() + } else { + p.deliverPackets() + } + m.ready[i] = false + } +} diff --git a/internal/rawfile_darwin/rawfile.go b/internal/rawfile_darwin/rawfile.go new file mode 100644 index 0000000..b73bd82 --- /dev/null +++ b/internal/rawfile_darwin/rawfile.go @@ -0,0 +1,188 @@ +package rawfile + +import ( + "reflect" + "unsafe" + + "golang.org/x/sys/unix" +) + +// SizeofIovec is the size of a unix.Iovec in bytes. +const SizeofIovec = unsafe.Sizeof(unix.Iovec{}) + +// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a +// host system call in a single array. +const MaxIovs = 1024 + +// IovecFromBytes returns a unix.Iovec representing bs. +// +// Preconditions: len(bs) > 0. 
+func IovecFromBytes(bs []byte) unix.Iovec { + iov := unix.Iovec{ + Base: &bs[0], + } + iov.SetLen(len(bs)) + return iov +} + +func bytesFromIovec(iov unix.Iovec) (bs []byte) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + sh.Data = uintptr(unsafe.Pointer(iov.Base)) + sh.Len = int(iov.Len) + sh.Cap = int(iov.Len) + return +} + +// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) == +// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >= +// max, AppendIovecFromBytes replaces the final iovec in iovs with one that +// also includes the contents of bs. Note that this implies that +// AppendIovecFromBytes is only usable when the returned iovec slice is used as +// the source of a write. +func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec { + if len(bs) == 0 { + return iovs + } + if len(iovs) < max { + return append(iovs, IovecFromBytes(bs)) + } + iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...)) + return iovs +} + +type MsgHdrX struct { + Msg unix.Msghdr + DataLen uint32 +} + +func NonBlockingSendMMsg(fd int, msgHdrs []MsgHdrX) (int, unix.Errno) { + n, _, e := unix.RawSyscall6(unix.SYS_SENDMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0) + return int(n), e +} + +const SizeofMsgHdrX = unsafe.Sizeof(MsgHdrX{}) + +// NonBlockingWriteIovec writes iovec to a file descriptor in a single unix. +// It fails if partial data is written. 
+func NonBlockingWriteIovec(fd int, iovec []unix.Iovec) unix.Errno { + iovecLen := uintptr(len(iovec)) + _, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen) + return e +} + +func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, unix.Errno) { + for { + n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) + if e == 0 { + return int(n), 0 + } + if e != 0 && e != unix.EWOULDBLOCK { + return 0, e + } + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, e + } + if e != 0 && e != unix.EINTR { + return 0, e + } + } +} + +func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MsgHdrX) (int, unix.Errno) { + for { + n, _, e := unix.RawSyscall6(unix.SYS_RECVMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0) + if e == 0 { + return int(n), e + } + + if e != 0 && e != unix.EWOULDBLOCK { + return 0, e + } + + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, e + } + if e != 0 && e != unix.EINTR { + return 0, e + } + } +} + +func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) { + // Create kqueue + kq, err := unix.Kqueue() + if err != nil { + return false, unix.Errno(err.(unix.Errno)) + } + defer unix.Close(kq) + + // Prepare kevents for registration + var kevents []unix.Kevent_t + + // Always monitor efd for read events + kevents = append(kevents, unix.Kevent_t{ + Ident: uint64(efd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD | unix.EV_ENABLE, + }) + + // Monitor fd based on requested events + // Convert poll events to kqueue filters + if events&unix.POLLIN != 0 { + kevents = append(kevents, unix.Kevent_t{ + Ident: uint64(fd), + Filter: unix.EVFILT_READ, + Flags: unix.EV_ADD | unix.EV_ENABLE, + }) + } + if events&unix.POLLOUT != 0 { + kevents = append(kevents, unix.Kevent_t{ + 
Ident: uint64(fd), + Filter: unix.EVFILT_WRITE, + Flags: unix.EV_ADD | unix.EV_ENABLE, + }) + } + + // Register events + _, err = unix.Kevent(kq, kevents, nil, nil) + if err != nil { + return false, unix.Errno(err.(unix.Errno)) + } + + // Wait for events (blocking) + revents := make([]unix.Kevent_t, len(kevents)) + n, err := unix.Kevent(kq, nil, revents, nil) + if err != nil { + return false, unix.Errno(err.(unix.Errno)) + } + + // Check results + var efdHasData bool + var errno unix.Errno + + for i := 0; i < n; i++ { + ev := &revents[i] + + if int(ev.Ident) == efd && ev.Filter == unix.EVFILT_READ { + efdHasData = true + } + + if int(ev.Ident) == fd { + // Check for errors or EOF + if ev.Flags&unix.EV_EOF != 0 { + errno = unix.ECONNRESET + } else if ev.Flags&unix.EV_ERROR != 0 { + // Extract error from Data field + if ev.Data != 0 { + errno = unix.Errno(ev.Data) + } else { + errno = unix.ECONNRESET + } + } + } + } + + return efdHasData, errno +} diff --git a/internal/stopfd_darwin/stopfd.go b/internal/stopfd_darwin/stopfd.go new file mode 100644 index 0000000..fdc3973 --- /dev/null +++ b/internal/stopfd_darwin/stopfd.go @@ -0,0 +1,61 @@ +package stopfd + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +type StopFD struct { + ReadFD int + WriteFD int +} + +func New() (StopFD, error) { + fds := make([]int, 2) + err := unix.Pipe(fds) + if err != nil { + return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to create pipe: %w", err) + } + + if err := unix.SetNonblock(fds[0], true); err != nil { + unix.Close(fds[0]) + unix.Close(fds[1]) + return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set read end non-blocking: %w", err) + } + + if err := unix.SetNonblock(fds[1], true); err != nil { + unix.Close(fds[0]) + unix.Close(fds[1]) + return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set write end non-blocking: %w", err) + } + + return StopFD{ReadFD: fds[0], WriteFD: fds[1]}, nil +} + +func (sf *StopFD) Stop() { + signal := []byte{1} + if n, 
err := unix.Write(sf.WriteFD, signal); n != len(signal) || err != nil { + panic(fmt.Sprintf("write(WriteFD) = (%d, %s), want (%d, nil)", n, err, len(signal))) + } +} + +func (sf *StopFD) Close() error { + var err1, err2 error + if sf.ReadFD != -1 { + err1 = unix.Close(sf.ReadFD) + sf.ReadFD = -1 + } + if sf.WriteFD != -1 { + err2 = unix.Close(sf.WriteFD) + sf.WriteFD = -1 + } + if err1 != nil { + return err1 + } + return err2 +} + +func (sf *StopFD) EFD() int { + return sf.ReadFD +} diff --git a/stack_gvisor.go b/stack_gvisor.go index 213d50f..5e03cc6 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -40,7 +40,7 @@ type GVisor struct { type GVisorTun interface { Tun - NewEndpoint() (stack.LinkEndpoint, error) + NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) } func NewGVisor( @@ -65,12 +65,12 @@ func NewGVisor( } func (t *GVisor) Start() error { - linkEndpoint, err := t.tun.NewEndpoint() + linkEndpoint, nicOptions, err := t.tun.NewEndpoint() if err != nil { return err } linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun} - ipStack, err := NewGVisorStack(linkEndpoint) + ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions) if err != nil { return err } @@ -110,6 +110,10 @@ func AddrFromAddress(address tcpip.Address) netip.Addr { } func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { + return NewGVisorStackWithOptions(ep, stack.NICOptions{}) +} + +func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -122,7 +126,7 @@ func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) { icmp.NewProtocol6, }, }) - err := ipStack.CreateNIC(DefaultNIC, ep) + err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts) if err != nil { return nil, gonet.TranslateNetstackError(err) } diff --git a/stack_mixed.go b/stack_mixed.go index 9293fb8..36eef1e 100644 --- 
a/stack_mixed.go +++ b/stack_mixed.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" ) @@ -72,10 +73,14 @@ func (m *Mixed) tunLoop() { m.txChecksumOffload = linuxTUN.TXChecksumOffload() batchSize := linuxTUN.BatchSize() if batchSize > 1 { - m.batchLoop(linuxTUN, batchSize) + m.batchLoopLinux(linuxTUN, batchSize) return } } + if darwinTUN, isDarwinTUN := m.tun.(DarwinTUN); isDarwinTUN { + m.batchLoopDarwin(darwinTUN) + return + } packetBuffer := make([]byte, m.mtu+PacketOffset) for { n, err := m.tun.Read(packetBuffer) @@ -119,12 +124,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) { } } -func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { +func (m *Mixed) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) { packetBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) for i := range packetBuffers { - packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom) + packetBuffers[i] = make([]byte, m.mtu+PacketOffset+m.frontHeadroom) } for { n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes) @@ -158,6 +163,40 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) { } } +func (m *Mixed) batchLoopDarwin(darwinTUN DarwinTUN) { + var writeBuffers []*buf.Buffer + for { + buffers, err := darwinTUN.BatchRead() + if err != nil { + if E.IsClosed(err) { + return + } + m.logger.Error(E.Cause(err, "batch read packet")) + } + if len(buffers) == 0 { + continue + } + writeBuffers = writeBuffers[:0] + for _, buffer := range buffers { + packetSize := buffer.Len() + if packetSize < header.IPv4MinimumSize { + continue + } + if m.processPacket(buffer.Bytes()) { + writeBuffers = append(writeBuffers, buffer) + } else { + buffer.Release() + } + } 
+ if len(writeBuffers) > 0 { + err = darwinTUN.BatchWrite(writeBuffers) + if err != nil { + m.logger.Trace(E.Cause(err, "batch write packet")) + } + } + } +} + func (m *Mixed) processPacket(packet []byte) bool { var ( writeBack bool diff --git a/stack_system.go b/stack_system.go index 23070fe..6b2fde4 100644 --- a/stack_system.go +++ b/stack_system.go @@ -170,10 +170,14 @@ func (s *System) tunLoop() { s.txChecksumOffload = linuxTUN.TXChecksumOffload() batchSize := linuxTUN.BatchSize() if batchSize > 1 { - s.batchLoop(linuxTUN, batchSize) + s.batchLoopLinux(linuxTUN, batchSize) return } } + if darwinTUN, isDarwinTUN := s.tun.(DarwinTUN); isDarwinTUN { + s.batchLoopDarwin(darwinTUN) + return + } packetBuffer := make([]byte, s.mtu+PacketOffset) for { n, err := s.tun.Read(packetBuffer) @@ -217,7 +221,7 @@ func (s *System) wintunLoop(winTun WinTun) { } } -func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { +func (s *System) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) { packetBuffers := make([][]byte, batchSize) writeBuffers := make([][]byte, batchSize) packetSizes := make([]int, batchSize) @@ -256,6 +260,40 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) { } } +func (s *System) batchLoopDarwin(darwinTUN DarwinTUN) { + var writeBuffers []*buf.Buffer + for { + buffers, err := darwinTUN.BatchRead() + if err != nil { + if E.IsClosed(err) { + return + } + s.logger.Error(E.Cause(err, "batch read packet")) + } + if len(buffers) == 0 { + continue + } + writeBuffers = writeBuffers[:0] + for _, buffer := range buffers { + packetSize := buffer.Len() + if packetSize < header.IPv4MinimumSize { + continue + } + if s.processPacket(buffer.Bytes()) { + writeBuffers = append(writeBuffers, buffer) + } else { + buffer.Release() + } + } + if len(writeBuffers) > 0 { + err = darwinTUN.BatchWrite(writeBuffers) + if err != nil { + s.logger.Trace(E.Cause(err, "batch write packet")) + } + } + } +} + func (s *System) processPacket(packet []byte) bool { var ( 
writeBack bool diff --git a/tun.go b/tun.go index 882adc5..03ded6a 100644 --- a/tun.go +++ b/tun.go @@ -8,6 +8,7 @@ import ( "strconv" "strings" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" @@ -45,6 +46,12 @@ type LinuxTUN interface { TXChecksumOffload() bool } +type DarwinTUN interface { + Tun + BatchRead() ([]*buf.Buffer, error) + BatchWrite(buffers []*buf.Buffer) error +} + const ( DefaultIPRoute2TableIndex = 2022 DefaultIPRoute2RuleIndex = 9000 diff --git a/tun_darwin.go b/tun_darwin.go index 2462d75..ed04234 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -10,6 +10,8 @@ import ( "unsafe" "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/internal/rawfile_darwin" + "github.com/sagernet/sing-tun/internal/stopfd_darwin" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -21,15 +23,64 @@ import ( "golang.org/x/sys/unix" ) +var _ DarwinTUN = (*NativeTun)(nil) + const PacketOffset = 4 type NativeTun struct { - tunFile *os.File - tunWriter N.VectorisedWriter - options Options - inet4Address [4]byte - inet6Address [16]byte - routeSet bool + tunFd int + tunFile *os.File + batchSize int + iovecs []iovecBuffer + iovecsOutput []iovecBuffer + msgHdrs []rawfile.MsgHdrX + msgHdrsOutput []rawfile.MsgHdrX + buffers []*buf.Buffer + stopFd stopfd.StopFD + tunWriter N.VectorisedWriter + options Options + inet4Address [4]byte + inet6Address [16]byte + routeSet bool +} + +type iovecBuffer struct { + mtu int + buffer *buf.Buffer + iovecs []unix.Iovec +} + +func newIovecBuffer(mtu int) iovecBuffer { + return iovecBuffer{ + mtu: mtu, + iovecs: make([]unix.Iovec, 2), + } +} + +func (b *iovecBuffer) nextIovecs() []unix.Iovec { + if b.iovecs[0].Len == 0 { + headBuffer := make([]byte, PacketOffset) + b.iovecs[0].Base = &headBuffer[0] + b.iovecs[0].SetLen(PacketOffset) + 
} + if b.buffer == nil { + b.buffer = buf.NewSize(b.mtu) + b.iovecs[1].Base = &b.buffer.FreeBytes()[0] + b.iovecs[1].SetLen(b.mtu) + } + return b.iovecs +} + +func (b *iovecBuffer) nextIovecsOutput(buffer *buf.Buffer) []unix.Iovec { + switch header.IPVersion(buffer.Bytes()) { + case header.IPv4Version: + b.iovecs[0] = packetHeaderVec4 + case header.IPv6Version: + b.iovecs[0] = packetHeaderVec6 + } + b.iovecs[1].Base = &buffer.Bytes()[0] + b.iovecs[1].SetLen(buffer.Len()) + return b.iovecs } func (t *NativeTun) Name() (string, error) { @@ -42,6 +93,7 @@ func (t *NativeTun) Name() (string, error) { func New(options Options) (Tun, error) { var tunFd int + batchSize := ((512 * 1024) / int(options.MTU)) + 1 if options.FileDescriptor == 0 { ifIndex := -1 _, err := fmt.Sscanf(options.Name, "utun%d", &ifIndex) @@ -54,18 +106,37 @@ func New(options Options) (Tun, error) { return nil, err } - err = configure(tunFd, ifIndex, options.Name, options) + err = create(tunFd, ifIndex, options.Name, options) + if err != nil { + unix.Close(tunFd) + return nil, err + } + err = configure(tunFd, batchSize) if err != nil { unix.Close(tunFd) return nil, err } } else { tunFd = options.FileDescriptor + err := configure(tunFd, batchSize) + if err != nil { + return nil, err + } } - nativeTun := &NativeTun{ - tunFile: os.NewFile(uintptr(tunFd), "utun"), - options: options, + tunFd: tunFd, + tunFile: os.NewFile(uintptr(tunFd), "utun"), + options: options, + batchSize: batchSize, + iovecs: make([]iovecBuffer, batchSize), + iovecsOutput: make([]iovecBuffer, batchSize), + msgHdrs: make([]rawfile.MsgHdrX, batchSize), + msgHdrsOutput: make([]rawfile.MsgHdrX, batchSize), + stopFd: common.Must1(stopfd.New()), + } + for i := 0; i < batchSize; i++ { + nativeTun.iovecs[i] = newIovecBuffer(int(options.MTU)) + nativeTun.iovecsOutput[i] = newIovecBuffer(int(options.MTU)) } if len(options.Inet4Address) > 0 { nativeTun.inet4Address = options.Inet4Address[0].Addr().As4() @@ -100,10 +171,17 @@ func (t 
*NativeTun) Write(p []byte) (n int, err error) { } var ( - packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET} - packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6} + packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET} + packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6} + packetHeaderVec4 = unix.Iovec{Base: &packetHeader4[0]} + packetHeaderVec6 = unix.Iovec{Base: &packetHeader6[0]} ) +func init() { + packetHeaderVec4.SetLen(4) + packetHeaderVec6.SetLen(4) +} + func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { var packetHeader []byte switch header.IPVersion(buffers[0].Bytes()) { @@ -147,7 +225,7 @@ type addrLifetime6 struct { Pltime uint32 } -func configure(tunFd int, ifIndex int, name string, options Options) error { +func create(tunFd int, ifIndex int, name string, options Options) error { ctlInfo := &unix.CtlInfo{} copy(ctlInfo.Name[:], utunControlName) err := unix.IoctlCtlInfo(tunFd, ctlInfo) @@ -163,11 +241,6 @@ func configure(tunFd int, ifIndex int, name string, options Options) error { return os.NewSyscallError("Connect", err) } - err = unix.SetNonblock(tunFd, true) - if err != nil { - return os.NewSyscallError("SetNonblock", err) - } - err = useSocket(unix.AF_INET, unix.SOCK_DGRAM, 0, func(socketFd int) error { var ifr unix.IfreqMTU copy(ifr.Name[:], name) @@ -259,6 +332,65 @@ func configure(tunFd int, ifIndex int, name string, options Options) error { return nil } +func configure(tunFd int, batchSize int) error { + err := unix.SetNonblock(tunFd, true) + if err != nil { + return os.NewSyscallError("SetNonblock", err) + } + const UTUN_OPT_MAX_PENDING_PACKETS = 16 + err = unix.SetsockoptInt(tunFd, 2, UTUN_OPT_MAX_PENDING_PACKETS, batchSize) + if err != nil { + return os.NewSyscallError("SetsockoptInt UTUN_OPT_MAX_PENDING_PACKETS", err) + } + return nil +} + +func (t *NativeTun) BatchSize() int { + return t.batchSize +} + +func (t *NativeTun) BatchRead() ([]*buf.Buffer, error) { + for i := 0; i < t.batchSize; i++ { + iovecs := 
t.iovecs[i].nextIovecs() + t.msgHdrs[i].DataLen = 0 + t.msgHdrs[i].Msg.Iov = &iovecs[0] + t.msgHdrs[i].Msg.Iovlen = 2 + } + n, errno := rawfile.BlockingRecvMMsgUntilStopped(t.stopFd.ReadFD, t.tunFd, t.msgHdrs) + if errno != 0 { + return nil, errno + } + if n < 1 { + return nil, nil + } + buffers := t.buffers + for k := 0; k < n; k++ { + buffer := t.iovecs[k].buffer + t.iovecs[k].buffer = nil + buffer.Truncate(int(t.msgHdrs[k].DataLen) - PacketOffset) + buffers = append(buffers, buffer) + } + t.buffers = buffers[:0] + return buffers, nil +} + +func (t *NativeTun) BatchWrite(buffers []*buf.Buffer) error { + for i, buffer := range buffers { + iovecs := t.iovecsOutput[i].nextIovecsOutput(buffer) + t.msgHdrsOutput[i].Msg.Iov = &iovecs[0] + t.msgHdrsOutput[i].Msg.Iovlen = 2 + } + _, errno := rawfile.NonBlockingSendMMsg(t.tunFd, t.msgHdrsOutput[:len(buffers)]) + if errno != 0 { + return errno + } + return nil +} + +func (t *NativeTun) TXChecksumOffload() bool { + return false +} + func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error { err := t.unsetRoutes() if err != nil { diff --git a/tun_darwin_gvisor.go b/tun_darwin_gvisor.go index df46bf1..16ecbe7 100644 --- a/tun_darwin_gvisor.go +++ b/tun_darwin_gvisor.go @@ -3,132 +3,23 @@ package tun import ( - "github.com/sagernet/gvisor/pkg/buffer" - "github.com/sagernet/gvisor/pkg/tcpip" - "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/link/qdisc/fifo" "github.com/sagernet/gvisor/pkg/tcpip/stack" - "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing-tun/internal/fdbased_darwin" ) var _ GVisorTun = (*NativeTun)(nil) -func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { - return &DarwinEndpoint{tun: t}, nil -} - -var _ stack.LinkEndpoint = (*DarwinEndpoint)(nil) - -type DarwinEndpoint struct { - tun *NativeTun - dispatcher stack.NetworkDispatcher -} - -func (e *DarwinEndpoint) MTU() uint32 { - return e.tun.options.MTU -} - -func (e *DarwinEndpoint) 
SetMTU(mtu uint32) { -} - -func (e *DarwinEndpoint) MaxHeaderLength() uint16 { - return 0 -} - -func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress { - return "" -} - -func (e *DarwinEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { -} - -func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return stack.CapabilityRXChecksumOffload -} - -func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - if dispatcher == nil && e.dispatcher != nil { - e.dispatcher = nil - return - } - if dispatcher != nil && e.dispatcher == nil { - e.dispatcher = dispatcher - go e.dispatchLoop() +func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) { + ep, err := fdbased.New(&fdbased.Options{ + FDs: []int{t.tunFd}, + MTU: t.options.MTU, + RXChecksumOffload: true, + }) + if err != nil { + return nil, stack.NICOptions{}, err } -} - -func (e *DarwinEndpoint) dispatchLoop() { - packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset) - for { - n, err := e.tun.tunFile.Read(packetBuffer) - if err != nil { - break - } - packet := packetBuffer[PacketOffset:n] - var networkProtocol tcpip.NetworkProtocolNumber - switch header.IPVersion(packet) { - case header.IPv4Version: - networkProtocol = header.IPv4ProtocolNumber - if header.IPv4(packet).DestinationAddress().As4() == e.tun.inet4Address { - e.tun.tunFile.Write(packetBuffer[:n]) - continue - } - case header.IPv6Version: - networkProtocol = header.IPv6ProtocolNumber - if header.IPv6(packet).DestinationAddress().As16() == e.tun.inet6Address { - e.tun.tunFile.Write(packetBuffer[:n]) - continue - } - default: - e.tun.tunFile.Write(packetBuffer[:n]) - continue - } - pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(packetBuffer[4:n]), - IsForwardedPacket: true, - }) - pkt.NetworkProtocolNumber = networkProtocol - dispatcher := e.dispatcher - if dispatcher == nil { - pkt.DecRef() - return - } - dispatcher.DeliverNetworkPacket(networkProtocol, pkt) 
- pkt.DecRef() - } -} - -func (e *DarwinEndpoint) IsAttached() bool { - return e.dispatcher != nil -} - -func (e *DarwinEndpoint) Wait() { -} - -func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType { - return header.ARPHardwareNone -} - -func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) { -} - -func (e *DarwinEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool { - return true -} - -func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) { - var n int - for _, packet := range packetBufferList.AsSlice() { - _, err := bufio.WriteVectorised(e.tun, packet.AsSlices()) - if err != nil { - return n, &tcpip.ErrAborted{} - } - n++ - } - return n, nil -} - -func (e *DarwinEndpoint) Close() { -} - -func (e *DarwinEndpoint) SetOnCloseAction(f func()) { + return ep, stack.NICOptions{ + QDisc: fifo.New(ep, 1, 1000), + }, nil } diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index f82d762..680d6f5 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -9,9 +9,9 @@ import ( var _ GVisorTun = (*NativeTun)(nil) -func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { +func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) { if t.vnetHdr { - return fdbased.New(&fdbased.Options{ + ep, err := fdbased.New(&fdbased.Options{ FDs: []int{t.tunFd}, MTU: t.options.MTU, GSOMaxSize: gsoMaxSize, @@ -19,11 +19,20 @@ func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { RXChecksumOffload: true, TXChecksumOffload: t.txChecksumOffload, }) + if err != nil { + return nil, stack.NICOptions{}, err + } + return ep, stack.NICOptions{}, nil + } else { + ep, err := fdbased.New(&fdbased.Options{ + FDs: []int{t.tunFd}, + MTU: t.options.MTU, + RXChecksumOffload: true, + TXChecksumOffload: t.txChecksumOffload, + }) + if err != nil { + return nil, stack.NICOptions{}, err + } + return ep, stack.NICOptions{}, nil } - return fdbased.New(&fdbased.Options{ - FDs: []int{t.tunFd}, - MTU: 
t.options.MTU, - RXChecksumOffload: true, - TXChecksumOffload: t.txChecksumOffload, - }) } diff --git a/tun_windows_gvisor.go b/tun_windows_gvisor.go index b87dbfe..463a88b 100644 --- a/tun_windows_gvisor.go +++ b/tun_windows_gvisor.go @@ -11,8 +11,8 @@ import ( var _ GVisorTun = (*NativeTun)(nil) -func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) { - return &WintunEndpoint{tun: t}, nil +func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) { + return &WintunEndpoint{tun: t}, stack.NICOptions{}, nil } var _ stack.LinkEndpoint = (*WintunEndpoint)(nil)