Skip to content

Commit

Permalink
Refactor metrics to precompute some parts
Browse files Browse the repository at this point in the history
  • Loading branch information
Jklein64 committed May 11, 2022
1 parent c48397a commit 9ee0953
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 62 deletions.
106 changes: 59 additions & 47 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,65 @@
import numpy as np
import ot

class Metric:
"""Class wrapper for metrics, used to precompute parts for efficiency."""

# make something semantic??
def __init__(self, pixels: np.ndarray) -> None:
self.pixels = pixels
self.value = self.compute()

def compute(pixels: np.ndarray):
pass

@staticmethod
def compare(a: Metric, b: Metric):
pass


class AverageColor(Metric):
def compute(self):
return np.mean(self.pixels, axis=0)

def compare(a, b):
return sum(np.square(a.value - b.value))


# FIXME precompute the average color before comparing to reduce computation,
# since comparison with a different superpixel doesn't change the color
def average_color_distance(pixels_1: np.ndarray, pixels_2: np.ndarray) -> float:
"""Compute the squared distance between the average colors of the given sets of pixels."""
# cast to float to avoid integer truncation
average_1 = np.mean(pixels_1, axis=0)
average_2 = np.mean(pixels_2, axis=0)
# return sum of squared difference
return sum(np.square(average_1 - average_2))


def wasserstein_image_distance(pixels_1: np.ndarray, pixels_2: np.ndarray) -> float:
"""Compute the Wasserstein or Earth Mover's distance between the given sets of integer-valued 8-bit pixels."""
# compute and normalize (by pixel count) color histograms for each channel
red_1, green_1, blue_1 = map(lambda h: h / len(pixels_1), color_histograms(pixels_1))
red_2, green_2, blue_2 = map(lambda h: h / len(pixels_2), color_histograms(pixels_2))
# cast to float to avoid integer truncation
pixels_1 = pixels_1[..., 0:3].astype(np.float64)
pixels_2 = pixels_2[..., 0:3].astype(np.float64)
# create and normalize the distance matrix
distance = ot.dist(np.arange(0.0, 256.0)[..., np.newaxis])
distance /= np.max(distance)
# find optimal flows for each channel
optimal_flow_red = ot.lp.emd(red_1, red_2, distance)
optimal_flow_green = ot.lp.emd(green_1, green_2, distance)
optimal_flow_blue = ot.lp.emd(blue_1, blue_2, distance)
# derive Wasserstein distances for each channel
wasserstein_red = np.sum(optimal_flow_red * distance)
wasserstein_green = np.sum(optimal_flow_green * distance)
wasserstein_blue = np.sum(optimal_flow_blue * distance)
# sum the channel-based distances to get final metric
return wasserstein_red + wasserstein_green + wasserstein_blue


def color_histograms(pixels: np.ndarray) -> list[np.ndarray]:
"""Maps a list of rgb pixels with integer values to three frequency charts with 256 bins each, one for each color channel. The frequencies are NOT normalized. This function is used in the Wasserstein-based metrics."""
r, g, b = np.transpose(pixels)[0:3]
histograms = []
for channel in (r, g, b):
# need explicit integer type since np.float64 is default
# needs to be big enough to store a count! uint8 won't do!
histogram = np.zeros((2 ** 8), dtype=np.uint64)
values, counts = np.unique(channel, return_counts=True)
histogram[values] = counts
histograms.append(histogram)
return histograms
class Wasserstein(Metric):
def compute(self):
n = len(self.pixels)
r, g, b = self.color_histograms()
return (r / n, g / n, b / n)

def compare(a, b):
red_1, green_1, blue_1 = a.value
red_2, green_2, blue_2 = b.value
# create and normalize the distance matrix
distance = ot.dist(np.arange(0.0, 256.0)[..., np.newaxis])
distance /= np.max(distance)
# find optimal flows for each channel
optimal_flow_red = ot.lp.emd(red_1, red_2, distance)
optimal_flow_green = ot.lp.emd(green_1, green_2, distance)
optimal_flow_blue = ot.lp.emd(blue_1, blue_2, distance)
# derive Wasserstein distances for each channel
wasserstein_red = np.sum(optimal_flow_red * distance)
wasserstein_green = np.sum(optimal_flow_green * distance)
wasserstein_blue = np.sum(optimal_flow_blue * distance)
# sum the channel-based distances to get final metric
return wasserstein_red + wasserstein_green + wasserstein_blue

def color_histograms(self) -> list[np.ndarray]:
"""Maps a list of rgb pixels with integer values to three frequency charts with 256 bins each, one for each color channel. The frequencies are NOT normalized. This function is used in the Wasserstein-based metrics."""
r, g, b = np.transpose(self.pixels)[0:3]
histograms = []
for channel in (r, g, b):
# need explicit integer type since np.float64 is default
# needs to be big enough to store a count! uint8 won't do!
histogram = np.zeros((2 ** 8), dtype=np.uint64)
values, counts = np.unique(channel, return_counts=True)
histogram[values] = counts
histograms.append(histogram)
return histograms


# make something semantic??

31 changes: 16 additions & 15 deletions superpixel_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from metrics import wasserstein_image_distance, average_color_distance
from metrics import Metric, AverageColor, Wasserstein
from visualization import show


Expand Down Expand Up @@ -46,8 +46,7 @@ def main():
show(original, regions=labels, constraints=constraints)

# merge neighbors within threshold to reduce the total number of superpixels
# FIXME figure out why this function takes so long, even with average color distance
neighbors = neighbor_matrix(original, labels, metric=wasserstein_image_distance)
neighbors = neighbor_matrix(original, labels, metric=Wasserstein)
# invert neighbors to represent distance instead of similarity; delta is arbitrary
# FIXME justify delta choice
merged_local = connected_within_threshold(labels, 1 - neighbors, delta=0.001)
Expand All @@ -56,7 +55,7 @@ def main():
show(original, regions=merged_local, constraints=constraints)

# create dense distances matrix and merge based on optimized delta
distances = distances_matrix(original, merged_local, metric=average_color_distance)
distances = distances_matrix(original, merged_local, metric=AverageColor)
merged_nonlocal = constrained_division(merged_local, np.zeros_like(labels), distances, (0, 1), constraints)

# show the image after applying the first two constraints
Expand Down Expand Up @@ -140,25 +139,25 @@ def constrained_division(superpixels: np.ndarray, merged_nonlocal: np.ndarray, d
# fill masked values with previous constraint
merged[merged == -1] = merged_nonlocal[merged == -1]
return merged




def neighbor_matrix(original: np.ndarray, superpixels: np.ndarray, metric: Callable[[np.ndarray, np.ndarray], float]) -> np.ndarray:
def neighbor_matrix(original: np.ndarray, superpixels: np.ndarray, metric: Metric) -> np.ndarray:
"""Create a weighed adjacency matrix between every pair of superpixels which neighbors each other. Weighed region adjacency graph! The resulting matrix is normalized so that the weights are within [0, 1], and then rescaled so that higher values correspond to higher similarity."""
# store list of valid superpixel labels
unique_labels = np.ma.compressed(np.ma.unique(superpixels))
# pixels is a list of pixel values for the n'th superpixel
pixels = [original[superpixels == label] for label in unique_labels]
# precompute part of the metric to avoid recomputing certain things
precomputed = [metric(original[superpixels == label]) for label in unique_labels]
# create n-by-n matrix to compare distances between neighbors
neighbors = np.zeros((len(unique_labels), len(unique_labels)))
# iterate through neighbors and calculate distances
for i in unique_labels:
for j in superpixel_neighbors(superpixels, i):
# only compute distances one way
if j < i:
neighbors[i, j] = metric(pixels[i], pixels[j])
# fill out the other half of the matrix
a = precomputed[i]
b = precomputed[j]
neighbors[i, j] = metric.compare(a, b)
# # fill out the other half of the matrix
neighbors = neighbors + np.transpose(neighbors)
# normalize to be within zero and one
neighbors /= np.max(neighbors)
Expand Down Expand Up @@ -196,17 +195,19 @@ def connected_within_threshold(superpixels: np.ndarray, distances: np.ndarray, d
return labels


def distances_matrix(original: np.ndarray, superpixels: np.ndarray, metric: Callable[[np.ndarray, np.ndarray], float]) -> np.ndarray:
def distances_matrix(original: np.ndarray, superpixels: np.ndarray, metric: Metric) -> np.ndarray:
"""Create a matrix with the metric-based distances between every pair of the given superpixels implied by the original and labelled images."""
# store list of valid superpixel labels
unique_labels = np.ma.compressed(np.ma.unique(superpixels))
# pixels is a list of pixel values for the n'th superpixel
pixels = [original[superpixels == label] for label in unique_labels]
# precompute part of the metric to avoid recomputing certain things
precomputed = [metric(original[superpixels == label]) for label in unique_labels]
# create n-by-n matrix to compare distances between n superpixels
distances = np.zeros((len(unique_labels), len(unique_labels)))
# distance is symmetric, so only compare each pair once (below diagonal)
for i, j in np.transpose(np.tril_indices(len(unique_labels), k=-1)):
distances[i, j] = metric(pixels[i], pixels[j])
a = precomputed[i]
b = precomputed[j]
distances[i, j] = metric.compare(a, b)
# fill in the rest of the distances matrix
return distances + np.transpose(distances)

Expand Down

0 comments on commit 9ee0953

Please sign in to comment.