mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
stat/card: allow unmarshaling into the zero value
This commit is contained in:
@@ -7,8 +7,11 @@
|
|||||||
package card
|
package card
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
"math"
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -59,3 +62,48 @@ func typeNameOf(v interface{}) string {
|
|||||||
}
|
}
|
||||||
return prefix + t.PkgPath() + "." + t.Name()
|
return prefix + t.PkgPath() + "." + t.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// hashes holds registered hashes.
|
||||||
|
var hashes sync.Map // map[string]userType
|
||||||
|
|
||||||
|
type userType struct {
|
||||||
|
fn reflect.Value // Holds a func() hash.Hash{32,64}.
|
||||||
|
typ reflect.Type // Type of the returned hash implementation.
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterHash registers a function that returns a new hash.Hash32 or hash.Hash64
|
||||||
|
// to the name of the type implementing the interface. The value of fn must be a
|
||||||
|
// func() hash.Hash32 or func() hash.Hash64, otherwise RegisterHash will panic.
|
||||||
|
// RegisterHash will panic if there is not a unique mapping from the name to the
|
||||||
|
// returned type.
|
||||||
|
func RegisterHash(fn interface{}) {
|
||||||
|
const invalidType = "card: must register func() hash.Hash32 or func() hash.Hash64"
|
||||||
|
|
||||||
|
rf := reflect.ValueOf(fn)
|
||||||
|
rt := rf.Type()
|
||||||
|
if rf.Kind() != reflect.Func {
|
||||||
|
panic(invalidType)
|
||||||
|
}
|
||||||
|
if rt.NumIn() != 0 {
|
||||||
|
panic(invalidType)
|
||||||
|
}
|
||||||
|
if rt.NumOut() != 1 {
|
||||||
|
panic(invalidType)
|
||||||
|
}
|
||||||
|
h := rf.Call(nil)[0].Interface()
|
||||||
|
var name string
|
||||||
|
var h32 hash.Hash32
|
||||||
|
var h64 hash.Hash64
|
||||||
|
switch rf.Type().Out(0) {
|
||||||
|
case reflect.TypeOf(&h32).Elem(), reflect.TypeOf(&h64).Elem():
|
||||||
|
name = typeNameOf(h)
|
||||||
|
default:
|
||||||
|
panic(invalidType)
|
||||||
|
}
|
||||||
|
user := userType{fn: rf, typ: reflect.TypeOf(h)}
|
||||||
|
ut, dup := hashes.LoadOrStore(name, user)
|
||||||
|
stored := ut.(userType)
|
||||||
|
if dup && stored.typ != user.typ {
|
||||||
|
panic(fmt.Sprintf("card: registering duplicate types for %q: %s != %s", name, stored.typ, user.typ))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -7,10 +7,12 @@ package card
|
|||||||
import (
|
import (
|
||||||
"encoding"
|
"encoding"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"hash"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
@@ -205,37 +207,43 @@ type counterEncoder interface {
|
|||||||
var counterEncoderTests = []struct {
|
var counterEncoderTests = []struct {
|
||||||
name string
|
name string
|
||||||
count int
|
count int
|
||||||
src, dst func() counterEncoder
|
src, dst, zdst func() counterEncoder
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "HyperLogLog32-4-4-FNV-1a", count: 1e3,
|
name: "HyperLogLog32-4-4-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog32{} },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HyperLogLog32-4-8-FNV-1a", count: 1e3,
|
name: "HyperLogLog32-4-8-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog32{} },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HyperLogLog32-8-4-FNV-1a", count: 1e3,
|
name: "HyperLogLog32-8-4-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog32{} },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HyperLogLog64-4-4-FNV-1a", count: 1e3,
|
name: "HyperLogLog64-4-4-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog64{} },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HyperLogLog64-4-8-FNV-1a", count: 1e3,
|
name: "HyperLogLog64-4-8-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog64{} },
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HyperLogLog64-8-4-FNV-1a", count: 1e3,
|
name: "HyperLogLog64-8-4-FNV-1a", count: 1e3,
|
||||||
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
|
src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) },
|
||||||
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) },
|
||||||
|
zdst: func() counterEncoder { return &HyperLogLog64{} },
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,6 +255,11 @@ func mustCounterEncoder(c counterEncoder, err error) counterEncoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestBinaryEncoding(t *testing.T) {
|
func TestBinaryEncoding(t *testing.T) {
|
||||||
|
RegisterHash(fnv.New32a)
|
||||||
|
RegisterHash(fnv.New64a)
|
||||||
|
defer func() {
|
||||||
|
hashes = sync.Map{}
|
||||||
|
}()
|
||||||
for _, test := range counterEncoderTests {
|
for _, test := range counterEncoderTests {
|
||||||
rnd := rand.New(rand.NewSource(1))
|
rnd := rand.New(rand.NewSource(1))
|
||||||
src := test.src()
|
src := test.src()
|
||||||
@@ -277,12 +290,56 @@ func TestBinaryEncoding(t *testing.T) {
|
|||||||
t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err)
|
t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
zdst := test.zdst()
|
||||||
|
err = zdst.UnmarshalBinary(buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("unexpected error unmarshaling binary into zero receiver for %s: %v", test.name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
gotSrc := src.Count()
|
gotSrc := src.Count()
|
||||||
gotDst := dst.Count()
|
gotDst := dst.Count()
|
||||||
|
gotZdst := zdst.Count()
|
||||||
|
|
||||||
if gotSrc != gotDst {
|
if gotSrc != gotDst {
|
||||||
t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc)
|
t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc)
|
||||||
}
|
}
|
||||||
|
if gotSrc != gotZdst {
|
||||||
|
t.Errorf("unexpected count for %s into zero receiver: got:%.0f want:%.0f", test.name, gotZdst, gotSrc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var invalidRegisterTests = []struct {
|
||||||
|
fn interface{}
|
||||||
|
panics bool
|
||||||
|
}{
|
||||||
|
{fn: int(0), panics: true},
|
||||||
|
{fn: func() {}, panics: true},
|
||||||
|
{fn: func(int) {}, panics: true},
|
||||||
|
{fn: func() int { return 0 }, panics: true},
|
||||||
|
{fn: func() hash.Hash { return fnv.New32a() }, panics: true},
|
||||||
|
{fn: func() hash.Hash32 { return fnv.New32a() }, panics: false},
|
||||||
|
{fn: func() hash.Hash { return fnv.New64a() }, panics: true},
|
||||||
|
{fn: func() hash.Hash64 { return fnv.New64a() }, panics: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegisterInvalid(t *testing.T) {
|
||||||
|
for _, test := range invalidRegisterTests {
|
||||||
|
var r interface{}
|
||||||
|
func() {
|
||||||
|
defer func() {
|
||||||
|
r = recover()
|
||||||
|
}()
|
||||||
|
RegisterHash(test.fn)
|
||||||
|
}()
|
||||||
|
panicked := r != nil
|
||||||
|
if panicked != test.panics {
|
||||||
|
if panicked {
|
||||||
|
t.Errorf("unexpected panic for %T", test.fn)
|
||||||
|
} else {
|
||||||
|
t.Errorf("expected panic for %T", test.fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -14,5 +14,6 @@ echo -e '// Code generated by "go generate gonum.org/v1/gonum/stat/card"; DO NOT
|
|||||||
-e 's/rho32/rho64/' \
|
-e 's/rho32/rho64/' \
|
||||||
-e 's/HyperLogLog32/HyperLogLog64/g' \
|
-e 's/HyperLogLog32/HyperLogLog64/g' \
|
||||||
-e 's/Hash32/Hash64/' \
|
-e 's/Hash32/Hash64/' \
|
||||||
|
-e 's/hash32/hash64/' \
|
||||||
-e 's/w32/w64/g' \
|
-e 's/w32/w64/g' \
|
||||||
>> hll64.go
|
>> hll64.go
|
@@ -179,9 +179,6 @@ func (h *HyperLogLog32) MarshalBinary() ([]byte, error) {
|
|||||||
// return. The receiver must have a non-nil hash function value that is
|
// return. The receiver must have a non-nil hash function value that is
|
||||||
// the same type as the one that was stored in the binary data.
|
// the same type as the one that was stored in the binary data.
|
||||||
func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
|
func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
|
||||||
if h.hash == nil {
|
|
||||||
return errors.New("card: hash function not set")
|
|
||||||
}
|
|
||||||
dec := gob.NewDecoder(bytes.NewReader(b))
|
dec := gob.NewDecoder(bytes.NewReader(b))
|
||||||
var size uint8
|
var size uint8
|
||||||
err := dec.Decode(&size)
|
err := dec.Decode(&size)
|
||||||
@@ -196,10 +193,17 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if h.hash == nil {
|
||||||
|
h.hash = hash32For(srcHash)
|
||||||
|
if h.hash == nil {
|
||||||
|
return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
dstHash := typeNameOf(h.hash)
|
dstHash := typeNameOf(h.hash)
|
||||||
if dstHash != srcHash {
|
if dstHash != srcHash {
|
||||||
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
|
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
err = dec.Decode(&h.p)
|
err = dec.Decode(&h.p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -212,3 +216,12 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hash32For(name string) hash.Hash32 {
|
||||||
|
fn, ok := hashes.Load(name)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash32)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
@@ -181,9 +181,6 @@ func (h *HyperLogLog64) MarshalBinary() ([]byte, error) {
|
|||||||
// return. The receiver must have a non-nil hash function value that is
|
// return. The receiver must have a non-nil hash function value that is
|
||||||
// the same type as the one that was stored in the binary data.
|
// the same type as the one that was stored in the binary data.
|
||||||
func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
|
func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
|
||||||
if h.hash == nil {
|
|
||||||
return errors.New("card: hash function not set")
|
|
||||||
}
|
|
||||||
dec := gob.NewDecoder(bytes.NewReader(b))
|
dec := gob.NewDecoder(bytes.NewReader(b))
|
||||||
var size uint8
|
var size uint8
|
||||||
err := dec.Decode(&size)
|
err := dec.Decode(&size)
|
||||||
@@ -198,10 +195,17 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if h.hash == nil {
|
||||||
|
h.hash = hash64For(srcHash)
|
||||||
|
if h.hash == nil {
|
||||||
|
return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
dstHash := typeNameOf(h.hash)
|
dstHash := typeNameOf(h.hash)
|
||||||
if dstHash != srcHash {
|
if dstHash != srcHash {
|
||||||
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
|
return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
err = dec.Decode(&h.p)
|
err = dec.Decode(&h.p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -214,3 +218,12 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hash64For(name string) hash.Hash64 {
|
||||||
|
fn, ok := hashes.Load(name)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash64)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user