From 81776cc9ef2600ab0e893cfa1cf6de643da36f2c Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Thu, 26 Dec 2019 14:46:11 +1030 Subject: [PATCH] stat/card: allow unmarshaling into the zero value --- stat/card/card.go | 48 ++++++++++++++++++++ stat/card/card_test.go | 87 ++++++++++++++++++++++++++++++------- stat/card/generate_64bit.sh | 1 + stat/card/hll32.go | 25 ++++++++--- stat/card/hll64.go | 25 ++++++++--- 5 files changed, 159 insertions(+), 27 deletions(-) diff --git a/stat/card/card.go b/stat/card/card.go index e3a0ea09..aff789f4 100644 --- a/stat/card/card.go +++ b/stat/card/card.go @@ -7,8 +7,11 @@ package card import ( + "fmt" + "hash" "math" "reflect" + "sync" ) const ( @@ -59,3 +62,48 @@ func typeNameOf(v interface{}) string { } return prefix + t.PkgPath() + "." + t.Name() } + +// hashes holds registered hashes. +var hashes sync.Map // map[string]userType + +type userType struct { + fn reflect.Value // Holds a func() hash.Hash{32,64}. + typ reflect.Type // Type of the returned hash implementation. +} + +// RegisterHash registers a function that returns a new hash.Hash32 or hash.Hash64 +// to the name of the type implementing the interface. The value of fn must be a +// func() hash.Hash32 or func() hash.Hash64, otherwise RegisterHash will panic. +// RegisterHash will panic if there is not a unique mapping from the name to the +// returned type. +func RegisterHash(fn interface{}) { + const invalidType = "card: must register func() hash.Hash32 or func() hash.Hash64" + + rf := reflect.ValueOf(fn) + rt := rf.Type() + if rf.Kind() != reflect.Func { + panic(invalidType) + } + if rt.NumIn() != 0 { + panic(invalidType) + } + if rt.NumOut() != 1 { + panic(invalidType) + } + h := rf.Call(nil)[0].Interface() + var name string + var h32 hash.Hash32 + var h64 hash.Hash64 + switch rf.Type().Out(0) { + case reflect.TypeOf(&h32).Elem(), reflect.TypeOf(&h64).Elem(): + name = typeNameOf(h) + default: + panic(invalidType) + } + user := userType{fn: rf, typ: reflect.TypeOf(h)} + ut, dup := hashes.LoadOrStore(name, user) + stored := ut.(userType) + if dup && stored.typ != user.typ { + panic(fmt.Sprintf("card: registering duplicate types for %q: %s != %s", name, stored.typ, user.typ)) + } +} diff --git a/stat/card/card_test.go b/stat/card/card_test.go index d3c1baa0..d7734328 100644 --- a/stat/card/card_test.go +++ b/stat/card/card_test.go @@ -7,10 +7,12 @@ package card import ( "encoding" "fmt" + "hash" "hash/fnv" "io" "strconv" "strings" + "sync" "testing" "golang.org/x/exp/rand" @@ -203,39 +205,45 @@ type counterEncoder interface { } var counterEncoderTests = []struct { - name string - count int - src, dst func() counterEncoder + name string + count int + src, dst, zdst func() counterEncoder }{ { name: "HyperLogLog32-4-4-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + zdst: func() counterEncoder { return &HyperLogLog32{} }, }, { name: "HyperLogLog32-4-8-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, + zdst: func() counterEncoder { return &HyperLogLog32{} }, }, { name: "HyperLogLog32-8-4-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(8, fnv.New32a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog32(4, fnv.New32a())) }, + zdst: func() counterEncoder { return &HyperLogLog32{} }, }, { name: "HyperLogLog64-4-4-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + zdst: func() counterEncoder { return &HyperLogLog64{} }, }, { name: "HyperLogLog64-4-8-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, + zdst: func() counterEncoder { return &HyperLogLog64{} }, }, { name: "HyperLogLog64-8-4-FNV-1a", count: 1e3, - src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, - dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + src: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(8, fnv.New64a())) }, + dst: func() counterEncoder { return mustCounterEncoder(NewHyperLogLog64(4, fnv.New64a())) }, + zdst: func() counterEncoder { return &HyperLogLog64{} }, }, } @@ -247,6 +255,11 @@ func mustCounterEncoder(c counterEncoder, err error) counterEncoder { } func TestBinaryEncoding(t *testing.T) { + RegisterHash(fnv.New32a) + RegisterHash(fnv.New64a) + defer func() { + hashes = sync.Map{} + }() for _, test := range counterEncoderTests { rnd := rand.New(rand.NewSource(1)) src := test.src() @@ -277,12 +290,56 @@ func TestBinaryEncoding(t *testing.T) { t.Errorf("unexpected error unmarshaling binary for %s: %v", test.name, err) continue } + zdst := test.zdst() + err = zdst.UnmarshalBinary(buf) + if err != nil { + t.Errorf("unexpected error unmarshaling binary into zero receiver for %s: %v", test.name, err) + continue + } gotSrc := src.Count() gotDst := dst.Count() + gotZdst := zdst.Count() if gotSrc != gotDst { t.Errorf("unexpected count for %s: got:%.0f want:%.0f", test.name, gotDst, gotSrc) } + if gotSrc != gotZdst { + t.Errorf("unexpected count for %s into zero receiver: got:%.0f want:%.0f", test.name, gotZdst, gotSrc) + } + } +} + +var invalidRegisterTests = []struct { + fn interface{} + panics bool +}{ + {fn: int(0), panics: true}, + {fn: func() {}, panics: true}, + {fn: func(int) {}, panics: true}, + {fn: func() int { return 0 }, panics: true}, + {fn: func() hash.Hash { return fnv.New32a() }, panics: true}, + {fn: func() hash.Hash32 { return fnv.New32a() }, panics: false}, + {fn: func() hash.Hash { return fnv.New64a() }, panics: true}, + {fn: func() hash.Hash64 { return fnv.New64a() }, panics: false}, +} + +func TestRegisterInvalid(t *testing.T) { + for _, test := range invalidRegisterTests { + var r interface{} + func() { + defer func() { + r = recover() + }() + RegisterHash(test.fn) + }() + panicked := r != nil + if panicked != test.panics { + if panicked { + t.Errorf("unexpected panic for %T", test.fn) + } else { + t.Errorf("expected panic for %T", test.fn) + } + } } } diff --git a/stat/card/generate_64bit.sh b/stat/card/generate_64bit.sh index 6840860c..63a5fa2d 100755 --- a/stat/card/generate_64bit.sh +++ b/stat/card/generate_64bit.sh @@ -14,5 +14,6 @@ echo -e '// Code generated by "go generate gonum.org/v1/gonum/stat/card"; DO NOT -e 's/rho32/rho64/' \ -e 's/HyperLogLog32/HyperLogLog64/g' \ -e 's/Hash32/Hash64/' \ + -e 's/hash32/hash64/' \ -e 's/w32/w64/g' \ >> hll64.go \ No newline at end of file diff --git a/stat/card/hll32.go b/stat/card/hll32.go index db3095a7..e62e3279 100644 --- a/stat/card/hll32.go +++ b/stat/card/hll32.go @@ -179,9 +179,6 @@ func (h *HyperLogLog32) MarshalBinary() ([]byte, error) { // return. The receiver must have a non-nil hash function value that is // the same type as the one that was stored in the binary data. func (h *HyperLogLog32) UnmarshalBinary(b []byte) error { - if h.hash == nil { - return errors.New("card: hash function not set") - } dec := gob.NewDecoder(bytes.NewReader(b)) var size uint8 err := dec.Decode(&size) @@ -196,9 +193,16 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error { if err != nil { return err } - dstHash := typeNameOf(h.hash) - if dstHash != srcHash { - return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) + if h.hash == nil { + h.hash = hash32For(srcHash) + if h.hash == nil { + return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash) + } + } else { + dstHash := typeNameOf(h.hash) + if dstHash != srcHash { + return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) + } } err = dec.Decode(&h.p) if err != nil { @@ -212,3 +216,12 @@ func (h *HyperLogLog32) UnmarshalBinary(b []byte) error { } return nil } + +func hash32For(name string) hash.Hash32 { + fn, ok := hashes.Load(name) + if !ok { + return nil + } + h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash32) + return h +} diff --git a/stat/card/hll64.go b/stat/card/hll64.go index 4adeeb83..39d154b9 100644 --- a/stat/card/hll64.go +++ b/stat/card/hll64.go @@ -181,9 +181,6 @@ func (h *HyperLogLog64) MarshalBinary() ([]byte, error) { // return. The receiver must have a non-nil hash function value that is // the same type as the one that was stored in the binary data. func (h *HyperLogLog64) UnmarshalBinary(b []byte) error { - if h.hash == nil { - return errors.New("card: hash function not set") - } dec := gob.NewDecoder(bytes.NewReader(b)) var size uint8 err := dec.Decode(&size) @@ -198,9 +195,16 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error { if err != nil { return err } - dstHash := typeNameOf(h.hash) - if dstHash != srcHash { - return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) + if h.hash == nil { + h.hash = hash64For(srcHash) + if h.hash == nil { + return fmt.Errorf("card: hash function not set and no hash registered for %q", srcHash) + } + } else { + dstHash := typeNameOf(h.hash) + if dstHash != srcHash { + return fmt.Errorf("card: mismatched hash function: dst=%s src=%s", dstHash, srcHash) + } } err = dec.Decode(&h.p) if err != nil { @@ -214,3 +218,12 @@ func (h *HyperLogLog64) UnmarshalBinary(b []byte) error { } return nil } + +func hash64For(name string) hash.Hash64 { + fn, ok := hashes.Load(name) + if !ok { + return nil + } + h, _ := fn.(userType).fn.Call(nil)[0].Interface().(hash.Hash64) + return h +}