Add reduction ops #28
Conversation
Force-pushed from 1fa94d4 to 9901dd0 (Compare)
Hey @davoclavo, could you have a look when you find the time? I've added most reduction ops, and also tests for most of them (made much easier thanks to your helpers).
LGTM! Amazing work! And awesome docs, I will try to follow that style from now on.
I just added a quick comment regarding type promotion.
And another quick question: I have been changing snake_case to camelCase to match Scala style, but there is some benefit to sticking with snake_case to match PyTorch nomenclature. Should we pick one over the other for consistency?
```scala
 * @param p
 *   the norm to be computed
 */
// TODO dtype promotion floatNN/complexNN => highest floatNN
```
Maybe you could use FloatPromoted[D]
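For reference, FloatPromoted could be a Scala 3 match type along these lines (a minimal, self-contained sketch; the names and cases are illustrative, not necessarily the repo's actual definition):

```scala
// Minimal sketch of the idea behind FloatPromoted, assuming a sealed
// DType hierarchy: non-floating dtypes promote to Float32, while
// Float64 is preserved. Illustrative only; the repo may differ.
sealed trait DType
sealed trait Int32   extends DType
sealed trait Float32 extends DType
sealed trait Float64 extends DType

type FloatPromoted[D <: DType] = D match
  case Float64 => Float64
  case _       => Float32
```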
Right, but trying that, I think we might need to add cases for float16/bfloat16 to FloatPromoted. E.g. sin in PyTorch:

```python
torch.sin(torch.tensor(1., dtype=torch.bfloat16)).dtype # torch.bfloat16
```

whereas here it currently fails:

```scala
scala> torch.sin(torch.Tensor(1f).to(dtype=torch.bfloat16)).dtype
java.lang.ClassCastException: class torch.DType$bfloat16$ cannot be cast to class torch.Float32 (torch.DType$bfloat16$ and torch.Float32 are in unnamed module of loader sbt.internal.inc.classpath.ClasspathUtil$$anon$2 @586fb5d6)
```
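One way this could be addressed (a hedged sketch continuing the match type above; the half-precision dtype names are illustrative) is to let float16/bfloat16 pass through unchanged:

```scala
// Revised sketch: add half-precision dtypes and let them pass through
// unchanged, so e.g. sin on a BFloat16 tensor stays BFloat16, matching
// PyTorch. Illustrative only; actual names in the repo may differ.
sealed trait Float16  extends DType
sealed trait BFloat16 extends DType

type FloatPromoted[D <: DType] = D match
  case Float16  => Float16
  case BFloat16 => BFloat16
  case Float64  => Float64
  case _        => Float32
```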
I think I got it now, although the type is getting somewhat verbose 😆
```scala
Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]]
```
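For context, a hedged sketch of where such a composed return type could show up, e.g. on a torch.dist-like binary op (the signature is illustrative, not this PR's exact code, and assumes Promoted and ComplexToReal are match types analogous to FloatPromoted above):

```scala
// Illustrative signature only, not this PR's actual code. Assumes match
// types Promoted and ComplexToReal defined analogously to FloatPromoted.
def dist[D <: DType, D2 <: DType](
    input: Tensor[D],
    other: Tensor[D2],
    p: Float = 2.0f
): Tensor[Promoted[FloatPromoted[ComplexToReal[D]], FloatPromoted[ComplexToReal[D2]]]] = ???
```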
Thanks! I wanted to document how I convert the docs, but haven't gotten to it yet. I'm using the original PyTorch rst doc sources, i.e. for ops in torch:
The conversion is not perfect, but it often gets you 80% of the way there, saving a lot of tedious manual work. Also, please don't feel pressured to add all the docs immediately when porting things over; they can always be added or improved later. :)
Yeah, that's something I'm not entirely satisfied with either way. I started with camelCase for idiomatic Scala style, but realized it makes porting and searching the PyTorch docs a bit more inconvenient. Since most names are already camelCase, I'm currently inclined to stick with it, but I'm absolutely not religious about it. If you feel it makes more sense to move to snake_case, now or later, let's discuss it.
https://pytorch.org/docs/stable/torch.html#reduction-ops