From 1073b463b24b2b28541e270099b470265ad87798 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 28 Feb 2020 12:32:13 +0700 Subject: [PATCH] re-add remove PSK decoding function, and deprecate it --- codec.go | 68 ++++++++++++++++++++++++++++ codec_test.go | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+) create mode 100644 codec.go create mode 100644 codec_test.go diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..bb20a61 --- /dev/null +++ b/codec.go @@ -0,0 +1,68 @@ +package pnet + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/hex" + "fmt" + "io" +) + +var ( + pathPSKv1 = []byte("/key/swarm/psk/1.0.0/") + pathBin = "/bin/" + pathBase16 = "/base16/" + pathBase64 = "/base64/" +) + +func readHeader(r *bufio.Reader) ([]byte, error) { + header, err := r.ReadBytes('\n') + if err != nil { + return nil, err + } + + return bytes.TrimRight(header, "\r\n"), nil +} + +func expectHeader(r *bufio.Reader, expected []byte) error { + header, err := readHeader(r) + if err != nil { + return err + } + if !bytes.Equal(header, expected) { + return fmt.Errorf("expected file header %s, got: %s", pathPSKv1, header) + } + return nil +} + +// DecodeV1PSK reads a Multicodec encoded V1 PSK. +// +// Deprecated: This functionality will soon be removed from libp2p. +func DecodeV1PSK(in io.Reader) ([]byte, error) { + reader := bufio.NewReader(in) + if err := expectHeader(reader, pathPSKv1); err != nil { + return nil, err + } + header, err := readHeader(reader) + if err != nil { + return nil, err + } + + var decoder io.Reader + switch string(header) { + case pathBase16: + decoder = hex.NewDecoder(reader) + case pathBase64: + decoder = base64.NewDecoder(base64.StdEncoding, reader) + case pathBin: + decoder = reader + default: + return nil, fmt.Errorf("unknown encoding: %s", header) + } + out := make([]byte, 32) + if _, err = io.ReadFull(decoder, out[:]); err != nil { + return nil, err + } + return out, nil +} diff --git a/codec_test.go b/codec_test.go new file mode 100644 index 0000000..b4b9272 --- /dev/null +++ b/codec_test.go @@ -0,0 +1,122 @@ +package pnet + +import ( + "bytes" + "encoding/base64" + "testing" +) + +func bufWithBase(base string, windows bool) *bytes.Buffer { + b := &bytes.Buffer{} + b.Write(pathPSKv1) + if windows { + b.WriteString("\r") + } + b.WriteString("\n") + b.WriteString(base) + if windows { + b.WriteString("\r") + } + b.WriteString("\n") + return b +} + +func TestDecodeHex(t *testing.T) { + testDecodeHex(t, true) + testDecodeHex(t, false) +} + +func TestDecodeBad(t *testing.T) { + testDecodeBad(t, true) + testDecodeBad(t, false) +} + +func testDecodeBad(t *testing.T, windows bool) { + b := bufWithBase("/verybadbase/", windows) + b.WriteString("Have fun decoding that key") + + _, err := DecodeV1PSK(b) + if err == nil { + t.Fatal("expected 'unknown encoding' got nil") + } +} + +func testDecodeHex(t *testing.T, windows bool) { + b := bufWithBase("/base16/", windows) + for i := 0; i < 32; i++ { + b.WriteString("FF") + } + + psk, err := DecodeV1PSK(b) + if err != nil { + t.Fatal(err) + } + + for _, b := range psk { + if b != 255 { + t.Fatal("byte was wrong") + } + } +} + +func TestDecodeB64(t *testing.T) { + testDecodeB64(t, true) + testDecodeB64(t, false) +} + +func testDecodeB64(t *testing.T, windows bool) { + b := bufWithBase("/base64/", windows) + key := make([]byte, 32) + for i := 0; i < 32; i++ { + key[i] = byte(i) + } + + e := base64.NewEncoder(base64.StdEncoding, b) + _, err := e.Write(key) + if err != nil { + t.Fatal(err) + } + err = e.Close() + if err != nil { + t.Fatal(err) + } + + psk, err := DecodeV1PSK(b) + if err != nil { + t.Fatal(err) + } + + for i, b := range psk { + if b != psk[i] { + t.Fatal("byte was wrong") + } + } + +} + +func TestDecodeBin(t *testing.T) { + testDecodeBin(t, true) + testDecodeBin(t, false) +} + +func testDecodeBin(t *testing.T, windows bool) { + b := bufWithBase("/bin/", windows) + key := make([]byte, 32) + for i := 0; i < 32; i++ { + key[i] = byte(i) + } + + b.Write(key) + + psk, err := DecodeV1PSK(b) + if err != nil { + t.Fatal(err) + } + + for i, b := range psk { + if b != psk[i] { + t.Fatal("byte was wrong") + } + } + +}