
Fix torch.randn to use proper native function #27

Merged 1 commit into sbrunk:main on Jun 17, 2023

Conversation

davoclavo (Contributor) commented on Jun 16, 2023

This fixes a subtle bug in torch.randn, which was mistakenly using rand under the hood.
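
For context, the bug class looks roughly like this. The snippet below is a hypothetical sketch with stand-in names (NativeOps, TorchFrontend), not the actual Storch/javacpp code:

```scala
// Hypothetical sketch of the bug class, not the actual Storch/javacpp code.
// The two native samplers differ only in their distribution: rand draws
// uniformly from [0, 1), while randn draws from the standard normal.
trait NativeOps:
  def rand(shape: Seq[Int]): Array[Float]  // uniform on [0, 1)
  def randn(shape: Seq[Int]): Array[Float] // standard normal

class TorchFrontend(native: NativeOps):
  def rand(shape: Int*): Array[Float] = native.rand(shape)
  // The bug: this previously delegated to native.rand, silently returning
  // uniform samples. The fix delegates to native.randn.
  def randn(shape: Int*): Array[Float] = native.randn(shape)
```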

I tried implementing a variance check as well, but I couldn't find the var function under javacpp. If anyone is able to help me locate it, I'd appreciate it.
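
A variance check can tell the two distributions apart, since uniform samples on [0, 1) have variance ≈ 1/12 ≈ 0.083 while standard-normal samples have variance ≈ 1. A minimal plain-Scala sketch of the idea, using scala.util.Random as a stand-in for the tensor samplers:

```scala
import scala.util.Random

// nextDouble() is uniform on [0, 1) like rand; nextGaussian() is standard
// normal like randn.
@main def checkDistributions(): Unit =
  val rng = new Random(42)
  val n = 100_000

  def meanAndVariance(xs: Seq[Double]): (Double, Double) =
    val mean = xs.sum / xs.size
    val variance = xs.map(x => (x - mean) * (x - mean)).sum / (xs.size - 1)
    (mean, variance)

  val (uMean, uVar) = meanAndVariance(Seq.fill(n)(rng.nextDouble()))
  val (gMean, gVar) = meanAndVariance(Seq.fill(n)(rng.nextGaussian()))

  println(f"uniform: mean=$uMean%.3f variance=$uVar%.3f") // ≈ 0.500, ≈ 0.083
  println(f"normal:  mean=$gMean%.3f variance=$gVar%.3f") // ≈ 0.000, ≈ 1.000
```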

As a side note, I will keep implementing more tensor ops soon and improving test coverage to avoid these kinds of subtle errors; a handful of them happened to me while implementing the pointwise ops.

davoclavo force-pushed the fix_randn branch 2 times, most recently from 1ee0ce3 to 6e2c3aa on June 16, 2023 at 19:19
davoclavo (Contributor, Author) commented on Jun 16, 2023

Update: I found the variance function; it lives under nativeTorch.`var`. I will implement that function plus the unit test check, so please don't merge yet.
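
Roughly, the delegation could look like this. This is a hypothetical sketch; NativeVarOps stands in for the javacpp binding, whose actual signature may differ:

```scala
// Hypothetical sketch, not the actual Storch code. The javacpp binding
// exposes the reduction under PyTorch's name `var`, which is a Scala
// keyword, so backticks are required to declare and call it.
trait NativeVarOps:
  def `var`(data: Array[Float]): Float

class Tensor(data: Array[Float], native: NativeVarOps):
  // Scala-friendly name that delegates to the native `var` function.
  def variance: Float = native.`var`(data)
```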

davoclavo marked this pull request as a draft on June 16, 2023 at 20:07
davoclavo marked this pull request as ready for review on June 16, 2023 at 21:21
davoclavo (Contributor, Author) commented on Jun 16, 2023

Done! I just added Tensor.variance plus the relevant check in the randn unit test.
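
The check has roughly this shape. This is a hedged sketch, not the actual test from this PR: it assumes a munit-style suite, and while torch.randn and Tensor.variance come from this PR, the exact randn signature, the `.item` accessor, and the tolerance are illustrative assumptions:

```scala
import munit.FunSuite

// Hedged sketch, not the actual test from this PR. Assumes a munit suite;
// `.item` extracting a Scala value from a scalar tensor is an assumption.
class RandnSuite extends FunSuite:
  test("randn produces roughly unit-variance samples"):
    val t = torch.randn(100_000)
    val v = t.variance.item
    // Standard normal: variance ≈ 1. The old buggy randn effectively
    // sampled uniformly on [0, 1), which would give variance ≈ 1/12.
    assert(math.abs(v - 1.0) < 0.05, s"unexpected variance: $v")
```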

I hesitated to stick to the PyTorch nomenclature because of the forced backticks in Scala (Tensor.`var`): it feels unnatural and would require translating those function names anyway (var -> `var`). However, please let me know if I should implement it with PyTorch's name (and backticks), or add it under both names; I'm happy to do either.

sbrunk (Owner) commented on Jun 17, 2023

Looks good to me, thanks!

> I hesitated to stick to the PyTorch nomenclature because of the forced backticks in Scala (Tensor.`var`): it feels unnatural and would require translating those function names anyway (var -> `var`). However, please let me know if I should implement it with PyTorch's name (and backticks), or add it under both names; I'm happy to do either.

I'm also in favor of just using variance here. I think it helps with discoverability: if you type var in the IDE or in the Scaladoc search, it should suggest variance.

> As a side note, I will keep implementing more tensor ops soon and improving test coverage to avoid these kinds of subtle errors; a handful of them happened to me while implementing the pointwise ops.

Awesome! FYI, just to avoid double work: I've started implementing reduction ops in #28.

sbrunk merged commit 45146d5 into sbrunk:main on Jun 17, 2023