allow writing primitives to static buffers

This commit is contained in:
aler9
2022-05-11 14:52:20 +02:00
parent ee6d7a87a3
commit c1b10a80be
19 changed files with 662 additions and 786 deletions

View File

@@ -35,11 +35,16 @@ func (b *body) read(header Header, rb *bufio.Reader) error {
return nil
}
func (b body) write(w io.Writer) error {
if len(b) == 0 {
return nil
}
_, err := w.Write(b)
return err
func (b body) writeSize() int {
return len(b)
}
func (b body) writeTo(buf []byte) int {
return copy(buf, b)
}
func (b body) write() []byte {
buf := make([]byte, b.writeSize())
b.writeTo(buf)
return buf
}

View File

@@ -20,11 +20,6 @@ var casesBody = []struct {
},
[]byte{0x01, 0x02, 0x03, 0x04},
},
{
"nil",
Header{},
[]byte(nil),
},
}
func TestBodyRead(t *testing.T) {
@@ -81,9 +76,8 @@ func TestBodyReadErrors(t *testing.T) {
func TestBodyWrite(t *testing.T) {
for _, ca := range casesBody {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
body(ca.byts).write(&buf)
require.Equal(t, ca.byts, buf.Bytes())
buf := body(ca.byts).write()
require.Equal(t, ca.byts, buf)
})
}
}

View File

@@ -3,7 +3,6 @@ package base
import (
"bufio"
"fmt"
"io"
"net/http"
"sort"
"strings"
@@ -98,7 +97,7 @@ func (h *Header) read(rb *bufio.Reader) error {
return nil
}
func (h Header) write(w io.Writer) error {
func (h Header) writeSize() int {
// sort headers by key
// in order to obtain deterministic results
keys := make([]string, len(h))
@@ -107,15 +106,43 @@ func (h Header) write(w io.Writer) error {
}
sort.Strings(keys)
n := 0
for _, key := range keys {
for _, val := range h[key] {
_, err := w.Write([]byte(key + ": " + val + "\r\n"))
if err != nil {
return err
}
n += len([]byte(key + ": " + val + "\r\n"))
}
}
_, err := w.Write([]byte("\r\n"))
return err
n += 2
return n
}
func (h Header) writeTo(buf []byte) int {
// sort headers by key
// in order to obtain deterministic results
keys := make([]string, len(h))
for key := range h {
keys = append(keys, key)
}
sort.Strings(keys)
pos := 0
for _, key := range keys {
for _, val := range h[key] {
pos += copy(buf[pos:], []byte(key+": "+val+"\r\n"))
}
}
pos += copy(buf[pos:], []byte("\r\n"))
return pos
}
func (h Header) write() []byte {
buf := make([]byte, h.writeSize())
h.writeTo(buf)
return buf
}

View File

@@ -176,9 +176,8 @@ func TestHeaderReadErrors(t *testing.T) {
func TestHeaderWrite(t *testing.T) {
for _, ca := range casesHeader {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
ca.header.write(&buf)
require.Equal(t, ca.enc, buf.Bytes())
buf := ca.header.write()
require.Equal(t, ca.enc, buf)
})
}
}

View File

@@ -105,16 +105,28 @@ func (f *InterleavedFrame) Read(maxPayloadSize int, br *bufio.Reader) error {
return nil
}
// Write writes an InterleavedFrame into a buffered writer.
func (f InterleavedFrame) Write(w io.Writer) error {
buf := []byte{0x24, byte(f.Channel), 0x00, 0x00}
binary.BigEndian.PutUint16(buf[2:], uint16(len(f.Payload)))
_, err := w.Write(buf)
if err != nil {
return err
}
_, err = w.Write(f.Payload)
return err
// WriteSize returns the size of an InterleavedFrame.
func (f InterleavedFrame) WriteSize() int {
return 4 + len(f.Payload)
}
// WriteTo writes an InterleavedFrame.
func (f InterleavedFrame) WriteTo(buf []byte) (int, error) {
pos := 0
pos += copy(buf[pos:], []byte{0x24, byte(f.Channel)})
binary.BigEndian.PutUint16(buf[pos:], uint16(len(f.Payload)))
pos += 2
pos += copy(buf[pos:], f.Payload)
return pos, nil
}
// Write writes an InterleavedFrame.
func (f InterleavedFrame) Write() ([]byte, error) {
buf := make([]byte, f.WriteSize())
_, err := f.WriteTo(buf)
return buf, err
}

View File

@@ -82,9 +82,9 @@ func TestInterleavedFrameReadErrors(t *testing.T) {
func TestInterleavedFrameWrite(t *testing.T) {
for _, ca := range casesInterleavedFrame {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
ca.dec.Write(&buf)
require.Equal(t, ca.enc, buf.Bytes())
buf, err := ca.dec.Write()
require.NoError(t, err)
require.Equal(t, ca.enc, buf)
})
}
}

View File

@@ -3,9 +3,7 @@ package base
import (
"bufio"
"bytes"
"fmt"
"io"
"strconv"
)
@@ -117,29 +115,51 @@ func (req *Request) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) error
}
}
// Write writes a request.
func (req Request) Write(w io.Writer) error {
// WriteSize returns the size of a Request.
func (req Request) WriteSize() int {
n := 0
urStr := req.URL.CloneWithoutCredentials().String()
_, err := w.Write([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n"))
if err != nil {
return err
}
n += len([]byte(string(req.Method) + " " + urStr + " " + rtspProtocol10 + "\r\n"))
if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}
}
err = req.Header.write(w)
if err != nil {
return err
n += req.Header.writeSize()
n += body(req.Body).writeSize()
return n
}
// WriteTo writes a Request.
func (req Request) WriteTo(buf []byte) (int, error) {
pos := 0
urStr := req.URL.CloneWithoutCredentials().String()
pos += copy(buf[pos:], []byte(string(req.Method)+" "+urStr+" "+rtspProtocol10+"\r\n"))
if len(req.Body) != 0 {
req.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(req.Body)), 10)}
}
return body(req.Body).write(w)
pos += req.Header.writeTo(buf[pos:])
pos += body(req.Body).writeTo(buf[pos:])
return pos, nil
}
// Write writes a Request.
func (req Request) Write() ([]byte, error) {
buf := make([]byte, req.WriteSize())
_, err := req.WriteTo(buf)
return buf, err
}
// String implements fmt.Stringer.
func (req Request) String() string {
var buf bytes.Buffer
req.Write(&buf)
return buf.String()
buf, _ := req.Write()
return string(buf)
}

View File

@@ -221,9 +221,9 @@ func TestRequestReadErrors(t *testing.T) {
func TestRequestWrite(t *testing.T) {
for _, ca := range casesRequest {
t.Run(ca.name, func(t *testing.T) {
var buf bytes.Buffer
ca.req.Write(&buf)
require.Equal(t, ca.byts, buf.Bytes())
buf, err := ca.req.Write()
require.NoError(t, err)
require.Equal(t, ca.byts, buf)
})
}
}

View File

@@ -2,9 +2,7 @@ package base
import (
"bufio"
"bytes"
"fmt"
"io"
"strconv"
)
@@ -203,36 +201,65 @@ func (res *Response) ReadIgnoreFrames(maxPayloadSize int, rb *bufio.Reader) erro
}
}
// Write writes a Response.
func (res Response) Write(w io.Writer) error {
// WriteSize returns the size of a Response.
func (res Response) WriteSize() int {
n := 0
if res.StatusMessage == "" {
if status, ok := statusMessages[res.StatusCode]; ok {
res.StatusMessage = status
}
}
_, err := w.Write([]byte(rtspProtocol10 + " " +
n += len([]byte(rtspProtocol10 + " " +
strconv.FormatInt(int64(res.StatusCode), 10) + " " +
res.StatusMessage + "\r\n"))
if err != nil {
return err
}
if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}
}
err = res.Header.write(w)
if err != nil {
return err
n += res.Header.writeSize()
n += body(res.Body).writeSize()
return n
}
// WriteTo writes a Response.
func (res Response) WriteTo(buf []byte) (int, error) {
if res.StatusMessage == "" {
if status, ok := statusMessages[res.StatusCode]; ok {
res.StatusMessage = status
}
}
return body(res.Body).write(w)
pos := 0
pos += copy(buf[pos:], []byte(rtspProtocol10+" "+
strconv.FormatInt(int64(res.StatusCode), 10)+" "+
res.StatusMessage+"\r\n"))
if len(res.Body) != 0 {
res.Header["Content-Length"] = HeaderValue{strconv.FormatInt(int64(len(res.Body)), 10)}
}
pos += res.Header.writeTo(buf[pos:])
pos += body(res.Body).writeTo(buf[pos:])
return pos, nil
}
// Write writes a Response.
func (res Response) Write() ([]byte, error) {
buf := make([]byte, res.WriteSize())
_, err := res.WriteTo(buf)
return buf, err
}
// String implements fmt.Stringer.
func (res Response) String() string {
var buf bytes.Buffer
res.Write(&buf)
return buf.String()
buf, _ := res.Write()
return string(buf)
}

View File

@@ -178,9 +178,9 @@ func TestResponseReadErrors(t *testing.T) {
func TestResponseWrite(t *testing.T) {
for _, c := range casesResponse {
t.Run(c.name, func(t *testing.T) {
var buf bytes.Buffer
c.res.Write(&buf)
require.Equal(t, c.byts, buf.Bytes())
buf, err := c.res.Write()
require.NoError(t, err)
require.Equal(t, c.byts, buf)
})
}
}
@@ -207,9 +207,9 @@ func TestResponseWriteAutoFillStatus(t *testing.T) {
"\r\n",
)
var buf bytes.Buffer
res.Write(&buf)
require.Equal(t, byts, buf.Bytes())
buf, err := res.Write()
require.NoError(t, err)
require.Equal(t, byts, buf)
}
func TestResponseReadIgnoreFrames(t *testing.T) {