mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 15:47:01 +08:00
mat: retain type and dimension attributes in encodings
This commit is contained in:
294
mat/io.go
294
mat/io.go
@@ -5,51 +5,86 @@
|
||||
package mat
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxLen is the biggest slice/array len one can create on a 32/64b platform.
|
||||
maxLen = int64(int(^uint(0) >> 1))
|
||||
)
|
||||
// version is the current on-disk codec version.
|
||||
const version uint32 = 0x1
|
||||
|
||||
// maxLen is the biggest slice/array len one can create on a 32/64b platform.
|
||||
const maxLen = int64(int(^uint(0) >> 1))
|
||||
|
||||
var (
|
||||
headerSize = binary.Size(storage{})
|
||||
sizeInt64 = binary.Size(int64(0))
|
||||
sizeFloat64 = binary.Size(float64(0))
|
||||
|
||||
errWrongType = errors.New("mat: wrong data type")
|
||||
|
||||
errTooBig = errors.New("mat: resulting data slice too big")
|
||||
errTooSmall = errors.New("mat: input slice too small")
|
||||
errBadBuffer = errors.New("mat: data buffer size mismatch")
|
||||
errBadSize = errors.New("mat: invalid dimension")
|
||||
)
|
||||
|
||||
// Type encoding scheme:
|
||||
//
|
||||
// Type Form Packing Uplo Unit Rows Columns kU kL
|
||||
// uint8 [GST] uint8 [BPF] uint8 [AUL] bool int64 int64 int64 int64
|
||||
// General 'G' 'F' 'A' false r c 0 0
|
||||
// Band 'G' 'B' 'A' false r c kU kL
|
||||
// Symmetric 'S' 'F' ul false n n 0 0
|
||||
// SymmetricBand 'S' 'B' ul false n n k k
|
||||
// SymmetricPacked 'S' 'P' ul false n n 0 0
|
||||
// Triangular 'T' 'F' ul Diag==Unit n n 0 0
|
||||
// TriangularBand 'T' 'B' ul Diag==Unit n n k k
|
||||
// TriangularPacked 'T' 'P' ul Diag==Unit n n 0 0
|
||||
//
|
||||
// G - general, S - symmetric, T - triangular
|
||||
// F - full, B - band, P - packed
|
||||
// A - all, U - upper, L - lower
|
||||
|
||||
// MarshalBinary encodes the receiver into a binary form and returns the result.
|
||||
//
|
||||
// Dense is little-endian encoded as follows:
|
||||
// 0 - 7 number of rows (int64)
|
||||
// 8 - 15 number of columns (int64)
|
||||
// 16 - .. matrix data elements (float64)
|
||||
// 0 - 3 Version = 1 (uint32)
|
||||
// 4 'G' (byte)
|
||||
// 5 'F' (byte)
|
||||
// 6 'A' (byte)
|
||||
// 7 0 (byte)
|
||||
// 8 - 15 number of rows (int64)
|
||||
// 16 - 23 number of columns (int64)
|
||||
// 24 - 31 0 (int64)
|
||||
// 32 - 39 0 (int64)
|
||||
// 40 - .. matrix data elements (float64)
|
||||
// [0,0] [0,1] ... [0,ncols-1]
|
||||
// [1,0] [1,1] ... [1,ncols-1]
|
||||
// ...
|
||||
// [nrows-1,0] ... [nrows-1,ncols-1]
|
||||
func (m Dense) MarshalBinary() ([]byte, error) {
|
||||
bufLen := int64(m.mat.Rows)*int64(m.mat.Cols)*int64(sizeFloat64) + 2*int64(sizeInt64)
|
||||
bufLen := int64(headerSize) + int64(m.mat.Rows)*int64(m.mat.Cols)*int64(sizeFloat64)
|
||||
if bufLen <= 0 {
|
||||
// bufLen is too big and has wrapped around.
|
||||
return nil, errTooBig
|
||||
}
|
||||
|
||||
p := 0
|
||||
header := storage{
|
||||
Form: 'G', Packing: 'F', Uplo: 'A',
|
||||
Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
|
||||
Version: version,
|
||||
}
|
||||
buf := make([]byte, bufLen)
|
||||
binary.LittleEndian.PutUint64(buf[p:p+sizeInt64], uint64(m.mat.Rows))
|
||||
p += sizeInt64
|
||||
binary.LittleEndian.PutUint64(buf[p:p+sizeInt64], uint64(m.mat.Cols))
|
||||
p += sizeInt64
|
||||
n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
|
||||
if err != nil {
|
||||
return buf[:n], err
|
||||
}
|
||||
|
||||
p := headerSize
|
||||
r, c := m.Dims()
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
@@ -66,26 +101,22 @@ func (m Dense) MarshalBinary() ([]byte, error) {
|
||||
//
|
||||
// See MarshalBinary for the on-disk layout.
|
||||
func (m Dense) MarshalBinaryTo(w io.Writer) (int, error) {
|
||||
var n int
|
||||
var buf [8]byte
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(m.mat.Rows))
|
||||
nn, err := w.Write(buf[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
header := storage{
|
||||
Form: 'G', Packing: 'F', Uplo: 'A',
|
||||
Rows: int64(m.mat.Rows), Cols: int64(m.mat.Cols),
|
||||
Version: version,
|
||||
}
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(m.mat.Cols))
|
||||
nn, err = w.Write(buf[:])
|
||||
n += nn
|
||||
n, err := header.marshalBinaryTo(w)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
r, c := m.Dims()
|
||||
var b [8]byte
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
binary.LittleEndian.PutUint64(buf[:], math.Float64bits(m.at(i, j)))
|
||||
nn, err = w.Write(buf[:])
|
||||
binary.LittleEndian.PutUint64(b[:], math.Float64bits(m.at(i, j)))
|
||||
nn, err := w.Write(b[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
@@ -113,19 +144,26 @@ func (m *Dense) UnmarshalBinary(data []byte) error {
|
||||
panic("mat: unmarshal into non-zero matrix")
|
||||
}
|
||||
|
||||
if len(data) < 2*sizeInt64 {
|
||||
if len(data) < headerSize {
|
||||
return errTooSmall
|
||||
}
|
||||
|
||||
p := 0
|
||||
rows := int64(binary.LittleEndian.Uint64(data[p : p+sizeInt64]))
|
||||
p += sizeInt64
|
||||
cols := int64(binary.LittleEndian.Uint64(data[p : p+sizeInt64]))
|
||||
p += sizeInt64
|
||||
var header storage
|
||||
err := header.unmarshalBinary(data[:headerSize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows := header.Rows
|
||||
cols := header.Cols
|
||||
header.Version = 0
|
||||
header.Rows = 0
|
||||
header.Cols = 0
|
||||
if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
|
||||
return errWrongType
|
||||
}
|
||||
if rows < 0 || cols < 0 {
|
||||
return errBadSize
|
||||
}
|
||||
|
||||
size := rows * cols
|
||||
if size == 0 {
|
||||
return ErrZeroLength
|
||||
@@ -133,11 +171,11 @@ func (m *Dense) UnmarshalBinary(data []byte) error {
|
||||
if int(size) < 0 || size > maxLen {
|
||||
return errTooBig
|
||||
}
|
||||
|
||||
if len(data) != int(size)*sizeFloat64+2*sizeInt64 {
|
||||
if len(data) != headerSize+int(rows*cols)*sizeFloat64 {
|
||||
return errBadBuffer
|
||||
}
|
||||
|
||||
p := headerSize
|
||||
m.reuseAs(int(rows), int(cols))
|
||||
for i := range m.mat.Data {
|
||||
m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
|
||||
@@ -165,27 +203,22 @@ func (m *Dense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
|
||||
panic("mat: unmarshal into non-zero matrix")
|
||||
}
|
||||
|
||||
var (
|
||||
n int
|
||||
buf [8]byte
|
||||
)
|
||||
nn, err := readFull(r, buf[:])
|
||||
n += nn
|
||||
var header storage
|
||||
n, err := header.unmarshalBinaryFrom(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
rows := int64(binary.LittleEndian.Uint64(buf[:]))
|
||||
|
||||
nn, err = readFull(r, buf[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
rows := header.Rows
|
||||
cols := header.Cols
|
||||
header.Version = 0
|
||||
header.Rows = 0
|
||||
header.Cols = 0
|
||||
if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
|
||||
return n, errWrongType
|
||||
}
|
||||
cols := int64(binary.LittleEndian.Uint64(buf[:]))
|
||||
if rows < 0 || cols < 0 {
|
||||
return n, errBadSize
|
||||
}
|
||||
|
||||
size := rows * cols
|
||||
if size == 0 {
|
||||
return n, ErrZeroLength
|
||||
@@ -195,13 +228,17 @@ func (m *Dense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
|
||||
}
|
||||
|
||||
m.reuseAs(int(rows), int(cols))
|
||||
var b [8]byte
|
||||
for i := range m.mat.Data {
|
||||
nn, err = readFull(r, buf[:])
|
||||
nn, err := readFull(r, b[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(buf[:]))
|
||||
m.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
|
||||
}
|
||||
|
||||
return n, nil
|
||||
@@ -210,20 +247,36 @@ func (m *Dense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
|
||||
// MarshalBinary encodes the receiver into a binary form and returns the result.
|
||||
//
|
||||
// VecDense is little-endian encoded as follows:
|
||||
// 0 - 7 number of elements (int64)
|
||||
// 8 - .. vector's data elements (float64)
|
||||
//
|
||||
// 0 - 3 Version = 1 (uint32)
|
||||
// 4 'G' (byte)
|
||||
// 5 'F' (byte)
|
||||
// 6 'A' (byte)
|
||||
// 7 0 (byte)
|
||||
// 8 - 15 number of elements (int64)
|
||||
// 16 - 23 1 (int64)
|
||||
// 24 - 31 0 (int64)
|
||||
// 32 - 39 0 (int64)
|
||||
// 40 - .. vector's data elements (float64)
|
||||
func (v VecDense) MarshalBinary() ([]byte, error) {
|
||||
bufLen := int64(sizeInt64) + int64(v.n)*int64(sizeFloat64)
|
||||
bufLen := int64(headerSize) + int64(v.n)*int64(sizeFloat64)
|
||||
if bufLen <= 0 {
|
||||
// bufLen is too big and has wrapped around.
|
||||
return nil, errTooBig
|
||||
}
|
||||
|
||||
p := 0
|
||||
header := storage{
|
||||
Form: 'G', Packing: 'F', Uplo: 'A',
|
||||
Rows: int64(v.n), Cols: 1,
|
||||
Version: version,
|
||||
}
|
||||
buf := make([]byte, bufLen)
|
||||
binary.LittleEndian.PutUint64(buf[p:p+sizeInt64], uint64(v.n))
|
||||
p += sizeInt64
|
||||
n, err := header.marshalBinaryTo(bytes.NewBuffer(buf[:0]))
|
||||
if err != nil {
|
||||
return buf[:n], err
|
||||
}
|
||||
|
||||
p := headerSize
|
||||
for i := 0; i < v.n; i++ {
|
||||
binary.LittleEndian.PutUint64(buf[p:p+sizeFloat64], math.Float64bits(v.at(i)))
|
||||
p += sizeFloat64
|
||||
@@ -237,21 +290,20 @@ func (v VecDense) MarshalBinary() ([]byte, error) {
|
||||
//
|
||||
// See MarshalBinary for the on-disk format.
|
||||
func (v VecDense) MarshalBinaryTo(w io.Writer) (int, error) {
|
||||
var (
|
||||
n int
|
||||
buf [8]byte
|
||||
)
|
||||
|
||||
binary.LittleEndian.PutUint64(buf[:], uint64(v.n))
|
||||
nn, err := w.Write(buf[:])
|
||||
n += nn
|
||||
header := storage{
|
||||
Form: 'G', Packing: 'F', Uplo: 'A',
|
||||
Rows: int64(v.n), Cols: 1,
|
||||
Version: version,
|
||||
}
|
||||
n, err := header.marshalBinaryTo(w)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
var buf [8]byte
|
||||
for i := 0; i < v.n; i++ {
|
||||
binary.LittleEndian.PutUint64(buf[:], math.Float64bits(v.at(i)))
|
||||
nn, err = w.Write(buf[:])
|
||||
nn, err := w.Write(buf[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
return n, err
|
||||
@@ -278,22 +330,39 @@ func (v *VecDense) UnmarshalBinary(data []byte) error {
|
||||
panic("mat: unmarshal into non-zero vector")
|
||||
}
|
||||
|
||||
p := 0
|
||||
n := int64(binary.LittleEndian.Uint64(data[p : p+sizeInt64]))
|
||||
if len(data) < headerSize {
|
||||
return errTooSmall
|
||||
}
|
||||
|
||||
var header storage
|
||||
err := header.unmarshalBinary(data[:headerSize])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if header.Cols != 1 {
|
||||
return ErrShape
|
||||
}
|
||||
n := header.Rows
|
||||
header.Version = 0
|
||||
header.Rows = 0
|
||||
header.Cols = 0
|
||||
if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
|
||||
return errWrongType
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrZeroLength
|
||||
}
|
||||
p += sizeInt64
|
||||
if n < 0 {
|
||||
return errBadSize
|
||||
}
|
||||
if n > maxLen {
|
||||
if int64(maxLen) < n {
|
||||
return errTooBig
|
||||
}
|
||||
if len(data) != int(n)*sizeFloat64+sizeInt64 {
|
||||
if len(data) != headerSize+int(n)*sizeFloat64 {
|
||||
return errBadBuffer
|
||||
}
|
||||
|
||||
p := headerSize
|
||||
v.reuseAs(int(n))
|
||||
for i := range v.mat.Data {
|
||||
v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(data[p : p+sizeFloat64]))
|
||||
@@ -314,43 +383,94 @@ func (v *VecDense) UnmarshalBinaryFrom(r io.Reader) (int, error) {
|
||||
panic("mat: unmarshal into non-zero vector")
|
||||
}
|
||||
|
||||
var (
|
||||
n int
|
||||
buf [8]byte
|
||||
)
|
||||
nn, err := readFull(r, buf[:])
|
||||
n += nn
|
||||
var header storage
|
||||
n, err := header.unmarshalBinaryFrom(r)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
sz := int64(binary.LittleEndian.Uint64(buf[:]))
|
||||
if sz == 0 {
|
||||
if header.Cols != 1 {
|
||||
return n, ErrShape
|
||||
}
|
||||
l := header.Rows
|
||||
header.Version = 0
|
||||
header.Rows = 0
|
||||
header.Cols = 0
|
||||
if (header != storage{Form: 'G', Packing: 'F', Uplo: 'A'}) {
|
||||
return n, errWrongType
|
||||
}
|
||||
if l == 0 {
|
||||
return n, ErrZeroLength
|
||||
}
|
||||
if sz < 0 {
|
||||
if l < 0 {
|
||||
return n, errBadSize
|
||||
}
|
||||
if sz > maxLen {
|
||||
if int64(maxLen) < l {
|
||||
return n, errTooBig
|
||||
}
|
||||
|
||||
v.reuseAs(int(sz))
|
||||
v.reuseAs(int(l))
|
||||
var b [8]byte
|
||||
for i := range v.mat.Data {
|
||||
nn, err = readFull(r, buf[:])
|
||||
nn, err := readFull(r, b[:])
|
||||
n += nn
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(buf[:]))
|
||||
}
|
||||
|
||||
if n != sizeInt64+int(sz)*sizeFloat64 {
|
||||
return n, io.ErrUnexpectedEOF
|
||||
v.mat.Data[i] = math.Float64frombits(binary.LittleEndian.Uint64(b[:]))
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// storage is the fixed-size header that precedes the element data of a
// serialised matrix. It is encoded field by field, in declaration order,
// so neither the order nor the types of the fields may change without
// changing the on-disk format.
//
// Form, Packing and Uplo hold single-character codes:
//
//	Form:    'G' general, 'S' symmetric, 'T' triangular
//	Packing: 'F' full, 'B' band, 'P' packed
//	Uplo:    'A' all, 'U' upper, 'L' lower
//
// Unit records Diag==Unit for triangular types, and KU and KL hold the
// band widths for band types; they are false/zero for other types.
type storage struct {
	Version uint32 // Keep this first.

	Form    byte // [GST]
	Packing byte // [BPF]
	Uplo    byte // [AUL]
	Unit    bool

	Rows, Cols int64
	KU, KL     int64
}
|
||||
|
||||
// TODO(kortschak): Consider replacing these with calls to direct
|
||||
// encoding/decoding of fields rather than to binary.Write/binary.Read.
|
||||
|
||||
func (s storage) marshalBinaryTo(w io.Writer) (int, error) {
|
||||
buf := bytes.NewBuffer(make([]byte, 0, headerSize))
|
||||
err := binary.Write(buf, binary.LittleEndian, s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
func (s *storage) unmarshalBinary(buf []byte) error {
|
||||
err := binary.Read(bytes.NewReader(buf), binary.LittleEndian, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.Version != version {
|
||||
return fmt.Errorf("mat: incorrect version: %d", s.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *storage) unmarshalBinaryFrom(r io.Reader) (int, error) {
|
||||
buf := make([]byte, headerSize)
|
||||
n, err := readFull(r, buf)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
return n, s.unmarshalBinary(buf[:n])
|
||||
}
|
||||
|
||||
// readFull reads from r into buf until it has read len(buf).
|
||||
// It returns the number of bytes copied and an error if fewer bytes were read.
|
||||
// If an EOF happens after reading fewer than len(buf) bytes, io.ErrUnexpectedEOF is returned.
|
||||
|
Reference in New Issue
Block a user