diff --git a/go/hummingbird/multipart.go b/go/hummingbird/multipart.go new file mode 100644 index 0000000000..cfb2aa87bb --- /dev/null +++ b/go/hummingbird/multipart.go @@ -0,0 +1,116 @@ +// Copyright (c) 2015 Rackspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TODO: Some of this code was pulled from the go stdlib and modified. figure out how to attribute this. +// https://wiki.openstack.org/wiki/LegalIssuesFAQ#Incorporating_BSD.2FMIT_Licensed_Code + +package hummingbird + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "net/textproto" +) + +type MultiWriter struct { + w io.Writer + boundary string + lastpart *part +} + +func NewMultiWriter(w io.Writer) *MultiWriter { + var buf [32]byte + _, err := io.ReadFull(rand.Reader, buf[:]) + if err != nil { + panic(err) + } + return &MultiWriter{ + w: w, + boundary: fmt.Sprintf("%x", buf[:]), + } +} + +func (w *MultiWriter) Boundary() string { + return w.boundary +} + +func (w *MultiWriter) CreatePart(header textproto.MIMEHeader) (io.Writer, error) { + if w.lastpart != nil { + if err := w.lastpart.close(); err != nil { + return nil, err + } + } + b := &bytes.Buffer{} + if w.lastpart != nil { + fmt.Fprintf(b, "\r\n--%s\r\n", w.boundary) + } else { + fmt.Fprintf(b, "--%s\r\n", w.boundary) + } + for k, vv := range header { + for _, v := range vv { + fmt.Fprintf(b, "%s: %s\r\n", k, v) + } + } + fmt.Fprintf(b, "\r\n") + _, err := io.Copy(w.w, b) + if err != nil { + return nil, err + } + p := &part{ + mw: w, + } + w.lastpart = p + return p, nil +} + +func (w *MultiWriter) Close() error { + if w.lastpart != nil { + if err := w.lastpart.close(); err != nil { + return err + } + w.lastpart = nil + } + _, err := fmt.Fprintf(w.w, "\r\n--%s--", w.boundary) + return err +} + +type part struct { + mw *MultiWriter + closed bool + we error +} + +func (p *part) close() error { + p.closed = true + return p.we +} + +func (p *part) Write(d []byte) (n int, err error) { + if p.closed { + return 0, errors.New("multipart: can't write to finished part") + } + n, err = p.mw.w.Write(d) + if err != nil { + p.we = err + } + return +} diff --git a/go/hummingbird/multipart_test.go b/go/hummingbird/multipart_test.go new file mode 100644 index 0000000000..4c19e6d35d --- /dev/null +++ b/go/hummingbird/multipart_test.go @@ -0,0 +1,76 @@ +// Copyright (c) 2015 Rackspace +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +// implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hummingbird + +import ( + "bytes" + "errors" + "net/textproto" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMultiWriter(t *testing.T) { + w := &bytes.Buffer{} + mw := NewMultiWriter(w) + + boundary := mw.Boundary() + + p, _ := mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + p.Write([]byte("HI")) + + p, _ = mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + p.Write([]byte("THERE")) + + mw.Close() + shouldBe := "--" + boundary + "\r\nContent-Type: text/plain\r\n\r\nHI\r\n--" + boundary + "\r\nContent-Type: text/plain\r\n\r\nTHERE\r\n--" + boundary + "--" + assert.Equal(t, shouldBe, string(w.Bytes())) +} + +func TestMultiWriterClosedPart(t *testing.T) { + w := &bytes.Buffer{} + mw := NewMultiWriter(w) + + p1, _ := mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + + _, err := p1.Write([]byte("HI")) + assert.NotNil(t, err) +} + +type FailWriter struct { + n int +} + +func (f *FailWriter) Write(d []byte) (n int, err error) { + if f.n > 0 { + return 0, errors.New("SOME ERROR") + } + f.n += 1 + return len(d), nil +} + +func TestMultiWriterFails(t *testing.T) { + mw := NewMultiWriter(&FailWriter{0}) + + p, _ := mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + _, err := p.Write([]byte("HI")) + assert.NotNil(t, err) + assert.NotNil(t, mw.Close()) + _, err = mw.CreatePart(textproto.MIMEHeader{"Content-Type": []string{"text/plain"}}) + assert.NotNil(t, err) +} diff --git a/go/objectserver/main.go b/go/objectserver/main.go index 3077916bdc..43000c173b 100644 --- a/go/objectserver/main.go +++ b/go/objectserver/main.go @@ -24,7 +24,6 @@ import ( "fmt" "io" "log/syslog" - "mime/multipart" "net" "net/http" _ "net/http/pprof" @@ -178,8 +177,8 @@ func (server *ObjectServer) ObjGetHandler(writer http.ResponseWriter, request *h hummingbird.CopyN(file, ranges[0].End-ranges[0].Start, writer) return } else if ranges != nil && len(ranges) > 1 { - w := multipart.NewWriter(writer) - responseLength := int64(6 + len(w.Boundary()) + (len(w.Boundary())+len(metadata["Content-Type"])+47)*len(ranges)) + w := hummingbird.NewMultiWriter(writer) + responseLength := int64(4 + len(w.Boundary()) + (len(w.Boundary())+len(metadata["Content-Type"])+47)*len(ranges)) for _, rng := range ranges { responseLength += int64(len(fmt.Sprintf("%d-%d/%d", rng.Start, rng.End-1, contentLength))) + rng.End - rng.Start } diff --git a/go/objectserver/main_test.go b/go/objectserver/main_test.go index ab8c260be8..c7e5e53bfe 100644 --- a/go/objectserver/main_test.go +++ b/go/objectserver/main_test.go @@ -223,7 +223,7 @@ func TestGetRanges(t *testing.T) { resp, body = getRanges("bytes=20-,-6") assert.Equal(t, http.StatusPartialContent, resp.StatusCode) assert.True(t, strings.HasPrefix(resp.Header.Get("Content-Type"), "multipart/byteranges;boundary=")) - assert.Equal(t, "356", resp.Header.Get("Content-Length")) + assert.Equal(t, "366", resp.Header.Get("Content-Length")) assert.Equal(t, 2, strings.Count(string(body), "UVWXYZ")) }