Fix recurrent block memory leak and output shape calculation (#556)
* Fix recurrent block memory leak and output shape calculation

* Code formatting

* Applied code review remarks
mpskowron authored Jan 25, 2021
1 parent a56ba55 commit 0b4f3a5
Showing 2 changed files with 22 additions and 18 deletions.
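
Context for the two fixes: `opInputs` (in both files below) flattens every parameter on each forward pass, and `flatten()` allocates a new `NDArray` that stays attached to the long-lived manager owning the parameters, so the temporary copies accumulated across passes. The fix re-attaches each copy to the manager of the incoming batch, tying its lifetime to the pass. A minimal sketch of this ownership model, with invented names and a per-pass sub-manager standing in for the batch manager (illustration only, not code from this commit):

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public final class AttachSketch {
    public static void main(String[] args) {
        try (NDManager modelScope = NDManager.newBaseManager()) {
            // Long-lived array, standing in for a model parameter.
            NDArray parameter = modelScope.ones(new Shape(2, 3));

            try (NDManager passScope = modelScope.newSubManager()) {
                // flatten() allocates a new array owned by the same manager
                // as its source, so by default it would outlive this pass.
                NDArray flat = parameter.flatten();
                // Re-parent the copy to the per-pass manager so it is freed
                // when passScope closes instead of accumulating per pass.
                flat.attach(passScope);
            } // flat is released here; parameter is still alive.
        } // parameter is released here.
    }
}
```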

api/src/main/java/ai/djl/nn/recurrent/LSTM.java (16 changes: 9 additions & 7 deletions)

@@ -16,6 +16,7 @@
 import ai.djl.ndarray.NDArray;
 import ai.djl.ndarray.NDArrays;
 import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
 import ai.djl.ndarray.internal.NDArrayEx;
 import ai.djl.ndarray.types.DataType;
 import ai.djl.ndarray.types.Shape;
@@ -128,27 +129,28 @@ protected NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean
         validateInputSize(inputs);
         long batchSize = inputs.head().getShape().get(0);
         inputs = updateInputLayoutToTNC(inputs);
-        NDArray head = inputs.singletonOrThrow();
+        NDArray head = inputs.head();
+        NDManager manager = head.getManager();
         Device device = head.getDevice();
 
         NDList result = new NDList(head);
         try (NDList parameterList = new NDList()) {
             for (Parameter parameter : parameters.values()) {
-                NDArray array = parameterStore.getValue(parameter, device, training);
-                parameterList.add(array.flatten());
+                NDArray array = parameterStore.getValue(parameter, device, training).flatten();
+                array.attach(manager);
+                parameterList.add(array);
             }
             NDArray array = NDArrays.concat(parameterList);
             result.add(array);
         }
         // Adding state and stateCell
-        Shape stateShape = new Shape(numStackedLayers * numDirections, batchSize, stateSize);
+        Shape stateShape = new Shape((long) numStackedLayers * numDirections, batchSize, stateSize);
         if (beginState != null) {
             result.add(beginState);
             result.add(beginStateCell);
         } else {
             // TODO manager creates the NDArray with the wrong device
-            result.add(head.getManager().zeros(stateShape, DataType.FLOAT32, device));
-            result.add(head.getManager().zeros(stateShape, DataType.FLOAT32, device));
+            result.add(manager.zeros(stateShape, DataType.FLOAT32, device));
+            result.add(manager.zeros(stateShape, DataType.FLOAT32, device));
         }
         if (useSequenceLength) {
             result.add(inputs.get(1));
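
Side note on the `(long)` cast introduced above: `numStackedLayers * numDirections` is `int` arithmetic in Java and would overflow before being widened to the `long` that `Shape` accepts. Layer and direction counts are far too small for this to be a live bug, so the cast reads as defensive, but the idiom is the right one. A quick standalone illustration (hypothetical numbers, not from the commit):

```java
public final class CastDemo {
    public static void main(String[] args) {
        int rows = 100_000;
        int cols = 30_000;
        long bad = rows * cols;         // int * int overflows, then widens: -1294967296
        long good = (long) rows * cols; // widen one operand first: 3000000000
        System.out.println(bad + " vs " + good);
    }
}
```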

api/src/main/java/ai/djl/nn/recurrent/RecurrentBlock.java (24 changes: 13 additions & 11 deletions)

@@ -164,17 +164,18 @@ protected void resetBeginStates() {
     /** {@inheritDoc} */
     @Override
     public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
-        // Input shape at this point is TNC. Output Shape should be NTS
+        // Input shape at this point is NTC. Output Shape should be NTS
         Shape inputShape = inputs[0];
+        long nShape = inputShape.get(0);
+        long tShape = inputShape.get(1);
+        Shape nonStateOutputShape = new Shape(nShape, tShape, stateSize * numDirections);
         if (stateOutputs) {
             return new Shape[] {
-                new Shape(inputShape.get(1), inputShape.get(0), stateSize * numDirections),
-                new Shape(numStackedLayers * numDirections, inputShape.get(1), stateSize)
+                nonStateOutputShape,
+                new Shape((long) numStackedLayers * numDirections, nShape, stateSize)
             };
         }
-        return new Shape[] {
-            new Shape(inputShape.get(1), inputShape.get(0), stateSize * numDirections)
-        };
+        return new Shape[] {nonStateOutputShape};
     }
 
     /** {@inheritDoc} */
@@ -183,7 +184,6 @@ public void beforeInitialize(Shape[] inputs) {
         super.beforeInitialize(inputs);
         Shape inputShape = inputs[0];
         Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
-        inputs[0] = new Shape(inputShape.get(1), inputShape.get(0), inputShape.get(2));
     }
 
     /** {@inheritDoc} */
@@ -227,22 +227,24 @@ protected NDList opInputs(ParameterStore parameterStore, NDList inputs, boolean
         long batchSize = inputs.head().getShape().get(0);
         inputs = updateInputLayoutToTNC(inputs);
         NDArray head = inputs.head();
+        NDManager manager = head.getManager();
         Device device = head.getDevice();
 
         NDList result = new NDList(head);
         try (NDList parameterList = new NDList()) {
             for (Parameter parameter : parameters.values()) {
-                NDArray array = parameterStore.getValue(parameter, device, training);
-                parameterList.add(array.flatten());
+                NDArray array = parameterStore.getValue(parameter, device, training).flatten();
+                array.attach(manager);
+                parameterList.add(array);
             }
             NDArray array = NDArrays.concat(parameterList);
             result.add(array);
         }
-        Shape stateShape = new Shape(numStackedLayers * numDirections, batchSize, stateSize);
+        Shape stateShape = new Shape((long) numStackedLayers * numDirections, batchSize, stateSize);
         if (beginState != null) {
             result.add(beginState);
         } else {
-            result.add(inputs.head().getManager().zeros(stateShape));
+            result.add(manager.zeros(stateShape));
         }
         if (useSequenceLength) {
             result.add(inputs.get(1));
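
To sanity-check the corrected `getOutputShapes`, a worked example with hypothetical sizes (not from the commit): a batch of 32 sequences of 10 time steps (layout NTC) through a bidirectional block with `stateSize = 256` and `numStackedLayers = 2`:

```java
import ai.djl.ndarray.types.Shape;

public final class ShapeDemo {
    public static void main(String[] args) {
        long batch = 32;       // N
        long time = 10;        // T
        int stateSize = 256;
        int numStackedLayers = 2;
        int numDirections = 2; // bidirectional

        // Main output, layout NTS: (32, 10, 256 * 2) = (32, 10, 512)
        Shape output = new Shape(batch, time, (long) stateSize * numDirections);

        // Extra state output when stateOutputs is true: (2 * 2, 32, 256) = (4, 32, 256)
        Shape state = new Shape((long) numStackedLayers * numDirections, batch, stateSize);

        System.out.println(output + " " + state);
    }
}
```

The removal of the NTC-to-TNC transpose in `beforeInitialize` makes this consistent: `getOutputShapes` now reads N and T directly from the NTC input shape, so the two methods agree on a single layout.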
