Skip to content

Commit

Permalink
Merge pull request #24 from sbrunk/update-deps
Browse files Browse the repository at this point in the history
Update to Scala 3.3.0, JavaCPP 1.5.9 stable and other dep updates
  • Loading branch information
sbrunk authored Jun 12, 2023
2 parents 66266aa + 33e25af commit 3fb3c5d
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 27 deletions.
14 changes: 6 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
scala: [3.3.0]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -105,7 +105,6 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -142,12 +141,12 @@ jobs:
~/Library/Caches/Coursier/v1
key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }}

- name: Download target directories (3.2.2)
- name: Download target directories (3.3.0)
uses: actions/download-artifact@v3
with:
name: target-${{ matrix.os }}-${{ matrix.java }}-3.2.2
name: target-${{ matrix.os }}-${{ matrix.java }}-3.3.0

- name: Inflate target directories (3.2.2)
- name: Inflate target directories (3.3.0)
run: |
tar xf targets.tar
rm targets.tar
Expand All @@ -164,14 +163,13 @@ jobs:
(echo "$PGP_PASSPHRASE"; echo; echo) | gpg --command-fd 0 --pinentry-mode loopback --change-passphrase $(gpg --list-secret-keys --with-colons 2> /dev/null | grep '^sec:' | cut --delimiter ':' --fields 5 | tail -n 1)
- name: Publish
run: sbt '++ ${{ matrix.scala }}' tlRelease
run: sbt tlCiRelease

site:
name: Generate Site
strategy:
matrix:
os: [ubuntu-latest]
scala: [3.2.2]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
Expand Down Expand Up @@ -209,7 +207,7 @@ jobs:
key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }}

- name: Generate site
run: sbt '++ ${{ matrix.scala }}' docs/tlSite
run: sbt docs/tlSite

- name: Publish site
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
Expand Down
14 changes: 9 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ ThisBuild / tlSitePublishBranch := Some("main")

ThisBuild / apiURL := Some(new URL("https://storch.dev/api/"))

val scrImageVersion = "4.0.32"
val scrImageVersion = "4.0.34"
val pytorchVersion = "2.0.1"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.2.2"
ThisBuild / scalaVersion := "3.3.0"
ThisBuild / javaCppVersion := "1.5.9"

ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

Expand All @@ -37,7 +38,7 @@ ThisBuild / enableGPU := false

lazy val commonSettings = Seq(
Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"),
javaCppVersion := "1.5.9-SNAPSHOT",
javaCppVersion := (ThisBuild / javaCppVersion).value,
javaCppPlatform := Seq(),
resolvers ++= Resolver.sonatypeOssRepos("snapshots")
// This is a hack to avoid depending on the native libs when publishing
Expand Down Expand Up @@ -80,8 +81,8 @@ lazy val core = project
libraryDependencies ++= Seq(
"org.bytedeco" % "pytorch" % s"$pytorchVersion-${javaCppVersion.value}",
"org.typelevel" %% "spire" % "0.18.0",
"org.typelevel" %% "shapeless3-typeable" % "3.2.0",
"com.lihaoyi" %% "os-lib" % "0.9.0",
"org.typelevel" %% "shapeless3-typeable" % "3.3.0",
"com.lihaoyi" %% "os-lib" % "0.9.1",
"com.lihaoyi" %% "sourcecode" % "0.3.0",
"dev.dirs" % "directories" % "26",
"org.scalameta" %% "munit" % "0.7.29" % Test,
Expand Down Expand Up @@ -137,3 +138,6 @@ lazy val root = project
.enablePlugins(NoPublishPlugin)
.in(file("."))
.aggregate(core, vision, examples, docs)
.settings(
javaCppVersion := (ThisBuild / javaCppVersion).value
)
8 changes: 4 additions & 4 deletions core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ private[torch] object NativeConverters:
toOptional(t, t => pytorch.TensorOptional(t.native))

def toArray(i: Long | (Long, Long)) = i match
case i: Long => Array(i)
case (i, j): (Long, Long) => Array(i, j)
case i: Long => Array(i)
case (i, j) => Array(i, j)

def toNative(input: Int | (Int, Int)) = input match
case (h, w): (Int, Int) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)
case (h, w) => LongPointer(Array(h.toLong, w.toLong)*)
case x: Int => LongPointer(Array(x.toLong, x.toLong)*)

def toScalar(x: ScalaType): pytorch.Scalar = x match
case x: Boolean => pytorch.Scalar(if true then 1: Byte else 0: Byte)
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/nn/functional/pooling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ import torch.internal.NativeConverters.toOptional

def maxPool2d[D <: DType](input: Tensor[D], kernelSize: Long | (Long, Long)): Tensor[D] =
val kernelSizeNative = kernelSize match
case (h, w): (Long, Long) => Array(h, w)
case x: Long => Array(x, x)
case (h, w) => Array(h, w)
case x: Long => Array(x, x)
Tensor(torchNative.max_pool2d(input.native, kernelSizeNative*))
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ final class AdaptiveAvgPool2d(
) extends Module {

private def nativeOutputSize = outputSize match
case (h, w): (Int, Int) => new LongOptionalVector(new LongOptional(h), new LongOptional(w))
case x: Int => new LongOptionalVector(new LongOptional(x), new LongOptional(x))
case (h, w): (Option[Int], Option[Int]) =>
case (h: Int, w: Int) => new LongOptionalVector(new LongOptional(h), new LongOptional(w))
case x: Int => new LongOptionalVector(new LongOptional(x), new LongOptional(x))
// We know this can only be int so we can suppress the type test for Option[Int] cannot be checked at runtime warning
case (h: Option[Int @unchecked], w: Option[Int @unchecked]) =>
new LongOptionalVector(toOptional(h.map(_.toLong)), toOptional(w.map(_.toLong)))
case x: Option[Int] =>
new LongOptionalVector(toOptional(x.map(_.toLong)), toOptional(x.map(_.toLong)))
Expand Down
2 changes: 1 addition & 1 deletion project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version=1.8.2
sbt.version=1.9.0
8 changes: 4 additions & 4 deletions project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.4.6")
addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0")
addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.3.6")
addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.3.7")
addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.4.18")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.4.18")
addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.4.22")
addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.4.22")

0 comments on commit 3fb3c5d

Please sign in to comment.