diff --git a/setup.py b/setup.py
index aa1f6766d1..2899e7a7e7 100644
--- a/setup.py
+++ b/setup.py
@@ -133,8 +133,10 @@ def _parse_requirements_file(file_path):
     "opencv-python<=4.6.0.66",
 ]
 _openpifpaf_integration_deps = [
-    "openpifpaf==0.13.6",
+    "openpifpaf==0.13.11",
     "opencv-python<=4.6.0.66",
+    "pycocotools>=2.0.6",
+    "scipy==1.10.1",
 ]
 _yolov8_integration_deps = _yolo_integration_deps + ["ultralytics==8.0.30"]
@@ -277,6 +279,7 @@ def _setup_entry_points() -> Dict:
         "deepsparse.yolov8.annotate=deepsparse.yolov8.annotate:main",
         "deepsparse.yolov8.eval=deepsparse.yolov8.validation:main",
         "deepsparse.pose_estimation.annotate=deepsparse.open_pif_paf.annotate:main",
+        "deepsparse.pose_estimation.eval=deepsparse.open_pif_paf.validation:main",
         "deepsparse.image_classification.annotate=deepsparse.image_classification.annotate:main",  # noqa E501
         "deepsparse.instance_segmentation.annotate=deepsparse.yolact.annotate:main",
         f"deepsparse.image_classification.eval={ic_eval}",
diff --git a/src/deepsparse/open_pif_paf/README.md b/src/deepsparse/open_pif_paf/README.md
index 456f92ad4d..bd6fb25728 100644
--- a/src/deepsparse/open_pif_paf/README.md
+++ b/src/deepsparse/open_pif_paf/README.md
@@ -10,14 +10,70 @@ DeepSparse pipeline for OpenPifPaf
 ```python
 from deepsparse import Pipeline
 
-model_path: str = ... # path to open_pif_paf model
-pipeline = Pipeline.create(task="open_pif_paf", batch_size=1, model_path=model_path)
+model_path: str = ... # path to open_pif_paf model (SparseZoo stub or ONNX model)
+pipeline = Pipeline.create(task="open_pif_paf", model_path=model_path)
 predictions = pipeline(images=['dancers.jpg'])
 # predictions have attributes 'data', 'keypoints', 'scores', 'skeletons'
 predictions[0].scores
 >> scores=[0.8542259724243828, 0.7930507659912109]
 ```
-# predictions have attributes `data', 'keypoints', 'scores', 'skeletons'
+### Output CifCaf fields
+Alternatively, instead of returning the detected poses, the pipeline can return the intermediate output: the CifCaf fields.
+This is the representation returned directly by the neural network, before it is processed by the pose-matching algorithm.
+
+```python
+...
+pipeline = Pipeline.create(task="open_pif_paf", model_path=model_path, return_cifcaf_fields=True)
+predictions = pipeline(images=['dancers.jpg'])
+predictions.fields
+```
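+
+As a quick sanity check, you can walk the nested `fields` structure and print the shapes of the raw head outputs. This is a minimal sketch; it assumes every leaf entry is a numpy array, mirroring how the validation helpers iterate the same structure:
+
+```python
+# minimal sketch: iterate the nested CifCaf fields and print array shapes
+for field in predictions.fields:
+    for array in field:
+        print(array.shape)  # raw Cif/Caf head output
+```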
+
+## Validation script
+This section describes how to run validation of an ONNX model or a SparseZoo stub.
+
+### Dataset
+For evaluation, you first need to download the dataset. The [OpenPifPaf documentation](https://openpifpaf.github.io/) describes
+in detail how to prepare the different datasets for validation. This is the example for the `crowdpose` dataset:
+
+```bash
+mkdir data-crowdpose
+cd data-crowdpose
+# download links here: https://github.com/Jeff-sjtu/CrowdPose
+unzip annotations.zip
+unzip images.zip
+# Now you can use the standard openpifpaf.train and openpifpaf.eval
+# commands as documented in Training with --dataset=crowdpose.
+```
+
+### Create an ONNX model
+
+```bash
+python3 -m openpifpaf.export_onnx --input-width 641 --input-height 641
+```
+
+### Validation command
+Once the dataset has been downloaded, run the command:
+```bash
+deepsparse.pose_estimation.eval --model-path openpifpaf-resnet50.onnx --dataset cocokp --image_size 641
+```
+
+This should produce evaluation output similar to the following:
+```bash
+...
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.502
+ Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.732
+ Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.523
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.429
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.605
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.534
+ Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.744
+ Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.554
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.457
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.643
+...
+```
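+
+For a quick cross-check of the DeepSparse numbers, the same evaluation can be run with the ONNX Runtime engine by switching `--engine-type` (all other flags unchanged):
+
+```bash
+deepsparse.pose_estimation.eval --model-path openpifpaf-resnet50.onnx --dataset cocokp --image_size 641 --engine-type onnxruntime
+```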
 
 ## The necessity of external OpenPifPaf helper function
 
 image
diff --git a/src/deepsparse/open_pif_paf/__init__.py b/src/deepsparse/open_pif_paf/__init__.py
index 0c44f887a4..8d3ec2e88e 100644
--- a/src/deepsparse/open_pif_paf/__init__.py
+++ b/src/deepsparse/open_pif_paf/__init__.py
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# flake8: noqa
+from .utils import *
diff --git a/src/deepsparse/open_pif_paf/utils/__init__.py b/src/deepsparse/open_pif_paf/utils/__init__.py
index be2130c93b..0873a8ab4d 100644
--- a/src/deepsparse/open_pif_paf/utils/__init__.py
+++ b/src/deepsparse/open_pif_paf/utils/__init__.py
@@ -13,3 +13,4 @@
 # limitations under the License.
 # flake8: noqa
 from .annotate import *
+from .validation import *
diff --git a/src/deepsparse/open_pif_paf/utils/validation/__init__.py b/src/deepsparse/open_pif_paf/utils/validation/__init__.py
new file mode 100644
index 0000000000..5280e7e84b
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+from .cli import *
+from .deepsparse_evaluator import *
diff --git a/src/deepsparse/open_pif_paf/utils/validation/cli.py b/src/deepsparse/open_pif_paf/utils/validation/cli.py
new file mode 100644
index 0000000000..b92c839d06
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/cli.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
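+
+"""
+Argument parsing for `deepsparse.pose_estimation.eval`.
+
+Adapted from `openpifpaf.eval`: builds the same argparse interface, registers
+the DeepSparse predictor/evaluator CLI groups, and uses `parse_known_args`
+so that unrecognized flags (e.g. those already consumed by the click entry
+point in `validation.py`) are ignored rather than rejected.
+"""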
+
+import argparse
+import logging
+
+import torch
+from deepsparse.open_pif_paf.utils.validation.deepsparse_evaluator import (
+    DeepSparseEvaluator,
+)
+from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import (
+    DeepSparsePredictor,
+)
+from openpifpaf import __version__, datasets, decoder, logger, network, show, visualizer
+from openpifpaf.eval import CustomFormatter
+
+
+LOG = logging.getLogger(__name__)
+
+__all__ = ["cli"]
+
+
+# adapted from OPENPIFPAF GITHUB:
+# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/eval.py
+# the appropriate edits are marked with # deepsparse edit:
+def cli():
+    parser = argparse.ArgumentParser(
+        prog="python3 -m openpifpaf.eval",
+        usage="%(prog)s [options]",
+        description=__doc__,
+        formatter_class=CustomFormatter,
+    )
+    parser.add_argument(
+        "--version",
+        action="version",
+        version="OpenPifPaf {version}".format(version=__version__),
+    )
+
+    datasets.cli(parser)
+    decoder.cli(parser)
+    logger.cli(parser)
+    network.Factory.cli(parser)
+    DeepSparsePredictor.cli(parser, skip_batch_size=True, skip_loader_workers=True)
+    show.cli(parser)
+    visualizer.cli(parser)
+    DeepSparseEvaluator.cli(parser)
+
+    parser.add_argument("--disable-cuda", action="store_true", help="disable CUDA")
+    parser.add_argument(
+        "--output", default=None, help="output filename without file extension"
+    )
+    parser.add_argument(
+        "--watch",
+        default=False,
+        const=60,
+        nargs="?",
+        type=int,
+        help=(
+            "Watch a directory for new checkpoint files. "
+            "Optionally specify the number of seconds between checks."
+        ),
+    )
+
+    # deepsparse edit: replace the parse_args call with parse_known_args
+    args, _ = parser.parse_known_args()
+
+    # add args.device
+    args.device = torch.device("cpu")
+    args.pin_memory = False
+    if not args.disable_cuda and torch.cuda.is_available():
+        args.device = torch.device("cuda")
+        args.pin_memory = True
+    LOG.debug("neural network device: %s", args.device)
+
+    datasets.configure(args)
+    decoder.configure(args)
+    network.Factory.configure(args)
+    DeepSparsePredictor.configure(args)
+    show.configure(args)
+    visualizer.configure(args)
+    DeepSparseEvaluator.configure(args)
+
+    return args
diff --git a/src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py
new file mode 100644
index 0000000000..8e2a9af102
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_decoder.py
@@ -0,0 +1,75 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
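+
+"""
+DeepSparseCifCaf: `openpifpaf.decoder.CifCaf` with the network forward pass
+replaced by a DeepSparse pipeline call; the CifCaf decoding step itself is
+left unchanged.
+"""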
+
+import logging
+import time
+from typing import List, Union
+
+from deepsparse import Pipeline
+from deepsparse.open_pif_paf.utils.validation.helpers import deepsparse_fields_to_torch
+from openpifpaf.decoder import CifCaf
+from openpifpaf.decoder.decoder import DummyPool
+
+
+LOG = logging.getLogger(__name__)
+
+__all__ = ["DeepSparseCifCaf"]
+
+
+class DeepSparseCifCaf(CifCaf):
+    def __init__(
+        self,
+        head_metas: List[Union["Cif", "Caf"]],  # noqa: F821
+        pipeline: Pipeline,
+    ):
+        self.pipeline = pipeline
+        cif_metas, caf_metas = head_metas
+        super().__init__([cif_metas], [caf_metas])
+
+    # adapted from OPENPIFPAF GITHUB:
+    # https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/decoder/decoder.py
+    # the appropriate edits are marked with # deepsparse edit:
+
+    # deepsparse edit: removed model argument (not needed, substituted with '_')
+    def batch(self, _, image_batch, *, device=None, gt_anns_batch=None):
+        """From image batch straight to annotations batch."""
+        start_nn = time.perf_counter()
+        # deepsparse edit: inference using deepsparse pipeline
+        # instead of torch model
+        fields_batch = deepsparse_fields_to_torch(
+            self.pipeline(images=image_batch.numpy())
+        )
+        self.last_nn_time = time.perf_counter() - start_nn
+
+        if gt_anns_batch is None:
+            gt_anns_batch = [None for _ in fields_batch]
+
+        if not isinstance(self.worker_pool, DummyPool):
+            # remove debug_images to save time during pickle
+            image_batch = [None for _ in fields_batch]
+            gt_anns_batch = [None for _ in fields_batch]
+
+        LOG.debug("parallel execution with worker %s", self.worker_pool)
+        start_decoder = time.perf_counter()
+        result = self.worker_pool.starmap(
+            self._mappable_annotations, zip(fields_batch, image_batch, gt_anns_batch)
+        )
+        self.last_decoder_time = time.perf_counter() - start_decoder
+
+        LOG.debug(
+            "time: nn = %.1fms, dec = %.1fms",
+            self.last_nn_time * 1000.0,
+            self.last_decoder_time * 1000.0,
+        )
+        return result
diff --git a/src/deepsparse/open_pif_paf/utils/validation/deepsparse_evaluator.py b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_evaluator.py
new file mode 100644
index 0000000000..2fa86c7a58
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_evaluator.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
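+
+"""
+DeepSparseEvaluator: `openpifpaf.eval.Evaluator` adapted to run inference
+through a DeepSparse pipeline. It pads validation images to a square
+(img_size x img_size) input and omits the torch-only model statistics when
+writing the *.stats.json output.
+"""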
+
+import json
+import logging
+import os
+import sys
+import typing as t
+from collections import defaultdict
+
+from deepsparse import Pipeline
+from deepsparse.open_pif_paf.utils.validation.deepsparse_predictor import (
+    DeepSparsePredictor,
+)
+from deepsparse.open_pif_paf.utils.validation.helpers import (
+    apply_deepsparse_preprocessing,
+)
+from openpifpaf import __version__, network
+from openpifpaf.eval import Evaluator
+
+
+LOG = logging.getLogger(__name__)
+
+__all__ = ["DeepSparseEvaluator"]
+
+
+# adapted from OPENPIFPAF GITHUB:
+# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/eval.py
+# the appropriate edits are marked with # deepsparse edit:
+class DeepSparseEvaluator(Evaluator):
+    # deepsparse edit: allow for passing in a pipeline
+    def __init__(self, pipeline: Pipeline, img_size: int, **kwargs):
+        self.pipeline = pipeline
+        super().__init__(**kwargs)
+        # deepsparse edit: required to enforce square images
+        apply_deepsparse_preprocessing(self.data_loader, img_size)
+
+    def evaluate(self, output: t.Optional[str]):
+        # generate a default output filename
+        if output is None:
+            assert self.args is not None
+            output = self.default_output_name(self.args)
+
+        # skip existing?
+        if self.skip_epoch0:
+            assert network.Factory.checkpoint is not None
+            if network.Factory.checkpoint.endswith(".epoch000"):
+                print("Not evaluating epoch 0.")
+                return
+        if self.skip_existing:
+            stats_file = output + ".stats.json"
+            if os.path.exists(stats_file):
+                print("Output file {} exists already. Exiting.".format(stats_file))
+                return
+            print(
+                "{} not found. Processing: {}".format(
+                    stats_file, network.Factory.checkpoint
+                )
+            )
+
+        # deepsparse edit: allow for passing in a pipeline
+        predictor = DeepSparsePredictor(
+            pipeline=self.pipeline, head_metas=self.datamodule.head_metas
+        )
+        metrics = self.datamodule.metrics()
+
+        total_time = self.accumulate(predictor, metrics)
+
+        # model stats
+        # deepsparse edit: removed model stats that are
+        # only applicable to torch models
+
+        # write
+        additional_data = {
+            "args": sys.argv,
+            "version": __version__,
+            "dataset": self.dataset_name,
+            "total_time": total_time,
+            "n_images": predictor.total_images,
+            "decoder_time": predictor.total_decoder_time,
+            "nn_time": predictor.total_nn_time,
+        }
+
+        metric_stats = defaultdict(list)
+        for metric in metrics:
+            if self.write_predictions:
+                metric.write_predictions(output, additional_data=additional_data)
+
+            this_metric_stats = metric.stats()
+            assert len(this_metric_stats.get("text_labels", [])) == len(
+                this_metric_stats.get("stats", [])
+            )
+
+            for k, v in this_metric_stats.items():
+                metric_stats[k] = metric_stats[k] + v
+
+        stats = dict(**metric_stats, **additional_data)
+
+        # write stats file
+        with open(output + ".stats.json", "w") as f:
+            json.dump(stats, f)
+
+        LOG.info("stats:\n%s", json.dumps(stats, indent=4))
+        LOG.info(
+            "time per image: decoder = %.0fms, nn = %.0fms, total = %.0fms",
+            1000 * stats["decoder_time"] / stats["n_images"],
+            1000 * stats["nn_time"] / stats["n_images"],
+            1000 * stats["total_time"] / stats["n_images"],
+        )
diff --git a/src/deepsparse/open_pif_paf/utils/validation/deepsparse_predictor.py b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_predictor.py
new file mode 100644
index 0000000000..b22af1dc96
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/deepsparse_predictor.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from deepsparse import Pipeline
+from deepsparse.open_pif_paf.utils.validation.deepsparse_decoder import DeepSparseCifCaf
+from openpifpaf import Predictor
+
+
+LOG = logging.getLogger(__name__)
+
+__all__ = ["DeepSparsePredictor"]
+
+
+# adapted from OPENPIFPAF GITHUB:
+# https://github.com/openpifpaf/openpifpaf/blob/main/src/openpifpaf/predictor.py
+# the appropriate edits are marked with # deepsparse edit:
+class DeepSparsePredictor(Predictor):
+    """
+    Convenience class to predict from various
+    inputs with a common configuration.
+    """
+
+    # deepsparse edit: allow for passing in a pipeline
+    def __init__(self, pipeline: Pipeline, **kwargs):
+        super().__init__(**kwargs)
+        # deepsparse edit: allow for passing in a pipeline and fix the processor
+        # to the CifCaf processor. Note: we still create a default torch model
+        # here, but it is only used to read off its head metas, which are
+        # required to initialize the DeepSparseCifCaf processor.
+        self.processor = DeepSparseCifCaf(
+            pipeline=pipeline, head_metas=self.model_cpu.head_metas
+        )
diff --git a/src/deepsparse/open_pif_paf/utils/validation/helpers.py b/src/deepsparse/open_pif_paf/utils/validation/helpers.py
new file mode 100644
index 0000000000..36c43c4996
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/utils/validation/helpers.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+from typing import List
+
+import torch
+from deepsparse.open_pif_paf.schemas import OpenPifPafFields
+from openpifpaf import transforms
+
+
+LOG = logging.getLogger(__name__)
+
+__all__ = ["apply_deepsparse_preprocessing", "deepsparse_fields_to_torch"]
+
+
+def deepsparse_fields_to_torch(
+    fields_batch: OpenPifPafFields, device="cpu"
+) -> List[List[torch.Tensor]]:
+    """
+    Convert a batch of fields from the deepsparse
+    openpifpaf fields schema to torch tensors
+
+    :param fields_batch: the batch of fields to convert
+    :param device: the device to move the tensors to
+    :return: a list of lists of torch tensors. The first
+        list is the batch dimension, the second list
+        contains two tensors: Cif and Caf field values
+    """
+    return [
+        [
+            torch.from_numpy(array).to(device)
+            for field in fields_batch.fields
+            for array in field
+        ]
+    ]
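+
+
+# Example (sketch, assuming a pipeline created with return_cifcaf_fields=True):
+#
+#   fields_batch = pipeline(images=image_batch.numpy())
+#   fields = deepsparse_fields_to_torch(fields_batch)
+#   # fields[0] is [cif_tensor, caf_tensor], as consumed by the CifCaf decoder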
+
+
+def apply_deepsparse_preprocessing(
+    data_loader: torch.utils.data.DataLoader, img_size: int
+) -> None:
+    """
+    Replace the CenterPadTight transform in the data loader
+    with a CenterPad transform to ensure that the images
+    from the data loader are (B, 3, D, D) where D is
+    the img_size. This function changes `data_loader`
+    in place
+
+    :param data_loader: the data loader to modify
+    :param img_size: the image size to pad to
+    """
+    data_loader.dataset.preprocess.preprocess_list[2] = transforms.CenterPad(img_size)
diff --git a/src/deepsparse/open_pif_paf/validation.py b/src/deepsparse/open_pif_paf/validation.py
new file mode 100644
index 0000000000..9ca4ae70f5
--- /dev/null
+++ b/src/deepsparse/open_pif_paf/validation.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
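+
+"""
+Entry point for `deepsparse.pose_estimation.eval`: validates an OpenPifPaf
+ONNX model or SparseZoo stub on a dataset supported by the openpifpaf
+framework (currently only `cocokp`).
+"""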
+
+import logging
+from typing import Optional
+
+import click
+
+from deepsparse import Pipeline
+from deepsparse.open_pif_paf.utils.validation import DeepSparseEvaluator, cli
+
+
+DEEPSPARSE_ENGINE = "deepsparse"
+ORT_ENGINE = "onnxruntime"
+SUPPORTED_DATASET_CONFIGS = ["cocokp"]
+
+logging.basicConfig(level=logging.INFO)
+
+
+@click.command(
+    context_settings=(
+        dict(token_normalize_func=lambda x: x.replace("-", "_"), show_default=True)
+    )
+)
+@click.option(
+    "--model-path",
+    required=True,
+    help="Path to the OpenPifPaf ONNX model or "
+    "SparseZoo stub to be evaluated.",
+)
+@click.option(
+    "--dataset",
+    type=str,
+    default="cocokp",
+    show_default=True,
+    help="Dataset name supported by the openpifpaf framework.",
+)
+@click.option(
+    "--num-cores",
+    type=int,
+    default=None,
+    show_default=True,
+    help="Number of CPU cores to run deepsparse with; default is all available",
+)
+@click.option(
+    "--image_size",
+    type=int,
+    default=641,
+    show_default=True,
+    help="Image size to use for evaluation. Will "
+    "be used to resize images to the same size "
+    "(B, C, image_size, image_size)",
+)
+@click.option(
+    "--name-validation-run",
+    type=str,
+    default="openpifpaf_validation",
+    show_default=True,
+    help="Name of the validation run, used for "
+    "creating a file to store the results",
+)
+@click.option(
+    "--engine-type",
+    default=DEEPSPARSE_ENGINE,
+    type=click.Choice([DEEPSPARSE_ENGINE, ORT_ENGINE]),
+    show_default=True,
+    help="engine type to use, valid choices: ['deepsparse', 'onnxruntime']",
+)
+@click.option(
+    "--device",
+    default="cuda",
+    type=str,
+    show_default=True,
+    help="Use 'device=cpu' or pass valid CUDA device(s) if available, "
+    "e.g. 'device=0' or 'device=0,1,2,3' for Multi-GPU",
+)
+def main(
+    model_path: str,
+    dataset: str,
+    num_cores: Optional[int],
+    image_size: int,
+    name_validation_run: str,
+    engine_type: str,
+    device: str,
+):
+
+    if dataset not in SUPPORTED_DATASET_CONFIGS:
+        raise ValueError(
+            f"Dataset {dataset} is not supported. "
+            f"Supported datasets are {SUPPORTED_DATASET_CONFIGS}"
+        )
+    args = cli()
+    args.dataset = dataset
+    args.output = name_validation_run
+    args.device = device
+
+    if dataset == "cocokp":
+        # eval for coco keypoints dataset
+        args.coco_eval_long_edge = image_size
+
+    pipeline = Pipeline.create(
+        task="open_pif_paf",
+        model_path=model_path,
+        engine_type=engine_type,
+        num_cores=num_cores,
+        image_size=image_size,
+        return_cifcaf_fields=True,
+    )
+
+    evaluator = DeepSparseEvaluator(
+        pipeline=pipeline,
+        dataset_name=args.dataset,
+        skip_epoch0=False,
+        img_size=image_size,
+    )
+    if args.watch:
+        # this pathway has not been tested
+        # and is not supported
+        assert args.output is None
+        evaluator.watch(args.checkpoint, args.watch)
+    else:
+        evaluator.evaluate(args.output)
+
+
+if __name__ == "__main__":
+    main()