diff --git a/go.mod b/go.mod index bef71a6..420b471 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/johannesboyne/gofakes3 +go 1.12 + require ( github.com/aws/aws-sdk-go v1.17.4 github.com/boltdb/bolt v1.3.1 diff --git a/gofakes3.go b/gofakes3.go index 25c6ec0..1735484 100644 --- a/gofakes3.go +++ b/gofakes3.go @@ -529,11 +529,17 @@ func (g *GoFakeS3) createObject(bucket, object string, w http.ResponseWriter, r return err } - size, err := strconv.ParseInt(r.Header.Get("Content-Length"), 10, 64) - if err != nil || size <= 0 { + contentLength := r.Header.Get("Content-Length") + if contentLength == "" { return ErrMissingContentLength } + size, err := strconv.ParseInt(contentLength, 10, 64) + if err != nil || size < 0 { + w.WriteHeader(http.StatusBadRequest) // XXX: no code for this, according to s3tests + return nil + } + if len(object) > KeySizeLimit { return ResourceError(ErrKeyTooLong, object) } diff --git a/gofakes3_test.go b/gofakes3_test.go index c486121..9f09c65 100644 --- a/gofakes3_test.go +++ b/gofakes3_test.go @@ -1,6 +1,7 @@ package gofakes3_test import ( + "bufio" "bytes" "encoding/xml" "fmt" @@ -182,6 +183,51 @@ func TestCreateObjectMD5(t *testing.T) { } } +func TestCreateObjectWithMissingContentLength(t *testing.T) { + ts := newTestServer(t) + defer ts.Close() + client := ts.rawClient() + body := []byte{} + rq, err := http.NewRequest("PUT", client.URL(fmt.Sprintf("/%s/yep", defaultBucket)).String(), maskReader(bytes.NewReader(body))) + if err != nil { + panic(err) + } + client.SetHeaders(rq, body) + rs, _ := client.Do(rq) + if rs.StatusCode != http.StatusLengthRequired { + t.Fatal() + } +} + +func TestCreateObjectWithInvalidContentLength(t *testing.T) { + ts := newTestServer(t) + defer ts.Close() + client := ts.rawClient() + + body := []byte{1, 2, 3} + rq, err := http.NewRequest("PUT", client.URL(fmt.Sprintf("/%s/yep", defaultBucket)).String(), maskReader(bytes.NewReader(body))) + if err != nil { + panic(err) + } + + client.SetHeaders(rq, body) + rq.Header.Set("Content-Length", "quack") + raw, err := client.SendRaw(rq) + if err != nil { + t.Fatal(err) + } + + rs, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(raw)), rq) + if err != nil { + t.Fatal(err) + } + defer rs.Body.Close() + + if rs.StatusCode != http.StatusBadRequest { + t.Fatal(rs.StatusCode, "!=", http.StatusBadRequest) + } +} + func TestDeleteBucket(t *testing.T) { t.Run("delete-empty", func(t *testing.T) { ts := newTestServer(t, withoutInitialBuckets()) diff --git a/init_test.go b/init_test.go index cb9aeeb..a78fd9f 100644 --- a/init_test.go +++ b/init_test.go @@ -7,6 +7,7 @@ package gofakes3_test import ( "bytes" "crypto/md5" + "crypto/sha256" "encoding/base64" "encoding/hex" "flag" @@ -14,11 +15,16 @@ import ( "io" "io/ioutil" "log" + "net" "net/http" "net/http/httptest" + "net/http/httputil" + "net/url" "os" + "path" "reflect" "sort" + "strconv" "strings" "sync" "testing" @@ -150,7 +156,9 @@ type testServer struct { type testServerOption func(ts *testServer) -func withoutInitialBuckets() testServerOption { return func(ts *testServer) { ts.initialBuckets = nil } } +func withoutInitialBuckets() testServerOption { + return func(ts *testServer) { ts.initialBuckets = nil } +} func withInitialBuckets(buckets ...string) testServerOption { return func(ts *testServer) { ts.initialBuckets = buckets } } @@ -295,6 +303,10 @@ func (ts *testServer) assertLs(bucket string, prefix string, expectedPrefixes [] ls.assertContents(ts.TT, expectedPrefixes, expectedObjects) } +func (ts *testServer) rawClient() *rawClient { + return newRawClient(httpClient(), ts.server.URL) +} + type multipartUploadOptions struct { partSize int64 } @@ -716,6 +728,12 @@ func hashMD5Bytes(body []byte) hashValue { return hashValue(h.Sum(nil)) } +func hashSHA256Bytes(body []byte) hashValue { + h := sha256.New() + h.Write(body) + return hashValue(h.Sum(nil)) +} + type hashValue []byte func (h hashValue) Base64() string { return base64.StdEncoding.EncodeToString(h) } @@ -810,3 +828,102 @@ func (b *backendWithUnimplementedPaging) ListBucket(name string, prefix *gofakes } return b.Backend.ListBucket(name, prefix, page) } + +type rawClient struct { + client *http.Client + base *url.URL +} + +func newRawClient(client *http.Client, base string) *rawClient { + u, err := url.Parse(base) + if err != nil { + panic(err) + } + return &rawClient{client: client, base: u} +} + +func (c *rawClient) URL(rqpath string) *url.URL { + u, err := url.Parse(c.base.String()) + if err != nil { + panic(err) + } + u.Path = path.Join(u.Path, rqpath) + return u +} + +func (c *rawClient) Request(method, rqpath string, body []byte) *http.Request { + u := c.URL(rqpath) + rq, err := http.NewRequest(method, u.String(), bytes.NewReader(body)) + if err != nil { + panic(err) + } + c.SetHeaders(rq, body) + return rq +} + +func (c *rawClient) SetHeaders(rq *http.Request, body []byte) { + // NOTE: This was put together by using httputil.DumpRequest inside routeBase(). We + // don't currently implement the Authorization header, so that has been skimmed for + // now. + rq.Header.Set("Accept-Encoding", "gzip") + rq.Header.Set("Authorization", "...") // TODO + rq.Header.Set("Content-Length", strconv.FormatInt(int64(len(body)), 10)) + rq.Header.Set("Content-Md5", hashMD5Bytes(body).Base64()) + rq.Header.Set("User-Agent", "aws-sdk-go/1.17.4 (go1.14rc1; linux; amd64)") + rq.Header.Set("X-Amz-Date", time.Now().In(time.UTC).Format("20060102T030405-0700")) + rq.Header.Set("X-Amz-Content-Sha256", hashSHA256Bytes(body).Hex()) +} + +// SendRaw can be used to bypass Go's http client, which helps us out a lot by taking +// care of some things for us, but which we actually want to test messing up from +// time to time. +func (c *rawClient) SendRaw(rq *http.Request) ([]byte, error) { + b, err := httputil.DumpRequest(rq, true) + if err != nil { + return nil, err + } + conn, err := net.DialTimeout("tcp", c.base.Host, 2*time.Second) + if err != nil { + return nil, err + } + defer conn.Close() + + deadline := time.Now().Add(2 * time.Second) + conn.SetDeadline(deadline) + if _, err := conn.Write(b); err != nil { + return nil, err + } + + var rs []byte + var scratch = make([]byte, 1024) + for { + n, err := conn.Read(scratch) + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + rs = append(rs, scratch[:n]...) + } + + return rs, nil +} + +func (c *rawClient) Do(rq *http.Request) (*http.Response, error) { + return c.client.Do(rq) +} + +func maskReader(r io.Reader) io.Reader { + // http.NewRequest() forces a ContentLength if it recognises + // the type of reader you pass as the body. This is a cheeky + // way to bypass that: + return &maskedReader{r} +} + +type maskedReader struct { + inner io.Reader +} + +func (r *maskedReader) Read(b []byte) (n int, err error) { + return r.inner.Read(b) +}