Skip to content

Commit 74a39f3

Browse files
committed
Turns out we do need explicit Content-Length for file uploads
This reverts commit 141388f.
1 parent 4d95349 commit 74a39f3

File tree

3 files changed

+104
-36
lines changed

3 files changed

+104
-36
lines changed

pkg/cmd/api/api.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,16 @@ func apiRun(opts *ApiOptions) error {
103103
}
104104

105105
if opts.RequestInputFile != "" {
106-
file, err := openUserFile(opts.RequestInputFile, opts.IO.In)
106+
file, size, err := openUserFile(opts.RequestInputFile, opts.IO.In)
107107
if err != nil {
108108
return err
109109
}
110110
defer file.Close()
111111
requestPath = addQuery(requestPath, params)
112112
requestBody = file
113+
if size >= 0 {
114+
requestHeaders = append([]string{fmt.Sprintf("Content-Length: %d", size)}, requestHeaders...)
115+
}
113116
}
114117

115118
httpClient, err := opts.HttpClient()
@@ -240,19 +243,36 @@ func magicFieldValue(v string, stdin io.ReadCloser) (interface{}, error) {
240243
}
241244

242245
func readUserFile(fn string, stdin io.ReadCloser) ([]byte, error) {
243-
r, err := openUserFile(fn, stdin)
244-
if err != nil {
245-
return nil, err
246+
var r io.ReadCloser
247+
if fn == "-" {
248+
r = stdin
249+
} else {
250+
var err error
251+
r, err = os.Open(fn)
252+
if err != nil {
253+
return nil, err
254+
}
246255
}
247256
defer r.Close()
248257
return ioutil.ReadAll(r)
249258
}
250259

251-
func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, error) {
260+
func openUserFile(fn string, stdin io.ReadCloser) (io.ReadCloser, int64, error) {
252261
if fn == "-" {
253-
return stdin, nil
262+
return stdin, -1, nil
254263
}
255-
return os.Open(fn)
264+
265+
r, err := os.Open(fn)
266+
if err != nil {
267+
return r, -1, err
268+
}
269+
270+
s, err := os.Stat(fn)
271+
if err != nil {
272+
return r, -1, err
273+
}
274+
275+
return r, s.Size(), nil
256276
}
257277

258278
func parseErrorResponse(r io.Reader, statusCode int) (io.Reader, string, error) {

pkg/cmd/api/api_test.go

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -247,41 +247,78 @@ func Test_apiRun(t *testing.T) {
247247
}
248248

249249
func Test_apiRun_inputFile(t *testing.T) {
250-
io, stdin, _, _ := iostreams.Test()
251-
resp := &http.Response{StatusCode: 204}
252-
253-
options := ApiOptions{
254-
RequestPath: "hello",
255-
RequestInputFile: "-",
256-
RawFields: []string{"a=b", "c=d"},
250+
tests := []struct {
251+
name string
252+
inputFile string
253+
inputContents []byte
257254

258-
IO: io,
259-
HttpClient: func() (*http.Client, error) {
260-
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
261-
resp.Request = req
262-
return resp, nil
263-
}
264-
return &http.Client{Transport: tr}, nil
255+
contentLength int64
256+
expectedContents []byte
257+
}{
258+
{
259+
name: "stdin",
260+
inputFile: "-",
261+
inputContents: []byte("I WORK OUT"),
262+
contentLength: 0,
263+
},
264+
{
265+
name: "from file",
266+
inputFile: "gh-test-file",
267+
inputContents: []byte("I WORK OUT"),
268+
contentLength: 10,
265269
},
266270
}
271+
for _, tt := range tests {
272+
t.Run(tt.name, func(t *testing.T) {
273+
io, stdin, _, _ := iostreams.Test()
274+
resp := &http.Response{StatusCode: 204}
267275

268-
fmt.Fprintln(stdin, "I WORK OUT")
276+
inputFile := tt.inputFile
277+
if tt.inputFile == "-" {
278+
_, _ = stdin.Write(tt.inputContents)
279+
} else {
280+
f, err := ioutil.TempFile("", tt.inputFile)
281+
if err != nil {
282+
t.Fatal(err)
283+
}
284+
_, _ = f.Write(tt.inputContents)
285+
f.Close()
286+
t.Cleanup(func() { os.Remove(f.Name()) })
287+
inputFile = f.Name()
288+
}
269289

270-
err := apiRun(&options)
271-
if err != nil {
272-
t.Errorf("got error %v", err)
273-
}
290+
var bodyBytes []byte
291+
options := ApiOptions{
292+
RequestPath: "hello",
293+
RequestInputFile: inputFile,
294+
RawFields: []string{"a=b", "c=d"},
274295

275-
assert.Equal(t, "POST", resp.Request.Method)
276-
assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI())
277-
assert.Equal(t, "", resp.Request.Header.Get("Content-Length"))
278-
assert.Equal(t, "", resp.Request.Header.Get("Content-Type"))
296+
IO: io,
297+
HttpClient: func() (*http.Client, error) {
298+
var tr roundTripper = func(req *http.Request) (*http.Response, error) {
299+
var err error
300+
if bodyBytes, err = ioutil.ReadAll(req.Body); err != nil {
301+
return nil, err
302+
}
303+
resp.Request = req
304+
return resp, nil
305+
}
306+
return &http.Client{Transport: tr}, nil
307+
},
308+
}
279309

280-
bb, err := ioutil.ReadAll(resp.Request.Body)
281-
if err != nil {
282-
t.Errorf("got error %v", err)
310+
err := apiRun(&options)
311+
if err != nil {
312+
t.Errorf("got error %v", err)
313+
}
314+
315+
assert.Equal(t, "POST", resp.Request.Method)
316+
assert.Equal(t, "/hello?a=b&c=d", resp.Request.URL.RequestURI())
317+
assert.Equal(t, tt.contentLength, resp.Request.ContentLength)
318+
assert.Equal(t, "", resp.Request.Header.Get("Content-Type"))
319+
assert.Equal(t, tt.inputContents, bodyBytes)
320+
})
283321
}
284-
assert.Equal(t, "I WORK OUT\n", string(bb))
285322
}
286323

287324
func Test_parseFields(t *testing.T) {
@@ -400,7 +437,7 @@ func Test_openUserFile(t *testing.T) {
400437
f.Close()
401438
t.Cleanup(func() { os.Remove(f.Name()) })
402439

403-
file, err := openUserFile(f.Name(), nil)
440+
file, length, err := openUserFile(f.Name(), nil)
404441
if err != nil {
405442
t.Fatal(err)
406443
}
@@ -411,5 +448,6 @@ func Test_openUserFile(t *testing.T) {
411448
t.Fatal(err)
412449
}
413450

451+
assert.Equal(t, int64(13), length)
414452
assert.Equal(t, "file contents", string(fb))
415453
}

pkg/cmd/api/http.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net/http"
99
"net/url"
10+
"strconv"
1011
"strings"
1112
)
1213

@@ -62,7 +63,16 @@ func httpRequest(client *http.Client, method string, p string, params interface{
6263
if idx == -1 {
6364
return nil, fmt.Errorf("header %q requires a value separated by ':'", h)
6465
}
65-
req.Header.Add(h[0:idx], strings.TrimSpace(h[idx+1:]))
66+
name, value := h[0:idx], strings.TrimSpace(h[idx+1:])
67+
if strings.EqualFold(name, "Content-Length") {
68+
length, err := strconv.ParseInt(value, 10, 0)
69+
if err != nil {
70+
return nil, err
71+
}
72+
req.ContentLength = length
73+
} else {
74+
req.Header.Add(name, value)
75+
}
6676
}
6777
if bodyIsJSON && req.Header.Get("Content-Type") == "" {
6878
req.Header.Set("Content-Type", "application/json; charset=utf-8")

0 commit comments

Comments
 (0)