Skip to content

Commit

Permalink
cuda version
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <kemingyang@tensorchord.ai>
  • Loading branch information
kemingy committed Sep 27, 2022
1 parent 5e407d0 commit 2a45a8f
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pkg/lang/ir/system.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ func (g Graph) compileCopy(root llb.State) llb.State {
return result
}

func (g *Graph) compileCUDAPackages(org, version string) llb.State {
return llb.Image(fmt.Sprintf(
"docker.io/%s/python:3.9-%s-cuda%s-cudnn%s-envd-%s",
org, g.OS, *g.CUDA, *g.CUDNN, version))
func (g *Graph) compileCUDAPackages(org string) llb.State {
return g.preparePythonBase(llb.Image(fmt.Sprintf(
"docker.io/%s/%s-cudnn%s-devel-%s",
org, *g.CUDA, *g.CUDNN, g.OS)))
}

func (g Graph) compileSystemPackages(root llb.State) llb.State {
Expand Down Expand Up @@ -143,14 +143,13 @@ func (g *Graph) compileExtraSource(root llb.State) (llb.State, error) {
return llb.Merge(inputs, llb.WithCustomName("[internal] build source layers")), nil
}

func (g *Graph) preparePythonBase() llb.State {
base := llb.Image(types.PythonBaseImage)
func (g *Graph) preparePythonBase(root llb.State) llb.State {
for _, env := range types.BaseEnvironment {
base = base.AddEnv(env.Name, env.Value)
root = root.AddEnv(env.Name, env.Value)
}

// envd-sshd
sshd := base.File(llb.Copy(
sshd := root.File(llb.Copy(
llb.Image(types.EnvdSshdImage), "/usr/bin/envd-sshd", "/var/envd/bin/envd-sshd",
&llb.CopyInfo{CreateDestPath: true}), llb.WithCustomName("[internal] add envd-sshd"))

Expand Down Expand Up @@ -208,13 +207,14 @@ func (g *Graph) compileBase() (llb.State, error) {
g.uid = 1001
}
case "python":
base = g.preparePythonBase()
// TODO(keming) use user input `base(os="")`
base = g.preparePythonBase(llb.Image(types.PythonBaseImage))
case "julia":
base = llb.Image(fmt.Sprintf(
"docker.io/%s/julia:1.8rc1-ubuntu20.04-envd-%s", org, v))
}
} else {
base = g.compileCUDAPackages(org, v)
base = g.compileCUDAPackages("nvidia/cuda")
}
var res llb.ExecState

Expand Down

0 comments on commit 2a45a8f

Please sign in to comment.