diff --git a/filetype.go b/filetype.go index 933058c..c99691e 100644 --- a/filetype.go +++ b/filetype.go @@ -29,8 +29,8 @@ func AddType(ext, mime string) types.Type { // Is checks if a given buffer matches with the given file type extension func Is(buf []byte, ext string) bool { - kind, ok := types.Types[ext] - if ok { + kind := types.Get(ext) + if kind != types.Unknown { return IsType(buf, kind) } return false @@ -52,33 +52,48 @@ func IsType(buf []byte, kind types.Type) bool { // IsMIME checks if a given buffer matches with the given MIME type func IsMIME(buf []byte, mime string) bool { - for _, kind := range types.Types { + result := false + types.Types.Range(func(k, v interface{}) bool { + kind := v.(types.Type) if kind.MIME.Value == mime { matcher := matchers.Matchers[kind] - return matcher(buf) != types.Unknown + result = matcher(buf) != types.Unknown + return false } - } - return false + return true + }) + + return result } // IsSupported checks if a given file extension is supported func IsSupported(ext string) bool { - for name := range Types { - if name == ext { - return true + result := false + types.Types.Range(func(k, v interface{}) bool { + key := k.(string) + if key == ext { + result = true + return false } - } - return false + return true + }) + + return result } // IsMIMESupported checks if a given MIME type is supported func IsMIMESupported(mime string) bool { - for _, m := range Types { - if m.MIME.Value == mime { - return true + result := false + types.Types.Range(func(k, v interface{}) bool { + kind := v.(types.Type) + if kind.MIME.Value == mime { + result = true + return false } - } - return false + return true + }) + + return result } // GetType retrieves a Type by file extension diff --git a/filetype_test.go b/filetype_test.go index f409784..49fffaa 100644 --- a/filetype_test.go +++ b/filetype_test.go @@ -2,10 +2,26 @@ package filetype import ( "testing" + "time" "github.com/h2non/filetype/types" ) +func TestConcurrent(t *testing.T) { + go func() { + for i := 0; i < 10000; i++ { + types.NewType("xml", "text/xml") + } + }() + go func() { + for i := 0; i < 10000; i++ { + types.NewType("xml", "text/xml") + } + }() + + time.Sleep(time.Second * 2) +} + func TestIs(t *testing.T) { cases := []struct { buf []byte @@ -22,6 +38,7 @@ func TestIs(t *testing.T) { t.Fatalf("Invalid match: %s", test.ext) } } + } func TestIsType(t *testing.T) { diff --git a/types/types.go b/types/types.go index 27d433e..f59e256 100644 --- a/types/types.go +++ b/types/types.go @@ -1,18 +1,23 @@ package types -var Types = make(map[string]Type) +import "sync" + +// Types Support concurrent map writes +var Types sync.Map // Add registers a new type in the package func Add(t Type) Type { - Types[t.Extension] = t + Types.Store(t.Extension, t) return t } // Get retrieves a Type by extension func Get(ext string) Type { - kind := Types[ext] - if kind.Extension != "" { - return kind + if tmp, ok := Types.Load(ext); ok { + kind := tmp.(Type) + if kind.Extension != "" { + return kind + } } return Unknown }