Skip to content

Commit

Permalink
🛠 Fix visualization (#417)
Browse files Browse the repository at this point in the history
* Fix scaling of gt_masks in ImageResult

Scaling of `gt_mask` checked `pred_mask` and thus never triggered

* Convert images from RGB to BGR only for `cv2` ops

Tensorboard and wandb expect images in RGB format, and therefore would
display "faulty" images otherwise

* Revert unintended change of default value in .show
  • Loading branch information
ORippler committed Jul 8, 2022
1 parent fa6808a commit 059fd44
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __post_init__(self):
if self.pred_mask is not None and np.max(self.pred_mask) <= 1.0:
self.pred_mask *= 255
self.segmentations = mark_boundaries(self.image, self.pred_mask, color=(1, 0, 0), mode="thick")
if self.gt_mask is not None and np.max(self.pred_mask) <= 1.0:
if self.gt_mask is not None and np.max(self.gt_mask) <= 1.0:
self.gt_mask *= 255


Expand Down Expand Up @@ -154,13 +154,13 @@ def _visualize_simple(self, image_result):
visualization = mark_boundaries(
image_result.heat_map, image_result.pred_mask, color=(1, 0, 0), mode="thick"
)
return cv2.cvtColor((visualization * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)
return (visualization * 255).astype(np.uint8)
if self.task == "classification":
if image_result.pred_label:
image_classified = add_anomalous_label(image_result.heat_map, image_result.pred_score)
else:
image_classified = add_normal_label(image_result.heat_map, 1 - image_result.pred_score)
return cv2.cvtColor(image_classified, cv2.COLOR_RGB2BGR)
return image_classified
raise ValueError(f"Unknown task type: {self.task}")

@staticmethod
Expand All @@ -172,6 +172,7 @@ def show(title: str, image: np.ndarray, delay: int = 0):
image (np.ndarray): Image that will be shown in the window.
delay (int): Delay in milliseconds to wait for keystroke. 0 for infinite.
"""
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imshow(title, image)
cv2.waitKey(delay)
cv2.destroyAllWindows()
Expand All @@ -185,6 +186,7 @@ def save(file_path: Path, image: np.ndarray):
image (np.ndarray): Image that will be saved to the file system.
"""
file_path.parent.mkdir(parents=True, exist_ok=True)
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
cv2.imwrite(str(file_path), image)


Expand Down Expand Up @@ -233,5 +235,4 @@ def generate(self) -> np.ndarray:
# convert canvas to numpy array to prepare for visualization with opencv
img = np.frombuffer(self.figure.canvas.tostring_rgb(), dtype=np.uint8)
img = img.reshape(self.figure.canvas.get_width_height()[::-1] + (3,))
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
return img

0 comments on commit 059fd44

Please sign in to comment.