mirror of
https://codeberg.org/cunicu/cunicu.git
synced 2025-10-23 00:40:19 +08:00
fix leaking resources in cancelled gRPC streams
Signed-off-by: Steffen Vogel <post@steffenvogel.de>
This commit is contained in:
@@ -12,21 +12,21 @@ import (
|
||||
)
|
||||
|
||||
func (s *Server) OnInterfaceAdded(i *core.Interface) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_INTERFACE_ADDED,
|
||||
Interface: i.Name(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnInterfaceRemoved(i *core.Interface) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_INTERFACE_REMOVED,
|
||||
Interface: i.Name(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_INTERFACE_MODIFIED,
|
||||
Interface: i.Name(),
|
||||
Event: &pb.Event_InterfaceModified{
|
||||
@@ -34,27 +34,27 @@ func (s *Server) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core
|
||||
Modified: uint32(mod),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnPeerAdded(p *core.Peer) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_PEER_ADDED,
|
||||
Interface: p.Interface.Name(),
|
||||
Peer: p.PublicKey().Bytes(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnPeerRemoved(p *core.Peer) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_PEER_REMOVED,
|
||||
Interface: p.Interface.Name(),
|
||||
Peer: p.PublicKey().Bytes(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnPeerModified(p *core.Peer, old *wgtypes.Peer, mod core.PeerModifier, ipsAdded, ipsRemoved []net.IPNet) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_PEER_MODIFIED,
|
||||
Interface: p.Interface.Name(),
|
||||
Peer: p.PublicKey().Bytes(),
|
||||
@@ -64,11 +64,11 @@ func (s *Server) OnPeerModified(p *core.Peer, old *wgtypes.Peer, mod core.PeerMo
|
||||
Modified: uint32(mod),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnSignalingBackendReady(b signaling.Backend) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_BACKEND_READY,
|
||||
|
||||
Event: &pb.Event_BackendReady{
|
||||
@@ -76,7 +76,7 @@ func (s *Server) OnSignalingBackendReady(b signaling.Backend) {
|
||||
Type: b.Type(),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Server) OnSignalingMessage(kp *crypto.PublicKeyPair, msg *signaling.Message) {
|
||||
|
@@ -33,7 +33,7 @@ type Server struct {
|
||||
|
||||
func NewServer(d *wice.Daemon) (*Server, error) {
|
||||
s := &Server{
|
||||
events: util.NewFanOut[*pb.Event](0),
|
||||
events: util.NewFanOut[*pb.Event](1),
|
||||
logger: zap.L().Named("rpc.server"),
|
||||
}
|
||||
|
||||
@@ -79,6 +79,7 @@ func (s *Server) Wait() {
|
||||
|
||||
func (s *Server) Close() error {
|
||||
s.grpc.GracefulStop()
|
||||
s.events.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package rpc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
wice "riasc.eu/wice/pkg"
|
||||
"riasc.eu/wice/pkg/pb"
|
||||
@@ -33,9 +34,21 @@ func (s *DaemonServer) StreamEvents(params *pb.StreamEventsParams, stream pb.Soc
|
||||
s.ep.SendConnectionStates(stream)
|
||||
}
|
||||
|
||||
for e := range s.events.Add() {
|
||||
if err := stream.Send(e); err != nil {
|
||||
return fmt.Errorf("failed to send event: %w", err)
|
||||
events := s.events.Add()
|
||||
defer s.events.Remove(events)
|
||||
|
||||
out:
|
||||
for {
|
||||
select {
|
||||
case event := <-events:
|
||||
if err := stream.Send(event); err == io.EOF {
|
||||
break out
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to send event: %w", err)
|
||||
}
|
||||
|
||||
case <-stream.Context().Done():
|
||||
break out
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -3,6 +3,7 @@ package rpc
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"riasc.eu/wice/pkg/crypto"
|
||||
@@ -58,19 +59,21 @@ func (s *EndpointDiscoveryServer) SendConnectionStates(stream pb.Socket_StreamEv
|
||||
Peer: p.Peer.PublicKey().Bytes(),
|
||||
Event: &pb.Event_PeerConnectionStateChange{
|
||||
PeerConnectionStateChange: &pb.PeerConnectionStateChangeEvent{
|
||||
NewState: pb.NewConnectionState(p.ConnectionState),
|
||||
NewState: pb.NewConnectionState(p.ConnectionState()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := stream.Send(e); err != nil {
|
||||
if err := stream.Send(e); err == io.EOF {
|
||||
continue
|
||||
} else if err != nil {
|
||||
s.logger.Error("Failed to send", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EndpointDiscoveryServer) OnConnectionStateChange(p *epice.Peer, new, prev icex.ConnectionState) {
|
||||
s.events.C <- &pb.Event{
|
||||
s.events.Send(&pb.Event{
|
||||
Type: pb.Event_PEER_CONNECTION_STATE_CHANGED,
|
||||
|
||||
Interface: p.Interface.Name(),
|
||||
@@ -82,5 +85,5 @@ func (s *EndpointDiscoveryServer) OnConnectionStateChange(p *epice.Peer, new, pr
|
||||
PrevState: pb.NewConnectionState(prev),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@@ -115,7 +115,7 @@ func (b *Backend) subscribeFromServer(ctx context.Context, pk *crypto.Key) error
|
||||
}
|
||||
|
||||
// Wait until subscription has been created
|
||||
// This avoids a race between Subscribe()/Publish() when two subscribers are subscribing
|
||||
// This avoids a race between Subscribe() / Publish() when two subscribers are subscribing
|
||||
// to each other.
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
return fmt.Errorf("failed receive sync envelope: %s", err)
|
||||
|
@@ -73,9 +73,18 @@ func (s *Server) Subscribe(params *pb.SubscribeParams, stream pb.Signaling_Subsc
|
||||
s.logger.Error("Failed to send sync envelope", zap.Error(err))
|
||||
}
|
||||
|
||||
for env := range ch {
|
||||
if err := stream.Send(env); err != nil && err != io.EOF {
|
||||
s.logger.Error("Failed to send envelope", zap.Error(err))
|
||||
out:
|
||||
for {
|
||||
select {
|
||||
case env := <-ch:
|
||||
if err := stream.Send(env); err == io.EOF {
|
||||
break out
|
||||
} else if err != nil {
|
||||
s.logger.Error("Failed to send envelope", zap.Error(err))
|
||||
}
|
||||
|
||||
case <-stream.Context().Done():
|
||||
break out
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -22,7 +22,7 @@ func (r *topicRegistry) getTopic(pk *crypto.Key) *topic {
|
||||
return top
|
||||
}
|
||||
|
||||
top = newTopic()
|
||||
top = NewTopic()
|
||||
|
||||
r.topics[*pk] = top
|
||||
|
||||
@@ -34,9 +34,7 @@ func (r *topicRegistry) Close() error {
|
||||
defer r.topicsLock.Unlock()
|
||||
|
||||
for _, t := range r.topics {
|
||||
if err := t.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
t.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -46,7 +44,7 @@ type topic struct {
|
||||
subs *util.FanOut[*signaling.Envelope]
|
||||
}
|
||||
|
||||
func newTopic() *topic {
|
||||
func NewTopic() *topic {
|
||||
t := &topic{
|
||||
subs: util.NewFanOut[*signaling.Envelope](128),
|
||||
}
|
||||
@@ -55,7 +53,7 @@ func newTopic() *topic {
|
||||
}
|
||||
|
||||
func (t *topic) Publish(env *signaling.Envelope) {
|
||||
t.subs.C <- env
|
||||
t.subs.Send(env)
|
||||
}
|
||||
|
||||
func (t *topic) Subscribe() chan *signaling.Envelope {
|
||||
@@ -66,6 +64,6 @@ func (t *topic) Unsubscribe(ch chan *signaling.Envelope) {
|
||||
t.subs.Remove(ch)
|
||||
}
|
||||
|
||||
func (t *topic) Close() error {
|
||||
return t.subs.Close()
|
||||
func (t *topic) Close() {
|
||||
t.subs.Close()
|
||||
}
|
||||
|
@@ -3,42 +3,36 @@ package util
|
||||
import "sync"
|
||||
|
||||
type FanOut[T any] struct {
|
||||
C chan T
|
||||
|
||||
buf int
|
||||
subs map[chan T]struct{}
|
||||
lock sync.RWMutex
|
||||
buf int
|
||||
subs map[chan T]any
|
||||
}
|
||||
|
||||
func NewFanOut[T any](buf int) *FanOut[T] {
|
||||
f := &FanOut[T]{
|
||||
C: make(chan T),
|
||||
subs: map[chan T]struct{}{},
|
||||
subs: map[chan T]any{},
|
||||
buf: buf,
|
||||
}
|
||||
|
||||
go f.run()
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *FanOut[T]) run() {
|
||||
for t := range f.C {
|
||||
f.lock.RLock()
|
||||
for ch := range f.subs {
|
||||
ch <- t
|
||||
}
|
||||
f.lock.RUnlock()
|
||||
func (f *FanOut[T]) Send(v T) {
|
||||
f.lock.RLock()
|
||||
defer f.lock.RUnlock()
|
||||
|
||||
for ch := range f.subs {
|
||||
ch <- v
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FanOut[T]) Add() chan T {
|
||||
ch := make(chan T, f.buf)
|
||||
|
||||
f.lock.Lock()
|
||||
defer f.lock.Unlock()
|
||||
|
||||
f.subs[ch] = struct{}{}
|
||||
ch := make(chan T, f.buf)
|
||||
|
||||
f.subs[ch] = nil
|
||||
|
||||
return ch
|
||||
}
|
||||
@@ -50,15 +44,8 @@ func (f *FanOut[T]) Remove(ch chan T) {
|
||||
delete(f.subs, ch)
|
||||
}
|
||||
|
||||
func (f *FanOut[T]) Close() error {
|
||||
f.lock.Lock()
|
||||
defer f.lock.Unlock()
|
||||
|
||||
func (f *FanOut[T]) Close() {
|
||||
for ch := range f.subs {
|
||||
close(ch)
|
||||
}
|
||||
|
||||
close(f.C)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@@ -6,60 +6,51 @@ import (
|
||||
"riasc.eu/wice/pkg/util"
|
||||
)
|
||||
|
||||
var _ = Describe("Fan-out", func() {
|
||||
It("should require buffered channels if we synchronously receive from the output channels", func() {
|
||||
var _ = Describe("fanout", func() {
|
||||
It("works with no channel", func() {
|
||||
fo := util.NewFanOut[int](1)
|
||||
|
||||
fo.Send(1234)
|
||||
|
||||
fo.Close()
|
||||
})
|
||||
|
||||
It("works with a single channel", func() {
|
||||
fo := util.NewFanOut[int](1)
|
||||
ch := fo.Add()
|
||||
|
||||
fo.Send(1234)
|
||||
|
||||
Eventually(ch).Should(Receive(Equal(1234)))
|
||||
|
||||
fo.Close()
|
||||
})
|
||||
|
||||
It("works with two channels", func() {
|
||||
fo := util.NewFanOut[int](1)
|
||||
|
||||
ch1 := fo.Add()
|
||||
ch2 := fo.Add()
|
||||
|
||||
fo.C <- 1234
|
||||
fo.Send(1234)
|
||||
|
||||
Eventually(ch1).Should(Receive(Equal(1234)))
|
||||
Eventually(ch2).Should(Receive(Equal(1234)))
|
||||
|
||||
err := fo.Close()
|
||||
Expect(err).To(Succeed())
|
||||
fo.Close()
|
||||
})
|
||||
|
||||
It("also works with unbuffered channels if there is only a single channel", func() {
|
||||
fo := util.NewFanOut[int](0)
|
||||
ch := fo.Add()
|
||||
|
||||
fo.C <- 1234
|
||||
|
||||
Eventually(ch).Should(Receive(Equal(1234)))
|
||||
|
||||
err := fo.Close()
|
||||
Expect(err).To(Succeed())
|
||||
})
|
||||
|
||||
It("also works with unbuffered channels if there is only a single channel or others have been removed", func() {
|
||||
fo := util.NewFanOut[int](0)
|
||||
It("works with two channels after one has been removed", func() {
|
||||
fo := util.NewFanOut[int](1)
|
||||
ch1 := fo.Add()
|
||||
ch2 := fo.Add()
|
||||
|
||||
fo.Remove(ch2)
|
||||
|
||||
fo.C <- 1234
|
||||
fo.Send(1234)
|
||||
|
||||
Eventually(ch1).Should(Receive(Equal(1234)))
|
||||
|
||||
err := fo.Close()
|
||||
Expect(err).To(Succeed())
|
||||
})
|
||||
|
||||
It("might deadlock if there are more receiving channels", func() {
|
||||
fo := util.NewFanOut[int](0)
|
||||
ch1 := fo.Add()
|
||||
ch2 := fo.Add()
|
||||
|
||||
fo.C <- 1234
|
||||
|
||||
Eventually(ch1).ShouldNot(Receive(Equal(1234)))
|
||||
Eventually(ch2).ShouldNot(Receive(Equal(1234)))
|
||||
|
||||
err := fo.Close()
|
||||
Expect(err).To(Succeed())
|
||||
fo.Close()
|
||||
})
|
||||
})
|
||||
|
Reference in New Issue
Block a user