Support aml (#2615)
SparkSnail authored Jul 1, 2020
1 parent f5caa19 commit 93f96d4
Showing 26 changed files with 791 additions and 45 deletions.
66 changes: 66 additions & 0 deletions docs/en_US/TrainingService/AMLMode.md
@@ -0,0 +1,66 @@
**Run an Experiment on Azure Machine Learning**
===
NNI supports running an experiment on [AML](https://azure.microsoft.com/en-us/services/machine-learning/), called aml mode.

## Setup environment
Step 1. Install NNI by following the install guide [here](../Tutorial/QuickStart.md).

Step 2. Create an AML account by following the document [here](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli).

Step 3. Get your account information.
![](../../img/aml_account.png)

Step 4. Install the AML Python packages:
```
python3 -m pip install azureml --user
python3 -m pip install azureml-sdk --user
```

## Run an experiment
Use `examples/trials/mnist-tfv1` as an example. The content of the NNI config YAML file is as follows:

```yaml
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}

```

Note: You should set `trainingServicePlatform: aml` in the NNI config YAML file if you want to start an experiment in aml mode.

Compared with [LocalMode](LocalMode.md), the trial configuration in aml mode has these additional keys:
* computeTarget
    * Required key. The compute cluster name you want to use in your AML workspace.
* image
    * Required key. The Docker image name used in the job.

amlConfig:
* subscriptionId
    * The subscriptionId of your Azure account.
* resourceGroup
    * The resourceGroup of your Azure account.
* workspaceName
    * The workspaceName of your Azure account.

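Once the placeholders above are filled in, the experiment can be started with `nnictl` from the directory containing the config file and trial code, for example (assuming the config is saved as `config_aml.yml`, as in the bundled examples):

```
nnictl create --config config_aml.yml
```

NNI then submits trial jobs to the specified compute target through the AML SDK (see `amlUtil.py` below) and prints the Web UI URL to the console.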
1 change: 1 addition & 0 deletions docs/en_US/training_services.rst
@@ -10,3 +10,4 @@ Introduction to NNI Training Services
Kubeflow<./TrainingService/KubeflowMode>
FrameworkController<./TrainingService/FrameworkControllerMode>
DLTS<./TrainingService/DLTSMode>
AML<./TrainingService/AMLMode>
Binary file added docs/img/aml_account.png
25 changes: 25 additions & 0 deletions examples/trials/mnist-pytorch/config_aml.yml
@@ -0,0 +1,25 @@
authorName: default
experimentName: example_mnist_pytorch
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
25 changes: 25 additions & 0 deletions examples/trials/mnist-tfv1/config_aml.yml
@@ -0,0 +1,25 @@
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
56 changes: 56 additions & 0 deletions src/nni_manager/config/aml/amlUtil.py
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import time
import json
from argparse import ArgumentParser
from azureml.core import Experiment, RunConfiguration, ScriptRunConfig
from azureml.core.compute import ComputeTarget
from azureml.core.run import RUNNING_STATES, RunStatus, Run
from azureml.core import Workspace
from azureml.core.conda_dependencies import CondaDependencies

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--subscription_id', help='the subscription id of aml')
    parser.add_argument('--resource_group', help='the resource group of aml')
    parser.add_argument('--workspace_name', help='the workspace name of aml')
    parser.add_argument('--compute_target', help='the compute cluster name of aml')
    parser.add_argument('--docker_image', help='the docker image of job')
    parser.add_argument('--experiment_name', help='the experiment name')
    parser.add_argument('--script_dir', help='script directory')
    parser.add_argument('--script_name', help='script name')
    args = parser.parse_args()

    ws = Workspace(args.subscription_id, args.resource_group, args.workspace_name)
    compute_target = ComputeTarget(workspace=ws, name=args.compute_target)
    experiment = Experiment(ws, args.experiment_name)
    run_config = RunConfiguration()
    dependencies = CondaDependencies()
    dependencies.add_pip_package("azureml-sdk")
    dependencies.add_pip_package("azureml")
    run_config.environment.python.conda_dependencies = dependencies
    run_config.environment.docker.enabled = True
    run_config.environment.docker.base_image = args.docker_image
    run_config.target = compute_target
    run_config.node_count = 1
    config = ScriptRunConfig(source_directory=args.script_dir, script=args.script_name, run_config=run_config)
    run = experiment.submit(config)
    print(run.get_details()["runId"])

    # Serve commands from the NNI manager: one command per line on stdin,
    # one prefixed reply per line on stdout.
    while True:
        line = sys.stdin.readline().rstrip()
        if line == 'update_status':
            print('status:' + run.get_status())
        elif line == 'tracking_url':
            print('tracking_url:' + run.get_portal_url())
        elif line == 'stop':
            run.cancel()
            exit(0)
        elif line == 'receive':
            print('receive:' + json.dumps(run.get_metrics()))
        elif line:
            items = line.split(':')
            if items[0] == 'command':
                # Strip the 'command:' prefix (8 characters) and forward the payload to the run.
                run.log('nni_manager', line[8:])
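The NNI manager drives this script through the `python-shell` package (see the `AMLClient` class below): it writes one command per line to the script's stdin and reads prefixed replies from stdout. The snippet below is a minimal illustration of that protocol from Python; the Azure resource values and the trial script name are placeholders, not values taken from this commit:

```python
import subprocess
import sys

# Start amlUtil.py unbuffered, mirroring the '-u' pythonOption used by AMLClient.
proc = subprocess.Popen(
    [sys.executable, '-u', 'amlUtil.py',
     '--subscription_id', '<subscription-id>',
     '--resource_group', '<resource-group>',
     '--workspace_name', '<workspace-name>',
     '--compute_target', '<compute-target>',
     '--docker_image', 'msranni/nni',
     '--experiment_name', 'nni_exp_example',
     '--script_dir', '.',
     '--script_name', '<trial-script>'],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)

print(proc.stdout.readline().strip())    # first line printed by the script is the AML run id

proc.stdin.write('update_status\n')       # ask for the current run status
proc.stdin.flush()
print(proc.stdout.readline().strip())     # e.g. 'status:Running'

proc.stdin.write('tracking_url\n')        # ask for the AML portal URL of the run
proc.stdin.flush()
print(proc.stdout.readline().strip())     # 'tracking_url:https://ml.azure.com/...'

proc.stdin.write('stop\n')                # cancel the run; the script then exits
proc.stdin.flush()
proc.wait()
```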
8 changes: 6 additions & 2 deletions src/nni_manager/main.ts
@@ -65,6 +65,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container.bind(TrainingService)
.to(DLTSTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'aml') {
Container.bind(TrainingService)
.to(RouterTrainingService)
.scope(Scope.Singleton);
} else {
throw new Error(`Error: unsupported mode: ${platformMode}`);
}
@@ -93,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN

function usage(): void {
console.info('usage: node main.js --port <port> --mode \
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
}

const strPort: string = parseArg(['--port', '-p']);
@@ -113,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
const port: number = parseInt(strPort, 10);

const mode: string = parseArg(['--mode', '-m']);
if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'].includes(mode)) {
if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) {
console.log(`FATAL: unknown mode: ${mode}`);
usage();
process.exit(1);
1 change: 1 addition & 0 deletions src/nni_manager/package.json
@@ -19,6 +19,7 @@
"ignore": "^5.1.4",
"js-base64": "^2.4.9",
"kubernetes-client": "^6.5.0",
"python-shell": "^2.0.1",
"rx": "^4.1.0",
"sqlite3": "^4.0.2",
"ssh2": "^0.6.1",
7 changes: 7 additions & 0 deletions src/nni_manager/rest_server/restValidationSchemas.ts
@@ -39,6 +39,8 @@ export namespace ValidationSchemas {
nniManagerNFSMountPath: joi.string().min(1),
containerNFSMountPath: joi.string().min(1),
paiConfigPath: joi.string(),
computeTarget: joi.string(),
nodeCount: joi.number(),
paiStorageConfigName: joi.string().min(1),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
portList: joi.array().items(joi.object({
@@ -150,6 +152,11 @@ export namespace ValidationSchemas {
email: joi.string().min(1),
password: joi.string().min(1)
}),
aml_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
subscriptionId: joi.string().min(1),
resourceGroup: joi.string().min(1),
workspaceName: joi.string().min(1)
}),
nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
nniManagerIp: joi.string().min(1)
})
@@ -19,6 +19,7 @@ export enum TrialConfigMetadataKey {
NNI_MANAGER_IP = 'nni_manager_ip',
FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config',
DLTS_CLUSTER_CONFIG = 'dlts_config',
AML_CLUSTER_CONFIG = 'aml_config',
VERSION_CHECK = 'version_check',
LOG_COLLECTION = 'log_collection'
}
125 changes: 125 additions & 0 deletions src/nni_manager/training_service/reusable/aml/amlClient.ts
@@ -0,0 +1,125 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import { Deferred } from 'ts-deferred';
import { PythonShell } from 'python-shell';

export class AMLClient {
    public subscriptionId: string;
    public resourceGroup: string;
    public workspaceName: string;
    public experimentId: string;
    public image: string;
    public scriptName: string;
    public pythonShellClient: undefined | PythonShell;
    public codeDir: string;
    public computeTarget: string;

    constructor(
        subscriptionId: string,
        resourceGroup: string,
        workspaceName: string,
        experimentId: string,
        computeTarget: string,
        image: string,
        scriptName: string,
        codeDir: string,
    ) {
        this.subscriptionId = subscriptionId;
        this.resourceGroup = resourceGroup;
        this.workspaceName = workspaceName;
        this.experimentId = experimentId;
        this.image = image;
        this.scriptName = scriptName;
        this.codeDir = codeDir;
        this.computeTarget = computeTarget;
    }

    public submit(): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        this.pythonShellClient = new PythonShell('amlUtil.py', {
            scriptPath: './config/aml',
            pythonOptions: ['-u'], // get print results in real-time
            args: [
                '--subscription_id', this.subscriptionId,
                '--resource_group', this.resourceGroup,
                '--workspace_name', this.workspaceName,
                '--compute_target', this.computeTarget,
                '--docker_image', this.image,
                '--experiment_name', `nni_exp_${this.experimentId}`,
                '--script_dir', this.codeDir,
                '--script_name', this.scriptName
            ]
        });
        this.pythonShellClient.on('message', function (envId: any) {
            // received a message sent from the Python script (a simple "print" statement)
            deferred.resolve(envId);
        });
        return deferred.promise;
    }

    public stop(): void {
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('stop');
    }

    public getTrackingUrl(): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('tracking_url');
        let trackingUrl = '';
        this.pythonShellClient.on('message', function (status: any) {
            const items = status.split(':');
            if (items[0] === 'tracking_url') {
                // re-join with ':' because the URL itself contains colons
                trackingUrl = items.splice(1, items.length).join(':');
            }
            deferred.resolve(trackingUrl);
        });
        return deferred.promise;
    }

    public updateStatus(oldStatus: string): Promise<string> {
        const deferred: Deferred<string> = new Deferred<string>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        let newStatus = oldStatus;
        this.pythonShellClient.send('update_status');
        this.pythonShellClient.on('message', function (status: any) {
            const items = status.split(':');
            if (items[0] === 'status') {
                newStatus = items.splice(1, items.length).join('');
            }
            deferred.resolve(newStatus);
        });
        return deferred.promise;
    }

    public sendCommand(message: string): void {
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send(`command:${message}`);
    }

    public receiveCommand(): Promise<any> {
        const deferred: Deferred<any> = new Deferred<any>();
        if (this.pythonShellClient === undefined) {
            throw Error('python shell client not initialized!');
        }
        this.pythonShellClient.send('receive');
        this.pythonShellClient.on('message', function (command: any) {
            const items = command.split(':');
            if (items[0] === 'receive') {
                // drop the 'receive:' prefix (8 characters) and parse the metrics JSON
                deferred.resolve(JSON.parse(command.slice(8)));
            }
        });
        return deferred.promise;
    }
}