Fix: wrong wait logic

This commit is contained in:
xjasonlyu
2022-03-28 23:14:00 +08:00
parent 9d7dacbea1
commit 42d6c96b6b
4 changed files with 18 additions and 12 deletions

View File

@@ -17,7 +17,4 @@ type Device interface {
// Type returns the driver type of the device. // Type returns the driver type of the device.
Type() string Type() string
// Wait waits for the device to close.
Wait()
} }

View File

@@ -36,6 +36,9 @@ type Endpoint struct {
// once is used to perform the init action once when attaching. // once is used to perform the init action once when attaching.
once sync.Once once sync.Once
// wg keeps track of running goroutines.
wg sync.WaitGroup
} }
// New returns stack.LinkEndpoint(.*Endpoint) and error. // New returns stack.LinkEndpoint(.*Endpoint) and error.
@@ -60,19 +63,26 @@ func New(rw io.ReadWriter, mtu uint32, offset int) (*Endpoint, error) {
}, nil }, nil
} }
func (e *Endpoint) Close() {
e.Endpoint.Close()
}
// Attach launches the goroutine that reads packets from io.Reader and // Attach launches the goroutine that reads packets from io.Reader and
// dispatches them via the provided dispatcher. // dispatches them via the provided dispatcher.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.Endpoint.Attach(dispatcher)
e.once.Do(func() { e.once.Do(func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go e.dispatchLoop(cancel) e.wg.Add(2)
go e.outboundLoop(ctx) go func() {
e.outboundLoop(ctx)
e.wg.Done()
}()
go func() {
e.dispatchLoop(cancel)
e.wg.Done()
}()
}) })
e.Endpoint.Attach(dispatcher) }
func (e *Endpoint) Wait() {
e.wg.Wait()
} }
// dispatchLoop dispatches packets to upper layer. // dispatchLoop dispatches packets to upper layer.

View File

@@ -69,6 +69,6 @@ func (t *TUN) Name() string {
} }
func (t *TUN) Close() error { func (t *TUN) Close() error {
t.Endpoint.Close() defer t.Endpoint.Close()
return t.nt.Close() return t.nt.Close()
} }

View File

@@ -83,7 +83,6 @@ func (e *engine) start() error {
func (e *engine) stop() (err error) { func (e *engine) stop() (err error) {
if e.device != nil { if e.device != nil {
err = e.device.Close() err = e.device.Close()
e.device.Wait()
} }
if e.stack != nil { if e.stack != nil {
e.stack.Close() e.stack.Close()