readme fix and make cifar10 work off grenoble (#6187)
garrett361 committed Mar 9, 2023
1 parent 46e9f74 commit 9563274
Showing 3 changed files with 13 additions and 7 deletions.
4 changes: 2 additions & 2 deletions ds_autotuning_prototype/README.md
@@ -16,7 +16,7 @@ The basic steps:
In the same directory as this `README`, run the following, for instance:

```bash
-python3 -m dsat.autotune examples/minimal_example/autotune_config.yaml examples/minimal_example
+python3 -m dsat.autotune examples/ffn_example/autotune_config.yaml examples/ffn_example
```

(the config may need to be altered for your cluster.)
@@ -43,4 +43,4 @@ TODOs:
- Not all native DS AT code paths are currently supported, e.g. providing explicit batch sizes to use,
fast mode.
- Benchmark against native DS AT. The 0.8.1 update which fixed the DS AT units issues might have also
-broken the `--autotuning run` flag?
+broken the `--autotuning run` flag?
14 changes: 10 additions & 4 deletions ds_autotuning_prototype/examples/cifar10_example/script.py
@@ -38,10 +38,16 @@ def main(
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

-    trainset = torchvision.datasets.CIFAR10(
-        root=hparams.data.root, train=True, download=False, transform=transform
-    )
+    # Cluster specific code for Grenoble.
+    try:
+        # This works on GG's Grenoble setup, but not otherwise.
+        trainset = torchvision.datasets.CIFAR10(
+            root=hparams.data.root, train=True, download=False, transform=transform
+        )
+    except:
+        trainset = torchvision.datasets.CIFAR10(
+            root=".", train=True, download=True, transform=transform
+        )

trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=2)
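The hunk above uses a bare `except:` for the fallback, which would also swallow `BaseException` subclasses such as `KeyboardInterrupt`. A minimal, hypothetical sketch of the same "try the cluster-local root, then fall back to downloading" pattern, with the loader callables standing in for the `torchvision.datasets.CIFAR10` constructor calls (the function name and its parameters are illustrative, not from the commit):

```python
def load_with_fallback(local_loader, download_loader):
    """Return the dataset from local_loader, falling back to download_loader.

    Catching RuntimeError specifically (what torchvision raises when
    download=False and the files are missing) is narrower than a bare
    `except:`, which would also catch KeyboardInterrupt and SystemExit.
    """
    try:
        return local_loader()
    except RuntimeError:
        return download_loader()
```

In the script above, `local_loader` would wrap the `root=hparams.data.root, download=False` call and `download_loader` the `root=".", download=True` call.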

@@ -13,7 +13,7 @@ searcher:
smaller_is_better: False
hyperparameters:
dim: 1024
-layers: 16 # Total params will be layers * 1024 ** 2
+layers: 4 # Total params will be layers * 1024 ** 2
# NOTE: dsat code expects usual DS config dict to appear as in the below.
ds_config:
train_micro_batch_size_per_gpu: 128
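Per the comment on the `layers` line, total parameter count is assumed to scale as `layers * dim ** 2` with `dim: 1024`. A quick sketch of what the change implies (a back-of-the-envelope check, not code from the repo):

```python
# Approximate parameter count implied by the config comment:
# total params ~= layers * dim ** 2, with dim = 1024.
dim = 1024
old_layers, new_layers = 16, 4  # values before and after this commit
old_params = old_layers * dim ** 2  # 16_777_216
new_params = new_layers * dim ** 2  # 4_194_304
print(old_params, new_params)
```

So the commit shrinks the example model from roughly 16.8M to 4.2M parameters.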
