From 54562d67ce99111f9d4a5a628679d4848c8464f2 Mon Sep 17 00:00:00 2001 From: Mrinal Jain Date: Mon, 30 Nov 2020 08:06:16 -0500 Subject: [PATCH] Minor fixes --- README.md | 2 +- capstone/training/base_trainer.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index bffaa71..ee89ffb 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Executing the following code will download, extract, and split the dataset. ### Base Requirements 1. Python (3.7) -2. [pynrrd](https://github.com/mhe/pynrrd) - For loading MICCAI data in `.nrrd` format +2. [pynrrd](https://github.com/mhe/pynrrd) (0.4) - For loading MICCAI data in `.nrrd` format 3. Tqdm - For displaying progress bars 4. PyTorch (1.7) 5. Torchvision (0.8) diff --git a/capstone/training/base_trainer.py b/capstone/training/base_trainer.py index 6087a68..0c62453 100644 --- a/capstone/training/base_trainer.py +++ b/capstone/training/base_trainer.py @@ -26,7 +26,7 @@ def __init__( use_res_units: bool = False, downsample: bool = False, lr: float = 1e-3, - loss_fx: list = ["CrossEntropy"], + loss_fx: list = ["Focal", "Dice"], exclude_missing: bool = False, **kwargs, ) -> None: @@ -57,9 +57,6 @@ def __init__( ) self.dice_score = DiceMetricWrapper() - if isinstance(self.logger, WandbLogger): - self.logger.watch(self) - @property def _n_classes(self): return len(miccai.STRUCTURES) + 1 # Additional background @@ -175,7 +172,7 @@ def add_model_specific_args(parent_parser): parser.add_argument( "--use_res_units", action="store_true", - default=True, + default=False, help="For using residual units in UNet.", ) parser.add_argument( @@ -197,7 +194,7 @@ def add_model_specific_args(parent_parser): parser.add_argument( "--exclude_missing", action="store_true", - default=True, + default=False, help="Exclude missing annotations from loss computation (as described in AnatomyNet).", ) parser.add_argument(