Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LayerNorm using PyTorch #1069

Merged
merged 8 commits into from
Jul 5, 2021
Merged

LayerNorm using PyTorch #1069

merged 8 commits into from
Jul 5, 2021

Conversation

enpasos
Copy link
Contributor

@enpasos enpasos commented Jul 1, 2021

Description

Similar to BatchNormalization there exist some other variants of normalizing data flowing through the network that have been implemented by the underlying ai frameworks. Here LayerNorm as one of them has been wired up to be used with PyTorch.

This PullRequest would close #1057.

api/src/main/java/ai/djl/nn/norm/LayerNorm.java Outdated Show resolved Hide resolved
@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(JNIEnv* env, jobject jthis,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run:

./gradlew formatCpp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just did

gradlew formatCpp
Found C:\djl\\gradle\wrapper\gradle-wrapper.jar

Deprecated Gradle features were used in this build, making it incompatible with Gradle 8.0.
Use '--warning-mode all' to show the individual deprecation warnings.
See https://docs.gradle.org/7.0.2/userguide/command_line_interface.html#sec:command_line_warnings

BUILD SUCCESSFUL in 1s
6 actionable tasks: 6 executed

but not seen an effect. Do I miss something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, formatCpp doesn't work on Windows
Here is what it should looks like:

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
    JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copy/paste the two lines ... hope it fits ... looks like I should build djl on linux.

api/src/main/java/ai/djl/nn/norm/LayerNorm.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/nn/norm/LayerNorm.java Outdated Show resolved Hide resolved
api/src/main/java/ai/djl/nn/norm/LayerNorm.java Outdated Show resolved Hide resolved
enpasos and others added 5 commits July 1, 2021 20:21
…Ex.java

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
…DArrayEx.java

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
…ngine/TfNDArrayEx.java

Co-authored-by: Frank Liu <frankfliu2000@gmail.com>
@codecov-commenter
Copy link

codecov-commenter commented Jul 1, 2021

Codecov Report

Merging #1069 (1b1c97d) into master (a3d7cb1) will increase coverage by 0.03%.
The diff coverage is 84.50%.

Impacted file tree graph

@@             Coverage Diff              @@
##             master    #1069      +/-   ##
============================================
+ Coverage     69.98%   70.02%   +0.03%     
- Complexity     5212     5227      +15     
============================================
  Files           510      511       +1     
  Lines         23255    23339      +84     
  Branches       2489     2492       +3     
============================================
+ Hits          16276    16342      +66     
- Misses         5650     5665      +15     
- Partials       1329     1332       +3     
Impacted Files Coverage Δ
...c/main/java/ai/djl/ndarray/internal/NDArrayEx.java 94.87% <ø> (ø)
...src/main/java/ai/djl/mxnet/engine/MxNDArrayEx.java 87.58% <0.00%> (-0.20%) ⬇️
...c/main/java/ai/djl/pytorch/jni/PyTorchLibrary.java 100.00% <ø> (ø)
...ain/java/ai/djl/tensorflow/engine/TfNDArrayEx.java 71.32% <0.00%> (-0.53%) ⬇️
api/src/main/java/ai/djl/nn/norm/LayerNorm.java 85.00% <85.00%> (ø)
...c/main/java/ai/djl/pytorch/engine/PtNDArrayEx.java 85.81% <100.00%> (+0.20%) ⬆️
...ine/src/main/java/ai/djl/pytorch/jni/JniUtils.java 93.08% <100.00%> (+0.07%) ⬆️
...ai/djl/tensorflow/engine/javacpp/JavacppUtils.java 64.94% <0.00%> (-2.80%) ⬇️
.../main/java/ai/djl/tensorflow/engine/TfNDArray.java 84.34% <0.00%> (-0.89%) ⬇️
...rc/main/java/ai/djl/sentencepiece/SpProcessor.java 77.27% <0.00%> (ø)
... and 5 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update a3d7cb1...1b1c97d. Read the comment docs.

@@ -144,6 +144,25 @@ JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNBatchNorm(
API_END_RETURN()
}

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(JNIEnv* env, jobject jthis,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, formatCpp doesn't work on Windows
Here is what it should looks like:

JNIEXPORT jlong JNICALL Java_ai_djl_pytorch_jni_PyTorchLibrary_torchNNLayerNorm(
    JNIEnv* env, jobject jthis, jlong jinput, jlongArray jnormalizedshape, jlong jweight, jlong jbias, jdouble jeps) {

(PtNDArray) gamma,
(PtNDArray) beta,
eps));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add en empty line here

Copy link
Contributor Author

@enpasos enpasos Jul 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have tried to fix it ... but could be that I am blind here ... I am using
gradlew formatJava
and
gradlew build

@frankfliu
Copy link
Contributor

@lanking520 @stu1130 Please take a look.

@lanking520 lanking520 merged commit 7a7f41f into deepjavalibrary:master Jul 5, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Layer Normalization
5 participants