Refactor wandb operations #4061

Merged · 1 commit · Jul 19, 2021
74 changes: 43 additions & 31 deletions utils/wandb_logging/wandb_utils.py
@@ -98,7 +98,14 @@ class WandbLogger():
def __init__(self, opt, name, run_id, data_dict, job_type='Training'):
# Pre-training routine --
self.job_type = job_type
self.wandb, self.wandb_run, self.data_dict = wandb, None if not wandb else wandb.run, data_dict
self.wandb, self.wandb_run = wandb, None if not wandb else wandb.run
self.val_artifact, self.train_artifact = None, None
self.train_artifact_path, self.val_artifact_path = None, None
self.result_artifact = None
self.val_table, self.result_table = None, None
self.data_dict = data_dict
self.bbox_media_panel_images = []
self.val_table_path_map = None
# It's more elegant to stick to 1 wandb.init call, but useful config data is overwritten in the WandbLogger's wandb.init call
if isinstance(opt.resume, str): # checks resume from artifact
if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
@@ -156,25 +163,27 @@ def setup_training(self, opt, data_dict):
self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
config.opt['hyp']
data_dict = dict(self.wandb_run.config.data_dict) # eliminates the need for config file to resume
if 'val_artifact' not in self.__dict__: # If --upload_dataset is set, use the existing artifact, don't download
if self.val_artifact is None: # If --upload_dataset is set, use the existing artifact, don't download
self.train_artifact_path, self.train_artifact = self.download_dataset_artifact(data_dict.get('train'),
opt.artifact_alias)
self.val_artifact_path, self.val_artifact = self.download_dataset_artifact(data_dict.get('val'),
opt.artifact_alias)
self.result_artifact, self.result_table, self.val_table, self.weights = None, None, None, None
if self.train_artifact_path is not None:
train_path = Path(self.train_artifact_path) / 'data/images/'
data_dict['train'] = str(train_path)
if self.val_artifact_path is not None:
val_path = Path(self.val_artifact_path) / 'data/images/'
data_dict['val'] = str(val_path)
self.val_table = self.val_artifact.get("val")
self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})

if self.train_artifact_path is not None:
train_path = Path(self.train_artifact_path) / 'data/images/'
data_dict['train'] = str(train_path)
if self.val_artifact_path is not None:
val_path = Path(self.val_artifact_path) / 'data/images/'
data_dict['val'] = str(val_path)


if self.val_artifact is not None:
self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
self.result_table = wandb.Table(["epoch", "id", "ground truth", "prediction", "avg_confidence"])
self.val_table = self.val_artifact.get("val")
if self.val_table_path_map is None:
self.map_val_table_path()
wandb.log({"validation dataset": self.val_table})
if opt.bbox_interval == -1:
self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
return data_dict
@@ -246,10 +255,10 @@ def log_dataset_artifact(self, data_file, single_cls, project, overwrite_config=
return path

def map_val_table_path(self):
self.val_table_map = {}
self.val_table_path_map = {}
print("Mapping dataset")
for i, data in enumerate(tqdm(self.val_table.data)):
self.val_table_map[data[3]] = data[0]
self.val_table_path_map[data[3]] = data[0]

def create_dataset_table(self, dataset, class_to_id, name='dataset'):
# TODO: Explore multiprocessing to split this loop in parallel | This is essential for speeding up the logging
@@ -283,7 +292,6 @@ def create_dataset_table(self, dataset, class_to_id, name='dataset'):
return artifact

def log_training_progress(self, predn, path, names):
if self.val_table and self.result_table:
class_set = wandb.Classes([{'id': id, 'name': name} for id, name in names.items()])
box_data = []
total_conf = 0
@@ -297,14 +305,30 @@
"domain": "pixel"})
total_conf = total_conf + conf
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
id = self.val_table_map[Path(path).name]
id = self.val_table_path_map[Path(path).name]
self.result_table.add_data(self.current_epoch,
id,
self.val_table.data[id][1],
wandb.Image(self.val_table.data[id][1], boxes=boxes, classes=class_set),
total_conf / max(1, len(box_data))
)

def val_one_image(self, pred, predn, path, names, im):
if self.val_table and self.result_table: # Log Table if Val dataset is uploaded as artifact
self.log_training_progress(predn, path, names)
else: # Default to bbox media panel if Val artifact not found
log_imgs = min(self.log_imgs, 100)
if len(self.bbox_media_panel_images) < log_imgs and self.current_epoch > 0:
if self.current_epoch % self.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
self.bbox_media_panel_images.append(wandb.Image(im, boxes=boxes, caption=path.name))


def log(self, log_dict):
if self.wandb_run:
for key, value in log_dict.items():
@@ -313,8 +337,11 @@
def end_epoch(self, best_result=False):
if self.wandb_run:
with all_logging_disabled():
if self.bbox_media_panel_images:
self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
wandb.log(self.log_dict)
self.log_dict = {}
self.bbox_media_panel_images = []
if self.result_artifact:
self.result_artifact.add(self.result_table, 'result')
wandb.log_artifact(self.result_artifact, aliases=['latest', 'last', 'epoch ' + str(self.current_epoch),
@@ -332,21 +359,6 @@ def finish_run(self):
wandb.run.finish()


def wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, im):
# Log 1 validation image, called in val.py
log_imgs = min(wandb_logger.log_imgs, 100)
if len(wandb_images) < log_imgs and wandb_logger.current_epoch > 0: # W&B logging - media panel plots
if wandb_logger.current_epoch % wandb_logger.bbox_interval == 0:
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
"class_id": int(cls),
"box_caption": "%s %.3f" % (names[cls], conf),
"scores": {"class_score": conf},
"domain": "pixel"} for *xyxy, conf, cls in pred.tolist()]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}} # inference-space
wandb_images.append(wandb_logger.wandb.Image(im, boxes=boxes, caption=path.name))
wandb_logger.log_training_progress(predn, path, names) if wandb_logger.wandb_run else None


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
""" source - https://gist.github.com/simon-weber/7853144
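Note (not part of the diff): both logging paths above build the same pixel-domain box payload for wandb.Image. A minimal standalone sketch of that structure, with invented sample values so it runs without wandb or the repository:

pred = [[31.0, 42.0, 120.0, 200.0, 0.87, 0]]  # one detection: x1, y1, x2, y2, conf, cls
names = {0: 'person'}                         # class-id -> class-name mapping
box_data = [{"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
             "class_id": int(cls),
             "box_caption": "%s %.3f" % (names[cls], conf),
             "scores": {"class_score": conf},
             "domain": "pixel"} for *xyxy, conf, cls in pred]
boxes = {"predictions": {"box_data": box_data, "class_labels": names}}  # inference-space payload
print(boxes)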
7 changes: 2 additions & 5 deletions val.py
@@ -26,7 +26,6 @@
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target, plot_study_txt
from utils.torch_utils import select_device, time_sync
from utils.wandb_logging.wandb_utils import wandb_val_one_image


def save_one_txt(predn, save_conf, shape, file):
@@ -154,7 +153,7 @@ def run(data,
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
jdict, stats, ap, ap_class = [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
t_ = time_sync()
img = img.to(device, non_blocking=True)
@@ -217,7 +216,7 @@ def run(data,
if save_json:
save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
if wandb_logger:
wandb_val_one_image(wandb_logger, wandb_images, pred, predn, path, names, img[si])
wandb_logger.val_one_image(pred, predn, path, names, img[si])

# Plot images
if plots and batch_i < 3:
@@ -257,8 +256,6 @@ def run(data,
if wandb_logger and wandb_logger.wandb:
val_batches = [wandb_logger.wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('val*.jpg'))]
wandb_logger.log({"Validation": val_batches})
if wandb_images:
wandb_logger.log({"Bounding Box Debugger/Images": wandb_images})

# Save JSON
if save_json and len(jdict):
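Note (not part of the diff): taken together, the two files leave the following call pattern in place. The sketch below is a self-contained illustration, not the repository code: wandb is replaced by print stubs and the class is heavily condensed, but the names and control flow mirror the diff above.

class MiniWandbLogger:
    # condensed stand-in for WandbLogger: per-epoch W&B state lives on the instance
    def __init__(self, log_imgs=16, bbox_interval=1):
        self.val_table, self.result_table = None, None   # set only when a val artifact was downloaded
        self.bbox_media_panel_images = []                 # fallback media-panel buffer
        self.log_dict = {}
        self.current_epoch, self.log_imgs, self.bbox_interval = 1, log_imgs, bbox_interval

    def val_one_image(self, pred, path):
        if self.val_table and self.result_table:          # artifact path: add a result-table row
            print("result_table row for", path)
        elif (len(self.bbox_media_panel_images) < min(self.log_imgs, 100)
              and self.current_epoch > 0 and self.current_epoch % self.bbox_interval == 0):
            self.bbox_media_panel_images.append((path, pred))   # stands in for wandb.Image(...)

    def end_epoch(self):
        if self.bbox_media_panel_images:                  # flushed by the logger, not by val.py
            self.log_dict["Bounding Box Debugger/Images"] = self.bbox_media_panel_images
        print("wandb.log:", list(self.log_dict))
        self.log_dict, self.bbox_media_panel_images = {}, []

wandb_logger = MiniWandbLogger()
for path in ("im0.jpg", "im1.jpg"):                       # val.py's per-image hook, post-refactor
    wandb_logger.val_one_image(pred=[[0.0, 0.0, 10.0, 10.0, 0.9, 0]], path=path)
wandb_logger.end_epoch()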