Commit

Merge branch 'dev-1.x' into 1.x
mzr1996 committed Dec 30, 2022
2 parents e9f9bb2 + 0d8f918 commit c7ec630
Showing 220 changed files with 9,805 additions and 452 deletions. Only the first three changed files are reproduced below.
6 changes: 3 additions & 3 deletions .circleci/test.yml
@@ -207,9 +207,9 @@ workflows:
       - lint
       - build_cpu_with_3rdparty:
           name: maximum_version_cpu
-          torch: 1.12.1
-          torchvision: 0.13.1
-          python: 3.9.0
+          torch: 1.13.0
+          torchvision: 0.14.0
+          python: 3.10.0
           requires:
             - minimum_version_cpu
       - hold:
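
The maximum_version_cpu job now tests against torch 1.13.0, torchvision 0.14.0 and Python 3.10.0. As a side note (not part of the commit), a local environment can be compared against these pins with a few lines of Python:

# Not part of the commit: a rough local check against the versions pinned in
# the "maximum_version_cpu" CI job above.
import platform

import torch
import torchvision

print('python     :', platform.python_version())   # CI job pins 3.10.0
print('torch      :', torch.__version__)           # CI job pins 1.13.0
print('torchvision:', torchvision.__version__)     # CI job pins 0.14.0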
48 changes: 28 additions & 20 deletions .dev_scripts/benchmark_regression/1-benchmark_valid.py
@@ -12,11 +12,12 @@
 from mmengine import Config, DictAction, MMLogger
 from mmengine.dataset import Compose, default_collate
 from mmengine.fileio import FileClient
-from mmengine.runner import Runner
+from mmengine.runner import Runner, load_checkpoint
 from modelindex.load_model_index import load
 from rich.console import Console
 from rich.table import Table

+from mmcls.apis import init_model
 from mmcls.datasets import CIFAR10, CIFAR100, ImageNet
 from mmcls.utils import register_all_modules
 from mmcls.visualization import ClsVisualizer
@@ -82,37 +83,44 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)

-    # build the data pipeline
-    test_dataset = cfg.test_dataloader.dataset
-    if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
-        test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
-    if test_dataset.type in ['CIFAR10', 'CIFAR100']:
-        # The image shape of CIFAR is (32, 32, 3)
-        test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
-
-    data = Compose(test_dataset.pipeline)({'img_path': args.img})
-    data = default_collate([data] * args.batch_size)
-    resolution = tuple(data['inputs'].shape[-2:])
-
-    runner: Runner = Runner.from_cfg(cfg)
-    model = runner.model
+    if 'test_dataloader' in cfg:
+        # build the data pipeline
+        test_dataset = cfg.test_dataloader.dataset
+        if test_dataset.pipeline[0]['type'] != 'LoadImageFromFile':
+            test_dataset.pipeline.insert(0, dict(type='LoadImageFromFile'))
+        if test_dataset.type in ['CIFAR10', 'CIFAR100']:
+            # The image shape of CIFAR is (32, 32, 3)
+            test_dataset.pipeline.insert(1, dict(type='Resize', scale=32))
+
+        data = Compose(test_dataset.pipeline)({'img_path': args.img})
+        data = default_collate([data] * args.batch_size)
+        resolution = tuple(data['inputs'].shape[-2:])
+        model = Runner.from_cfg(cfg).model
+        forward = model.val_step
+    else:
+        # For configs only for get model.
+        model = init_model(cfg)
+        load_checkpoint(model, checkpoint, map_location='cpu')
+        data = torch.empty(1, 3, 224, 224).to(model.data_preprocessor.device)
+        resolution = (224, 224)
+        forward = model.extract_feat

     # forward the model
     result = {'resolution': resolution}
     with torch.no_grad():
         if args.inference_time:
             time_record = []
             for _ in range(10):
-                model.val_step(data)  # warmup before profiling
+                forward(data)  # warmup before profiling
                 torch.cuda.synchronize()
                 start = time()
-                model.val_step(data)
+                forward(data)
                 torch.cuda.synchronize()
                 time_record.append((time() - start) / args.batch_size * 1000)
             result['time_mean'] = np.mean(time_record[1:-1])
             result['time_std'] = np.std(time_record[1:-1])
         else:
-            model.val_step(data)
+            forward(data)

     result['model'] = config_file.stem

@@ -144,8 +152,8 @@ def show_summary(summary_data, args):
     if args.inference_time:
         table.add_column('Inference Time (std) (ms/im)')
     if args.flops:
-        table.add_column('Flops', justify='right', width=11)
-        table.add_column('Params', justify='right')
+        table.add_column('Flops', justify='right', width=13)
+        table.add_column('Params', justify='right', width=11)

     for model_name, summary in summary_data.items():
         row = [model_name]
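
The inference() changes above do two things: configs without a test_dataloader are now handled by building the model with init_model and loading weights with load_checkpoint, and the timed call is abstracted into a forward callable (val_step or extract_feat). The timing loop brackets each forward with torch.cuda.synchronize() so that time() measures the completed GPU work rather than just the kernel launch. A standalone sketch of that measurement pattern (not part of the commit; forward, data and batch_size are hypothetical stand-ins for the script's own objects, and a CUDA device is assumed, as in the script):

# Standalone sketch of the synchronized timing used above (not part of the
# commit). `forward`, `data` and `batch_size` are hypothetical stand-ins.
from time import time

import numpy as np
import torch


def measure_latency(forward, data, batch_size, repeats=10):
    time_record = []
    with torch.no_grad():
        for _ in range(repeats):
            forward(data)              # warmup before profiling
            torch.cuda.synchronize()   # flush pending kernels before timing
            start = time()
            forward(data)
            torch.cuda.synchronize()   # wait for the timed forward to finish
            time_record.append((time() - start) / batch_size * 1000)  # ms/im
    # Discard the first and last samples, mirroring the script above.
    return np.mean(time_record[1:-1]), np.std(time_record[1:-1])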
186 changes: 186 additions & 0 deletions .dev_scripts/ckpt_tree.py
@@ -0,0 +1,186 @@
import argparse
import math
from pathlib import Path

import torch
from rich.console import Console

console = Console()

prog_description = """\
Draw the state dict tree.
"""


def parse_args():
    parser = argparse.ArgumentParser(description=prog_description)
    parser.add_argument(
        'path',
        type=Path,
        help='The path of the checkpoint or model config to draw.')
    parser.add_argument('--depth', type=int, help='The max depth to draw.')
    parser.add_argument(
        '--full-name',
        action='store_true',
        help='Whether to print the full name of the key.')
    parser.add_argument(
        '--shape',
        action='store_true',
        help='Whether to print the shape of the parameter.')
    parser.add_argument(
        '--state-key',
        type=str,
        help='The key of the state dict in the checkpoint.')
    parser.add_argument(
        '--number',
        action='store_true',
        help='Mark all parameters and their index number.')
    parser.add_argument(
        '--node',
        type=str,
        help='Show the sub-tree of a node, like "backbone.layers".')
    args = parser.parse_args()
    return args


def ckpt_to_state_dict(checkpoint, key=None):
    if key is not None:
        state_dict = checkpoint[key]
    elif 'state_dict' in checkpoint:
        # try mmcls style
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif isinstance(next(iter(checkpoint.values())), torch.Tensor):
        # try native style
        state_dict = checkpoint
    else:
        raise KeyError('Please specify the key of state '
                       f'dict from {list(checkpoint.keys())}.')
    return state_dict


class StateDictTree:

    def __init__(self, key='', value=None):
        self.children = {}
        self.key: str = key
        self.value = value

    def add_parameter(self, key, value):
        keys = key.split('.', 1)
        if len(keys) == 1:
            self.children[key] = StateDictTree(key, value)
        elif keys[0] in self.children:
            self.children[keys[0]].add_parameter(keys[1], value)
        else:
            node = StateDictTree(keys[0])
            node.add_parameter(keys[1], value)
            self.children[keys[0]] = node

    def __getitem__(self, key: str):
        return self.children[key]

    def __repr__(self) -> str:
        with console.capture() as capture:
            for line in self.iter_tree():
                console.print(line)
        return capture.get()

    def __len__(self):
        return len(self.children)

    def draw_tree(self,
                  max_depth=None,
                  full_name=False,
                  with_shape=False,
                  with_value=False):
        for line in self.iter_tree(
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value,
        ):
            console.print(line, highlight=False)

    def iter_tree(
        self,
        lead='',
        prefix='',
        max_depth=None,
        full_name=False,
        with_shape=False,
        with_value=False,
    ):
        if self.value is None:
            key_str = f'[blue]{self.key}[/]'
        elif with_shape:
            key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}'
        elif with_value:
            key_str = f'[green]{self.key}[/] {self.value}'
        else:
            key_str = f'[green]{self.key}[/]'

        yield lead + prefix + key_str

        lead = lead.replace('├─', '│ ')
        lead = lead.replace('└─', '  ')
        if self.key and full_name:
            prefix = f'{prefix}{self.key}.'

        if max_depth == 0:
            return
        elif max_depth is not None:
            max_depth -= 1

        for i, child in enumerate(self.children.values()):
            level_lead = '├─' if i < len(self.children) - 1 else '└─'
            yield from child.iter_tree(
                lead=f'{lead}{level_lead} ',
                prefix=prefix,
                max_depth=max_depth,
                full_name=full_name,
                with_shape=with_shape,
                with_value=with_value)


def main():
    args = parse_args()
    if args.path.suffix in ['.json', '.py', '.yml']:
        from mmengine.runner import get_state_dict

        from mmcls.apis import init_model
        model = init_model(args.path, device='cpu')
        state_dict = get_state_dict(model)
    else:
        ckpt = torch.load(args.path, map_location='cpu')
        state_dict = ckpt_to_state_dict(ckpt, args.state_key)

    root = StateDictTree()
    for k, v in state_dict.items():
        root.add_parameter(k, v)

    para_index = 0
    mark_width = math.floor(math.log(len(state_dict), 10) + 1)
    if args.node is not None:
        for key in args.node.split('.'):
            root = root[key]

    for line in root.iter_tree(
            max_depth=args.depth,
            full_name=args.full_name,
            with_shape=args.shape,
    ):
        if not args.number:
            mark = ''
        # A hack method to determine whether a line is parameter.
        elif '[green]' in line:
            mark = f'[red]({str(para_index).ljust(mark_width)})[/]'
            para_index += 1
        else:
            mark = ' ' * (mark_width + 2)
        console.print(mark + line, highlight=False)


if __name__ == '__main__':
    main()
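
For reference, a minimal usage sketch of the StateDictTree class defined above, built from an invented state dict rather than a real checkpoint (not part of the commit; the import assumes the .dev_scripts directory is on the Python path, and the keys and tensor shapes are made up for illustration):

# Not part of the commit: a tiny usage sketch of StateDictTree.
import torch

from ckpt_tree import StateDictTree  # assumes .dev_scripts is on the path

fake_state_dict = {
    'backbone.conv1.weight': torch.empty(64, 3, 7, 7),
    'backbone.conv1.bias': torch.empty(64),
    'head.fc.weight': torch.empty(1000, 64),
    'head.fc.bias': torch.empty(1000),
}

root = StateDictTree()
for key, value in fake_state_dict.items():
    root.add_parameter(key, value)

# Renders the hierarchy with rich; leaf parameters are printed with shapes.
root.draw_tree(with_shape=True)

Invoked as a script, the tool would be run along the lines of python .dev_scripts/ckpt_tree.py checkpoint.pth --depth 3 --shape --number, per the argparse options defined above.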