From b7805429c17e4436e0641d6787950f6db2dace65 Mon Sep 17 00:00:00 2001 From: Matthias Unverzagt Date: Thu, 1 Jul 2021 20:42:11 +0200 Subject: [PATCH] save/loadMetadata fixes --- api/src/main/java/ai/djl/nn/norm/LayerNorm.java | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java index 7b76a2c20a7..0861e1fa72f 100644 --- a/api/src/main/java/ai/djl/nn/norm/LayerNorm.java +++ b/api/src/main/java/ai/djl/nn/norm/LayerNorm.java @@ -165,26 +165,18 @@ public void prepare(Shape[] inputShapes) { @Override protected void saveMetadata(DataOutputStream os) throws IOException { saveInputShapes(os); - os.writeInt(normalizedShape.getShape().length); - for (int i = 0; i < normalizedShape.getShape().length; i++) { - os.writeLong(normalizedShape.getShape()[i]); - } + os.write(normalizedShape.getEncoded()); } /** {@inheritDoc} */ @Override public void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedModelException { - if (version == VERSION) { - readInputShapes(is); - } else if (version != 1) { + if (version != VERSION) { throw new MalformedModelException("Unsupported encoding version: " + version); } - long[] shapeRaw = new long[is.readInt()]; - for (int i = 0; i < shapeRaw.length; i++) { - shapeRaw[i] = is.readLong(); - } - normalizedShape = new Shape(shapeRaw); + readInputShapes(is); + normalizedShape = Shape.decode(is); } /** The Builder to construct a {@link LayerNorm}. */