Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyplot for token attributions (continued from PR #11) #13

Merged
merged 2 commits into from
Sep 14, 2022

Conversation

TomPham97
Copy link
Contributor

My apologies for accidentally unlinking the original fork, thus I have to create a new PR based on this new fork.

The pyplot function appears to use multiple sub-functions (plt.title, plt.xlabel, and so on) instead of arguments like in Pandas'. Therefore, I'm not entirely sure if **kwargs is necessary any more, what do you think?

I also set the default plot type to be horizontal bar since it's more aesthetically pleasing, but it's ultimately up to you to decide which one is the most suitable ☺️

The code snippet below was used to test and succeeded:

# Dependencies
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Convert list of tuples into a dataframe
df = pd.DataFrame(output.token_attributions,
                  columns = ['tokens', 'attributions'],
                  index = None)
# Extract vectors
tokens = df['tokens'].copy()
attributions = df['attributions'].copy()

def plot(tokens: object = tokens,
         attributions: object = attributions,
        #  self,
         plot_type: str = 'barh',
         **plot_kwargs) -> None:
         plt.title('Token Attributions')
         
         plot_kwargs = {'title': 'Token Attributions', **plot_kwargs}

         if plot_type == 'bar':
             # Bar chart
             plt.bar(tokens, attributions)
             plt.xlabel('tokens')
             plt.ylabel('attribution value')

         elif plot_type == 'barh':
             # Horizontal bar chart
             plt.barh(tokens, attributions)
             plt.xlabel('attribution value')
             plt.ylabel('tokens')
             plt.gca().invert_yaxis()

         elif plot_type == 'pie':
             # Pie chart
             plt.pie(attributions,
                     startangle = 90,
                     counterclock = False,
                    #  explode = (attributions <= 3) * 0.5,
                     labels = tokens,
                     autopct = '%1.1f%%',
                     pctdistance = 0.8)

         else:
             raise NotImplementedError(f"`plot_type = {plot_type} is not implemented. Choose one of: ['bar', 'barh', 'pie']")

Co-authored-by: João Lages <joaop.glages@gmail.com>
@JoaoLages
Copy link
Owner

JoaoLages commented Sep 14, 2022

Thanks for another contribution! 💪 🚀

@JoaoLages JoaoLages merged commit 1a2d6ca into JoaoLages:main Sep 14, 2022
@TomPham97 TomPham97 deleted the plots branch September 14, 2022 18:01
@TomPham97 TomPham97 restored the plots branch September 14, 2022 18:42
@TomPham97
Copy link
Contributor Author

What would happen if Pandas' plot function was already loaded? Would it raise a conflict with our newly-defined plot function? Fortunately, I suspect that most users would not use Pandas along with stable diffusion 😅

@JoaoLages
Copy link
Owner

What would happen if Pandas' plot function was already loaded? Would it raise a conflict with our newly-defined plot function? Fortunately, I suspect that most users would not use Pandas along with stable diffusion 😅

The object TokenAttributions is of type list, not pandas.DataFrame, so there are no conflicts there. Idk if I got your question right 🤔

@TomPham97 TomPham97 mentioned this pull request Sep 15, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants