mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 20:02:37 +08:00 
			
		
		
		
	 d664444928
			
		
	
	d664444928
	
	
	
		
			
			Persnickety consumers can now do:
    func init() {
        tun.WintunPool, _ = wintun.MakePool("Flurp")
        tun.WintunStaticRequestedGUID, _ = windows.GUIDFromString("{5ae2716f-0b3e-4dc4-a8b5-48eba11a6e16}")
    }
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
		
	
		
			
				
	
	
		
			234 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			234 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: MIT
 | |
|  *
 | |
|  * Copyright (C) 2018-2020 WireGuard LLC. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package tun
 | |
| 
 | |
| import (
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"log"
 | |
| 	"os"
 | |
| 	"sync/atomic"
 | |
| 	"time"
 | |
| 	_ "unsafe"
 | |
| 
 | |
| 	"golang.org/x/sys/windows"
 | |
| 
 | |
| 	"golang.zx2c4.com/wireguard/tun/wintun"
 | |
| )
 | |
| 
 | |
const (
	// rateMeasurementGranularity is the length of one throughput-measurement
	// window, in nanoseconds (half a second).
	rateMeasurementGranularity = uint64(500 * time.Millisecond)

	// spinloopRateThreshold is the throughput, in bytes per second, at or above
	// which Read spins before sleeping on the read-wait event: 800 Mbit/s
	// expressed in bytes.
	spinloopRateThreshold = 800000000 / 8 // 800mbps

	// spinloopDuration caps how long Read spins, in nanoseconds (12.5µs).
	// NOTE(review): presumably chosen as roughly one full-size packet's time on
	// the wire at ~1 Gbit/s — confirm.
	spinloopDuration = uint64(time.Millisecond / 80) // ~1gbit/s
)
 | |
| 
 | |
// rateJuggler tracks recent traffic throughput. Every field is read and written
// through sync/atomic, presumably because update is reached from both Read and
// Write, which may run on different goroutines.
//
// NOTE(review): on 32-bit platforms, 64-bit atomic operations require 8-byte
// alignment. Keep the uint64/int64 fields first and do not reorder them.
type rateJuggler struct {
	// current is the most recently published throughput, in bytes per second.
	current       uint64
	// nextByteCount accumulates bytes seen since nextStartTime.
	nextByteCount uint64
	// nextStartTime is the nanotime() reading at which the current
	// measurement window began.
	nextStartTime int64
	// changing is 1 while one caller of update publishes a new sample; other
	// callers that lose the compare-and-swap skip publishing.
	changing      int32
}
 | |
| 
 | |
// NativeTun is the Windows Device backed by a Wintun adapter and a single
// Wintun session on that adapter.
//
// NOTE(review): rate contains fields used with 64-bit atomics. On 32-bit
// Windows its offset within this struct must stay 8-byte aligned, so be careful
// when adding or reordering the fields above it.
type NativeTun struct {
	// wt is the adapter created by CreateTUNWithRequestedGUID; Close deletes it.
	wt        *wintun.Adapter
	// handle is initialized to windows.InvalidHandle at creation and is not
	// otherwise used by the methods here — TODO confirm whether still needed.
	handle    windows.Handle
	// close is set by Close; Read and Write return os.ErrClosed once it is set.
	// NOTE(review): plain bool, not accessed atomically.
	close     bool
	// events is returned by Events and closed by Close.
	events    chan Event
	// errors: an error queued here is returned by the next call to Read.
	errors    chan error
	// forcedMTU is the value reported by MTU; set at creation and by ForceMTU.
	forcedMTU int
	// rate holds recent throughput, which Read uses to decide whether to spin.
	rate      rateJuggler
	// session is the Wintun session packets are received from and sent to.
	session   wintun.Session
	// readWait is the session's read-wait event; Read blocks on it when no
	// packet is available.
	readWait  windows.Handle
}
 | |
| 
 | |
| var WintunPool, _ = wintun.MakePool("WireGuard")
 | |
| var WintunStaticRequestedGUID *windows.GUID
 | |
| 
 | |
// procyield is bound to the Go runtime's internal runtime.procyield through
// go:linkname, which is why this file blank-imports "unsafe". Read calls
// procyield(1) between polls of the receive ring while spinning.
// NOTE(review): presumably issues `cycles` CPU pause/yield hints — confirm
// against the targeted Go runtime.
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
 | |
| 
 | |
// nanotime is bound to the Go runtime's internal runtime.nanotime through
// go:linkname. Callers here only subtract one reading from another (spin
// timing in Read, measurement windows in rateJuggler.update), so only
// differences between readings matter.
//go:linkname nanotime runtime.nanotime
func nanotime() int64
 | |
| 
 | |
| //
 | |
| // CreateTUN creates a Wintun interface with the given name. Should a Wintun
 | |
| // interface with the same name exist, it is reused.
 | |
| //
 | |
| func CreateTUN(ifname string, mtu int) (Device, error) {
 | |
| 	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
 | |
| }
 | |
| 
 | |
| //
 | |
| // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
 | |
| // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
 | |
| //
 | |
| func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
 | |
| 	var err error
 | |
| 	var wt *wintun.Adapter
 | |
| 
 | |
| 	// Does an interface with this name already exist?
 | |
| 	wt, err = WintunPool.OpenAdapter(ifname)
 | |
| 	if err == nil {
 | |
| 		// If so, we delete it, in case it has weird residual configuration.
 | |
| 		_, err = wt.Delete(true)
 | |
| 		if err != nil {
 | |
| 			return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
 | |
| 		}
 | |
| 	}
 | |
| 	wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
 | |
| 	if err != nil {
 | |
| 		return nil, fmt.Errorf("Error creating interface: %w", err)
 | |
| 	}
 | |
| 	if rebootRequired {
 | |
| 		log.Println("Windows indicated a reboot is required.")
 | |
| 	}
 | |
| 
 | |
| 	forcedMTU := 1420
 | |
| 	if mtu > 0 {
 | |
| 		forcedMTU = mtu
 | |
| 	}
 | |
| 
 | |
| 	tun := &NativeTun{
 | |
| 		wt:        wt,
 | |
| 		handle:    windows.InvalidHandle,
 | |
| 		events:    make(chan Event, 10),
 | |
| 		errors:    make(chan error, 1),
 | |
| 		forcedMTU: forcedMTU,
 | |
| 	}
 | |
| 
 | |
| 	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
 | |
| 	if err != nil {
 | |
| 		tun.wt.Delete(false)
 | |
| 		close(tun.events)
 | |
| 		return nil, fmt.Errorf("Error starting session: %w", err)
 | |
| 	}
 | |
| 	tun.readWait = tun.session.ReadWaitEvent()
 | |
| 	return tun, nil
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) Name() (string, error) {
 | |
| 	return tun.wt.Name()
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) File() *os.File {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) Events() chan Event {
 | |
| 	return tun.events
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) Close() error {
 | |
| 	tun.close = true
 | |
| 	tun.session.End()
 | |
| 	var err error
 | |
| 	if tun.wt != nil {
 | |
| 		_, err = tun.wt.Delete(false)
 | |
| 	}
 | |
| 	close(tun.events)
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) MTU() (int, error) {
 | |
| 	return tun.forcedMTU, nil
 | |
| }
 | |
| 
 | |
| // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
 | |
| func (tun *NativeTun) ForceMTU(mtu int) {
 | |
| 	tun.forcedMTU = mtu
 | |
| }
 | |
| 
 | |
// Note: Read() and Write() each assume they are called from only one goroutine at a time; there's no locking.
 | |
| 
 | |
| func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
 | |
| retry:
 | |
| 	select {
 | |
| 	case err := <-tun.errors:
 | |
| 		return 0, err
 | |
| 	default:
 | |
| 	}
 | |
| 	start := nanotime()
 | |
| 	shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
 | |
| 	for {
 | |
| 		if tun.close {
 | |
| 			return 0, os.ErrClosed
 | |
| 		}
 | |
| 		packet, err := tun.session.ReceivePacket()
 | |
| 		switch err {
 | |
| 		case nil:
 | |
| 			packetSize := len(packet)
 | |
| 			copy(buff[offset:], packet)
 | |
| 			tun.session.ReleaseReceivePacket(packet)
 | |
| 			tun.rate.update(uint64(packetSize))
 | |
| 			return packetSize, nil
 | |
| 		case windows.ERROR_NO_MORE_ITEMS:
 | |
| 			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
 | |
| 				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
 | |
| 				goto retry
 | |
| 			}
 | |
| 			procyield(1)
 | |
| 			continue
 | |
| 		case windows.ERROR_HANDLE_EOF:
 | |
| 			return 0, os.ErrClosed
 | |
| 		case windows.ERROR_INVALID_DATA:
 | |
| 			return 0, errors.New("Send ring corrupt")
 | |
| 		}
 | |
| 		return 0, fmt.Errorf("Read failed: %w", err)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) Flush() error {
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
 | |
| 	if tun.close {
 | |
| 		return 0, os.ErrClosed
 | |
| 	}
 | |
| 
 | |
| 	packetSize := len(buff) - offset
 | |
| 	tun.rate.update(uint64(packetSize))
 | |
| 
 | |
| 	packet, err := tun.session.AllocateSendPacket(packetSize)
 | |
| 	if err == nil {
 | |
| 		copy(packet, buff[offset:])
 | |
| 		tun.session.SendPacket(packet)
 | |
| 		return packetSize, nil
 | |
| 	}
 | |
| 	switch err {
 | |
| 	case windows.ERROR_HANDLE_EOF:
 | |
| 		return 0, os.ErrClosed
 | |
| 	case windows.ERROR_BUFFER_OVERFLOW:
 | |
| 		return 0, nil // Dropping when ring is full.
 | |
| 	}
 | |
| 	return 0, fmt.Errorf("Write failed: %w", err)
 | |
| }
 | |
| 
 | |
| // LUID returns Windows interface instance ID.
 | |
| func (tun *NativeTun) LUID() uint64 {
 | |
| 	return tun.wt.LUID()
 | |
| }
 | |
| 
 | |
| // RunningVersion returns the running version of the Wintun driver.
 | |
| func (tun *NativeTun) RunningVersion() (version uint32, err error) {
 | |
| 	return wintun.RunningVersion()
 | |
| }
 | |
| 
 | |
| func (rate *rateJuggler) update(packetLen uint64) {
 | |
| 	now := nanotime()
 | |
| 	total := atomic.AddUint64(&rate.nextByteCount, packetLen)
 | |
| 	period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
 | |
| 	if period >= rateMeasurementGranularity {
 | |
| 		if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
 | |
| 			return
 | |
| 		}
 | |
| 		atomic.StoreInt64(&rate.nextStartTime, now)
 | |
| 		atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
 | |
| 		atomic.StoreUint64(&rate.nextByteCount, 0)
 | |
| 		atomic.StoreInt32(&rate.changing, 0)
 | |
| 	}
 | |
| }
 |