Files
tun/device/endpoint.go
Shiming Zhang 31edb3af59 Cleanup
2022-03-13 19:00:24 +08:00

136 lines
3.2 KiB
Go

package device
import (
"context"
"fmt"
"io"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
// defaultOutQueueLen is the queue length handed to channel.New for outbound
// packets waiting to be read and written to the device. Overflow causes
// packet drops.
const defaultOutQueueLen = 1 << 10
// Endpoint bridges an io.ReadWriter (typically a TUN device) and a gVisor
// channel.Endpoint so that it can be used as a stack.LinkEndpoint.
//
// Once attached, packets read from rw are injected inbound into the embedded
// channel endpoint, and packets queued on the channel endpoint are written
// back to rw.
type Endpoint struct {
// Embedded channel endpoint; it queues outbound packets and receives the
// inbound packets injected by dispatchLoop.
*channel.Endpoint
// rw is the io.ReadWriter for reading and writing packets.
rw io.ReadWriter
// mtu (maximum transmission unit) is the maximum size of a packet. It is
// passed to channel.New and sizes the per-read buffer.
mtu uint32
// offset is the number of bytes reserved in front of every packet: reads
// expect the payload to start at this index, and writes are prefixed with
// this many zero bytes. Useful for TUN device I/O with TUN_PI enabled.
offset int
// once ensures the read/write loops are started only once, on the first
// call to Attach.
once sync.Once
}
// New builds an *Endpoint on top of rw.
//
// mtu is the maximum packet size and must be non-zero. rw must be non-nil.
// offset is the number of bytes reserved in front of each packet and must be
// non-negative. An error is returned when any of these checks fails; the
// checks run in the order mtu, rw, offset.
func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) {
	switch {
	case mtu == 0:
		return nil, fmt.Errorf("MTU size is zero")
	case rw == nil:
		return nil, fmt.Errorf("RW interface is nil")
	case offset < 0:
		return nil, fmt.Errorf("offset must be non-negative")
	}

	ep := &Endpoint{
		rw:     rw,
		mtu:    mtu,
		offset: offset,
	}
	ep.Endpoint = channel.New(defaultOutQueueLen, mtu, "")
	return ep, nil
}
// Close closes the embedded channel.Endpoint. It does not close or otherwise
// touch rw.
//
// NOTE(review): dispatchLoop only returns when a read from rw fails, so Close
// alone does not appear to stop it — confirm that callers also close the
// underlying device.
func (e *Endpoint) Close() {
e.Endpoint.Close()
}
// Attach registers dispatcher with the embedded channel endpoint.
//
// On the first call only, it also starts two goroutines: dispatchLoop, which
// reads packets from rw and injects them inbound, and outboundLoop, which
// writes queued outbound packets to rw. The loops are started before the
// dispatcher is attached, and later calls only replace the dispatcher.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	startLoops := func() {
		// dispatchLoop owns the cancel function; when it returns, the context
		// given to outboundLoop is cancelled as well.
		ctx, cancel := context.WithCancel(context.Background())
		go e.dispatchLoop(cancel)
		go e.outboundLoop(ctx)
	}
	e.once.Do(startLoops)

	e.Endpoint.Attach(dispatcher)
}
// dispatchLoop reads inbound packets from e.rw and injects them into the
// embedded channel endpoint until a read returns an error.
//
// The value returned by Read is treated as the payload length, with the
// payload starting at data[e.offset]. Reads that report an empty payload or
// one larger than the MTU are dropped; without this check an oversized length
// combined with a non-zero offset would slice past the end of the buffer and
// panic. Packets that are neither IPv4 nor IPv6 are silently discarded.
//
// cancel is always invoked on return so that outboundLoop, which runs on the
// matching context, exits as well.
func (e *Endpoint) dispatchLoop(cancel context.CancelFunc) {
	defer cancel()

	offset, mtu := e.offset, int(e.mtu)
	for {
		// NOTE(review): a fresh buffer is allocated per read, presumably
		// because the packet buffer keeps referencing it — confirm before
		// introducing any buffer reuse.
		data := make([]byte, offset+mtu)

		n, err := e.rw.Read(data)
		if err != nil {
			break
		}

		if n <= 0 || n > mtu {
			continue /* empty or oversized read, drop packet */
		}

		if !e.IsAttached() {
			continue /* unattached, drop packet */
		}

		payload := data[offset : offset+n]
		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
			Data: buffer.View(payload).ToVectorisedView(),
		})

		switch header.IPVersion(payload) {
		case header.IPv4Version:
			e.InjectInbound(header.IPv4ProtocolNumber, pkt)
		case header.IPv6Version:
			e.InjectInbound(header.IPv6ProtocolNumber, pkt)
		}
		pkt.DecRef() /* release */
	}
}
// outboundLoop drains outbound packets queued on the embedded channel
// endpoint and hands each one to writePacket, which sends it to the lower
// layer. The loop ends as soon as ReadContext returns nil, which is expected
// once ctx has been cancelled by dispatchLoop.
func (e *Endpoint) outboundLoop(ctx context.Context) {
	for pkt := e.ReadContext(ctx); pkt != nil; pkt = e.ReadContext(ctx) {
		// Best effort: a failed write is not reported anywhere and the loop
		// simply moves on to the next packet.
		_ = e.writePacket(pkt)
	}
}
// writePacket flattens an outbound packet into a single buffer and writes it
// to e.rw.
//
// When e.offset is non-zero, the packet is prefixed with e.offset zero bytes
// (room for a device header such as TUN_PI). The size given to the vectorised
// view includes that prefix so it matches the total length of the views.
//
// It returns *tcpip.ErrInvalidEndpointState when the write fails, whatever
// the underlying error was, and nil otherwise.
func (e *Endpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error {
	size := pkt.Size()
	views := pkt.Views()

	if e.offset != 0 {
		views = append([]buffer.View{
			make(buffer.View, e.offset),
		}, views...)
		// Account for the prepended header so size equals the sum of the
		// view lengths.
		size += e.offset
	}

	vView := buffer.NewVectorisedView(size, views)
	if _, err := e.rw.Write(vView.ToView()); err != nil {
		return &tcpip.ErrInvalidEndpointState{}
	}
	return nil
}