From e78ca86ed60950ceca9802ab912b67c0fd7315d5 Mon Sep 17 00:00:00 2001
From: Ayush Chaurasia
Date: Mon, 19 Jul 2021 14:02:47 +0530
Subject: [PATCH] Refactor wandb operations (#4061)

---
 utils/wandb_logging/wandb_utils.py | 74 +++++++++++++++++-------------
 val.py                             |  7 +--
 2 files changed, 45 insertions(+), 36 deletions(-)

diff --git a/utils/wandb_logging/wandb_utils.py b/utils/wandb_logging/wandb_utils.py
index 07cb8f27ecd8..a7e84ca100e4 100644
--- a/utils/wandb_logging/wandb_utils.py
+++ b/utils/wandb_logging/wandb_utils.py
@@ -98,7 +98,14 @@ class WandbLogger():
     def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
         # Pre-training routine --
         self.job_type = job_type
-        self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
+        self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
+        self.val_artifact, self.train_artifact = None, None
+        self.train_artifact_path, self.val_artifact_path = None, None
+        self.result_artifact = None
+        self.val_table, self.result_table = None, None
+        self.data_dict = data_dict
+        self.bbox_media_panel_images = []
+        self.val_table_path_map = None
         # It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
         if isinstance(opt.resume, str):  # checks resume from artifact
             if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -156,25 +163,27 @@ def setup_training(self, opt, data_dict):
                     self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
                     config.opt['hyp']
             data_dict = dict(self.wandb_run.config.data_dict)  # eliminates the need for config file to resume
-        if 'val_artifact' not in self.__dict__:  # If --upload_dataset is set, use the existing artifact, don't download
+        if self.val_artifact is None:  # If --upload_dataset is set, use the existing artifact, don't download
             self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'), opt.artifact_alias)
             self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'), opt.artifact_alias)
-            self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
-            if self.train_artifact_path is not None:
-                train_path = Path(self.train_artifact_path) / 'data/images/'
-                data_dict['train'] = str(train_path)
-            if self.val_artifact_path is not None:
-                val_path = Path(self.val_artifact_path) / 'data/images/'
-                data_dict['val'] = str(val_path)
-                self.val_table = self.val_artifact.get("val")
-                self.map_val_table_path()
-                wandb.log({"validation dataset": self.val_table})
+
+        if self.train_artifact_path is not None:
+            train_path = Path(self.train_artifact_path) / 'data/images/'
+            data_dict['train'] = str(train_path)
+        if self.val_artifact_path is not None:
+            val_path = Path(self.val_artifact_path) / 'data/images/'
+            data_dict['val'] = str(val_path)
+        if self.val_artifact is not None:
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
             self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
+            self.val_table = self.val_artifact.get("val")
+            if self.val_table_path_map is None:
+                self.map_val_table_path()
+            wandb.log({"validation dataset": self.val_table})
         if opt.bbox_interval == -1:
             self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
         return data_dict
@@ -246,10 +255,10 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
         return path

     def map_val_table_path(self):
-        self.val_table_map = {}
+        self.val_table_path_map = {}
         print("Mapping dataset")
         for i, data in enumerate(tqdm(self.val_table.data)):
-            self.val_table_map[data[3]] = data[0]
+            self.val_table_path_map[data[3]] = data[0]

     def create_dataset_table(self, dataset, class_to_id, name='dataset'):
         # TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging
@@ -283,7 +292,6 @@ def create_dataset_table(self, dataset, class_to_id, name='dataset'):
         return artifact

     def log_training_progress(self, predn, path, names):
-        if self.val_table and self.result_table:
         class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
         box_data = []
         total_conf = 0
@@ -297,7 +305,7 @@ def log_training_progress(self, predn, path, names):
                              "domain": "pixel"})
                 total_conf = total_conf + conf
         boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-        id = self.val_table_map[Path(path).name]
+        id = self.val_table_path_map[Path(path).name]
         self.result_table.add_data(self.current_epoch,
                                    id,
                                    self.val_table.data[id][1],
@@ -305,6 +313,22 @@ def log_training_progress(self, predn, path, names):
                                    total_conf / max(1, len(box_data))
                                    )

+    def val_one_image(self, pred, predn, path, names, im):
+        if self.val_table and self.result_table:  # Log Table if Val dataset is uploaded as artifact
+            self.log_training_progress(predn, path, names)
+        else:  # Default to bbox media panel if Val artifact not found
+            log_imgs = min(self.log_imgs, 100)
+            if len(self.bbox_media_panel_images) < log_imgs and self.current_epoch > 0:
+                if self.current_epoch % self.bbox_interval == 0:
+                    box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
+                                 "class_id": int(cls),
+                                 "box_caption": "%s %.3f" % (names[cls], conf),
+                                 "scores": {"class_score": conf},
+                                 "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
+                    boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
+                    self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))
+
+
     def log(self, log_dict):
         if self.wandb_run:
             for key, value in log_dict.items():
@@ -313,8 +337,11 @@ def log(self, log_dict):
     def end_epoch(self, best_result=False):
         if self.wandb_run:
             with all_logging_disabled():
+                if self.bbox_media_panel_images:
+                    self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
                 wandb.log(self.log_dict)
                 self.log_dict = {}
+                self.bbox_media_panel_images = []
             if self.result_artifact:
                 self.result_artifact.add(self.result_table, 'result')
                 wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
@@ -332,21 +359,6 @@ def finish_run(self):
             wandb.run.finish()


-def wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, im):
-    # Log 1 validation image, called in val.py
-    log_imgs = min(wandb_logger.log_imgs, 100)
-    if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0:  # W&B logging - media panel plots
-        if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
-            box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
-                         "class_id": int(cls),
-                         "box_caption": "%s %.3f" % (names[cls], conf),
-                         "scores": {"class_score": conf},
-                         "domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
-            boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space
-            wandb_images.append(wandb_logger.wandb.Image(im, boxes=boxes, caption=path.name))
-    wandb_logger.log_training_progress(predn, path, names) if wandb_logger.wandb_run else None
-
-
 @contextmanager
 def all_logging_disabled(highest_level=logging.CRITICAL):
     """ source - https://gist.github.com/simon-weber/7853144
diff --git a/val.py b/val.py
index 4edbb1c26a85..5a8486720577 100644
--- a/val.py
+++ b/val.py
@@ -26,7 +26,6 @@
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_sync
-from utils.wandb_logging.wandb_utils import wandb_val_one_image


 def save_one_txt(predn, save_conf, shape, file):
@@ -154,7 +153,7 @@ def run(data,
     s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
     p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
     loss = torch.zeros(3, device=device)
-    jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
+    jdict, stats, ap, ap_class = [], [], [], []
     for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
         t_ = time_sync()
         img = img.to(device, non_blocking=True)
@@ -217,7 +216,7 @@ def run(data,
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
             if wandb_logger:
-                wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, img[si])
+                wandb_logger.val_one_image(pred, predn, path, names, img[si])

         # Plot images
         if plots and batch_i < 3:
@@ -257,8 +256,6 @@ def run(data,
     if wandb_logger and wandb_logger.wandb:
         val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
         wandb_logger.log({"Validation": val_batches})
-    if wandb_images:
-        wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

     # Save JSON
     if save_json and len(jdict):
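
For context, a minimal usage sketch of the refactored interface follows. It is not part of the patch: the surrounding objects (`model`, `dataloader`, `names`, `wandb_logger`) and the prediction handling are assumed or simplified, and only `val_one_image()`, `log()` and `end_epoch()` are taken from the diff above.

```python
from pathlib import Path

# Hypothetical validation loop sketching the refactored WandbLogger API.
# `wandb_logger`, `dataloader`, `model` and `names` are assumed to already exist;
# only the logger calls below correspond to code introduced or kept by this patch.
for img, targets, paths, shapes in dataloader:
    out = model(img)                 # assumed: one prediction tensor per image
    for si, pred in enumerate(out):
        predn = pred.clone()         # assumed: predictions scaled to native-image space
        if wandb_logger:
            # The logger now decides internally: add a row to the W&B result table when a
            # val artifact is available, otherwise buffer a bbox media-panel image in
            # self.bbox_media_panel_images.
            wandb_logger.val_one_image(pred, predn, Path(paths[si]), names, img[si])

# end_epoch() flushes any buffered panel images under "Bounding Box Debugger/Images",
# so the caller no longer logs them directly as val.py did before this change.
wandb_logger.end_epoch()
```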