diff --git a/gen.go b/gen.go index 67f31bd..c572c0d 100644 --- a/gen.go +++ b/gen.go @@ -589,16 +589,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { {{ end }} c, err := cbg.ReadCid(br) if err != nil { @@ -628,16 +626,14 @@ func emitCborUnmarshalStructField(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { {{ .Name }} = new({{ .TypeName }}) if err := {{ .Name }}.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling {{ .Name }} pointer: %w", err) @@ -685,16 +681,14 @@ func emitCborUnmarshalUint64Field(w io.Writer, f Field) error { return doTemplate(w, f, ` { {{ if .Pointer }} - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = {{ ReadHeader "br" }} if err != nil { return err diff --git a/peeker.go b/peeker.go new file mode 100644 index 0000000..2b0658c --- /dev/null +++ b/peeker.go @@ -0,0 +1,80 @@ +package typegen + +import ( + "bufio" + "io" +) + +// BytePeeker combines the Reader and ByteScanner interfaces. +type BytePeeker interface { + io.Reader + io.ByteScanner +} + +func GetPeeker(r io.Reader) BytePeeker { + if r, ok := r.(BytePeeker); ok { + return r + } + return &peeker{reader: r} +} + +// peeker is a non-buffering BytePeeker. +type peeker struct { + reader io.Reader + peekState int + lastByte byte +} + +const ( + peekEmpty = iota + peekSet + peekUnread +) + +func (p *peeker) Read(buf []byte) (n int, err error) { + // Read "nothing". I.e., read an error, maybe. + if len(buf) == 0 { + // There's something pending in the + if p.peekState == peekUnread { + return 0, nil + } + return p.reader.Read(nil) + } + + if p.peekState == peekUnread { + buf[0] = p.lastByte + n, err = p.reader.Read(buf[1:]) + n += 1 + } else { + n, err = p.reader.Read(buf) + } + if n > 0 { + p.peekState = peekSet + p.lastByte = buf[n-1] + } + return n, err +} + +func (p *peeker) ReadByte() (byte, error) { + if p.peekState == peekUnread { + p.peekState = peekSet + return p.lastByte, nil + } + var buf [1]byte + n, err := p.reader.Read(buf[:]) + if n == 0 { + return 0, err + } + b := buf[0] + p.lastByte = b + p.peekState = peekSet + return b, err +} + +func (p *peeker) UnreadByte() error { + if p.peekState != peekSet { + return bufio.ErrInvalidUnreadByte + } + p.peekState = peekUnread + return nil +} diff --git a/peeker_test.go b/peeker_test.go new file mode 100644 index 0000000..17df763 --- /dev/null +++ b/peeker_test.go @@ -0,0 +1,103 @@ +package typegen + +import ( + "bufio" + "bytes" + "io" + "testing" +) + +func TestPeeker(t *testing.T) { + buf := bytes.NewBuffer([]byte{0, 1, 2, 3}) + p := peeker{reader: buf} + n, err := p.Read(nil) + if err != nil { + t.Fatal(err) + } + if n != 0 { + t.Fatal(err) + } + + err = p.UnreadByte() + if err != bufio.ErrInvalidUnreadByte { + t.Fatal(err) + } + + // read 2 bytes + var out [2]byte + n, err = p.Read(out[:]) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %d", n) + } + if !bytes.Equal(out[:], []byte{0, 1}) { + t.Fatalf("unexpected output") + } + + // unread that last byte and read it again. + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + b, err := p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 1 { + t.Fatal("expected 1") + } + + // unread that last byte then read 2 + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + n, err = p.Read(out[:]) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %d", n) + } + if !bytes.Equal(out[:], []byte{1, 2}) { + t.Fatalf("unexpected output") + } + + // read another byte + b, err = p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 3 { + t.Fatal("expected 1") + } + + // Should read eof at end. + n, err = p.Read(out[:]) + if err != io.EOF { + t.Fatal(err) + } + if n != 0 { + t.Fatal("should have been at end") + } + // should unread eof + err = p.UnreadByte() + if err != nil { + t.Fatal(err) + } + + _, err = p.Read(nil) + if err != nil { + t.Fatal(err) + } + + b, err = p.ReadByte() + if err != nil { + t.Fatal(err) + } + if b != 3 { + t.Fatal("expected 1") + } +} diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index f9d3655..186f9a2 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -427,16 +427,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeTwo) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) @@ -617,16 +615,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err @@ -643,16 +639,14 @@ func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err @@ -753,16 +747,14 @@ func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeOne) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index a29385e..1016c5d 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -221,16 +221,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stuff = new(SimpleTypeTree) if err := t.Stuff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stuff pointer: %w", err) @@ -243,16 +241,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { t.Stufff = new(SimpleTypeTwo) if err := t.Stufff.UnmarshalCBOR(br); err != nil { return xerrors.Errorf("unmarshaling t.Stufff pointer: %w", err) @@ -384,16 +380,14 @@ func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { { - pb, err := br.PeekByte() + b, err := br.ReadByte() if err != nil { return err } - if pb == cbg.CborNull[0] { - var nbuf [1]byte - if _, err := br.Read(nbuf[:]); err != nil { + if b != cbg.CborNull[0] { + if err := br.UnreadByte(); err != nil { return err } - } else { maj, extra, err = cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err diff --git a/utils.go b/utils.go index 858b253..b2496c8 100644 --- a/utils.go +++ b/utils.go @@ -1,8 +1,6 @@ package typegen import ( - "bufio" - "bytes" "encoding/binary" "errors" "fmt" @@ -203,57 +201,6 @@ func (d *Deferred) UnmarshalCBOR(br io.Reader) error { } } -// this is a bit gnarly i should just switch to taking in a byte array at the top level -type BytePeeker interface { - io.Reader - PeekByte() (byte, error) -} - -type peeker struct { - io.Reader -} - -func (p *peeker) PeekByte() (byte, error) { - switch r := p.Reader.(type) { - case *bytes.Reader: - b, err := r.ReadByte() - if err != nil { - return 0, err - } - return b, r.UnreadByte() - case *bytes.Buffer: - b, err := r.ReadByte() - if err != nil { - return 0, err - } - return b, r.UnreadByte() - case *bufio.Reader: - o, err := r.Peek(1) - if err != nil { - return 0, err - } - - return o[0], nil - default: - panic("invariant violated") - } -} - -func GetPeeker(r io.Reader) BytePeeker { - switch r := r.(type) { - case *bytes.Reader: - return &peeker{r} - case *bytes.Buffer: - return &peeker{r} - case *bufio.Reader: - return &peeker{r} - case *peeker: - return r - default: - return &peeker{bufio.NewReaderSize(r, 16)} - } -} - func readByte(r io.Reader) (byte, error) { if br, ok := r.(io.ByteReader); ok { return br.ReadByte()