diff --git a/.travis.yml b/.travis.yml index 5b4ed2ae..3196115f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,7 @@ sudo: false language: go go: - - 1.9.x + - 1.11.x install: - make deps diff --git a/cli/responseemitter.go b/cli/responseemitter.go index 03a9b5f8..bc7886dd 100644 --- a/cli/responseemitter.go +++ b/cli/responseemitter.go @@ -66,10 +66,6 @@ func (re *responseEmitter) SetLength(l uint64) { re.length = l } -func (re *responseEmitter) SetEncoder(enc func(io.Writer) cmds.Encoder) { - re.enc = enc(re.stdout) -} - func (re *responseEmitter) CloseWithError(err error) error { if err == nil { return re.Close() diff --git a/command.go b/command.go index 09e54801..3fe507cb 100644 --- a/command.go +++ b/command.go @@ -104,20 +104,6 @@ func (c *Command) call(req *Request, re ResponseEmitter, env Environment) error return err } - // If this ResponseEmitter encodes messages (e.g. http, cli or writer - but not chan), - // we need to update the encoding to the one specified by the command. - if re_, ok := re.(EncodingEmitter); ok { - encType := GetEncoding(req) - - if enc, ok := cmd.Encoders[encType]; ok { - re_.SetEncoder(enc(req)) - } else if enc, ok := Encoders[encType]; ok { - re_.SetEncoder(enc(req)) - } else { - return fmt.Errorf("unknown encoding %q", encType) - } - } - return cmd.Run(req, re, env) } diff --git a/encoding.go b/encoding.go index 8def5cf4..06a5c6eb 100644 --- a/encoding.go +++ b/encoding.go @@ -4,6 +4,7 @@ import ( "encoding/json" "encoding/xml" "fmt" + "github.com/ipfs/go-ipfs-cmdkit" "io" "reflect" ) @@ -130,3 +131,32 @@ func (e TextEncoder) Encode(v interface{}) error { _, err := fmt.Fprintf(e.w, "%s%s", v, e.suffix) return err } + +// GetEncoders takes a request and returns returns the encoding type, an error encoder, and a value encoder. +func GetEncoders(req *Request, w io.Writer) (encType EncodingType, valEnc, errEnc Encoder, err error) { + encType = GetEncoding(req) + + if fn, ok := Encoders[encType]; ok { + errEnc = fn(req)(w) + } else { + return encType, nil, nil, cmdkit.Errorf(cmdkit.ErrClient, "invalid encoding: %s", encType) + } + + // Only override the value encoder. + if fn, ok := req.Command.Encoders[encType]; ok { + valEnc = fn(req)(w) + } else { + valEnc = errEnc + } + return encType, valEnc, errEnc, nil +} + +// GetDecoder takes a request and returns the encoding type and the decoder. +func GetDecoder(req *Request, r io.Reader) (encType EncodingType, dec Decoder, err error) { + encType = GetEncoding(req) + + if fn, ok := Decoders[encType]; ok { + return encType, fn(r), nil + } + return encType, nil, cmdkit.Errorf(cmdkit.ErrClient, "invalid encoding: %s", encType) +} diff --git a/executor.go b/executor.go index 35567c07..d9df7819 100644 --- a/executor.go +++ b/executor.go @@ -49,26 +49,6 @@ func (x *executor) Execute(req *Request, re ResponseEmitter, env Environment) (e return err } - // If this ResponseEmitter encodes messages (e.g. http, cli or writer - but not chan), - // we need to update the encoding to the one specified by the command. - if ee, ok := re.(EncodingEmitter); ok { - encType := GetEncoding(req) - - // use JSON if text was requested but the command doesn't have a text-encoder - if _, ok := cmd.Encoders[encType]; encType == Text && !ok { - encType = JSON - } - - if enc, ok := cmd.Encoders[encType]; ok { - ee.SetEncoder(enc(req)) - } else if enc, ok := Encoders[encType]; ok { - ee.SetEncoder(enc(req)) - } else { - log.Errorf("unknown encoding %q, using json", encType) - ee.SetEncoder(Encoders[JSON](req)) - } - } - if cmd.PreRun != nil { err = cmd.PreRun(req, env) if err != nil { diff --git a/executor_test.go b/executor_test.go index 8fa021ac..255773fe 100644 --- a/executor_test.go +++ b/executor_test.go @@ -50,7 +50,10 @@ func TestExecutor(t *testing.T) { } var buf bytes.Buffer - re := NewWriterResponseEmitter(wc{&buf, nopCloser{}}, req, Encoders[Text]) + re, err := NewWriterResponseEmitter(wc{&buf, nopCloser{}}, req) + if err != nil { + t.Fatal(err) + } x := NewExecutor(root) x.Execute(req, re, &env) diff --git a/http/client.go b/http/client.go index 0645af75..cfa2ff41 100644 --- a/http/client.go +++ b/http/client.go @@ -75,22 +75,6 @@ func NewClient(address string, opts ...ClientOpt) Client { func (c *client) Execute(req *cmds.Request, re cmds.ResponseEmitter, env cmds.Environment) error { cmd := req.Command - // If this ResponseEmitter encodes messages (e.g. http, cli or writer - but not chan), - // we need to update the encoding to the one specified by the command. - if ee, ok := re.(cmds.EncodingEmitter); ok { - encType := cmds.GetEncoding(req) - - // note the difference: cmd.Encoders vs. cmds.Encoders - if enc, ok := cmd.Encoders[encType]; ok { - ee.SetEncoder(enc(req)) - } else if enc, ok := cmds.Encoders[encType]; ok { - ee.SetEncoder(enc(req)) - } else { - log.Errorf("unknown encoding %q, using json", encType) - ee.SetEncoder(cmds.Encoders[cmds.JSON](req)) - } - } - if cmd.PreRun != nil { err := cmd.PreRun(req, env) if err != nil { diff --git a/http/errors_test.go b/http/errors_test.go index c1638050..6b3e6feb 100644 --- a/http/errors_test.go +++ b/http/errors_test.go @@ -9,11 +9,13 @@ import ( "strings" "testing" + "github.com/ipfs/go-ipfs-cmdkit" "github.com/ipfs/go-ipfs-cmds" ) func TestErrors(t *testing.T) { type testcase struct { + opts cmdkit.OptMap path []string bodyStr string status string @@ -46,6 +48,25 @@ func TestErrors(t *testing.T) { errTrailer: "an error occurred", }, + { + path: []string{"encode"}, + opts: cmdkit.OptMap{ + cmds.EncLong: cmds.Text, + }, + status: "500 Internal Server Error", + bodyStr: "an error occurred", + }, + + { + path: []string{"lateencode"}, + opts: cmdkit.OptMap{ + cmds.EncLong: cmds.Text, + }, + status: "200 OK", + bodyStr: "hello\n", + errTrailer: "an error occurred", + }, + { path: []string{"doubleclose"}, status: "200 OK", @@ -69,7 +90,7 @@ func TestErrors(t *testing.T) { return func(t *testing.T) { _, srv := getTestServer(t, nil) // handler_test:/^func getTestServer/ c := NewClient(srv.URL) - req, err := cmds.NewRequest(context.Background(), tc.path, nil, nil, nil, cmdRoot) + req, err := cmds.NewRequest(context.Background(), tc.path, tc.opts, nil, nil, cmdRoot) if err != nil { t.Fatal(err) } diff --git a/http/handler.go b/http/handler.go index 2f6a3188..03e2e4e6 100644 --- a/http/handler.go +++ b/http/handler.go @@ -157,7 +157,11 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - re := NewResponseEmitter(w, r.Method, req) + re, err := NewResponseEmitter(w, r.Method, req) + if err != nil { + re.CloseWithError(err) + return + } h.root.Call(req, re, h.env) } diff --git a/http/handler_test.go b/http/handler_test.go index 0123b27f..b71d5ce5 100644 --- a/http/handler_test.go +++ b/http/handler_test.go @@ -82,6 +82,34 @@ var ( }, Type: "", }, + "encode": &cmds.Command{ + Run: func(req *cmds.Request, re cmds.ResponseEmitter, env cmds.Environment) error { + return errors.New("an error occurred") + }, + Type: "", + Encoders: cmds.EncoderMap{ + cmds.Text: cmds.MakeTypedEncoder(func(req *cmds.Request, w io.Writer, v string) error { + fmt.Fprintln(w, v) + return nil + }), + }, + }, + "lateencode": &cmds.Command{ + Run: func(req *cmds.Request, re cmds.ResponseEmitter, env cmds.Environment) error { + re.Emit("hello") + return errors.New("an error occurred") + }, + Type: "", + Encoders: cmds.EncoderMap{ + cmds.Text: cmds.MakeTypedEncoder(func(req *cmds.Request, w io.Writer, v string) error { + fmt.Fprintln(w, v) + if v != "hello" { + return fmt.Errorf("expected hello, got %s", v) + } + return nil + }), + }, + }, "doubleclose": &cmds.Command{ Run: func(req *cmds.Request, re cmds.ResponseEmitter, env cmds.Environment) error { t, ok := getTestingT(env) diff --git a/http/responseemitter.go b/http/responseemitter.go index 087c91c4..aca0fd86 100644 --- a/http/responseemitter.go +++ b/http/responseemitter.go @@ -28,23 +28,20 @@ var ( ) // NewResponeEmitter returns a new ResponseEmitter. -func NewResponseEmitter(w http.ResponseWriter, method string, req *cmds.Request) ResponseEmitter { - encType := cmds.GetEncoding(req) - - var enc cmds.Encoder - - if _, ok := cmds.Encoders[encType]; ok { - enc = cmds.Encoders[encType](req)(w) +func NewResponseEmitter(w http.ResponseWriter, method string, req *cmds.Request) (ResponseEmitter, error) { + encType, valEnc, errEnc, err := cmds.GetEncoders(req, w) + if err != nil { + return nil, err } - re := &responseEmitter{ w: w, encType: encType, - enc: enc, + errEnc: errEnc, + valEnc: valEnc, method: method, req: req, } - return re + return re, nil } type ResponseEmitter interface { @@ -55,7 +52,8 @@ type ResponseEmitter interface { type responseEmitter struct { w http.ResponseWriter - enc cmds.Encoder + errEnc cmds.Encoder + valEnc cmds.Encoder // overrides the normal encoder encType cmds.EncodingType req *cmds.Request @@ -121,7 +119,7 @@ func (re *responseEmitter) Emit(value interface{}) error { case io.Reader: err = flushCopy(re.w, v) default: - err = re.enc.Encode(value) + err = re.valEnc.Encode(value) } if isSingle && err == nil { @@ -259,7 +257,7 @@ func (re *responseEmitter) doPreamble(value interface{}) { err = &cmdkit.Error{Message: err.Error()} } - err = re.enc.Encode(err) + err = re.errEnc.Encode(err) if err != nil { log.Error("error sending error value after non-200 response", err) } @@ -272,10 +270,6 @@ type responseWriterer interface { Lower() http.ResponseWriter } -func (re *responseEmitter) SetEncoder(enc func(io.Writer) cmds.Encoder) { - re.enc = enc(re.w) -} - func flushCopy(w io.Writer, r io.Reader) error { buf := make([]byte, 4096) f, ok := w.(http.Flusher) diff --git a/response_test.go b/response_test.go index 0f0d4258..1c690d8f 100644 --- a/response_test.go +++ b/response_test.go @@ -38,7 +38,10 @@ func TestMarshalling(t *testing.T) { buf := bytes.NewBuffer(nil) wc := writecloser{Writer: buf, Closer: nopCloser{}} - re := NewWriterResponseEmitter(wc, req, Encoders[JSON]) + re, err := NewWriterResponseEmitter(wc, req) + if err != nil { + t.Fatal(err) + } err = re.Emit(TestOutput{"beep", "boop", 1337}) if err != nil { diff --git a/responseemitter.go b/responseemitter.go index b9b3e32c..65fb6c53 100644 --- a/responseemitter.go +++ b/responseemitter.go @@ -50,12 +50,6 @@ type ResponseEmitter interface { Emit(value interface{}) error } -type EncodingEmitter interface { - ResponseEmitter - - SetEncoder(func(io.Writer) Encoder) -} - // Copy sends all values received on res to re. If res is closed, it closes re. func Copy(re ResponseEmitter, res Response) error { re.SetLength(res.Length()) diff --git a/single_test.go b/single_test.go index f88a9a24..4d9dc324 100644 --- a/single_test.go +++ b/single_test.go @@ -59,8 +59,14 @@ func TestSingleWriter(t *testing.T) { } pr, pw := io.Pipe() - re := NewWriterResponseEmitter(pw, req, Encoders["json"]) - res := NewReaderResponse(pr, "json", req) + re, err := NewWriterResponseEmitter(pw, req) + if err != nil { + t.Fatal(err) + } + res, err := NewReaderResponse(pr, req) + if err != nil { + t.Fatal(err) + } var wg sync.WaitGroup diff --git a/writer.go b/writer.go index b8175df4..76d89024 100644 --- a/writer.go +++ b/writer.go @@ -11,28 +11,34 @@ import ( "github.com/ipfs/go-ipfs-cmds/debug" ) -func NewWriterResponseEmitter(w io.WriteCloser, req *Request, enc func(*Request) func(io.Writer) Encoder) ResponseEmitter { +func NewWriterResponseEmitter(w io.WriteCloser, req *Request) (ResponseEmitter, error) { + _, valEnc, _, err := GetEncoders(req, w) + if err != nil { + return nil, err + } + re := &writerResponseEmitter{ w: w, c: w, req: req, + enc: valEnc, } - if enc != nil { - re.enc = enc(req)(w) - } - - return re + return re, nil } -func NewReaderResponse(r io.Reader, encType EncodingType, req *Request) Response { +func NewReaderResponse(r io.Reader, req *Request) (Response, error) { + encType, dec, err := GetDecoder(req, r) + if err != nil { + return nil, err + } return &readerResponse{ req: req, r: r, encType: encType, - dec: Decoders[encType](r), + dec: dec, emitted: make(chan struct{}), - } + }, nil } type readerResponse struct { @@ -101,10 +107,6 @@ type writerResponseEmitter struct { closed bool } -func (re *writerResponseEmitter) SetEncoder(mkEnc func(io.Writer) Encoder) { - re.enc = mkEnc(re.w) -} - func (re *writerResponseEmitter) CloseWithError(err error) error { if re.closed { return ErrClosingClosedEmitter