mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-02 04:22:37 +08:00
Refactor, fixes to Circ Peek, start processor/parser
This commit is contained in:
18
mqtt.go
18
mqtt.go
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/mochi-co/mqtt/auth"
|
||||
"github.com/mochi-co/mqtt/listeners"
|
||||
"github.com/mochi-co/mqtt/packets"
|
||||
"github.com/mochi-co/mqtt/processors/circ"
|
||||
"github.com/mochi-co/mqtt/topics"
|
||||
"github.com/mochi-co/mqtt/topics/trie"
|
||||
)
|
||||
@@ -90,7 +91,22 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
|
||||
// Serve begins the event loops for establishing client connections on all
|
||||
// attached listeners.
|
||||
func (s *Server) Serve() error {
|
||||
s.listeners.ServeAll(s.EstablishConnection)
|
||||
s.listeners.ServeAll(s.EstablishConnection2)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EstablishConnection establishes a new client connection with the broker.
|
||||
func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller) error {
|
||||
|
||||
p := NewProcessor(c, circ.NewReader(128), circ.NewWriter(128))
|
||||
|
||||
// Pull the header from the first packet and check for a CONNECT message.
|
||||
fh := new(packets.FixedHeader)
|
||||
err := p.ReadFixedHeader(fh)
|
||||
if err != nil {
|
||||
return ErrReadConnectFixedHeader
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -67,6 +68,8 @@ func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Peeked", peeked)
|
||||
|
||||
// Unpack message type and flags from byte 1.
|
||||
err = fh.Decode(peeked[0])
|
||||
if err != nil {
|
||||
|
||||
@@ -194,7 +194,7 @@ var fixedHeaderExpected = []fixedHeaderTable{
|
||||
}
|
||||
*/
|
||||
|
||||
func TestReadFixedHeader(t *testing.T) {
|
||||
func TestParserReadFixedHeader(t *testing.T) {
|
||||
conn := new(MockNetConn)
|
||||
|
||||
// Test null data.
|
||||
|
||||
108
processor.go
Normal file
108
processor.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/mochi-co/mqtt/packets"
|
||||
"github.com/mochi-co/mqtt/processors/circ"
|
||||
)
|
||||
|
||||
// Processor reads and writes bytes to a network connection.
|
||||
type Processor struct {
|
||||
|
||||
// Conn is the net.Conn used to establish the connection.
|
||||
Conn net.Conn
|
||||
|
||||
// R is a reader for reading incoming bytes.
|
||||
R *circ.Reader
|
||||
|
||||
// W is a writer for writing outgoing bytes.
|
||||
W *circ.Writer
|
||||
|
||||
// FixedHeader is the FixedHeader from the last read packet.
|
||||
FixedHeader packets.FixedHeader
|
||||
}
|
||||
|
||||
// NewProcessor returns a new instance of Processor.
|
||||
func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
|
||||
return &Processor{
|
||||
Conn: c,
|
||||
R: r,
|
||||
W: w,
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||
func (p *Processor) RefreshDeadline(keepalive uint16) {
|
||||
if p.Conn != nil {
|
||||
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
|
||||
p.Conn.SetDeadline(time.Now().Add(expiry))
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||
func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||
|
||||
// Peek the maximum message type and flags, and length.
|
||||
peeked, err := p.R.Peek(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("Peeked", peeked)
|
||||
|
||||
// Unpack message type and flags from byte 1.
|
||||
err = fh.Decode(peeked[0])
|
||||
if err != nil {
|
||||
|
||||
// @SPEC [MQTT-2.2.2-2]
|
||||
// If invalid flags are received, the receiver MUST close the Network Connection.
|
||||
return packets.ErrInvalidFlags
|
||||
}
|
||||
|
||||
log.Printf("FH %+v\n", fh)
|
||||
|
||||
// The remaining length value can be up to 5 bytes. Peek through each byte
|
||||
// looking for continue values, and if found increase the peek. Otherwise
|
||||
// decode the bytes that were legit.
|
||||
//p.fhBuffer = p.fhBuffer[:0]
|
||||
buf := make([]byte, 0, 6)
|
||||
i := 1
|
||||
//var b int64 = 2
|
||||
for b := int64(2); b < 6; b++ {
|
||||
peeked, err = p.R.Peek(b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the byte to the length bytes slice.
|
||||
buf = append(buf, peeked[i])
|
||||
|
||||
// If it's not a continuation flag, end here.
|
||||
if peeked[i] < 128 {
|
||||
break
|
||||
}
|
||||
|
||||
// If i has reached 4 without a length terminator, throw a protocol violation.
|
||||
i++
|
||||
if i == 4 {
|
||||
return packets.ErrOversizedLengthIndicator
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate and store the remaining length.
|
||||
rem, _ := binary.Uvarint(buf)
|
||||
fh.Remaining = int(rem)
|
||||
|
||||
// Discard the number of used length bytes + first byte.
|
||||
//p.R.Discard(b)
|
||||
|
||||
// Set the fixed header in the parser.
|
||||
p.FixedHeader = *fh
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
44
processor_test.go
Normal file
44
processor_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package mqtt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/mochi-co/mqtt/packets"
|
||||
"github.com/mochi-co/mqtt/processors/circ"
|
||||
)
|
||||
|
||||
func TestNewProcessor(t *testing.T) {
|
||||
conn := new(MockNetConn)
|
||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
||||
require.NotNil(t, p.R)
|
||||
}
|
||||
|
||||
func TestProcessorReadFixedHeader(t *testing.T) {
|
||||
conn := new(MockNetConn)
|
||||
|
||||
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
||||
|
||||
// Test null data.
|
||||
fh := new(packets.FixedHeader)
|
||||
//p.R.SetPos(0, 1)
|
||||
fmt.Printf("A %+v\n", p.R)
|
||||
err := p.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
|
||||
// Test insufficient peeking.
|
||||
fh = new(packets.FixedHeader)
|
||||
p.R.Set([]byte{packets.Connect << 4}, 0, 1)
|
||||
p.R.SetPos(0, 1)
|
||||
fmt.Printf("B %+v\n", p.R)
|
||||
err = p.ReadFixedHeader(fh)
|
||||
require.Error(t, err)
|
||||
|
||||
fh = new(packets.FixedHeader)
|
||||
p.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||
p.R.SetPos(0, 2)
|
||||
err = p.ReadFixedHeader(fh)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -51,6 +51,22 @@ func newBuffer(size int64) buffer {
|
||||
}
|
||||
}
|
||||
|
||||
// Set writes bytes to a byte buffer. This method should only be used for testing
|
||||
// and will panic if out of range.
|
||||
func (b *buffer) Set(p []byte, start, end int) {
|
||||
o := 0
|
||||
for i := start; i < end; i++ {
|
||||
b.buf[i] = p[o]
|
||||
o++
|
||||
}
|
||||
}
|
||||
|
||||
// SetPos sets the head and tail of the buffer. This method should only be
|
||||
// used for testing.
|
||||
func (b *buffer) SetPos(tail, head int64) {
|
||||
b.tail, b.head = tail, head
|
||||
}
|
||||
|
||||
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
|
||||
// until there are at least n bytes to write without overrunning the tail.
|
||||
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
||||
|
||||
@@ -23,6 +23,17 @@ func TestNewBuffer(t *testing.T) {
|
||||
require.Equal(t, size, buf.size)
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
buf := newBuffer(8)
|
||||
require.Equal(t, make([]byte, 4), buf.buf[0:4])
|
||||
p := []byte{'1', '2', '3', '4'}
|
||||
buf.Set(p, 0, 4)
|
||||
require.Equal(t, p, buf.buf[0:4])
|
||||
|
||||
buf.Set(p, 2, 6)
|
||||
require.Equal(t, p, buf.buf[2:6])
|
||||
}
|
||||
|
||||
func TestAwaitCapacity(t *testing.T) {
|
||||
tests := []struct {
|
||||
tail int64
|
||||
|
||||
@@ -70,6 +70,40 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
head := atomic.LoadInt64(&b.head)
|
||||
|
||||
// Wait until there's at least 1 byte of data.
|
||||
/*
|
||||
b.wcond.L.Lock()
|
||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
b.wcond.Wait()
|
||||
}
|
||||
b.wcond.L.Unlock()
|
||||
*/
|
||||
|
||||
// Figure out if we can get all n bytes.
|
||||
if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
||||
return nil, ErrInsufficientBytes
|
||||
}
|
||||
|
||||
// If we've wrapped, get the bytes from the next and start.
|
||||
if head < tail {
|
||||
b.tmp = b.buf[tail:b.size]
|
||||
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
||||
return b.tmp, nil
|
||||
} else {
|
||||
return b.buf[tail : tail+n], nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// PeekWait waits until their are n bytes available to return.
|
||||
func (b *Reader) PeekWait(n int64) ([]byte, error) {
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
head := atomic.LoadInt64(&b.head)
|
||||
|
||||
// Wait until there's at least 1 byte of data.
|
||||
b.wcond.L.Lock()
|
||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||
|
||||
@@ -54,7 +54,17 @@ DONE:
|
||||
return
|
||||
}
|
||||
|
||||
// Write writes the buffer to the buffer p.
|
||||
func (b *Writer) Write(p []byte) (nn int, err error) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Wait until there's
|
||||
_, err = b.awaitCapacity(int64(len(p)))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
package ring
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
var (
|
||||
blockSize int64 = 4
|
||||
|
||||
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
|
||||
)
|
||||
|
||||
// _______________________________________________________________________
|
||||
// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Buffer is a ring-style byte buffer.
|
||||
type Buffer struct {
|
||||
|
||||
// size is the size of the buffer.
|
||||
size int64
|
||||
|
||||
// buffer is the circular byte buffer.
|
||||
buffer []byte
|
||||
|
||||
// tmp is a temporary buffer.
|
||||
tmp []byte
|
||||
|
||||
// head is the current position in the sequence.
|
||||
head int64
|
||||
|
||||
// tail is the committed position in the sequence, I.E. where we
|
||||
// have successfully consumed and processed to.
|
||||
tail int64
|
||||
|
||||
// rcond is the sync condition for the reader.
|
||||
rcond *sync.Cond
|
||||
|
||||
// done indicates that the buffer is closed.
|
||||
done int64
|
||||
|
||||
// wcond is the sync condition for the writer.
|
||||
wcond *sync.Cond
|
||||
}
|
||||
|
||||
// NewBuffer returns a pointer to a new instance of Buffer.
|
||||
func NewBuffer(size int64) *Buffer {
|
||||
return &Buffer{
|
||||
size: size,
|
||||
buffer: make([]byte, size),
|
||||
tmp: make([]byte, size),
|
||||
rcond: sync.NewCond(new(sync.Mutex)),
|
||||
wcond: sync.NewCond(new(sync.Mutex)),
|
||||
}
|
||||
}
|
||||
|
||||
// awaitCapacity will hold until there are n bytes free in the buffer.
|
||||
// awaitCapacity will block until there are at least n bytes for the head to
|
||||
// write into without overrunning the tail.
|
||||
func (b *Buffer) awaitCapacity(n int64) (int64, error) {
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
head := atomic.LoadInt64(&b.head)
|
||||
next := head + n
|
||||
wrapped := next - b.size
|
||||
|
||||
b.rcond.L.Lock()
|
||||
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
|
||||
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
|
||||
b.rcond.Wait()
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
fmt.Println("ending")
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
b.rcond.L.Unlock()
|
||||
|
||||
return head, nil
|
||||
}
|
||||
|
||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
||||
// there is sufficient capacity to do so.
|
||||
func (b *Buffer) ReadFrom(r io.Reader) (total int64, err error) {
|
||||
DONE:
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
err = io.EOF
|
||||
break DONE
|
||||
//return 0, io.EOF
|
||||
}
|
||||
|
||||
// Wait until there's enough capacity in the buffer.
|
||||
st, err := b.awaitCapacity(blockSize)
|
||||
if err != nil {
|
||||
break DONE
|
||||
//return 0, err
|
||||
}
|
||||
|
||||
// Find the start and end indexes within the buffer.
|
||||
start := st & (b.size - 1)
|
||||
end := start + blockSize
|
||||
|
||||
// If we've overrun, just fill up to the end and then collect
|
||||
// the rest on the next pass.
|
||||
if end > b.size {
|
||||
end = b.size
|
||||
}
|
||||
|
||||
// Read into the buffer between the start and end indexes only.
|
||||
n, err := r.Read(b.buffer[start:end])
|
||||
total += int64(n) // incr total bytes read.
|
||||
if err != nil {
|
||||
break DONE
|
||||
//return int64(n), err
|
||||
}
|
||||
|
||||
// Move the head forward.
|
||||
//fmt.Println(start, end, b.buffer[start:end])
|
||||
// fmt.Println(">>", b.head, start+int64(n), end)
|
||||
atomic.StoreInt64(&b.head, end) //start+int64(n))
|
||||
b.wcond.L.Lock()
|
||||
b.wcond.Broadcast()
|
||||
b.wcond.L.Unlock()
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
||||
func (b *Buffer) WriteTo(w io.Writer) (total int64, err error) {
|
||||
var p []byte
|
||||
var n int
|
||||
DONE:
|
||||
for {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
err = io.EOF
|
||||
break DONE
|
||||
}
|
||||
|
||||
// Peek until there's bytes to write.
|
||||
p, err = b.Peek(blockSize)
|
||||
if err != nil {
|
||||
break DONE
|
||||
}
|
||||
|
||||
n, err = w.Write(p)
|
||||
total += int64(n)
|
||||
if err != nil {
|
||||
break DONE
|
||||
}
|
||||
|
||||
// Move the tail forward the bytes written.
|
||||
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
|
||||
atomic.StoreInt64(&b.tail, end)
|
||||
b.rcond.L.Lock()
|
||||
b.rcond.Broadcast()
|
||||
b.rcond.L.Unlock()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Peek returns the next n bytes without advancing the reader.
|
||||
func (b *Buffer) Peek(n int64) ([]byte, error) {
|
||||
tail := atomic.LoadInt64(&b.tail)
|
||||
head := atomic.LoadInt64(&b.head)
|
||||
|
||||
// Wait until there's at least 1 byte of data.
|
||||
b.wcond.L.Lock()
|
||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||
if atomic.LoadInt64(&b.done) == 1 {
|
||||
return nil, io.EOF
|
||||
}
|
||||
b.wcond.Wait()
|
||||
}
|
||||
b.wcond.L.Unlock()
|
||||
|
||||
// Figure out if we can get all n bytes.
|
||||
start := tail
|
||||
end := tail + n
|
||||
if head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
||||
return nil, ErrInsufficientBytes
|
||||
}
|
||||
|
||||
// If we've wrapped, get the bytes from the end and start.
|
||||
if head < tail {
|
||||
b.tmp = b.buffer[start:b.size]
|
||||
b.tmp = append(b.tmp, b.buffer[:(end%b.size)]...)
|
||||
return b.tmp, nil
|
||||
} else {
|
||||
return b.buffer[start:end], nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
package ring
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
blockSize = 4
|
||||
}
|
||||
|
||||
func TestNewBuffer(t *testing.T) {
|
||||
var size int64 = 256
|
||||
buf := NewBuffer(size)
|
||||
|
||||
require.NotNil(t, buf.buffer)
|
||||
require.Equal(t, size, int64(len(buf.buffer)))
|
||||
require.Equal(t, size, buf.size)
|
||||
}
|
||||
|
||||
func TestAwaitCapacity(t *testing.T) {
|
||||
tt := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
need int64
|
||||
await int
|
||||
desc string
|
||||
}{
|
||||
{0, 0, 4, 0, "OK 4, 0"},
|
||||
{6, 6, 4, 0, "OK 6, 10"},
|
||||
{12, 16, 4, 0, "OK 16, 4"},
|
||||
{16, 12, 4, 0, "OK 16 16"},
|
||||
{3, 6, 4, 0, "OK 3, 10"},
|
||||
{12, 0, 4, 0, "OK 16, 0"},
|
||||
{1, 16, 4, 4, "next is more than tail, wait for tail incr"},
|
||||
{7, 5, 4, 2, "tail is great than head, wrapped and caught up with tail, wait for tail incr"},
|
||||
}
|
||||
|
||||
buf := NewBuffer(16)
|
||||
for i, check := range tt {
|
||||
buf.tail, buf.head = check.tail, check.head
|
||||
o := make(chan []interface{})
|
||||
var start int64 = -1
|
||||
var err error
|
||||
go func() {
|
||||
start, err = buf.awaitCapacity(4)
|
||||
o <- []interface{}{start, err}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond) // atomic updates are super fast so wait a bit
|
||||
for i := 0; i < check.await; i++ {
|
||||
atomic.AddInt64(&buf.tail, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
||||
if start == -1 {
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
buf.rcond.L.Lock()
|
||||
buf.rcond.Broadcast()
|
||||
buf.rcond.L.Unlock()
|
||||
}
|
||||
done := <-o
|
||||
require.Equal(t, check.head, done[0].(int64), "Head-Start mismatch [i:%d] %s", i, check.desc)
|
||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, check.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFrom(t *testing.T) {
|
||||
buf := NewBuffer(16)
|
||||
|
||||
b4 := bytes.Repeat([]byte{'-'}, 4)
|
||||
br := bytes.NewReader(b4)
|
||||
_, err := buf.ReadFrom(br)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buffer[:4])
|
||||
require.Equal(t, int64(4), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.Equal(t, int64(8), buf.head)
|
||||
|
||||
br.Reset(b4)
|
||||
_, err = buf.ReadFrom(br)
|
||||
require.Equal(t, int64(12), buf.head)
|
||||
|
||||
br.Reset([]byte{'-', '-', '-', '-', '/', '/', '/', '/'})
|
||||
o := make(chan error)
|
||||
go func() {
|
||||
_, err := buf.ReadFrom(br)
|
||||
o <- err
|
||||
}()
|
||||
|
||||
atomic.StoreInt64(&buf.tail, 4)
|
||||
<-o
|
||||
require.Equal(t, int64(4), buf.head)
|
||||
require.Equal(t, []byte{'/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'}, buf.buffer)
|
||||
}
|
||||
|
||||
func TestPeek(t *testing.T) {
|
||||
|
||||
buf := NewBuffer(16)
|
||||
buf.buffer = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
||||
|
||||
tests := []struct {
|
||||
tail int64
|
||||
head int64
|
||||
want int
|
||||
bytes []byte
|
||||
err error
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, head: 4, want: 4, bytes: []byte{'a', 'b', 'c', 'd'}, err: nil, desc: "0,4: OK"},
|
||||
{tail: 3, head: 15, want: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, err: nil, desc: "0,4: OK"},
|
||||
{tail: 14, head: 5, want: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, err: nil, desc: "14,5 OK"},
|
||||
{tail: 14, head: 1, want: 6, bytes: nil, err: ErrInsufficientBytes, desc: "14,1 insufficient bytes"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf.tail, buf.head = tt.tail, tt.head
|
||||
o := make(chan []interface{})
|
||||
go func() {
|
||||
bs, err := buf.Peek(int64(tt.want))
|
||||
o <- []interface{}{bs, err}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.head, buf.head+int64(tt.want))
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
|
||||
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
||||
done := <-o
|
||||
if tt.err != nil {
|
||||
require.Error(t, done[1].(error), "Expected Error [i:%d] %s", i, tt.desc)
|
||||
} else {
|
||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
|
||||
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
|
||||
fmt.Println("done")
|
||||
|
||||
}
|
||||
|
||||
func TestWriteTo(t *testing.T) {
|
||||
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
||||
|
||||
tests := []struct {
|
||||
tail int64
|
||||
next int64
|
||||
head int64
|
||||
err error
|
||||
bytes []byte
|
||||
total int64
|
||||
desc string
|
||||
}{
|
||||
{tail: 0, next: 4, head: 4, err: io.EOF, bytes: []byte{'a', 'b', 'c', 'd'}, total: 4, desc: "0, 4 OK"},
|
||||
{tail: 14, next: 2, head: 2, err: io.EOF, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "14, 2 wrap"},
|
||||
{tail: 0, next: 0, head: 2, err: ErrInsufficientBytes, bytes: []byte{}, total: 0, desc: "0, 2 insufficient"},
|
||||
{tail: 14, next: 2, head: 3, err: ErrInsufficientBytes, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "0, 3 OK > insufficient wrap"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
buf := NewBuffer(16)
|
||||
buf.buffer = bb
|
||||
buf.tail, buf.head = tt.tail, tt.head
|
||||
r, w := net.Pipe()
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond)
|
||||
atomic.StoreInt64(&buf.done, 1)
|
||||
buf.wcond.L.Lock()
|
||||
buf.wcond.Broadcast()
|
||||
buf.wcond.L.Unlock()
|
||||
w.Close()
|
||||
}()
|
||||
|
||||
recv := make(chan []byte)
|
||||
go func() {
|
||||
buf, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
recv <- buf
|
||||
}()
|
||||
|
||||
total, err := buf.WriteTo(w)
|
||||
require.Equal(t, tt.next, buf.tail, "Tail placement mismatched [i:%d] %s", i, tt.desc)
|
||||
require.Equal(t, tt.total, total, "Total bytes written mismatch [i:%d] %s", i, tt.desc)
|
||||
if tt.err != nil {
|
||||
require.Error(t, err, "Expected Error [i:%d] %s", i, tt.desc)
|
||||
} else {
|
||||
require.NoError(t, err, "Unexpected Error [i:%d] %s", i, tt.desc)
|
||||
}
|
||||
require.Equal(t, tt.bytes, <-recv, "Written bytes mismatch [i:%d] %s", i, tt.desc)
|
||||
|
||||
r.Close()
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1 +0,0 @@
|
||||
package ring
|
||||
@@ -1 +0,0 @@
|
||||
package ring
|
||||
Reference in New Issue
Block a user