Skip to content

Commit

Permalink
fidelity: issue 151
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonin POCHE authored and AntoninPoche committed Dec 13, 2023
1 parent 68a36a4 commit 21092bf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/metrics/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_data_types_shapes():
explanation_metrics = {
Deletion: {"steps": 3},
Insertion: {"steps": 3},
MuFidelity: {"nb_samples": 3},
MuFidelity: {"nb_samples": 3, "grid_size": None, "subset_percent": 0.9},
}

for data_type, input_shape in data_types_input_shapes.items():
Expand Down
8 changes: 3 additions & 5 deletions xplique/metrics/fidelity.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ def __init__(self,

# if unspecified use the original equation (pixel-wise modification)
self.grid_size = grid_size or self.inputs.shape[1]
# cardinal of subset (|S| in the equation)
self.subset_size = int(self.grid_size ** 2 * self.subset_percent)

self.base_predictions = self.batch_inference_function(self.model, self.inputs,
self.targets, self.batch_size)
Expand Down Expand Up @@ -198,14 +196,14 @@ def _perturb_samples(self,
if len(inputs.shape) == 2: # tabular data, grid size is ignored
# prepare the random masks
subset_masks = tf.random.uniform((nb_perturbations, inputs.shape[1]), 0, 1, tf.float32)
subset_masks = tf.argsort(subset_masks, axis=-1) > self.subset_size
subset_masks = subset_masks > self.subset_percent
subset_masks = tf.cast(subset_masks, tf.float32)

elif len(inputs.shape) == 3: # time series
# prepare the random masks
subset_masks = tf.random.uniform((nb_perturbations, self.grid_size * inputs.shape[2]),
minval=0, maxval=1, dtype=tf.float32)
subset_masks = tf.argsort(subset_masks, axis=-1) > self.subset_size
subset_masks = subset_masks > self.subset_percent

# and interpolate them if needed
subset_masks = tf.reshape(tf.cast(subset_masks, tf.float32),
Expand All @@ -217,7 +215,7 @@ def _perturb_samples(self,
# prepare the random masks
subset_masks = tf.random.uniform(shape=(nb_perturbations, self.grid_size ** 2),
minval=0, maxval=1, dtype=tf.float32)
subset_masks = tf.argsort(subset_masks, axis=-1) > self.subset_size
subset_masks = subset_masks > self.subset_percent

# and interpolate them if needed
subset_masks = tf.reshape(tf.cast(subset_masks, tf.float32),
Expand Down

0 comments on commit 21092bf

Please sign in to comment.