Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] NAS-Bench-101 #3871

Merged
merged 18 commits into from
Jul 12, 2021
51 changes: 51 additions & 0 deletions examples/nas/multi-trial/nasbench101/base_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import math

import torch.nn as nn


def truncated_normal_(tensor, mean=0, std=1):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)


class ConvBnRelu(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(ConvBnRelu, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.reset_parameters()

def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()

def forward(self, x):
return self.conv_bn_relu(x)


class Conv3x3BnRelu(ConvBnRelu):
def __init__(self, in_channels, out_channels):
super(Conv3x3BnRelu, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)


class Conv1x1BnRelu(ConvBnRelu):
def __init__(self, in_channels, out_channels):
super(Conv1x1BnRelu, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)


Projection = Conv1x1BnRelu
173 changes: 173 additions & 0 deletions examples/nas/multi-trial/nasbench101/network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR10

from base_ops import Conv3x3BnRelu, Conv1x1BnRelu, Projection


@model_wrapper
class NasBench101(nn.Module):
def __init__(self,
stem_out_channels: int = 128,
num_stacks: int = 3,
num_modules_per_stack: int = 3,
max_num_vertices: int = 7,
max_num_edges: int = 9,
num_labels: int = 10,
bn_eps: float = 1e-5,
bn_momentum: float = 0.003):
super().__init__()

op_candidates = {
'conv3x3': lambda num_features: Conv3x3BnRelu(num_features, num_features),
'conv1x1': lambda num_features: Conv1x1BnRelu(num_features, num_features),
'maxpool': lambda num_features: nn.MaxPool2d(3, 1, 1)
}

# initial stem convolution
self.stem_conv = Conv3x3BnRelu(3, stem_out_channels)

layers = []
in_channels = out_channels = stem_out_channels
for stack_num in range(num_stacks):
if stack_num > 0:
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
layers.append(downsample)
out_channels *= 2
for _ in range(num_modules_per_stack):
cell = NasBench101Cell(op_candidates, in_channels, out_channels,
lambda cin, cout: Projection(cin, cout),
max_num_vertices, max_num_edges, label='cell')
layers.append(cell)
in_channels = out_channels

self.features = nn.ModuleList(layers)
self.gap = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(out_channels, num_labels)

for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = bn_eps
module.momentum = bn_momentum

def forward(self, x):
bs = x.size(0)
out = self.stem_conv(x)
for layer in self.features:
out = layer(out)
out = self.gap(out).view(bs, -1)
out = self.classifier(out)
return out

def reset_parameters(self):
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = self.config.bn_eps
module.momentum = self.config.bn_momentum


class AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)


@serialize_cls
class NasBench101TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
super().__init__()
self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
self.criterion = nn.CrossEntropyLoss()
self.accuracy = AccuracyWithLogits()

def forward(self, x):
y_hat = self.model(x)
return y_hat

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

def configure_optimizers(self):
optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay,
momentum=0.9, alpha=0.9, eps=1.0)
return {
'optimizer': optimizer,
'scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
}

def on_validation_epoch_end(self):
nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())


@click.command()
@click.option('--epochs', default=108, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
# initalize dataset. Note that 50k+10k is used. It's a little different from paper
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
]
train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
test_dataset = serialize(CIFAR10, 'data', train=False, transform=transforms.Compose(normalize))

# specify training hyper-parameters
training_module = NasBench101TrainingModule(max_epochs=epochs)
# FIXME: need to fix a bug in serializer for this to work
# lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
trainer = pl.Trainer(max_epochs=epochs, gpus=1)
lightning = pl.Lightning(
lightning_module=training_module,
trainer=trainer,
train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
)

strategy = Random()

model = NasBench101()

exp = RetiariiExperiment(model, lightning, [], strategy)

exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 20
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = False

exp.run(exp_config, port)


if __name__ == '__main__':
_multi_trial_test()
10 changes: 7 additions & 3 deletions nni/retiarii/mutator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import (Any, Iterable, List, Optional)
from typing import (Any, Iterable, List, Optional, Tuple)

from .graph import Model, Mutation, ModelStatus


__all__ = ['Sampler', 'Mutator']
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']


Choice = Any
Expand Down Expand Up @@ -77,7 +77,7 @@ def apply(self, model: Model) -> Model:
self._cur_choice_idx = None
return copy

def dry_run(self, model: Model) -> List[List[Choice]]:
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on same or different models,
Expand Down Expand Up @@ -115,3 +115,7 @@ def __init__(self):
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]


class InvalidMutation(Exception):
pass
40 changes: 28 additions & 12 deletions nni/retiarii/nn/pytorch/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import copy
import warnings
from collections import OrderedDict
from typing import Any, List, Union, Dict, Optional

import torch
import torch.nn as nn

from ...serializer import Translatable, basic_unit
from ...utils import NoContextError
from .utils import generate_new_label, get_fixed_value


Expand All @@ -26,6 +26,8 @@ class LayerChoice(nn.Module):
----------
candidates : list of nn.Module or OrderedDict
A module list to be selected from.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the layer choice.

Expand Down Expand Up @@ -55,17 +57,21 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""

def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
# FIXME: prior is designed but not supported yet

def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
try:
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
return candidates[chosen]
except AssertionError:
except NoContextError:
return super().__new__(cls)

def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
Expand All @@ -75,10 +81,12 @@ def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], lab
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)

self.names = []
if isinstance(candidates, OrderedDict):
if isinstance(candidates, dict):
for name, module in candidates.items():
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
Expand Down Expand Up @@ -173,13 +181,15 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""

def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
def __new__(cls, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', label: Optional[str] = None, **kwargs):
try:
return ChosenInputs(get_fixed_value(label), reduction=reduction)
except AssertionError:
except NoContextError:
return super().__new__(cls)

def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', label: Optional[str] = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
Expand Down Expand Up @@ -277,19 +287,25 @@ def forward(self, x):
----------
candidates : list
List of values to choose from.
prior : list of float
Prior distribution to sample from.
label : str
Identifier of the value choice.
"""

def __new__(cls, candidates: List[Any], label: Optional[str] = None):
# FIXME: prior is designed but not supported yet

def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved
try:
return get_fixed_value(label)
except AssertionError:
except NoContextError:
return super().__new__(cls)

def __init__(self, candidates: List[Any], label: Optional[str] = None):
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
self._accessor = []

Expand Down Expand Up @@ -323,7 +339,7 @@ def __copy__(self):
return self

def __deepcopy__(self, memo):
new_item = ValueChoice(self.candidates, self.label)
new_item = ValueChoice(self.candidates, label=self.label)
new_item._accessor = [*self._accessor]
return new_item

Expand Down
Loading