
Resolve softmax issues with pytorch_model.py #283

Merged: 8 commits merged into main from softmax_arg_issue on Aug 29, 2023

Conversation

@dilyabareeva (Collaborator) commented on Jul 7, 2023

Description

  • Extracting the softmax layer from PyTorch models sometimes leads to problems.
  • We want to update the softmax-adjustment logic as follows:
    [diagram attached in the original PR, showing the six softmax-adjustment cases]

Implemented changes

  • Updated the softmax-adjustment logic in the model depending on the model structure
  • The six cases from the diagram are handled in the get_softmax_arg_model method (see comments)
  • A better way to extract the Softmax layer from the model: replace it with Identity instead of rebuilding the model and its forward function (a sketch follows this list)
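A minimal sketch of the Identity-replacement idea, assuming the Softmax to strip is registered as a named submodule; the helper name strip_top_softmax and the traversal below are illustrative, not the exact code merged in this PR:

```python
import copy

import torch.nn as nn


def strip_top_softmax(model: nn.Module) -> nn.Module:
    """Return a copy of `model` with its last Softmax submodule replaced by Identity."""
    model = copy.deepcopy(model)  # leave the user's original model untouched
    # Remember the qualified name of the last Softmax encountered.
    softmax_name = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Softmax):
            softmax_name = name
    if softmax_name is not None:
        # Resolve the parent module and swap the Softmax for Identity in place.
        parent_name, _, child = softmax_name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child, nn.Identity())
    return model
```

Compared with rebuilding the model and its forward function, this keeps parameters, hooks, and the rest of the graph untouched; only the probability head is bypassed.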

@codecov-commenter commented on Jul 7, 2023

Codecov Report

Merging #283 (2b85111) into main (c288a7d) will increase coverage by 0.06%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main     #283      +/-   ##
==========================================
+ Coverage   93.29%   93.35%   +0.06%     
==========================================
  Files          62       62              
  Lines        3430     3448      +18     
==========================================
+ Hits         3200     3219      +19     
+ Misses        230      229       -1     
| Files Changed | Coverage Δ |
| --- | --- |
| quantus/helpers/model/pytorch_model.py | 87.05% <100.00%> (+2.75%) ⬆️ |

@annahedstroem (Member) left a comment:

A few things! Thanks @dilyabareeva!

(Three review threads on quantus/helpers/model/pytorch_model.py were resolved.)
@@ -5,15 +5,15 @@
# Quantus is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
A project member commented on the diff above:
Can you double-check the tf implementation so that the function names as well as the logic are replicated there?

@dilyabareeva (Collaborator, PR author) replied:
Both use the same ModelInterface, so the public methods are the same. The logic in tf_model might need a similar change: I'm not 100% sure that model.layers[-1] reliably returns the actual output layer in TensorFlow. Since this PR is about PyTorch, though, it may be worth opening a separate issue to investigate.
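For context, a small Keras sketch of why that's uncertain: the softmax can be fused into the final Dense layer via the activation kwarg, in which case no Softmax layer appears in model.layers at all:

```python
import tensorflow as tf

# Case 1: softmax as a separate layer -- layers[-1] is the Softmax.
explicit = tf.keras.Sequential([
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),
])

# Case 2: softmax fused into the Dense layer -- layers[-1] is the Dense
# layer, and no Softmax module shows up in model.layers at all.
fused = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation="softmax"),
])

print(type(explicit.layers[-1]).__name__)  # Softmax
print(type(fused.layers[-1]).__name__)     # Dense
```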

(Further review threads on quantus/helpers/model/pytorch_model.py and tests/conftest.py were resolved.)
@aaarrti (Collaborator) commented on Jul 16, 2023

@dilyabareeva You are modifying the structure of the model provided by the user; it would be great to make visible to the user what exactly is being done under the hood and which assumptions were made.
E.g., for the case when the last softmax layer is replaced by a linear one, we could log a message:

<XAI method name> requires a model with a linear top, but found a softmax activation on top for <model name>. Removed <layer name> at <layer index>.

This should not be a warning, since it is expected behaviour. I'd suggest using logging.info and setting the log format to something like "[%(filename)s:%(lineno)s->%(funcName)s()]:%(levelname)s: %(message)s", so the user can see where the message comes from.
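A sketch of the setup aaarrti describes, using the stdlib logging module with the format string quoted above; the method, model, and layer names passed into the message are illustrative placeholders:

```python
import logging

# Format suggested above: shows the file, line, and function a message comes from.
logging.basicConfig(
    format="[%(filename)s:%(lineno)s->%(funcName)s()]:%(levelname)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Expected behaviour, so logged at INFO rather than WARNING.
logger.info(
    "%s requires a model with a linear top, but found a softmax activation "
    "on top for %s. Removed %s at layer index %s.",
    "IntegratedGradients",  # illustrative XAI method name
    "LeNet",                # illustrative model name
    "Softmax",              # illustrative layer name
    -1,                     # illustrative layer index
)
```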

@annahedstroem (Member) left a comment:

OK.

@annahedstroem merged commit 72a519b into main on Aug 29, 2023. 7 checks passed.
@annahedstroem deleted the softmax_arg_issue branch on November 27, 2023.