Refactor wandb operations (#4061)
AyushExel committed Jul 19, 2021
1 parent 8f5a7a9 commit e78ca86
Showing 2 changed files with 45 additions and 36 deletions.
74 changes: 43 additions & 31 deletions utils/wandb_logging/wandb_utils.py
@@ -98,7 +98,14 @@ class WandbLogger():
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
# Pre-training routine --
self.job_type = job_type
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
self.val_artifact, self.train_artifact = None, None
self.train_artifact_path, self.val_artifact_path = None, None
self.result_artifact = None
self.val_table, self.result_table = None, None
self.data_dict = data_dict
self.bbox_media_panel_images = []
self.val_table_path_map = None
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -156,25 +163,27 @@ def setup_training(self, opt, data_dict):
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
config.opt['hyp']
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
opt.artifact_alias)
self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
opt.artifact_alias)
self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
if self.train_artifact_path is not None:
train_path = Path(self.train_artifact_path) / 'data/images/'
data_dict['train'] = str(train_path)
if self.val_artifact_path is not None:
val_path = Path(self.val_artifact_path) / 'data/images/'
data_dict['val'] = str(val_path)
self.val_table = self.val_artifact.get("val")
self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})

if self.train_artifact_path is not None:
train_path = Path(self.train_artifact_path) / 'data/images/'
data_dict['train'] = str(train_path)
if self.val_artifact_path is not None:
val_path = Path(self.val_artifact_path) / 'data/images/'
data_dict['val'] = str(val_path)


if self.val_artifact is not None:
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
self.val_table = self.val_artifact.get("val")
if self.val_table_path_map is None:
self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})
if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict
@@ -246,10 +255,10 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
return path

def map_val_table_path(self):
self.val_table_map = {}
self.val_table_path_map = {}
print("Mapping dataset")
for i, data in enumerate(tqdm(self.val_table.data)):
self.val_table_map[data[3]] = data[0]
self.val_table_path_map[data[3]] = data[0]

def create_dataset_table(self, dataset, class_to_id, name='dataset'):
        # TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging
@@ -283,7 +292,6 @@ def create_dataset_table(self, dataset, class_to_id, name='dataset'):
return artifact

def log_training_progress(self, predn, path, names):
if self.val_table and self.result_table:
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
box_data = []
total_conf = 0
@@ -297,14 +305,30 @@
"domain": "pixel"})
total_conf = total_conf + conf
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
id = self.val_table_map[Path(path).name]
id = self.val_table_path_map[Path(path).name]
self.result_table.add_data(self.current_epoch,
id,
self.val_table.data[id][1],
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
total_conf / max(1, len(box_data))
)

def val_one_image(self, pred, predn, path, names, im):
if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
self.log_training_progress(predn, path, names)
        else: # Default to bbox media panel if Val artifact not found
log_imgs = min(self.log_imgs, 100)
if len(self.bbox_media_panel_images) < log_imgs and self.current_epoch > 0:
if self.current_epoch % self.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))


def log(self, log_dict):
if self.wandb_run:
for key, value in log_dict.items():
@@ -313,8 +337,11 @@
def end_epoch(self, best_result=False):
if self.wandb_run:
with all_logging_disabled():
if self.bbox_media_panel_images:
self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
wandb.log(self.log_dict)
self.log_dict = {}
self.bbox_media_panel_images = []
if self.result_artifact:
self.result_artifact.add(self.result_table, 'result')
wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
@@ -332,21 +359,6 @@ def finish_run(self):
wandb.run.finish()


def wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, im):
# Log 1 validation image, called in val.py
log_imgs = min(wandb_logger.log_imgs, 100)
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # W&B logging - media panel plots
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb_logger.wandb.Image(im, boxes=boxes, caption=path.name))
wandb_logger.log_training_progress(predn, path, names) if wandb_logger.wandb_run else None


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
""" source - https://gist.github.com/simon-weber/7853144
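The wandb_utils.py changes above fold per-image validation logging into WandbLogger itself (val_one_image) and move attribute setup into __init__. Below is a minimal sketch of how a caller might drive the refactored logger; the Namespace fields, tensors, and file paths are placeholders rather than values from this commit, and it assumes wandb is installed and logged in.

```python
# Hypothetical driver for the refactored WandbLogger (a sketch, not code from this commit).
from argparse import Namespace
from pathlib import Path

import torch

from utils.wandb_logging.wandb_utils import WandbLogger

# Placeholder options; real runs pass the parsed train.py / val.py arguments.
opt = Namespace(resume=False, upload_dataset=False, project='runs/train', entity=None,
                artifact_alias='latest', bbox_interval=-1, epochs=30, batch_size=16, save_period=-1)
data_dict = {'train': 'data/images/train', 'val': 'data/images/val', 'nc': 1, 'names': ['item']}
names = {0: 'item'}

wandb_logger = WandbLogger(opt, 'exp', None, data_dict, job_type='Training')
wandb_logger.current_epoch, wandb_logger.log_imgs = 1, 16  # normally set by the training/val loop

if wandb_logger.wandb_run:
    pred = torch.tensor([[10., 20., 110., 220., 0.9, 0.]])  # xyxy, conf, cls (placeholder)
    im = torch.zeros(3, 640, 640)                           # placeholder image tensor
    # One call per validated image: routed to the W&B result table when a val artifact
    # exists, otherwise accumulated in bbox_media_panel_images for the media panel.
    wandb_logger.val_one_image(pred, pred.clone(), Path('data/images/val/0001.jpg'), names, im)

    # end_epoch() flushes log_dict plus any accumulated media-panel images.
    wandb_logger.log({'metrics/mAP_0.5': 0.5})
    wandb_logger.end_epoch(best_result=False)
```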
7 changes: 2 additions & 5 deletions val.py
@@ -26,7 +26,6 @@
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_sync
from utils.wandb_logging.wandb_utils import wandb_val_one_image


def save_one_txt(predn, save_conf, shape, file):
@@ -154,7 +153,7 @@ def run(data,
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
jdict, stats, ap, ap_class = [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
t_ = time_sync()
img = img.to(device, non_blocking=True)
@@ -217,7 +216,7 @@ def run(data,
if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
if wandb_logger:
wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, img[si])
wandb_logger.val_one_image(pred, predn, path, names, img[si])

# Plot images
if plots and batch_i < 3:
@@ -257,8 +256,6 @@ def run(data,
if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
wandb_logger.log({"Validation": val_batches})
if wandb_images:
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

# Save JSON
if save_json and len(jdict):
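On the val.py side, the loop no longer keeps its own wandb_images list or logs the bounding-box media panel at the end of the run; it simply hands each image to the logger and lets end_epoch() flush the accumulated images. A hypothetical helper illustrating that caller contract follows; the function name and argument shapes are illustrative, not part of this commit.

```python
# Hypothetical sketch of the caller contract after this refactor: one hook per image,
# no local wandb_images bookkeeping in val.py.
def log_val_predictions(wandb_logger, per_image_results, names):
    for pred, predn, path, im in per_image_results:  # one entry per validated image
        if wandb_logger:
            wandb_logger.val_one_image(pred, predn, path, names, im)
    # The media-panel images accumulated above are flushed later by end_epoch(),
    # so the validation loop itself no longer calls wandb.log for them.
```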
