Skip to content

Commit

Permalink
save/loadMetadata fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
enpasos committed Jul 1, 2021
1 parent 380fb3b commit b780542
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions api/src/main/java/ai/djl/nn/norm/LayerNorm.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,26 +165,18 @@ public void prepare(Shape[] inputShapes) {
@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
saveInputShapes(os);
os.writeInt(normalizedShape.getShape().length);
for (int i = 0; i < normalizedShape.getShape().length; i++) {
os.writeLong(normalizedShape.getShape()[i]);
}
os.write(normalizedShape.getEncoded());
}

/** {@inheritDoc} */
@Override
public void loadMetadata(byte version, DataInputStream is)
throws IOException, MalformedModelException {
if (version == VERSION) {
readInputShapes(is);
} else if (version != 1) {
if (version != VERSION) {
throw new MalformedModelException("Unsupported encoding version: " + version);
}
long[] shapeRaw = new long[is.readInt()];
for (int i = 0; i < shapeRaw.length; i++) {
shapeRaw[i] = is.readLong();
}
normalizedShape = new Shape(shapeRaw);
readInputShapes(is);
normalizedShape = Shape.decode(is);
}

/** The Builder to construct a {@link LayerNorm}. */
Expand Down

0 comments on commit b780542

Please sign in to comment.