[Master issue] Removing unmaintained functionality #848

Open
3 of 5 tasks
fmassa opened this issue Sep 5, 2023 · 3 comments

Comments

@fmassa
Contributor

fmassa commented Sep 5, 2023

Context

Over the past couple of years, xFormers has evolved, and some of the originally implemented functionality is no longer maintained. This means breakages are possible, and we might not notice them for a while.

To be more efficient and to provide a better user experience, we will be removing functionality that is no longer maintained.

How

The process will start with a deprecation warning, which will be present for a couple of months before the functionality is removed entirely.

If you are a user of xFormers and would be impacted by some of the features being deprecated or removed, please let us know, as it will help us reassess the deprecated features.

What

The features currently planned for deprecation are the following:

  • experimental. This is not part of the library and depends on an older version of Triton. Since it is not part of the xFormers distribution, it will be removed directly, without the deprecation warning period (Remove experimental folder #851).
  • xformers/factory. These model builders haven't been updated to support memory-efficient attention. We now favor writing your own nn.Module that calls into the functions you want, instead of using a config-based builder (based on Hydra, which is unmaintained as well). This means all traces of the config system will be removed too (Deprecate xformers/factory #850).
  • xformers/components. The current implementations are sometimes a bit convoluted and haven't been optimized for performance as of 2023 (especially in light of torch.compile). We might keep a few of these operators in a different form (sparse attention, rope, etc.), but we will reduce the complexity of how they are exposed.
  • xformers/triton will be reorganized. Some of these operators might not be as competitive as they were when they were implemented, so we will keep only a selected subset of them (cca3421).
  • All related documentation and tutorials
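As a minimal sketch of what "writing your own nn.Module that calls into the functions you want" can look like (our own illustration, not code from the xFormers repository), the block below builds a pre-norm encoder block whose attention calls `xformers.ops.memory_efficient_attention` when xFormers is installed and the inputs are on CUDA, and otherwise falls back to PyTorch's `torch.nn.functional.scaled_dot_product_attention` (assumes PyTorch 2.0 or later). The names `attend`, `SelfAttention` and `EncoderBlock` are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn

try:  # xFormers is optional in this sketch; without it we fall back to PyTorch SDPA
    import xformers.ops as xops
except ImportError:
    xops = None


def attend(q, k, v, causal: bool = False, p: float = 0.0) -> torch.Tensor:
    """Attention on [batch, seq, heads, head_dim] tensors (the layout xops uses)."""
    if xops is not None and q.is_cuda:
        bias = xops.LowerTriangularMask() if causal else None
        return xops.memory_efficient_attention(q, k, v, attn_bias=bias, p=p)
    # PyTorch's SDPA expects [batch, heads, seq, head_dim], so swap seq and heads
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        dropout_p=p, is_causal=causal,
    )
    return out.transpose(1, 2)


class SelfAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int, dropout: float = 0.0, causal: bool = False):
        super().__init__()
        assert dim % n_heads == 0, "dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.dropout = dropout
        self.causal = causal
        self.qkv = nn.Linear(dim, 3 * dim)  # fused Q, K, V projection
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, D // self.n_heads)
        q, k, v = qkv.unbind(dim=2)  # each [B, S, heads, head_dim]
        p = self.dropout if self.training else 0.0  # no attention dropout at eval time
        y = attend(q, k, v, causal=self.causal, p=p)
        return self.proj(y.reshape(B, S, D))


class EncoderBlock(nn.Module):
    """Pre-norm block: x + attn(norm(x)), then x + mlp(norm(x))."""

    def __init__(self, dim: int, n_heads: int, mlp_ratio: int = 4,
                 dropout: float = 0.0, causal: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttention(dim, n_heads, dropout, causal)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
```

Switching to a different attention mechanism then means replacing `attend` (or `SelfAttention`) directly, with no registry or config system involved.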

I'm impacted

If any of those changes affects you, let us know in this issue. Please point out which parts of xFormers you use (with code pointers if available!)
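One way to gather such code pointers, sketched under our own assumptions: run the code that touches xFormers while Python's `warnings` module records every warning, then print where each one was triggered. This assumes the deprecations are emitted through `warnings.warn` with a `stacklevel` that points at the caller; since we do not know which warning category xFormers uses, the sketch records all categories. `legacy_builder` and `my_model_setup` are hypothetical stand-ins, not xFormers APIs.

```python
import warnings


def legacy_builder(config: dict) -> dict:
    """Hypothetical stand-in for a deprecated entry point (not a real xFormers API)."""
    warnings.warn(
        "legacy_builder is deprecated and will be removed in a future release",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller's line, not this one
    )
    return dict(config)


def collect_deprecations(fn, *args, **kwargs):
    """Run fn and return (result, list of 'file:line: Category: message' pointers)."""
    with warnings.catch_warnings(record=True) as caught:
        # "always" keeps repeated warnings from being deduplicated and
        # stops DeprecationWarning from being hidden by the default filters.
        warnings.simplefilter("always")
        result = fn(*args, **kwargs)
    pointers = [
        f"{w.filename}:{w.lineno}: {w.category.__name__}: {w.message}"
        for w in caught
    ]
    return result, pointers


def my_model_setup() -> dict:
    # Hypothetical code under audit: this call is what we want a pointer to.
    return legacy_builder({"dim_model": 384})


if __name__ == "__main__":
    _, found = collect_deprecations(my_model_setup)
    for pointer in found:
        print(pointer)
```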

Thanks!

@chriseviparker

@fmassa it would be good to understand what the core usage of the library then becomes, i.e. which folder should the examples, in theory, be using? And what advantage of xFormers do the examples demonstrate?

All the examples heavily use both components and factory. I believe that is to demonstrate what the headline promised: being a "block" zoo with composability.

Is it the ops folder that now becomes the focal point?
From what I can see, components == xFormers.

@fmassa
Contributor Author

fmassa commented Sep 26, 2023

Hi @chriseviparker

That's a good question. We will be removing all the old examples from xformers as well (the ones that use the builders, etc.), and will instead be using newer ones like https://github.com/facebookresearch/xformers/tree/main/examples/llama_inference (which is actually a quite fast example of how to use xformers for LLMs).

The ops folder is where most of the core components are, and the components/ folder is indeed unmaintained and will be deprecated / removed.

@YoelShoshan

YoelShoshan commented Nov 12, 2023

@fmassa I started using xformers very recently, and I really enjoyed the freedom to both take something like xformers.ops.memory_efficient_attention and use it directly in my model, but ALSO to quickly construct standard blocks, for example by using xFormerEncoderBlock + xFormerEncoderConfig.
It didn't force me to use Hydra configs, since it accepts standard kwargs/dictionaries, and it was very convenient for quickly testing things out and switching between existing attention mechanisms.

To test memory_efficient_attention, I basically copied the ScaledDotProduct class code, turned it into a memory-efficient version (MemoryEfficientScaledDotProduct), and registered it under a new name (code below).

I reached this issue after seeing the deprecation warning when using xFormerEncoderBlock.

So I wonder: if xformers intends to drop the blocks part, what library do you recommend for creating blocks and stacks of blocks? Should each user code their own, or is there a public library that handles this well and makes it easy to plug in different attention mechanisms from xformers?

If there isn't such library, my vote is to keep this functionality in xformers :)

Code for adding and registering a memory_efficient_attention-based attention that I could use when building blocks with xFormerEncoderBlock():

Warning: I did NOT fully verify the mask related aspects of this code.

import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch
from torch import nn

from xformers.components.attention import (
    Attention,
    AttentionConfig,
    AttentionMask,
    register_attention,
)
from xformers import ops as xops

logger = logging.getLogger("xformers")


@dataclass
class ScaledDotProductConfig(AttentionConfig):
    causal: Optional[bool]
    seq_len: Optional[int]
    to_seq_len: Optional[int]


@register_attention("memory_efficient_scaled_dot_product", ScaledDotProductConfig)
class MemoryEfficientScaledDotProduct(Attention):
    r"""
    A memory-efficient version of the Scaled Dot-Product attention proposed in
    `Attention is all you need`_, Vaswani et al., computed with
    xformers.ops.memory_efficient_attention.

    .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5
    """

    mask: Optional[AttentionMask]

    def __init__(
        self,
        dropout: float = 0.0,
        causal: bool = False,
        seq_len: Optional[int] = None,
        to_seq_len: Optional[int] = None,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Dropout is applied inside memory_efficient_attention (argument p),
        # so no separate nn.Dropout module is needed here.
        self.dropout = dropout
        self.causal = causal
        self.seq_len = seq_len

        if causal and seq_len is not None:
            self.mask = AttentionMask.make_causal(seq_len, to_seq_len)
        else:
            self.mask = None

        # Properties specific to this attention mechanism
        self.supports_attention_mask = True
        self.supports_key_padding_mask = False

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        att_mask    A 2D or 3D mask which ignores attention at certain positions.

                    - If the mask is boolean, a value of True will keep the value,
                        while a value of False will mask the value.

                        Key padding masks (dimension: batch x sequence length) and attention masks
                        (dimension: sequence length x sequence length OR batch x sequence length x sequence length)
                        can be combined and passed in here. Method maybe_merge_masks provided in the utils can be
                        used for that merging.

                    - If the mask has the float type, then an additive mask is expected (masked values are -inf)

        """

        # Convenience, create an attention mask if a tensor was passed
        if att_mask is not None and isinstance(att_mask, torch.Tensor):
            # By default we don't know of the causality, and a check would be expensive
            att_mask = (
                AttentionMask.from_bool(att_mask)
                if att_mask.dtype == torch.bool
                else AttentionMask(att_mask, is_causal=False)
            )

        # If seq_len was unknown at init time, build the causal mask now from the input shape
        mask = self.mask
        if self.causal and self.mask is None:
            mask = AttentionMask.make_causal(
                seq_len=q.shape[-2],
                to_seq_len=q.shape[-2],
                device=q.device,
                dtype=q.dtype,
            )

        # Merge the optional causal mask and the user-provided mask
        if mask is not None:
            mask = mask.to(dtype=q.dtype, device=q.device)

            att_mask = att_mask + mask if att_mask is not None else mask

        # Try to handle a case where the sequence is smaller than the mask
        if (
            att_mask is not None
            and q.shape[-2] == k.shape[-2]
            and q.shape[-2] < att_mask.shape[1]
        ):
            if isinstance(att_mask, AttentionMask):
                att_mask = att_mask.make_crop(seq_len=q.shape[-2])
            else:
                logger.error(
                    "Mismatching sparse attention mask and sequence length."
                    + " Please pad the inputs or adjust the attention mask"
                )
                raise NotImplementedError

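        if att_mask is not None:
            # memory_efficient_attention takes a tensor bias on the same dtype/device as q,
            # with one (S x S) slice per (batch x head) entry
            att_mask = att_mask.values.to(dtype=q.dtype, device=q.device).expand(
                q.shape[0], -1, -1
            )

        # (B x nh, S, hs) inputs -> (B x nh, S, hs) output; the (B x nh, S, S)
        # attention matrix is never materialized.
        # Only apply dropout while training, matching nn.Dropout semantics.
        y = xops.memory_efficient_attention(
            query=q,
            key=k,
            value=v,
            p=self.dropout if self.training else 0.0,
            attn_bias=att_mask,
        )
        return y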
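The warning above notes that the mask handling was not fully verified. One way to check it, sketched here under our own assumptions, is to compare a fused attention kernel against a naive `softmax(QK^T / sqrt(d) + bias) V` reference on small random inputs, using the same conventions as the class: (B x nh, S, hs) tensors, boolean masks where True keeps a position, and additive float masks with -inf at masked positions. This sketch uses PyTorch's `torch.nn.functional.scaled_dot_product_attention` so it runs on CPU; on a CUDA machine with xFormers installed, the same comparison could be made against `xops.memory_efficient_attention(q, k, v, attn_bias=bias)`. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def bool_to_additive(mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """True = keep -> 0.0, False = masked -> -inf (the docstring's convention)."""
    return torch.zeros(mask.shape, dtype=dtype).masked_fill(~mask, float("-inf"))


def causal_additive(seq_len: int, dtype=torch.float32) -> torch.Tensor:
    """Additive causal mask: query i may only attend to keys 0..i."""
    keep = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()
    return bool_to_additive(keep, dtype)


def reference_attention(q, k, v, bias):
    """Naive attention on (B x nh, S, hs) tensors; materializes the S x S scores."""
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores + bias, dim=-1) @ v


def max_mismatch(seq_len=8, batch_heads=6, head_dim=16, n_padded=2, seed=0) -> float:
    g = torch.Generator().manual_seed(seed)
    q, k, v = (torch.randn(batch_heads, seq_len, head_dim, generator=g) for _ in range(3))
    keep = torch.ones(seq_len, seq_len, dtype=torch.bool)
    keep[:, seq_len - n_padded:] = False  # treat the last keys as padding
    # Merging additive masks is plain addition, as in `att_mask + mask` above
    bias = causal_additive(seq_len) + bool_to_additive(keep)
    bias = bias[None].expand(batch_heads, -1, -1)  # same expand as in forward()
    ref = reference_attention(q, k, v, bias)
    fused = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    return (ref - fused).abs().max().item()
```

Every query row must keep at least one key; a fully masked row turns the softmax into NaNs in both implementations, which is worth testing separately.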
