mirror of
https://github.com/mochi-mqtt/server.git
synced 2025-11-03 02:23:49 +08:00
Refactor, fixes to Circ Peek, start processor/parser
This commit is contained in:
18
mqtt.go
18
mqtt.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/mochi-co/mqtt/auth"
|
"github.com/mochi-co/mqtt/auth"
|
||||||
"github.com/mochi-co/mqtt/listeners"
|
"github.com/mochi-co/mqtt/listeners"
|
||||||
"github.com/mochi-co/mqtt/packets"
|
"github.com/mochi-co/mqtt/packets"
|
||||||
|
"github.com/mochi-co/mqtt/processors/circ"
|
||||||
"github.com/mochi-co/mqtt/topics"
|
"github.com/mochi-co/mqtt/topics"
|
||||||
"github.com/mochi-co/mqtt/topics/trie"
|
"github.com/mochi-co/mqtt/topics/trie"
|
||||||
)
|
)
|
||||||
@@ -90,7 +91,22 @@ func (s *Server) AddListener(listener listeners.Listener, config *listeners.Conf
|
|||||||
// Serve begins the event loops for establishing client connections on all
|
// Serve begins the event loops for establishing client connections on all
|
||||||
// attached listeners.
|
// attached listeners.
|
||||||
func (s *Server) Serve() error {
|
func (s *Server) Serve() error {
|
||||||
s.listeners.ServeAll(s.EstablishConnection)
|
s.listeners.ServeAll(s.EstablishConnection2)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// EstablishConnection establishes a new client connection with the broker.
|
||||||
|
func (s *Server) EstablishConnection2(lid string, c net.Conn, ac auth.Controller) error {
|
||||||
|
|
||||||
|
p := NewProcessor(c, circ.NewReader(128), circ.NewWriter(128))
|
||||||
|
|
||||||
|
// Pull the header from the first packet and check for a CONNECT message.
|
||||||
|
fh := new(packets.FixedHeader)
|
||||||
|
err := p.ReadFixedHeader(fh)
|
||||||
|
if err != nil {
|
||||||
|
return ErrReadConnectFixedHeader
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -67,6 +68,8 @@ func (p *Parser) ReadFixedHeader(fh *packets.FixedHeader) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Println("Peeked", peeked)
|
||||||
|
|
||||||
// Unpack message type and flags from byte 1.
|
// Unpack message type and flags from byte 1.
|
||||||
err = fh.Decode(peeked[0])
|
err = fh.Decode(peeked[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ var fixedHeaderExpected = []fixedHeaderTable{
|
|||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestReadFixedHeader(t *testing.T) {
|
func TestParserReadFixedHeader(t *testing.T) {
|
||||||
conn := new(MockNetConn)
|
conn := new(MockNetConn)
|
||||||
|
|
||||||
// Test null data.
|
// Test null data.
|
||||||
|
|||||||
108
processor.go
Normal file
108
processor.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package mqtt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/packets"
|
||||||
|
"github.com/mochi-co/mqtt/processors/circ"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Processor reads and writes bytes to a network connection.
|
||||||
|
type Processor struct {
|
||||||
|
|
||||||
|
// Conn is the net.Conn used to establish the connection.
|
||||||
|
Conn net.Conn
|
||||||
|
|
||||||
|
// R is a reader for reading incoming bytes.
|
||||||
|
R *circ.Reader
|
||||||
|
|
||||||
|
// W is a writer for writing outgoing bytes.
|
||||||
|
W *circ.Writer
|
||||||
|
|
||||||
|
// FixedHeader is the FixedHeader from the last read packet.
|
||||||
|
FixedHeader packets.FixedHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProcessor returns a new instance of Processor.
|
||||||
|
func NewProcessor(c net.Conn, r *circ.Reader, w *circ.Writer) *Processor {
|
||||||
|
return &Processor{
|
||||||
|
Conn: c,
|
||||||
|
R: r,
|
||||||
|
W: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshDeadline refreshes the read/write deadline for the net.Conn connection.
|
||||||
|
func (p *Processor) RefreshDeadline(keepalive uint16) {
|
||||||
|
if p.Conn != nil {
|
||||||
|
expiry := time.Duration(keepalive+(keepalive/2)) * time.Second
|
||||||
|
p.Conn.SetDeadline(time.Now().Add(expiry))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadFixedHeader reads in the values of the next packet's fixed header.
|
||||||
|
func (p *Processor) ReadFixedHeader(fh *packets.FixedHeader) error {
|
||||||
|
|
||||||
|
// Peek the maximum message type and flags, and length.
|
||||||
|
peeked, err := p.R.Peek(1)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Peeked", peeked)
|
||||||
|
|
||||||
|
// Unpack message type and flags from byte 1.
|
||||||
|
err = fh.Decode(peeked[0])
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
// @SPEC [MQTT-2.2.2-2]
|
||||||
|
// If invalid flags are received, the receiver MUST close the Network Connection.
|
||||||
|
return packets.ErrInvalidFlags
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("FH %+v\n", fh)
|
||||||
|
|
||||||
|
// The remaining length value can be up to 5 bytes. Peek through each byte
|
||||||
|
// looking for continue values, and if found increase the peek. Otherwise
|
||||||
|
// decode the bytes that were legit.
|
||||||
|
//p.fhBuffer = p.fhBuffer[:0]
|
||||||
|
buf := make([]byte, 0, 6)
|
||||||
|
i := 1
|
||||||
|
//var b int64 = 2
|
||||||
|
for b := int64(2); b < 6; b++ {
|
||||||
|
peeked, err = p.R.Peek(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the byte to the length bytes slice.
|
||||||
|
buf = append(buf, peeked[i])
|
||||||
|
|
||||||
|
// If it's not a continuation flag, end here.
|
||||||
|
if peeked[i] < 128 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// If i has reached 4 without a length terminator, throw a protocol violation.
|
||||||
|
i++
|
||||||
|
if i == 4 {
|
||||||
|
return packets.ErrOversizedLengthIndicator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate and store the remaining length.
|
||||||
|
rem, _ := binary.Uvarint(buf)
|
||||||
|
fh.Remaining = int(rem)
|
||||||
|
|
||||||
|
// Discard the number of used length bytes + first byte.
|
||||||
|
//p.R.Discard(b)
|
||||||
|
|
||||||
|
// Set the fixed header in the parser.
|
||||||
|
p.FixedHeader = *fh
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
44
processor_test.go
Normal file
44
processor_test.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package mqtt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/mochi-co/mqtt/packets"
|
||||||
|
"github.com/mochi-co/mqtt/processors/circ"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewProcessor(t *testing.T) {
|
||||||
|
conn := new(MockNetConn)
|
||||||
|
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
||||||
|
require.NotNil(t, p.R)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessorReadFixedHeader(t *testing.T) {
|
||||||
|
conn := new(MockNetConn)
|
||||||
|
|
||||||
|
p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
|
||||||
|
|
||||||
|
// Test null data.
|
||||||
|
fh := new(packets.FixedHeader)
|
||||||
|
//p.R.SetPos(0, 1)
|
||||||
|
fmt.Printf("A %+v\n", p.R)
|
||||||
|
err := p.ReadFixedHeader(fh)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
// Test insufficient peeking.
|
||||||
|
fh = new(packets.FixedHeader)
|
||||||
|
p.R.Set([]byte{packets.Connect << 4}, 0, 1)
|
||||||
|
p.R.SetPos(0, 1)
|
||||||
|
fmt.Printf("B %+v\n", p.R)
|
||||||
|
err = p.ReadFixedHeader(fh)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
fh = new(packets.FixedHeader)
|
||||||
|
p.R.Set([]byte{packets.Connect << 4, 0x00}, 0, 2)
|
||||||
|
p.R.SetPos(0, 2)
|
||||||
|
err = p.ReadFixedHeader(fh)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
@@ -51,6 +51,22 @@ func newBuffer(size int64) buffer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set writes bytes to a byte buffer. This method should only be used for testing
|
||||||
|
// and will panic if out of range.
|
||||||
|
func (b *buffer) Set(p []byte, start, end int) {
|
||||||
|
o := 0
|
||||||
|
for i := start; i < end; i++ {
|
||||||
|
b.buf[i] = p[o]
|
||||||
|
o++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPos sets the head and tail of the buffer. This method should only be
|
||||||
|
// used for testing.
|
||||||
|
func (b *buffer) SetPos(tail, head int64) {
|
||||||
|
b.tail, b.head = tail, head
|
||||||
|
}
|
||||||
|
|
||||||
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
|
// awaitCapacity will hold until there are n bytes free in the buffer, blocking
|
||||||
// until there are at least n bytes to write without overrunning the tail.
|
// until there are at least n bytes to write without overrunning the tail.
|
||||||
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
func (b *buffer) awaitCapacity(n int64) (head int64, err error) {
|
||||||
|
|||||||
@@ -23,6 +23,17 @@ func TestNewBuffer(t *testing.T) {
|
|||||||
require.Equal(t, size, buf.size)
|
require.Equal(t, size, buf.size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSet(t *testing.T) {
|
||||||
|
buf := newBuffer(8)
|
||||||
|
require.Equal(t, make([]byte, 4), buf.buf[0:4])
|
||||||
|
p := []byte{'1', '2', '3', '4'}
|
||||||
|
buf.Set(p, 0, 4)
|
||||||
|
require.Equal(t, p, buf.buf[0:4])
|
||||||
|
|
||||||
|
buf.Set(p, 2, 6)
|
||||||
|
require.Equal(t, p, buf.buf[2:6])
|
||||||
|
}
|
||||||
|
|
||||||
func TestAwaitCapacity(t *testing.T) {
|
func TestAwaitCapacity(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
tail int64
|
tail int64
|
||||||
|
|||||||
@@ -70,6 +70,40 @@ func (b *Reader) Peek(n int64) ([]byte, error) {
|
|||||||
tail := atomic.LoadInt64(&b.tail)
|
tail := atomic.LoadInt64(&b.tail)
|
||||||
head := atomic.LoadInt64(&b.head)
|
head := atomic.LoadInt64(&b.head)
|
||||||
|
|
||||||
|
// Wait until there's at least 1 byte of data.
|
||||||
|
/*
|
||||||
|
b.wcond.L.Lock()
|
||||||
|
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||||
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
|
return nil, io.EOF
|
||||||
|
}
|
||||||
|
b.wcond.Wait()
|
||||||
|
}
|
||||||
|
b.wcond.L.Unlock()
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Figure out if we can get all n bytes.
|
||||||
|
if head == tail || head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
||||||
|
return nil, ErrInsufficientBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've wrapped, get the bytes from the next and start.
|
||||||
|
if head < tail {
|
||||||
|
b.tmp = b.buf[tail:b.size]
|
||||||
|
b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
|
||||||
|
return b.tmp, nil
|
||||||
|
} else {
|
||||||
|
return b.buf[tail : tail+n], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PeekWait waits until their are n bytes available to return.
|
||||||
|
func (b *Reader) PeekWait(n int64) ([]byte, error) {
|
||||||
|
tail := atomic.LoadInt64(&b.tail)
|
||||||
|
head := atomic.LoadInt64(&b.head)
|
||||||
|
|
||||||
// Wait until there's at least 1 byte of data.
|
// Wait until there's at least 1 byte of data.
|
||||||
b.wcond.L.Lock()
|
b.wcond.L.Lock()
|
||||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
||||||
|
|||||||
@@ -54,7 +54,17 @@ DONE:
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write writes the buffer to the buffer p.
|
||||||
func (b *Writer) Write(p []byte) (nn int, err error) {
|
func (b *Writer) Write(p []byte) (nn int, err error) {
|
||||||
|
if atomic.LoadInt64(&b.done) == 1 {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait until there's
|
||||||
|
_, err = b.awaitCapacity(int64(len(p)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,199 +0,0 @@
|
|||||||
package ring
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
blockSize int64 = 4
|
|
||||||
|
|
||||||
ErrInsufficientBytes = errors.New("Insufficient bytes to return")
|
|
||||||
)
|
|
||||||
|
|
||||||
// _______________________________________________________________________
|
|
||||||
// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|
|
||||||
// -----------------------------------------------------------------------
|
|
||||||
|
|
||||||
// Buffer is a ring-style byte buffer.
|
|
||||||
type Buffer struct {
|
|
||||||
|
|
||||||
// size is the size of the buffer.
|
|
||||||
size int64
|
|
||||||
|
|
||||||
// buffer is the circular byte buffer.
|
|
||||||
buffer []byte
|
|
||||||
|
|
||||||
// tmp is a temporary buffer.
|
|
||||||
tmp []byte
|
|
||||||
|
|
||||||
// head is the current position in the sequence.
|
|
||||||
head int64
|
|
||||||
|
|
||||||
// tail is the committed position in the sequence, I.E. where we
|
|
||||||
// have successfully consumed and processed to.
|
|
||||||
tail int64
|
|
||||||
|
|
||||||
// rcond is the sync condition for the reader.
|
|
||||||
rcond *sync.Cond
|
|
||||||
|
|
||||||
// done indicates that the buffer is closed.
|
|
||||||
done int64
|
|
||||||
|
|
||||||
// wcond is the sync condition for the writer.
|
|
||||||
wcond *sync.Cond
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewBuffer returns a pointer to a new instance of Buffer.
|
|
||||||
func NewBuffer(size int64) *Buffer {
|
|
||||||
return &Buffer{
|
|
||||||
size: size,
|
|
||||||
buffer: make([]byte, size),
|
|
||||||
tmp: make([]byte, size),
|
|
||||||
rcond: sync.NewCond(new(sync.Mutex)),
|
|
||||||
wcond: sync.NewCond(new(sync.Mutex)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// awaitCapacity will hold until there are n bytes free in the buffer.
|
|
||||||
// awaitCapacity will block until there are at least n bytes for the head to
|
|
||||||
// write into without overrunning the tail.
|
|
||||||
func (b *Buffer) awaitCapacity(n int64) (int64, error) {
|
|
||||||
tail := atomic.LoadInt64(&b.tail)
|
|
||||||
head := atomic.LoadInt64(&b.head)
|
|
||||||
next := head + n
|
|
||||||
wrapped := next - b.size
|
|
||||||
|
|
||||||
b.rcond.L.Lock()
|
|
||||||
for ; wrapped > tail || (tail > head && next > tail && wrapped < 0); tail = atomic.LoadInt64(&b.tail) {
|
|
||||||
//fmt.Println("\t", wrapped, ">", tail, wrapped > tail, "||", tail, ">", head, "&&", next, ">", tail, "&&", wrapped, "<", 0, (tail > head && next > tail && wrapped < 0))
|
|
||||||
b.rcond.Wait()
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
|
||||||
fmt.Println("ending")
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.rcond.L.Unlock()
|
|
||||||
|
|
||||||
return head, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadFrom reads bytes from an io.Reader and commits them to the buffer when
|
|
||||||
// there is sufficient capacity to do so.
|
|
||||||
func (b *Buffer) ReadFrom(r io.Reader) (total int64, err error) {
|
|
||||||
DONE:
|
|
||||||
for {
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
|
||||||
err = io.EOF
|
|
||||||
break DONE
|
|
||||||
//return 0, io.EOF
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait until there's enough capacity in the buffer.
|
|
||||||
st, err := b.awaitCapacity(blockSize)
|
|
||||||
if err != nil {
|
|
||||||
break DONE
|
|
||||||
//return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the start and end indexes within the buffer.
|
|
||||||
start := st & (b.size - 1)
|
|
||||||
end := start + blockSize
|
|
||||||
|
|
||||||
// If we've overrun, just fill up to the end and then collect
|
|
||||||
// the rest on the next pass.
|
|
||||||
if end > b.size {
|
|
||||||
end = b.size
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read into the buffer between the start and end indexes only.
|
|
||||||
n, err := r.Read(b.buffer[start:end])
|
|
||||||
total += int64(n) // incr total bytes read.
|
|
||||||
if err != nil {
|
|
||||||
break DONE
|
|
||||||
//return int64(n), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move the head forward.
|
|
||||||
//fmt.Println(start, end, b.buffer[start:end])
|
|
||||||
// fmt.Println(">>", b.head, start+int64(n), end)
|
|
||||||
atomic.StoreInt64(&b.head, end) //start+int64(n))
|
|
||||||
b.wcond.L.Lock()
|
|
||||||
b.wcond.Broadcast()
|
|
||||||
b.wcond.L.Unlock()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteTo writes the contents of the buffer to an io.Writer.
|
|
||||||
func (b *Buffer) WriteTo(w io.Writer) (total int64, err error) {
|
|
||||||
var p []byte
|
|
||||||
var n int
|
|
||||||
DONE:
|
|
||||||
for {
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
|
||||||
err = io.EOF
|
|
||||||
break DONE
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peek until there's bytes to write.
|
|
||||||
p, err = b.Peek(blockSize)
|
|
||||||
if err != nil {
|
|
||||||
break DONE
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = w.Write(p)
|
|
||||||
total += int64(n)
|
|
||||||
if err != nil {
|
|
||||||
break DONE
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move the tail forward the bytes written.
|
|
||||||
end := (atomic.LoadInt64(&b.tail) + int64(n)) % b.size
|
|
||||||
atomic.StoreInt64(&b.tail, end)
|
|
||||||
b.rcond.L.Lock()
|
|
||||||
b.rcond.Broadcast()
|
|
||||||
b.rcond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peek returns the next n bytes without advancing the reader.
|
|
||||||
func (b *Buffer) Peek(n int64) ([]byte, error) {
|
|
||||||
tail := atomic.LoadInt64(&b.tail)
|
|
||||||
head := atomic.LoadInt64(&b.head)
|
|
||||||
|
|
||||||
// Wait until there's at least 1 byte of data.
|
|
||||||
b.wcond.L.Lock()
|
|
||||||
for ; head == tail; head = atomic.LoadInt64(&b.head) {
|
|
||||||
if atomic.LoadInt64(&b.done) == 1 {
|
|
||||||
return nil, io.EOF
|
|
||||||
}
|
|
||||||
b.wcond.Wait()
|
|
||||||
}
|
|
||||||
b.wcond.L.Unlock()
|
|
||||||
|
|
||||||
// Figure out if we can get all n bytes.
|
|
||||||
start := tail
|
|
||||||
end := tail + n
|
|
||||||
if head > tail && tail+n > head || head < tail && b.size-tail+head < n {
|
|
||||||
return nil, ErrInsufficientBytes
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've wrapped, get the bytes from the end and start.
|
|
||||||
if head < tail {
|
|
||||||
b.tmp = b.buffer[start:b.size]
|
|
||||||
b.tmp = append(b.tmp, b.buffer[:(end%b.size)]...)
|
|
||||||
return b.tmp, nil
|
|
||||||
} else {
|
|
||||||
return b.buffer[start:end], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
package ring
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
blockSize = 4
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewBuffer(t *testing.T) {
|
|
||||||
var size int64 = 256
|
|
||||||
buf := NewBuffer(size)
|
|
||||||
|
|
||||||
require.NotNil(t, buf.buffer)
|
|
||||||
require.Equal(t, size, int64(len(buf.buffer)))
|
|
||||||
require.Equal(t, size, buf.size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAwaitCapacity(t *testing.T) {
|
|
||||||
tt := []struct {
|
|
||||||
tail int64
|
|
||||||
head int64
|
|
||||||
need int64
|
|
||||||
await int
|
|
||||||
desc string
|
|
||||||
}{
|
|
||||||
{0, 0, 4, 0, "OK 4, 0"},
|
|
||||||
{6, 6, 4, 0, "OK 6, 10"},
|
|
||||||
{12, 16, 4, 0, "OK 16, 4"},
|
|
||||||
{16, 12, 4, 0, "OK 16 16"},
|
|
||||||
{3, 6, 4, 0, "OK 3, 10"},
|
|
||||||
{12, 0, 4, 0, "OK 16, 0"},
|
|
||||||
{1, 16, 4, 4, "next is more than tail, wait for tail incr"},
|
|
||||||
{7, 5, 4, 2, "tail is great than head, wrapped and caught up with tail, wait for tail incr"},
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := NewBuffer(16)
|
|
||||||
for i, check := range tt {
|
|
||||||
buf.tail, buf.head = check.tail, check.head
|
|
||||||
o := make(chan []interface{})
|
|
||||||
var start int64 = -1
|
|
||||||
var err error
|
|
||||||
go func() {
|
|
||||||
start, err = buf.awaitCapacity(4)
|
|
||||||
o <- []interface{}{start, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond) // atomic updates are super fast so wait a bit
|
|
||||||
for i := 0; i < check.await; i++ {
|
|
||||||
atomic.AddInt64(&buf.tail, 1)
|
|
||||||
buf.rcond.L.Lock()
|
|
||||||
buf.rcond.Broadcast()
|
|
||||||
buf.rcond.L.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
|
||||||
if start == -1 {
|
|
||||||
atomic.StoreInt64(&buf.done, 1)
|
|
||||||
buf.rcond.L.Lock()
|
|
||||||
buf.rcond.Broadcast()
|
|
||||||
buf.rcond.L.Unlock()
|
|
||||||
}
|
|
||||||
done := <-o
|
|
||||||
require.Equal(t, check.head, done[0].(int64), "Head-Start mismatch [i:%d] %s", i, check.desc)
|
|
||||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, check.desc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadFrom(t *testing.T) {
|
|
||||||
buf := NewBuffer(16)
|
|
||||||
|
|
||||||
b4 := bytes.Repeat([]byte{'-'}, 4)
|
|
||||||
br := bytes.NewReader(b4)
|
|
||||||
_, err := buf.ReadFrom(br)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, bytes.Repeat([]byte{'-'}, 4), buf.buffer[:4])
|
|
||||||
require.Equal(t, int64(4), buf.head)
|
|
||||||
|
|
||||||
br.Reset(b4)
|
|
||||||
_, err = buf.ReadFrom(br)
|
|
||||||
require.Equal(t, int64(8), buf.head)
|
|
||||||
|
|
||||||
br.Reset(b4)
|
|
||||||
_, err = buf.ReadFrom(br)
|
|
||||||
require.Equal(t, int64(12), buf.head)
|
|
||||||
|
|
||||||
br.Reset([]byte{'-', '-', '-', '-', '/', '/', '/', '/'})
|
|
||||||
o := make(chan error)
|
|
||||||
go func() {
|
|
||||||
_, err := buf.ReadFrom(br)
|
|
||||||
o <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
atomic.StoreInt64(&buf.tail, 4)
|
|
||||||
<-o
|
|
||||||
require.Equal(t, int64(4), buf.head)
|
|
||||||
require.Equal(t, []byte{'/', '/', '/', '/', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-', '-'}, buf.buffer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPeek(t *testing.T) {
|
|
||||||
|
|
||||||
buf := NewBuffer(16)
|
|
||||||
buf.buffer = []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
tail int64
|
|
||||||
head int64
|
|
||||||
want int
|
|
||||||
bytes []byte
|
|
||||||
err error
|
|
||||||
desc string
|
|
||||||
}{
|
|
||||||
{tail: 0, head: 4, want: 4, bytes: []byte{'a', 'b', 'c', 'd'}, err: nil, desc: "0,4: OK"},
|
|
||||||
{tail: 3, head: 15, want: 8, bytes: []byte{'d', 'e', 'f', 'g', 'h', 'i', 'j', 'k'}, err: nil, desc: "0,4: OK"},
|
|
||||||
{tail: 14, head: 5, want: 6, bytes: []byte{'o', 'p', 'a', 'b', 'c', 'd'}, err: nil, desc: "14,5 OK"},
|
|
||||||
{tail: 14, head: 1, want: 6, bytes: nil, err: ErrInsufficientBytes, desc: "14,1 insufficient bytes"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
buf.tail, buf.head = tt.tail, tt.head
|
|
||||||
o := make(chan []interface{})
|
|
||||||
go func() {
|
|
||||||
bs, err := buf.Peek(int64(tt.want))
|
|
||||||
o <- []interface{}{bs, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
atomic.StoreInt64(&buf.head, buf.head+int64(tt.want))
|
|
||||||
time.Sleep(time.Millisecond * 10)
|
|
||||||
|
|
||||||
buf.wcond.L.Lock()
|
|
||||||
buf.wcond.Broadcast()
|
|
||||||
buf.wcond.L.Unlock()
|
|
||||||
|
|
||||||
time.Sleep(time.Millisecond) // wait for await capacity to actually exit
|
|
||||||
done := <-o
|
|
||||||
if tt.err != nil {
|
|
||||||
require.Error(t, done[1].(error), "Expected Error [i:%d] %s", i, tt.desc)
|
|
||||||
} else {
|
|
||||||
require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, tt.bytes, done[0].([]byte), "Peeked bytes mismatch [i:%d] %s", i, tt.desc)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("done")
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteTo(t *testing.T) {
|
|
||||||
bb := []byte{'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p'}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
tail int64
|
|
||||||
next int64
|
|
||||||
head int64
|
|
||||||
err error
|
|
||||||
bytes []byte
|
|
||||||
total int64
|
|
||||||
desc string
|
|
||||||
}{
|
|
||||||
{tail: 0, next: 4, head: 4, err: io.EOF, bytes: []byte{'a', 'b', 'c', 'd'}, total: 4, desc: "0, 4 OK"},
|
|
||||||
{tail: 14, next: 2, head: 2, err: io.EOF, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "14, 2 wrap"},
|
|
||||||
{tail: 0, next: 0, head: 2, err: ErrInsufficientBytes, bytes: []byte{}, total: 0, desc: "0, 2 insufficient"},
|
|
||||||
{tail: 14, next: 2, head: 3, err: ErrInsufficientBytes, bytes: []byte{'o', 'p', 'a', 'b'}, total: 4, desc: "0, 3 OK > insufficient wrap"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tt := range tests {
|
|
||||||
buf := NewBuffer(16)
|
|
||||||
buf.buffer = bb
|
|
||||||
buf.tail, buf.head = tt.tail, tt.head
|
|
||||||
r, w := net.Pipe()
|
|
||||||
go func() {
|
|
||||||
time.Sleep(time.Millisecond)
|
|
||||||
atomic.StoreInt64(&buf.done, 1)
|
|
||||||
buf.wcond.L.Lock()
|
|
||||||
buf.wcond.Broadcast()
|
|
||||||
buf.wcond.L.Unlock()
|
|
||||||
w.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
recv := make(chan []byte)
|
|
||||||
go func() {
|
|
||||||
buf, err := ioutil.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
recv <- buf
|
|
||||||
}()
|
|
||||||
|
|
||||||
total, err := buf.WriteTo(w)
|
|
||||||
require.Equal(t, tt.next, buf.tail, "Tail placement mismatched [i:%d] %s", i, tt.desc)
|
|
||||||
require.Equal(t, tt.total, total, "Total bytes written mismatch [i:%d] %s", i, tt.desc)
|
|
||||||
if tt.err != nil {
|
|
||||||
require.Error(t, err, "Expected Error [i:%d] %s", i, tt.desc)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err, "Unexpected Error [i:%d] %s", i, tt.desc)
|
|
||||||
}
|
|
||||||
require.Equal(t, tt.bytes, <-recv, "Written bytes mismatch [i:%d] %s", i, tt.desc)
|
|
||||||
|
|
||||||
r.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
package ring
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
package ring
|
|
||||||
Reference in New Issue
Block a user