Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to Scala 3.3.0, JavaCPP 1.5.9 stable and other dep updates #24

Merged
merged 4 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
scala: [3.3.0]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -105,7 +105,6 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -142,12 +141,12 @@ jobs:
~/Library/Caches/Coursier/v1
key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }}

- name: Download target directories (3.2.2)
- name: Download target directories (3.3.0)
uses: actions/download-artifact@v3
with:
name: target-${{ matrix.os }}-${{ matrix.java }}-3.2.2
name: target-${{ matrix.os }}-${{ matrix.java }}-3.3.0

- name: Inflate target directories (3.2.2)
- name: Inflate target directories (3.3.0)
run: |
tar xf targets.tar
rm targets.tar
Expand All @@ -164,14 +163,13 @@ jobs:
(echo "$PGP_PASSPHRASE"; echo; echo) | gpg --command-fd 0 --pinentry-mode loopback --change-passphrase $(gpg --list-secret-keys --with-colons 2> /dev/null | grep '^sec:' | cut --delimiter ':' --fields 5 | tail -n 1)

- name: Publish
run: sbt '++ ${{ matrix.scala }}' tlRelease
run: sbt tlCiRelease

site:
name: Generate Site
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -209,7 +207,7 @@ jobs:
key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }}

- name: Generate site
run: sbt '++ ${{ matrix.scala }}' docs/tlSite
run: sbt docs/tlSite

- name: Publish site
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
Expand Down
14 changes: 9 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ ThisBuild / tlSitePublishBranch := Some("main")

ThisBuild / apiURL := Some(new URL("https://storch.dev/api/"))

val scrImageVersion = "4.0.32"
val scrImageVersion = "4.0.34"
val pytorchVersion = "2.0.1"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.2.2"
ThisBuild / scalaVersion := "3.3.0"
ThisBuild / javaCppVersion := "1.5.9"

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

Expand All @@ -37,7 +38,7 @@ ThisBuild / enableGPU := false

lazy val commonSettings = Seq(
Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
javaCppVersion := "1.5.9-SNAPSHOT",
javaCppVersion := (ThisBuild / javaCppVersion).value,
javaCppPlatform := Seq(),
resolvers ++= Resolver.sonatypeOssRepos("snapshots")
// This is a hack to avoid depending on the native libs when publishing
Expand Down Expand Up @@ -80,8 +81,8 @@ lazy val core = project
libraryDependencies ++= Seq(
"org.bytedeco" % "pytorch" % s"$pytorchVersion-${javaCppVersion.value}",
"org.typelevel" %% "spire" % "0.18.0",
"org.typelevel" %% "shapeless3-typeable" % "3.2.0",
"com.lihaoyi" %% "os-lib" % "0.9.0",
"org.typelevel" %% "shapeless3-typeable" % "3.3.0",
"com.lihaoyi" %% "os-lib" % "0.9.1",
"com.lihaoyi" %% "sourcecode" % "0.3.0",
"dev.dirs" % "directories" % "26",
"org.scalameta" %% "munit" % "0.7.29" % Test,
Expand Down Expand Up @@ -137,3 +138,6 @@ lazy val root = project
.enablePlugins(NoPublishPlugin)
.in(file("."))
.aggregate(core, vision, examples, docs)
.settings(
javaCppVersion := (ThisBuild / javaCppVersion).value
)
8 changes: 4 additions & 4 deletions core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ private[torch] object NativeConverters:
toOptional(t, t => pytorch.TensorOptional(t.native))

def toArray(i: Long | (Long, Long)) = i match
case i: Long => Array(i)
case (i, j): (Long, Long) => Array(i, j)
case i: Long => Array(i)
case (i, j) => Array(i, j)

def toNative(input: Int | (Int, Int)) = input match
case (h, w): (Int, Int) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

def toScalar(x: ScalaType): pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte)
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/functional/pooling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ import torch.internal.NativeConverters.toOptional

def maxPool2d[D <: DType](input: Tensor[D], kernelSize: Long | (Long, Long)): Tensor[D] =
val kernelSizeNative = kernelSize match
case (h, w): (Long, Long) => Array(h, w)
case x: Long => Array(x, x)
case (h, w) => Array(h, w)
case x: Long => Array(x, x)
Tensor(torchNative.max_pool2d(input.native, kernelSizeNative*))
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ final class AdaptiveAvgPool2d(
) extends Module {

private def nativeOutputSize = outputSize match
case (h, w): (Int, Int) => new LongOptionalVector(new LongOptional(h), new LongOptional(w))
case x: Int => new LongOptionalVector(new LongOptional(x), new LongOptional(x))
case (h, w): (Option[Int], Option[Int]) =>
case (h: Int, w: Int) => new LongOptionalVector(new LongOptional(h), new LongOptional(w))
case x: Int => new LongOptionalVector(new LongOptional(x), new LongOptional(x))
// We know this can only be int so we can suppress the type test for Option[Int] cannot be checked at runtime warning
case (h: Option[Int @unchecked], w: Option[Int @unchecked]) =>
new LongOptionalVector(toOptional(h.map(_.toLong)), toOptional(w.map(_.toLong)))
case x: Option[Int] =>
new LongOptionalVector(toOptional(x.map(_.toLong)), toOptional(x.map(_.toLong)))
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.8.2
sbt.version=1.9.0
8 changes: 4 additions & 4 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0")
addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.3.6")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.3.7")
addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.4.18")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.4.18")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.4.22")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.4.22")