fix leaking resources in cancelled gRPC streams

Signed-off-by: Steffen Vogel <post@steffenvogel.de>
This commit is contained in:
Steffen Vogel
2022-08-23 08:36:42 +02:00
parent e8c9091ed3
commit c2302a6a2c
9 changed files with 98 additions and 96 deletions

View File

@@ -12,21 +12,21 @@ import (
) )
func (s *Server) OnInterfaceAdded(i *core.Interface) { func (s *Server) OnInterfaceAdded(i *core.Interface) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_INTERFACE_ADDED, Type: pb.Event_INTERFACE_ADDED,
Interface: i.Name(), Interface: i.Name(),
} })
} }
func (s *Server) OnInterfaceRemoved(i *core.Interface) { func (s *Server) OnInterfaceRemoved(i *core.Interface) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_INTERFACE_REMOVED, Type: pb.Event_INTERFACE_REMOVED,
Interface: i.Name(), Interface: i.Name(),
} })
} }
func (s *Server) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) { func (s *Server) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_INTERFACE_MODIFIED, Type: pb.Event_INTERFACE_MODIFIED,
Interface: i.Name(), Interface: i.Name(),
Event: &pb.Event_InterfaceModified{ Event: &pb.Event_InterfaceModified{
@@ -34,27 +34,27 @@ func (s *Server) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core
Modified: uint32(mod), Modified: uint32(mod),
}, },
}, },
} })
} }
func (s *Server) OnPeerAdded(p *core.Peer) { func (s *Server) OnPeerAdded(p *core.Peer) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_PEER_ADDED, Type: pb.Event_PEER_ADDED,
Interface: p.Interface.Name(), Interface: p.Interface.Name(),
Peer: p.PublicKey().Bytes(), Peer: p.PublicKey().Bytes(),
} })
} }
func (s *Server) OnPeerRemoved(p *core.Peer) { func (s *Server) OnPeerRemoved(p *core.Peer) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_PEER_REMOVED, Type: pb.Event_PEER_REMOVED,
Interface: p.Interface.Name(), Interface: p.Interface.Name(),
Peer: p.PublicKey().Bytes(), Peer: p.PublicKey().Bytes(),
} })
} }
func (s *Server) OnPeerModified(p *core.Peer, old *wgtypes.Peer, mod core.PeerModifier, ipsAdded, ipsRemoved []net.IPNet) { func (s *Server) OnPeerModified(p *core.Peer, old *wgtypes.Peer, mod core.PeerModifier, ipsAdded, ipsRemoved []net.IPNet) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_PEER_MODIFIED, Type: pb.Event_PEER_MODIFIED,
Interface: p.Interface.Name(), Interface: p.Interface.Name(),
Peer: p.PublicKey().Bytes(), Peer: p.PublicKey().Bytes(),
@@ -64,11 +64,11 @@ func (s *Server) OnPeerModified(p *core.Peer, old *wgtypes.Peer, mod core.PeerMo
Modified: uint32(mod), Modified: uint32(mod),
}, },
}, },
} })
} }
func (s *Server) OnSignalingBackendReady(b signaling.Backend) { func (s *Server) OnSignalingBackendReady(b signaling.Backend) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_BACKEND_READY, Type: pb.Event_BACKEND_READY,
Event: &pb.Event_BackendReady{ Event: &pb.Event_BackendReady{
@@ -76,7 +76,7 @@ func (s *Server) OnSignalingBackendReady(b signaling.Backend) {
Type: b.Type(), Type: b.Type(),
}, },
}, },
} })
} }
func (s *Server) OnSignalingMessage(kp *crypto.PublicKeyPair, msg *signaling.Message) { func (s *Server) OnSignalingMessage(kp *crypto.PublicKeyPair, msg *signaling.Message) {

View File

@@ -33,7 +33,7 @@ type Server struct {
func NewServer(d *wice.Daemon) (*Server, error) { func NewServer(d *wice.Daemon) (*Server, error) {
s := &Server{ s := &Server{
events: util.NewFanOut[*pb.Event](0), events: util.NewFanOut[*pb.Event](1),
logger: zap.L().Named("rpc.server"), logger: zap.L().Named("rpc.server"),
} }
@@ -79,6 +79,7 @@ func (s *Server) Wait() {
func (s *Server) Close() error { func (s *Server) Close() error {
s.grpc.GracefulStop() s.grpc.GracefulStop()
s.events.Close()
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package rpc
import ( import (
"context" "context"
"fmt" "fmt"
"io"
wice "riasc.eu/wice/pkg" wice "riasc.eu/wice/pkg"
"riasc.eu/wice/pkg/pb" "riasc.eu/wice/pkg/pb"
@@ -33,9 +34,21 @@ func (s *DaemonServer) StreamEvents(params *pb.StreamEventsParams, stream pb.Soc
s.ep.SendConnectionStates(stream) s.ep.SendConnectionStates(stream)
} }
for e := range s.events.Add() { events := s.events.Add()
if err := stream.Send(e); err != nil { defer s.events.Remove(events)
return fmt.Errorf("failed to send event: %w", err)
out:
for {
select {
case event := <-events:
if err := stream.Send(event); err == io.EOF {
break out
} else if err != nil {
return fmt.Errorf("failed to send event: %w", err)
}
case <-stream.Context().Done():
break out
} }
} }

View File

@@ -3,6 +3,7 @@ package rpc
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"go.uber.org/zap" "go.uber.org/zap"
"riasc.eu/wice/pkg/crypto" "riasc.eu/wice/pkg/crypto"
@@ -58,19 +59,21 @@ func (s *EndpointDiscoveryServer) SendConnectionStates(stream pb.Socket_StreamEv
Peer: p.Peer.PublicKey().Bytes(), Peer: p.Peer.PublicKey().Bytes(),
Event: &pb.Event_PeerConnectionStateChange{ Event: &pb.Event_PeerConnectionStateChange{
PeerConnectionStateChange: &pb.PeerConnectionStateChangeEvent{ PeerConnectionStateChange: &pb.PeerConnectionStateChangeEvent{
NewState: pb.NewConnectionState(p.ConnectionState), NewState: pb.NewConnectionState(p.ConnectionState()),
}, },
}, },
} }
if err := stream.Send(e); err != nil { if err := stream.Send(e); err == io.EOF {
continue
} else if err != nil {
s.logger.Error("Failed to send", zap.Error(err)) s.logger.Error("Failed to send", zap.Error(err))
} }
} }
} }
func (s *EndpointDiscoveryServer) OnConnectionStateChange(p *epice.Peer, new, prev icex.ConnectionState) { func (s *EndpointDiscoveryServer) OnConnectionStateChange(p *epice.Peer, new, prev icex.ConnectionState) {
s.events.C <- &pb.Event{ s.events.Send(&pb.Event{
Type: pb.Event_PEER_CONNECTION_STATE_CHANGED, Type: pb.Event_PEER_CONNECTION_STATE_CHANGED,
Interface: p.Interface.Name(), Interface: p.Interface.Name(),
@@ -82,5 +85,5 @@ func (s *EndpointDiscoveryServer) OnConnectionStateChange(p *epice.Peer, new, pr
PrevState: pb.NewConnectionState(prev), PrevState: pb.NewConnectionState(prev),
}, },
}, },
} })
} }

View File

@@ -115,7 +115,7 @@ func (b *Backend) subscribeFromServer(ctx context.Context, pk *crypto.Key) error
} }
// Wait until subscription has been created // Wait until subscription has been created
// This avoids a race between Subscribe()/Publish() when two subscribers are subscribing // This avoids a race between Subscribe() / Publish() when two subscribers are subscribing
// to each other. // to each other.
if _, err := stream.Recv(); err != nil { if _, err := stream.Recv(); err != nil {
return fmt.Errorf("failed receive sync envelope: %s", err) return fmt.Errorf("failed receive sync envelope: %s", err)

View File

@@ -73,9 +73,18 @@ func (s *Server) Subscribe(params *pb.SubscribeParams, stream pb.Signaling_Subsc
s.logger.Error("Failed to send sync envelope", zap.Error(err)) s.logger.Error("Failed to send sync envelope", zap.Error(err))
} }
for env := range ch { out:
if err := stream.Send(env); err != nil && err != io.EOF { for {
s.logger.Error("Failed to send envelope", zap.Error(err)) select {
case env := <-ch:
if err := stream.Send(env); err == io.EOF {
break out
} else if err != nil {
s.logger.Error("Failed to send envelope", zap.Error(err))
}
case <-stream.Context().Done():
break out
} }
} }

View File

@@ -22,7 +22,7 @@ func (r *topicRegistry) getTopic(pk *crypto.Key) *topic {
return top return top
} }
top = newTopic() top = NewTopic()
r.topics[*pk] = top r.topics[*pk] = top
@@ -34,9 +34,7 @@ func (r *topicRegistry) Close() error {
defer r.topicsLock.Unlock() defer r.topicsLock.Unlock()
for _, t := range r.topics { for _, t := range r.topics {
if err := t.Close(); err != nil { t.Close()
return err
}
} }
return nil return nil
@@ -46,7 +44,7 @@ type topic struct {
subs *util.FanOut[*signaling.Envelope] subs *util.FanOut[*signaling.Envelope]
} }
func newTopic() *topic { func NewTopic() *topic {
t := &topic{ t := &topic{
subs: util.NewFanOut[*signaling.Envelope](128), subs: util.NewFanOut[*signaling.Envelope](128),
} }
@@ -55,7 +53,7 @@ func newTopic() *topic {
} }
func (t *topic) Publish(env *signaling.Envelope) { func (t *topic) Publish(env *signaling.Envelope) {
t.subs.C <- env t.subs.Send(env)
} }
func (t *topic) Subscribe() chan *signaling.Envelope { func (t *topic) Subscribe() chan *signaling.Envelope {
@@ -66,6 +64,6 @@ func (t *topic) Unsubscribe(ch chan *signaling.Envelope) {
t.subs.Remove(ch) t.subs.Remove(ch)
} }
func (t *topic) Close() error { func (t *topic) Close() {
return t.subs.Close() t.subs.Close()
} }

View File

@@ -3,42 +3,36 @@ package util
import "sync" import "sync"
type FanOut[T any] struct { type FanOut[T any] struct {
C chan T
buf int
subs map[chan T]struct{}
lock sync.RWMutex lock sync.RWMutex
buf int
subs map[chan T]any
} }
func NewFanOut[T any](buf int) *FanOut[T] { func NewFanOut[T any](buf int) *FanOut[T] {
f := &FanOut[T]{ f := &FanOut[T]{
C: make(chan T), subs: map[chan T]any{},
subs: map[chan T]struct{}{},
buf: buf, buf: buf,
} }
go f.run()
return f return f
} }
func (f *FanOut[T]) run() { func (f *FanOut[T]) Send(v T) {
for t := range f.C { f.lock.RLock()
f.lock.RLock() defer f.lock.RUnlock()
for ch := range f.subs {
ch <- t for ch := range f.subs {
} ch <- v
f.lock.RUnlock()
} }
} }
func (f *FanOut[T]) Add() chan T { func (f *FanOut[T]) Add() chan T {
ch := make(chan T, f.buf)
f.lock.Lock() f.lock.Lock()
defer f.lock.Unlock() defer f.lock.Unlock()
f.subs[ch] = struct{}{} ch := make(chan T, f.buf)
f.subs[ch] = nil
return ch return ch
} }
@@ -50,15 +44,8 @@ func (f *FanOut[T]) Remove(ch chan T) {
delete(f.subs, ch) delete(f.subs, ch)
} }
func (f *FanOut[T]) Close() error { func (f *FanOut[T]) Close() {
f.lock.Lock()
defer f.lock.Unlock()
for ch := range f.subs { for ch := range f.subs {
close(ch) close(ch)
} }
close(f.C)
return nil
} }

View File

@@ -6,60 +6,51 @@ import (
"riasc.eu/wice/pkg/util" "riasc.eu/wice/pkg/util"
) )
var _ = Describe("Fan-out", func() { var _ = Describe("fanout", func() {
It("should require buffered channels if we synchronously receive from the output channels", func() { It("works with no channel", func() {
fo := util.NewFanOut[int](1)
fo.Send(1234)
fo.Close()
})
It("works with a single channel", func() {
fo := util.NewFanOut[int](1)
ch := fo.Add()
fo.Send(1234)
Eventually(ch).Should(Receive(Equal(1234)))
fo.Close()
})
It("works with two channels", func() {
fo := util.NewFanOut[int](1) fo := util.NewFanOut[int](1)
ch1 := fo.Add() ch1 := fo.Add()
ch2 := fo.Add() ch2 := fo.Add()
fo.C <- 1234 fo.Send(1234)
Eventually(ch1).Should(Receive(Equal(1234))) Eventually(ch1).Should(Receive(Equal(1234)))
Eventually(ch2).Should(Receive(Equal(1234))) Eventually(ch2).Should(Receive(Equal(1234)))
err := fo.Close() fo.Close()
Expect(err).To(Succeed())
}) })
It("also works with unbuffered channels if there is only a single channel", func() { It("works with two channels after one has been removed", func() {
fo := util.NewFanOut[int](0) fo := util.NewFanOut[int](1)
ch := fo.Add()
fo.C <- 1234
Eventually(ch).Should(Receive(Equal(1234)))
err := fo.Close()
Expect(err).To(Succeed())
})
It("also works with unbuffered channels if there is only a single channel or others have been removed", func() {
fo := util.NewFanOut[int](0)
ch1 := fo.Add() ch1 := fo.Add()
ch2 := fo.Add() ch2 := fo.Add()
fo.Remove(ch2) fo.Remove(ch2)
fo.C <- 1234 fo.Send(1234)
Eventually(ch1).Should(Receive(Equal(1234))) Eventually(ch1).Should(Receive(Equal(1234)))
err := fo.Close() fo.Close()
Expect(err).To(Succeed())
})
It("might deadlock if there are more receiving channels", func() {
fo := util.NewFanOut[int](0)
ch1 := fo.Add()
ch2 := fo.Add()
fo.C <- 1234
Eventually(ch1).ShouldNot(Receive(Equal(1234)))
Eventually(ch2).ShouldNot(Receive(Equal(1234)))
err := fo.Close()
Expect(err).To(Succeed())
}) })
}) })