stat/card: allow unmarshaling into the zero value

This commit is contained in:
Dan Kortschak
2019-12-26 14:46:11 +10:30
parent c9bf87fc97
commit 81776cc9ef
5 changed files with 159 additions and 27 deletions

View File

@@ -7,8 +7,11 @@
package card package card
import ( import (
"fmt"
"hash"
"math" "math"
"reflect" "reflect"
"sync"
) )
const ( const (
@@ -59,3 +62,48 @@ func typeNameOf(v interface{}) string {
} }
return prefix + t.PkgPath() + "." + t.Name() return prefix + t.PkgPath() + "." + t.Name()
} }
// hashes holds registered hashes.
var hashes sync.Map // map[string]userType
type userType struct {
fn reflect.Value // Holds a func() hash.Hash{32,64}.
typ reflect.Type // Type of the returned hash implementation.
}
// RegisterHash registers a function that returns a new hash.Hash32 or hash.Hash64
// to the name of the type implementing the interface. The value of fn must be a
// func() hash.Hash32 or func() hash.Hash64, otherwise RegisterHash will panic.
// RegisterHash will panic if there is not a unique mapping from the name to the
// returned type.
func RegisterHash(fn interface{}) {
const invalidType = "card: must register func() hash.Hash32 or func() hash.Hash64"
rf := reflect.ValueOf(fn)
rt := rf.Type()
if rf.Kind() != reflect.Func {
panic(invalidType)
}
if rt.NumIn() != 0 {
panic(invalidType)
}
if rt.NumOut() != 1 {
panic(invalidType)
}
h := rf.Call(nil)[0].Interface()
var name string
var h32 hash.Hash32
var h64 hash.Hash64
switch rf.Type().Out(0) {
case reflect.TypeOf(&h32).Elem(), reflect.TypeOf(&h64).Elem():
name = typeNameOf(h)
default:
panic(invalidType)
}
user := userType{fn: rf, typ: reflect.TypeOf(h)}
ut, dup := hashes.LoadOrStore(name, user)
stored := ut.(userType)
if dup && stored.typ != user.typ {
panic(fmt.Sprintf("card: registering duplicate types for %q: %s != %s", name, stored.typ, user.typ))
}
}

View File

@@ -7,10 +7,12 @@ package card
import ( import (
"encoding" "encoding"
"fmt" "fmt"
"hash"
"hash/fnv" "hash/fnv"
"io" "io"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"golang.org/x/exp/rand" "golang.org/x/exp/rand"
@@ -203,39 +205,45 @@ type counterEncoder interface {
} }
var counterEncoderTests = []struct { var counterEncoderTests = []struct {
name string name string
count int count int
src, dst func() counterEncoder src, dst, zdst func() counterEncoder
}{ }{
{ {
name: "HyperLogLog32-4-4-FNV-1a", count: 1e3, name: "HyperLogLog32-4-4-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
zdst: func() counterEncoder { return &HyperLogLog32{} },
}, },
{ {
name: "HyperLogLog32-4-8-FNV-1a", count: 1e3, name: "HyperLogLog32-4-8-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
zdst: func() counterEncoder { return &HyperLogLog32{} },
}, },
{ {
name: "HyperLogLog32-8-4-FNV-1a", count: 1e3, name: "HyperLogLog32-8-4-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
zdst: func() counterEncoder { return &HyperLogLog32{} },
}, },
{ {
name: "HyperLogLog64-4-4-FNV-1a", count: 1e3, name: "HyperLogLog64-4-4-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
zdst: func() counterEncoder { return &HyperLogLog64{} },
}, },
{ {
name: "HyperLogLog64-4-8-FNV-1a", count: 1e3, name: "HyperLogLog64-4-8-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
zdst: func() counterEncoder { return &HyperLogLog64{} },
}, },
{ {
name: "HyperLogLog64-8-4-FNV-1a", count: 1e3, name: "HyperLogLog64-8-4-FNV-1a", count: 1e3,
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
zdst: func() counterEncoder { return &HyperLogLog64{} },
}, },
} }
@@ -247,6 +255,11 @@ func mustCounterEncoder(c counterEncoder, err error) counterEncoder {
} }
func TestBinaryEncoding(t *testing.T) { func TestBinaryEncoding(t *testing.T) {
RegisterHash(fnv.New32a)
RegisterHash(fnv.New64a)
defer func() {
hashes = sync.Map{}
}()
for _, test := range counterEncoderTests { for _, test := range counterEncoderTests {
rnd := rand.New(rand.NewSource(1)) rnd := rand.New(rand.NewSource(1))
src := test.src() src := test.src()
@@ -277,12 +290,56 @@ func TestBinaryEncoding(t *testing.T) {
t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err) t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err)
continue continue
} }
zdst := test.zdst()
err = zdst.UnmarshalBinary(buf)
if err != nil {
t.Errorf("unexpected error unmarshaling binary into zero receiver for %s: %v", test.name, err)
continue
}
gotSrc := src.Count() gotSrc := src.Count()
gotDst := dst.Count() gotDst := dst.Count()
gotZdst := zdst.Count()
if gotSrc != gotDst { if gotSrc != gotDst {
t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc) t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc)
} }
if gotSrc != gotZdst {
t.Errorf("unexpected count for %s into zero receiver: got:%.0f want:%.0f", test.name, gotZdst, gotSrc)
}
}
}
var invalidRegisterTests = []struct {
fn interface{}
panics bool
}{
{fn: int(0), panics: true},
{fn: func() {}, panics: true},
{fn: func(int) {}, panics: true},
{fn: func() int { return 0 }, panics: true},
{fn: func() hash.Hash { return fnv.New32a() }, panics: true},
{fn: func() hash.Hash32 { return fnv.New32a() }, panics: false},
{fn: func() hash.Hash { return fnv.New64a() }, panics: true},
{fn: func() hash.Hash64 { return fnv.New64a() }, panics: false},
}
func TestRegisterInvalid(t *testing.T) {
for _, test := range invalidRegisterTests {
var r interface{}
func() {
defer func() {
r = recover()
}()
RegisterHash(test.fn)
}()
panicked := r != nil
if panicked != test.panics {
if panicked {
t.Errorf("unexpected panic for %T", test.fn)
} else {
t.Errorf("expected panic for %T", test.fn)
}
}
} }
} }

View File

@@ -14,5 +14,6 @@ echo -e '// Code generated by "go generate gonum.org/v1/gonum/stat/card"; DO NOT
-e 's/rho32/rho64/' \ -e 's/rho32/rho64/' \
-e 's/HyperLogLog32/HyperLogLog64/g' \ -e 's/HyperLogLog32/HyperLogLog64/g' \
-e 's/Hash32/Hash64/' \ -e 's/Hash32/Hash64/' \
-e 's/hash32/hash64/' \
-e 's/w32/w64/g' \ -e 's/w32/w64/g' \
>> hll64.go >> hll64.go

View File

@@ -179,9 +179,6 @@ func (h *HyperLogLog32) MarshalBinary() ([]byte, error) {
// return. The receiver must have a non-nil hash function value that is // return. The receiver must have a non-nil hash function value that is
// the same type as the one that was stored in the binary data. // the same type as the one that was stored in the binary data.
func (h *HyperLogLog32) UnmarshalBinary(b []byte) error { func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
if h.hash == nil {
return errors.New("card: hash function not set")
}
dec := gob.NewDecoder(bytes.NewReader(b)) dec := gob.NewDecoder(bytes.NewReader(b))
var size uint8 var size uint8
err := dec.Decode(&size) err := dec.Decode(&size)
@@ -196,9 +193,16 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
if err != nil { if err != nil {
return err return err
} }
dstHash := typeNameOf(h.hash) if h.hash == nil {
if dstHash != srcHash { h.hash = hash32For(srcHash)
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) if h.hash == nil {
return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash)
}
} else {
dstHash := typeNameOf(h.hash)
if dstHash != srcHash {
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
}
} }
err = dec.Decode(&h.p) err = dec.Decode(&h.p)
if err != nil { if err != nil {
@@ -212,3 +216,12 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
} }
return nil return nil
} }
func hash32For(name string) hash.Hash32 {
fn, ok := hashes.Load(name)
if !ok {
return nil
}
h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash32)
return h
}

View File

@@ -181,9 +181,6 @@ func (h *HyperLogLog64) MarshalBinary() ([]byte, error) {
// return. The receiver must have a non-nil hash function value that is // return. The receiver must have a non-nil hash function value that is
// the same type as the one that was stored in the binary data. // the same type as the one that was stored in the binary data.
func (h *HyperLogLog64) UnmarshalBinary(b []byte) error { func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
if h.hash == nil {
return errors.New("card: hash function not set")
}
dec := gob.NewDecoder(bytes.NewReader(b)) dec := gob.NewDecoder(bytes.NewReader(b))
var size uint8 var size uint8
err := dec.Decode(&size) err := dec.Decode(&size)
@@ -198,9 +195,16 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
if err != nil { if err != nil {
return err return err
} }
dstHash := typeNameOf(h.hash) if h.hash == nil {
if dstHash != srcHash { h.hash = hash64For(srcHash)
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) if h.hash == nil {
return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash)
}
} else {
dstHash := typeNameOf(h.hash)
if dstHash != srcHash {
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
}
} }
err = dec.Decode(&h.p) err = dec.Decode(&h.p)
if err != nil { if err != nil {
@@ -214,3 +218,12 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
} }
return nil return nil
} }
func hash64For(name string) hash.Hash64 {
fn, ok := hashes.Load(name)
if !ok {
return nil
}
h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash64)
return h
}