
Prefer flex attention when available #3979

Open

danielhanchen wants to merge 3 commits into main from flex-attn-prefer

Conversation

@danielhanchen
Contributor

danielhanchen commented Feb 4, 2026

Summary

  • Default to flex attention when supported and keep a single global gate via UNSLOTH_ENABLE_FLEX_ATTENTION (see the opt-in sketch after this list).
  • Use the selected attention implementation for LLM and vision model loading.
  • Keep AutoConfig attn_implementation aligned with the chosen backend.
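
For readers who want to try it, opting in should come down to setting the gate before any model is loaded. The snippet below is illustrative only; note that the summary above names the variable UNSLOTH_ENABLE_FLEX_ATTENTION, while the diff excerpts quoted later in the thread check UNSLOTH_PREFER_FLEX_ATTENTION.

    import os

    # Hypothetical opt-in: set the gate before importing/loading models so the
    # loader sees it. "1" enables the flex-attention preference; any other
    # value keeps the previous default.
    os.environ["UNSLOTH_ENABLE_FLEX_ATTENTION"] = "1"

    # from unsloth import FastLanguageModel  # model loading happens afterwards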

Testing

  • Not run (not requested)

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the model loading process by introducing a mechanism to prefer FlexAttention when it is supported and explicitly enabled via an environment variable. This aims to improve performance by leveraging more efficient attention implementations where possible, ensuring that both LLM and vision models can dynamically adapt their attention backend while maintaining consistent configuration.

Highlights

  • FlexAttention Helper: A new helper function, prefer_flex_attn_if_supported, has been introduced to dynamically determine if FlexAttention should be used based on environment variables, Torch support, and model capabilities. This function sets the _attn_implementation and attn_implementation attributes on the model's configuration if FlexAttention is preferred and supported (an illustrative sketch follows this list).
  • LLM Integration: The unsloth/models/llama.py file has been updated to utilize the new prefer_flex_attn_if_supported helper. When loading Llama models, the attention implementation will now default to FlexAttention if available and opted in, otherwise falling back to 'eager' attention.
  • Vision Model Integration: The unsloth/models/vision.py file now also incorporates the FlexAttention preference logic. It dynamically checks for FlexAttention support for vision models, potentially fetching AutoConfig if not already available, and sets the attn_implementation accordingly. This ensures vision models can also benefit from FlexAttention when applicable.
  • Configuration Alignment: The changes ensure that the attn_implementation attribute within the AutoConfig object is consistently aligned with the chosen attention backend (FlexAttention, SDPA, or eager), providing a unified configuration across different model types.
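
To make the helper's contract from the first highlight concrete, here is a minimal sketch of a function with that behavior. This is not the pull request's implementation: the signature and the environment-variable name come from the diff excerpts quoted later in the review, while the torch availability probe and the _supports_flex_attn capability check are assumptions.

    import os

    def prefer_flex_attn_if_supported(model_class, config):
        """Sketch only: return "flex_attention" and update `config` when the
        environment gate, torch support, and the model class all allow it;
        otherwise return None."""
        # Opt-in gate; variable name as it appears in the quoted vision.py diff.
        if os.environ.get("UNSLOTH_PREFER_FLEX_ATTENTION", "0") != "1":
            return None
        try:
            # Torch-level availability probe (assumed).
            from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
            if model_class is not None:
                # `_supports_flex_attn` is an assumption modeled on Transformers'
                # per-model support flags.
                if not getattr(model_class, "_supports_flex_attn", False):
                    return None
            # Mirror the highlight above: set both attributes on the config.
            setattr(config, "_attn_implementation", "flex_attention")
            setattr(config, "attn_implementation", "flex_attention")
            return "flex_attention"
        except (ImportError, AttributeError):
            return None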


Changelog
  • unsloth/models/_utils.py
    • Exported the new prefer_flex_attn_if_supported function.
    • Added the prefer_flex_attn_if_supported function, which checks for an environment variable, Torch FlexAttention availability, and model support before setting attn_implementation to 'flex_attention' in the model config.
  • unsloth/models/llama.py
    • Integrated the prefer_flex_attn_if_supported helper to determine the preferred_attn_impl.
    • Modified from_pretrained calls to use preferred_attn_impl instead of a hardcoded 'eager' attention implementation (an illustrative call site follows this changelog).
  • unsloth/models/vision.py
    • Added logic to check the UNSLOTH_PREFER_FLEX_ATTENTION environment variable.
    • Dynamically retrieves AutoConfig and model_class if not present to evaluate FlexAttention support for vision models.
    • Sets default_attn_impl to 'flex_attention' if preferred, otherwise 'sdpa'.
    • Updated the attn_implementation parameter in kwargs and AutoConfig.from_pretrained to use the determined attention implementation.
    • Adjusted the condition for falling back from SDPA to eager attention to account for the new FlexAttention preference.
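
For a sense of how the llama.py items fit together, the following illustrative call site (not taken from the PR) falls back to 'eager' exactly as described above; the checkpoint name is a placeholder and the helper's import path is assumed from the _utils.py entry.

    from transformers import AutoConfig, AutoModelForCausalLM
    # Import path assumed from the _utils.py changelog entry above.
    from unsloth.models._utils import prefer_flex_attn_if_supported

    model_name = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
    config = AutoConfig.from_pretrained(model_name)

    # Resolve the concrete model class the same way the quoted vision.py diff does.
    try:
        model_class = AutoModelForCausalLM._model_mapping[config.__class__]
    except (KeyError, AttributeError):
        model_class = None

    # Prefer flex attention when available and opted in; otherwise keep the
    # old hardcoded "eager" default described in the changelog.
    preferred_attn_impl = prefer_flex_attn_if_supported(model_class, config) or "eager"

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation = preferred_attn_impl,
    )
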
Activity
  • The author indicated that testing was not run as it was not requested for this pull request.

@gemini-code-assist
Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces a helper function to prefer 'flex attention' when it's available and enabled via an environment variable. The changes are well-structured and improve performance options. My review includes a couple of suggestions to improve code clarity and robustness by refining exception handling and removing a redundant check. Overall, a good addition.

Comment on lines +205 to +202
        setattr(config, "attn_implementation", "flex_attention")
        return "flex_attention"

medium

Using a broad except Exception: can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to handle, such as ImportError or AttributeError. This will make the code more robust.

Suggested change
-        setattr(config, "attn_implementation", "flex_attention")
-        return "flex_attention"
+    except (ImportError, AttributeError):
+        return None

Comment on lines 520 to 532
        flex_attn_impl = None
        if os.environ.get("UNSLOTH_PREFER_FLEX_ATTENTION", "0") == "1":
            if auto_config is None:
                auto_config = AutoConfig.from_pretrained(
                    model_name,
                    token = token,
                    trust_remote_code = trust_remote_code,
                )
            try:
                model_class = auto_model._model_mapping[auto_config.__class__]
            except Exception:
                model_class = None
            flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)

medium

This code block can be simplified.

  1. The check for the UNSLOTH_PREFER_FLEX_ATTENTION environment variable is redundant, as prefer_flex_attn_if_supported already handles this internally. Removing this outer if will simplify the code and avoid duplication.
  2. The except Exception: is too broad. It's better to catch a more specific KeyError which might be raised when accessing auto_model._model_mapping.

Refactoring this will make the code cleaner and more maintainable.

        if auto_config is None:
            auto_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                trust_remote_code = trust_remote_code,
            )
        try:
            model_class = auto_model._model_mapping[auto_config.__class__]
        except KeyError:
            model_class = None
        flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)

@danielhanchen
Contributor Author

/gemini review

@gemini-code-assist
Contributor

gemini-code-assist bot left a comment


Code Review

This pull request introduces support for flex_attention as a preferred backend, which is a good performance enhancement. The changes are applied consistently across both language and vision models. My feedback focuses on improving code robustness by narrowing down exception handling and increasing efficiency by avoiding redundant configuration loading. Overall, the changes look good and align with the goal of optimizing attention mechanisms.

Comment on lines +205 to +204
    except Exception:
        return None

medium

Using a broad except Exception: can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions. In this case, since you are mainly guarding against a missing transformers utility, except ImportError: would be more appropriate. If other exceptions are possible, they should be caught explicitly or logged.

Suggested change
-    except Exception:
-        return None
+    except ImportError:
+        return None

Comment on lines +521 to +531
        if auto_config is None:
            auto_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                trust_remote_code = trust_remote_code,
            )
        try:
            model_class = auto_model._model_mapping[auto_config.__class__]
        except Exception:
            model_class = None
        flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)

medium

The current logic loads the model configuration twice using AutoConfig.from_pretrained (here and later around line 671). This is inefficient as it can involve unnecessary network and disk I/O.

Consider refactoring to load the configuration only once. You can load auto_config, modify it with the chosen attention implementation, and then reuse it as model_config later, avoiding the second from_pretrained call.

For example:

  1. Load auto_config once at the beginning.
  2. Determine flex_attn_impl and update kwargs['attn_implementation'].
  3. Instead of reloading, directly set the attn_implementation on the auto_config object if it wasn't already set by prefer_flex_attn_if_supported.
  4. Use the modified auto_config as model_config.

This would make the model loading process more efficient.
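
As a rough sketch of that single-load refactor (illustrative only: variable names and the 'sdpa' fallback follow the quoted diff and changelog, while the function wrapper, its name, and the helper's import path are assumptions):

    from transformers import AutoConfig

    # Import path assumed from the _utils.py changelog entry above.
    from unsloth.models._utils import prefer_flex_attn_if_supported

    def resolve_attention_config(model_name, auto_model, kwargs,
                                 token = None, trust_remote_code = False):
        # 1. Load the configuration once.
        auto_config = AutoConfig.from_pretrained(
            model_name,
            token = token,
            trust_remote_code = trust_remote_code,
        )
        # 2. Decide the attention backend and record it in kwargs.
        try:
            model_class = auto_model._model_mapping[auto_config.__class__]
        except (KeyError, AttributeError):
            model_class = None
        flex_attn_impl = prefer_flex_attn_if_supported(model_class, auto_config)
        kwargs["attn_implementation"] = flex_attn_impl or "sdpa"
        # 3. Make sure the config object itself carries the chosen backend.
        if getattr(auto_config, "attn_implementation", None) != kwargs["attn_implementation"]:
            auto_config.attn_implementation = kwargs["attn_implementation"]
        # 4. Reuse the already-loaded config as model_config; no second
        #    from_pretrained call is needed.
        return auto_config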

Comment on lines +529 to +530
        except Exception:
            model_class = None

medium

Using a broad except Exception: is not ideal as it can mask bugs. Here, you are accessing a dictionary _model_mapping. A KeyError is a more specific exception to catch if the key is not found. If auto_model might not have _model_mapping, an AttributeError could also be caught.

Suggested change
-        except Exception:
-            model_class = None
+        except (KeyError, AttributeError):
+            model_class = None
