Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Enable fixed_arch on Retiarii #3972

Merged
merged 6 commits into from
Jul 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/en_US/NAS/ApiReference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,6 @@ Retiarii Experiments
Utilities
---------

.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.serialize

.. autofunction:: nni.retiarii.fixed_arch
8 changes: 7 additions & 1 deletion docs/en_US/NAS/OneshotTrainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
trainer.fit()
final_architecture = trainer.export()

**Format of the exported architecture.** TBD.
After the searching is done, we can use the exported architecture to instantiate the full network for retraining. Here is an example:

.. code-block:: python

from nni.retiarii import fixed_arch
with fixed_arch('/path/to/checkpoint.json'):
model = Model()
2 changes: 1 addition & 1 deletion docs/en_US/NAS/WriteOneshot.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key
self.name = layer_choice.label
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

Expand Down
8 changes: 4 additions & 4 deletions examples/nas/oneshot/darts/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn

import ops
from nni.nas.pytorch import mutables
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice


class AuxiliaryHead(nn.Module):
Expand Down Expand Up @@ -45,17 +45,17 @@ def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(OrderedDict([
LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
("sepconv3x3", ops.SepConv(channels, channels, 3, stride, 1, affine=False)),
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
]), label=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
self.input_switch = InputChoice(n_candidates=len(choice_keys), n_chosen=2, label="{}_switch".format(node_id))

def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
Expand Down
6 changes: 3 additions & 3 deletions examples/nas/oneshot/darts/retrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from nni.retiarii import fixed_arch

logger = logging.getLogger('nni')

Expand Down Expand Up @@ -119,8 +119,8 @@ def validate(config, valid_loader, model, criterion, epoch, cur_step):
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)

model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint)
with fixed_arch(args.arc_checkpoint):
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
criterion = nn.CrossEntropyLoss()

model.to(device)
Expand Down
1 change: 1 addition & 0 deletions nni/retiarii/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .operation import Operation
from .graph import *
from .execution import *
from .fixed import fixed_arch
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
40 changes: 40 additions & 0 deletions nni/retiarii/fixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any

from .utils import ContextStack

_logger = logging.getLogger(__name__)


def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,

.. code-block:: python

with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)

Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True

Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""

if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)

if verbose:
_logger.info(f'Fixed architecture: %s', fixed_arch)

return ContextStack('fixed', fixed_arch)
9 changes: 5 additions & 4 deletions nni/retiarii/oneshot/pytorch/darts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import copy
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
Expand All @@ -18,8 +19,8 @@
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.name = layer_choice.label
self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

def forward(self, *args, **kwargs):
Expand All @@ -38,13 +39,13 @@ def named_parameters(self):
yield name, p

def export(self):
return torch.argmax(self.alpha).item()
return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]


class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.name = input_choice.key
self.name = input_choice.label
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1

Expand Down