diff --git a/mrpc/jstreamrpc/jstreamrpc.go b/mrpc/jstreamrpc/jstreamrpc.go index b3f9e35..5b07c7c 100644 --- a/mrpc/jstreamrpc/jstreamrpc.go +++ b/mrpc/jstreamrpc/jstreamrpc.go @@ -6,7 +6,6 @@ package jstreamrpc import ( "context" "errors" - "io" "github.com/mediocregopher/mediocre-go-lib/jstream" "github.com/mediocregopher/mediocre-go-lib/mrpc" @@ -39,18 +38,8 @@ const ( func unmarshalBody(i interface{}, el jstream.Element) error { switch iT := i.(type) { - case func(*jstream.StreamReader) error: - stream, err := el.DecodeStream() - if err != nil { - return err - } - return iT(stream) - case *io.Reader: - ioR, err := el.DecodeBytes() - if err != nil { - return err - } - *iT = ioR + case *jstream.Element: + *iT = el return nil default: return el.DecodeValue(i) @@ -60,9 +49,7 @@ func unmarshalBody(i interface{}, el jstream.Element) error { func marshalBody(w *jstream.StreamWriter, i interface{}) error { switch iT := i.(type) { case func(*jstream.StreamWriter) error: - return w.EncodeStream(0, iT) - case io.Reader: - return w.EncodeBytes(0, iT) + return iT(w) default: return w.EncodeValue(iT) } @@ -88,17 +75,20 @@ func HandleCall( return errors.New("request head missing 'method' field") } - var didReadBody bool ctx = context.WithValue(ctx, ctxValR, r) ctx = context.WithValue(ctx, ctxValW, w) + body := r.Next() + if body.Err != nil { + return body.Err + } + rw := new(mrpc.ResponseWriter) h.ServeRPC(mrpc.Request{ Context: ctx, Method: head.Method, Unmarshal: func(i interface{}) error { - didReadBody = true - return unmarshalBody(i, r.Next()) + return unmarshalBody(i, body) }, Debug: head.debug.Debug, }, rw) @@ -118,12 +108,9 @@ func HandleCall( } } - // Reading the tail (and maybe discarding the body) should only be done once - // marshalBody has finished - if !didReadBody { - if err := r.Next().Discard(); err != nil { - return err - } + // make sure the body has been consumed + if err := body.Discard(); err != nil { + return err } if err := w.EncodeValue(resTail{ @@ -138,32 +125,40 @@ func HandleCall( /* func sqr(r mrpc.Request, rw *mrpc.ResponseWriter) { + var el jstream.Element + if err := r.Unmarshal(&el); err != nil { + rw.Response = err + return + } + + sr, err := el.DecodeStream() + if err != nil { + rw.Response = err + return + } + ch := make(chan int) - rw.Response = func(w *jstream.StreamWriter) error { + go func() { + defer close(ch) + for { + var i int + if err := sr.Next().Value(&i); err == jstream.ErrStreamEnded { + return + } else if err != nil { + panic("TODO") + } + ch <- i + } + }() + + rw.Response = func(sw *jstream.StreamWriter) error { + sw = sw.EncodeStream() for i := range ch { - if err := w.EncodeValue(i * i); err != nil { + if err := sw.EncodeValue(i * i); err != nil { return err } } return nil } - - go func() { - defer close(ch) - err := r.Unmarshal(func(r *jstream.StreamReader) error { - for { - var i int - if err := r.Next().Value(&i); err == jstream.ErrStreamEnded { - return nil - } else if err != nil { - return err - } - ch <- i - } - }) - if err != nil { - panic("TODO") - } - }() } */