Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add device detection in lmes driver #298

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ var (
grpcService = flag.String("grpc-service", "", "grpc service name")
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", true, "detect available device(s), CUDA or CPU")
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
driverLog = ctrl.Log.WithName("driver")
)
Expand Down Expand Up @@ -83,6 +84,7 @@ func main() {
OutputPath: *outputPath,
GrpcService: *grpcService,
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
Args: args,
ReportInterval: *reportInterval,
Expand Down
2 changes: 2 additions & 0 deletions controllers/lmes/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
GrpcClientSecretKey = "lmes-grpc-client-secret"
MaxBatchSizeKey = "lmes-max-batch-size"
DefaultBatchSizeKey = "lmes-default-batch-size"
DetectDeviceKey = "lmes-detect-device"
DriverReportIntervalKey = "driver-report-interval"
GrpcServerCertEnv = "GRPC_SERVER_CERT"
GrpcServerKeyEnv = "GRPC_SERVER_KEY"
Expand All @@ -51,5 +52,6 @@ const (
DefaultGrpcClientSecret = "grpc-client-cert"
DefaultMaxBatchSize = 24
DefaultBatchSize = 8
DefaultDetectDevice = true
ServiceName = "LMES"
)
53 changes: 53 additions & 0 deletions controllers/lmes/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type DriverOption struct {
GrpcService string
GrpcPort int
OutputPath string
DetectDevice bool
Logger logr.Logger
Args []string
ReportInterval time.Duration
Expand Down Expand Up @@ -200,7 +201,59 @@ func getGRPCClientConn(option *DriverOption) (clientConn *grpc.ClientConn, err e
return
}

// deviceProbePattern matches the single line printed by the torch probe that
// detectDevice runs, e.g. `==True:1==`. Group 1 is the CUDA availability flag
// and group 2 is the reported device count. It is compiled once at package
// scope instead of on every detectDevice call.
var deviceProbePattern = regexp.MustCompile(`(?m)^==(True|False):(\d+?)==$`)

// detectDevice checks whether CUDA is available and patches the lm_eval
// command in d.Option.Args with a matching `--device` option.
//
// It is a no-op (returning nil) when the receiver is nil or when
// Option.DetectDevice is false. Otherwise it runs a short python snippet that
// asks torch for CUDA availability and parses the snippet's stdout.
//
// It returns an error when the python process cannot be run or exits
// unsuccessfully, or when its output does not contain the expected
// `==<True|False>:<count>==` line. The process error is wrapped with %w so
// callers can inspect it with errors.Is / errors.As.
func (d *driverImpl) detectDevice() error {
	if d == nil || !d.Option.DetectDevice {
		return nil
	}

	// Assumes python and the torch python package are available in the
	// runtime image; the torch API is used to detect CUDA's availability.
	out, err := exec.Command(
		"python",
		"-c",
		"import torch; print('=={}:{}=='.format(torch.cuda.is_available(), torch.cuda.device_count()));",
	).Output()
	if err != nil {
		return fmt.Errorf("failed to detect available device(s): %w", err)
	}

	matches := deviceProbePattern.FindStringSubmatch(string(out))
	if matches == nil {
		return fmt.Errorf("failed to find the matched output")
	}

	// Only the availability flag (group 1) drives the decision; the device
	// count captured in group 2 is currently not used.
	patchDevice(d.Option.Args, matches[1] == "True")

	return nil
}

// patchDevice rewrites, in place, the first element of args that starts with
// `python -m lm_eval` so that it carries a `--device` option: `--device cuda`
// when hasCuda is true, `--device cpu` otherwise.
//
// An element that already contains `--device` is left as it is. Only the
// first element with the `python -m lm_eval` prefix is considered; every
// other element of args is untouched.
func patchDevice(args []string, hasCuda bool) {
	target := "cpu"
	if hasCuda {
		target = "cuda"
	}

	for i := range args {
		cmdline := args[i]
		if !strings.HasPrefix(cmdline, "python -m lm_eval") {
			continue
		}
		// Respect a device that is already specified in the command.
		if !strings.Contains(cmdline, "--device") {
			args[i] = fmt.Sprintf("%s --device %s", cmdline, target)
		}
		// Stop at the first lm_eval command, whether or not it was patched.
		return
	}
}

func (d *driverImpl) exec() error {

// Detect available devices if needed
if err := d.detectDevice(); err != nil {
return err
}

fmt.Printf("%q\n", d.Option.Args)

// Run user program.
var args []string
if len(d.Option.Args) > 1 {
Expand Down
73 changes: 73 additions & 0 deletions controllers/lmes/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,76 @@ func Test_ProgressUpdate(t *testing.T) {
assert.Nil(t, os.Remove("./stderr.log"))
assert.Nil(t, os.Remove("./stdout.log"))
}

// Test_DetectDeviceError runs the driver with device detection enabled and
// expects the detection step to fail with "exit status 1" (presumably because
// torch cannot be imported in the test environment; confirm against CI setup).
// It checks that Run still returns nil, that the failure is reported through
// the progress-update gRPC service, and that no stdout/stderr log files were
// left behind.
func Test_DetectDeviceError(t *testing.T) {
	// Local gRPC server that records the status messages sent by the driver.
	grpcServer := grpc.NewServer()
	recorder := ProgressUpdateServer{}
	v1beta1.RegisterLMEvalJobUpdateServiceServer(grpcServer, &recorder)
	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082))
	assert.Nil(t, err)
	go grpcServer.Serve(listener)

	opt := DriverOption{
		Context:        context.Background(),
		JobNamespace:   "fms-lm-eval-service-system",
		JobName:        "evaljob-sample",
		GrpcService:    "localhost",
		GrpcPort:       8082,
		OutputPath:     ".",
		DetectDevice:   true,
		Logger:         driverLog,
		Args:           []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"},
		ReportInterval: time.Second * 5,
	}
	driver, err := NewDriver(&opt)
	assert.Nil(t, err)

	assert.Nil(t, driver.Run())

	wantMsgs := []string{
		"update status from the driver: running",
		"failed to detect available device(s): exit status 1",
	}
	assert.Equal(t, wantMsgs, recorder.progressMsgs)

	grpcServer.Stop()

	// The log files must not exist in this scenario, so removing them has to
	// fail.
	assert.NotNil(t, os.Remove("./stderr.log"))
	assert.NotNil(t, os.Remove("./stdout.log"))
}

// Test_PatchDevice verifies how patchDevice rewrites the lm_eval command held
// in DriverOption.Args: `--device cuda` or `--device cpu` is appended
// depending on the flag, and a command that already names a device is left
// unchanged.
func Test_PatchDevice(t *testing.T) {
	driverOpt := DriverOption{
		Context:        context.Background(),
		JobNamespace:   "fms-lm-eval-service-system",
		JobName:        "evaljob-sample",
		GrpcService:    "localhost",
		GrpcPort:       8082,
		OutputPath:     ".",
		DetectDevice:   true,
		Logger:         driverLog,
		ReportInterval: time.Second * 5,
	}

	scenarios := []struct {
		command  string
		hasCuda  bool
		expected string
	}{
		{
			// append `--device cuda`
			command:  "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2",
			hasCuda:  true,
			expected: "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --device cuda",
		},
		{
			// append `--device cpu`
			command:  "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2",
			hasCuda:  false,
			expected: "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --device cpu",
		},
		{
			// no change because `--device cpu` exists
			command:  "python -m lm_eval --device cpu --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2",
			hasCuda:  true,
			expected: "python -m lm_eval --device cpu --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2",
		},
	}

	for _, sc := range scenarios {
		// Fresh Args per scenario so an earlier patch cannot leak into the next one.
		driverOpt.Args = []string{"sh", "-ec", sc.command}
		patchDevice(driverOpt.Args, sc.hasCuda)
		assert.Equal(t, sc.expected, driverOpt.Args[2])
	}
}
7 changes: 5 additions & 2 deletions controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ var (
"DriverReportInterval": DriverReportIntervalKey,
"DefaultBatchSize": DefaultBatchSizeKey,
"MaxBatchSize": MaxBatchSizeKey,
"DetectDevice": DetectDeviceKey,
}
)

Expand Down Expand Up @@ -101,6 +102,7 @@ type ServiceOptions struct {
GrpcClientSecret string
MaxBatchSize int
DefaultBatchSize int
DetectDevice bool
grpcTLSMode TLSMode
}

Expand Down Expand Up @@ -303,6 +305,7 @@ func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
GrpcServerSecret: DefaultGrpcServerSecret,
GrpcClientSecret: DefaultGrpcClientSecret,
MaxBatchSize: DefaultMaxBatchSize,
DetectDevice: DefaultDetectDevice,
DefaultBatchSize: DefaultBatchSize,
}

Expand Down Expand Up @@ -679,8 +682,7 @@ func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr
}

cmds := make([]string, 0, 10)
// FIXME: use CPU for now
cmds = append(cmds, "python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output", "--device", "cpu")
cmds = append(cmds, "python", "-m", "lm_eval", "--output_path", "/opt/app-root/src/output")
// --model
cmds = append(cmds, "--model", job.Spec.Model)
// --model_args
Expand Down Expand Up @@ -732,6 +734,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string
"--grpc-service", fmt.Sprintf("%s.%s.svc", r.options.GrpcService, r.Namespace),
"--grpc-port", strconv.Itoa(r.options.GrpcPort),
"--output-path", "/opt/app-root/src/output",
"--detect-device", fmt.Sprintf("%t", r.options.DetectDevice),
"--report-interval", r.options.DriverReportInterval.String(),
"--",
}
Expand Down
6 changes: 3 additions & 3 deletions controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,22 +760,22 @@ func Test_GenerateArgBatchSize(t *testing.T) {
// no batchSize in the job, use default batchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 8",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 8",
}, lmevalRec.generateArgs(job, log))

// exceed the max-batch-size, use max-batch-size
var biggerBatchSize = 30
job.Spec.BatchSize = &biggerBatchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 24",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 24",
}, lmevalRec.generateArgs(job, log))

// normal batchSize
var normalBatchSize = 16
job.Spec.BatchSize = &normalBatchSize
assert.Equal(t, []string{
"sh", "-ec",
"python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 16",
"python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 16",
}, lmevalRec.generateArgs(job, log))
}
Loading