Commit
Allow multiple outputs for agg_mode=True in Feature Ablation (pytorch#425)

Summary:
Pull Request resolved: pytorch#425

## Description

What is aggregation output mode? It is the case where there is no 1:1 correspondence between `num_examples` (`batch_size`) and the number of outputs your model produces, i.e. the model output does not grow in size as `batch_size` becomes larger. This allows the `forward_func` to return an arbitrarily sized tensor for feature ablation.

---

## Implementation Details

We assume `aggregation_output_mode` if `perturbations_per_eval == 1` and [ `feature_mask is None` __or__ is of length 1 (i.e. associated to all inputs) ]. This is not perfect, but for feature ablation the underlying logic is the same when there is a 1:1 correspondence (i.e. the model has `batch_size` outputs) and `agg_output_mode=True`.

If `agg_output_mode == True`:
- Feature ablation will output a tensor of shape `1xOxF`, where `O` is the number of output features and `F` is the number of input features under aggregation mode. If the model outputs a tensor with more than two dimensions, the user must reshape it, since the implementation treats the output as a 2D tensor. It is therefore recommended to output only a 2D tensor, even though the implementation allows for >2D.

If we are not in `agg_output_mode`, we must ensure the number of output elements is `n` (`batch_size`). If it is not, we raise an error to the user. Here we could actually check that the number of elements is at least `n`, but for simplicity I am not doing this.

## Tests

Added tests to check for:

`agg_mode=True`:
- Incorrect feature mask (i.e. where `fm.shape[0] > 1`)
- Output a `Fx1` tensor, where `F` is the number of features in the input
- The above, but for a feature mask with the first two features treated as one feature
- Output a `2x3x5` constant tensor (not associated to outputs); internally this will be interpreted as a `1x30` 2D tensor

`agg_mode=False`:
- Check that there are exactly `n` outputs, where `n == batch_size`; if not, check that we throw an exception (assertion error). **This already exists in `test_error_perturbations_per_eval_limit_batch_scalar`**

## Notes

I created a new function rather than modifying `_find_output_mode_and_verify`, as otherwise this breaks shapley value sampling. Will have to fix this in a separate PR.

Differential Revision: D22416476

fbshipit-source-id: 0d08ca990a1e999339e51f0a7fa50be197d2f3b9
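The aggregation-mode heuristic and the output-size checks described under Implementation Details can be sketched in plain Python. This is only an illustration of the stated rules, not the code in this commit: the helper names `is_aggregation_mode`, `check_output_count`, and `flattened_shape` are hypothetical, and the feature mask is represented only by the size of its first dimension rather than by a tensor.

```python
from typing import Optional, Sequence, Tuple


def is_aggregation_mode(
    perturbations_per_eval: int,
    feature_mask_first_dims: Optional[Sequence[int]],
) -> bool:
    """Hypothetical helper mirroring the heuristic in the commit message.

    feature_mask_first_dims holds the size of dimension 0 of each feature
    mask, or None when no feature mask was supplied (an assumption made
    for this sketch; the real code works on tensors).
    """
    if perturbations_per_eval != 1:
        return False
    if feature_mask_first_dims is None:
        return True
    # A mask of length 1 applies to all inputs, so it does not tie the
    # output size to the batch size.
    return all(size == 1 for size in feature_mask_first_dims)


def check_output_count(num_outputs: int, batch_size: int, agg_mode: bool) -> None:
    """Outside aggregation mode, require exactly batch_size outputs."""
    if not agg_mode and num_outputs != batch_size:
        raise AssertionError(
            f"expected {batch_size} outputs, got {num_outputs}"
        )


def flattened_shape(shape: Sequence[int]) -> Tuple[int, int]:
    """Shape of an output of any rank when it is treated as one 2D row."""
    total = 1
    for dim in shape:
        total *= dim
    return (1, total)


# A 2x3x5 output is interpreted as a 1x30 tensor, as in the added test.
print(flattened_shape((2, 3, 5)))
```

The sketch uses the strict equality check on the number of outputs, matching the choice in the commit message to not relax it to "at least `n`".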