code optimization

code optimization
This commit is contained in:
werben
2023-04-21 19:19:17 +08:00
parent 20864f8089
commit 570b2638cb
4 changed files with 101 additions and 14 deletions

30
errors.go Normal file
View File

@@ -0,0 +1,30 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 werbenhu
// SPDX-FileContributor: werbenhu
package eventbus
type err struct {
Msg string
Code int
}
func (e err) String() string {
return e.Msg
}
func (e err) Error() string {
return e.Msg
}
var (
ErrHandlerIsNotFunc = err{Code: 10000, Msg: "handler is not a function"}
ErrHandlerParamNum = err{Code: 10001, Msg: "the number of parameters of the handler must be two"}
ErrHandlerFirstParam = err{Code: 10002, Msg: "the first of parameters of the handler must be a string"}
ErrNoSubscriber = err{Code: 10003, Msg: "no subscriber on topic"}
ErrChannelClosed = err{Code: 10004, Msg: "channel is closed"}
ErrGroupExisted = err{Code: 10001, Msg: "group already existed"}
ErrNoResultMatched = err{Code: 10002, Msg: "no result matched"}
ErrKeyExisted = err{Code: 10003, Msg: "key already existed"}
)

29
errors_test.go Normal file
View File

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 werbenhu
// SPDX-FileContributor: werbenhu
package eventbus
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestErrString(t *testing.T) {
c := err{
Msg: "test",
Code: 0x1,
}
require.Equal(t, "test", c.String())
}
func TestErrErrorr(t *testing.T) {
c := err{
Msg: "error",
Code: 0x1,
}
require.Equal(t, "error", error(c).Error())
}

View File

@@ -1,7 +1,6 @@
package eventbus package eventbus
import ( import (
"fmt"
"reflect" "reflect"
"sync" "sync"
) )
@@ -69,10 +68,11 @@ func (c *channel) subscribe(handler any) error {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
if c.closed { if c.closed {
return fmt.Errorf("channel on topic:%s is closed", c.topic) return ErrChannelClosed
} }
fn := reflect.ValueOf(handler) fn := reflect.ValueOf(handler)
c.handlers.Store(fn.Pointer(), &fn) c.handlers.Store(fn.Pointer(), &fn)
return nil
} }
// publish trigger handlers defined for this channel. payload argument will be transferred to handlers. // publish trigger handlers defined for this channel. payload argument will be transferred to handlers.
@@ -80,16 +80,22 @@ func (c *channel) publish(payload any) error {
c.RLock() c.RLock()
defer c.RUnlock() defer c.RUnlock()
if c.closed { if c.closed {
return fmt.Errorf("channel on topic:%s is closed", c.topic) return ErrChannelClosed
} }
c.channel <- payload c.channel <- payload
return nil return nil
} }
// unsubscribe removes handler defined for this channel. // unsubscribe removes handler defined for this channel.
func (c *channel) unsubscribe(handler any) { func (c *channel) unsubscribe(handler any) error {
c.RLock()
defer c.RUnlock()
if c.closed {
return ErrChannelClosed
}
fn := reflect.ValueOf(handler) fn := reflect.ValueOf(handler)
c.handlers.Delete(fn.Pointer()) c.handlers.Delete(fn.Pointer())
return nil
} }
// close closes a channel // close closes a channel
@@ -132,7 +138,7 @@ func New() *EventBus {
func (e *EventBus) Unsubscribe(topic string, handler any) error { func (e *EventBus) Unsubscribe(topic string, handler any) error {
ch, ok := e.channels.Load(topic) ch, ok := e.channels.Load(topic)
if !ok { if !ok {
return fmt.Errorf("no subscriber on topic:%s", topic) return ErrNoSubscriber
} }
ch.(*channel).unsubscribe(handler) ch.(*channel).unsubscribe(handler)
return nil return nil
@@ -142,13 +148,13 @@ func (e *EventBus) Unsubscribe(topic string, handler any) error {
func (e *EventBus) Subscribe(topic string, handler any) error { func (e *EventBus) Subscribe(topic string, handler any) error {
typ := reflect.TypeOf(handler) typ := reflect.TypeOf(handler)
if typ.Kind() != reflect.Func { if typ.Kind() != reflect.Func {
return fmt.Errorf("the type of handler is %s, not type reflect.Func", reflect.TypeOf(handler).Kind()) return ErrHandlerIsNotFunc
} }
if typ.NumIn() != 2 { if typ.NumIn() != 2 {
return fmt.Errorf("the number of parameters of the handler must be two") return ErrHandlerParamNum
} }
if typ.In(0).Kind() != reflect.String { if typ.In(0).Kind() != reflect.String {
return fmt.Errorf("the first of parameters of the handler must be string type") return ErrHandlerFirstParam
} }
ch, ok := e.channels.Load(topic) ch, ok := e.channels.Load(topic)

View File

@@ -22,6 +22,7 @@ func Test_newChannel(t *testing.T) {
assert.NotNil(t, ch.channel) assert.NotNil(t, ch.channel)
assert.Equal(t, "test_topic", ch.topic) assert.Equal(t, "test_topic", ch.topic)
assert.NotNil(t, ch.stopCh) assert.NotNil(t, ch.stopCh)
assert.NotNil(t, ch.handlers)
ch.close() ch.close()
} }
@@ -31,9 +32,29 @@ func Test_channelSubscribe(t *testing.T) {
assert.NotNil(t, ch.channel) assert.NotNil(t, ch.channel)
assert.Equal(t, "test_topic", ch.topic) assert.Equal(t, "test_topic", ch.topic)
ch.subscribe(sub1) err := ch.subscribe(sub1)
ch.subscribe(sub2) assert.Nil(t, err)
ch.close() ch.close()
err = ch.subscribe(sub2)
assert.Equal(t, ErrChannelClosed, err)
}
func Test_channelUnsubscribe(t *testing.T) {
ch := newChannel("test_topic", -1)
assert.NotNil(t, ch)
assert.NotNil(t, ch.channel)
assert.Equal(t, "test_topic", ch.topic)
err := ch.subscribe(sub1)
assert.Nil(t, err)
err = ch.unsubscribe(sub1)
assert.Nil(t, err)
err = ch.subscribe(sub1)
assert.Nil(t, err)
ch.close()
err = ch.subscribe(sub2)
assert.Equal(t, ErrChannelClosed, err)
} }
func Test_channelPublish(t *testing.T) { func Test_channelPublish(t *testing.T) {
@@ -47,11 +68,12 @@ func Test_channelPublish(t *testing.T) {
go func() { go func() {
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
ch.publish(i) err := ch.publish(i)
assert.Nil(t, err)
} }
}() }()
time.Sleep(1000 * time.Millisecond) time.Sleep(1000 * time.Millisecond)
ch.close()
// ch.close() err := ch.publish(1)
// ch.publish(13) assert.Equal(t, ErrChannelClosed, err)
} }