[Model] Pixtral Support #253

Open · wants to merge 1 commit into base `main`
1 change: 1 addition & 0 deletions README.md
@@ -224,6 +224,7 @@ loss.backward()
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Pixtral | `liger_kernel.transformers.apply_liger_kernel_to_pixtral` | RoPE, RMSNorm, SwiGLU |
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
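A minimal usage sketch for the new README row; the checkpoint id and the `AutoModelForVision2Seq` loader below are illustrative choices, not taken from this PR:

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_pixtral

# Patch Pixtral's vision-encoder modules (RoPE, RMSNorm, SwiGLU) before the model is built.
apply_liger_kernel_to_pixtral()

# Illustrative checkpoint; any transformers-format checkpoint with a Pixtral vision tower would do.
model = transformers.AutoModelForVision2Seq.from_pretrained("mistral-community/pixtral-12b")
```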
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -14,6 +14,7 @@
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_pixtral,
    apply_liger_kernel_to_qwen2,
    apply_liger_kernel_to_qwen2_vl,
)
103 changes: 103 additions & 0 deletions src/liger_kernel/transformers/model/pixtral.py
@@ -0,0 +1,103 @@
from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.pixtral.modeling_pixtral import (
    _CONFIG_FOR_DOC,
    PIXTRAL_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)


@add_start_docstrings_to_model_forward(PIXTRAL_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def lce_forward(
    self,
    inputs_embeds,
    attention_mask: Optional[torch.Tensor] = None,
    position_embeddings: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
    r"""
    Copy-paste of Pixtral's `PixtralTransformer.forward` from transformers v4.44.2. Unlike the other
    `lce_forward` implementations in this directory, the Pixtral vision encoder has no language-model
    head or loss, so there is no torch cross entropy to swap for Liger's fused linear cross entropy here.

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    encoder_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    hidden_states = inputs_embeds
    for encoder_layer in self.layers:
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                encoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_embeddings,
                output_attentions,
            )
        else:
            layer_outputs = encoder_layer(
                hidden_states,
                attention_mask,
                position_embeddings=position_embeddings,
                output_attentions=output_attentions,
            )

        hidden_states = layer_outputs[0]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[1],)

    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v for v in [hidden_states, encoder_states, all_attentions] if v is not None
        )

    return BaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=encoder_states,
        attentions=all_attentions,
    )
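For reference, the `lce_forward` name mirrors the other model files in this directory, whose forwards do swap torch cross entropy for Liger's fused linear cross entropy. A rough sketch of that pattern, assuming `LigerFusedLinearCrossEntropyLoss` accepts the LM-head weight, flattened hidden states, and flattened labels in that order:

```python
import torch

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


def fused_lm_loss(
    lm_head_weight: torch.Tensor,  # (vocab_size, hidden_size)
    hidden_states: torch.Tensor,   # (batch, seq_len, hidden_size)
    labels: torch.Tensor,          # (batch, seq_len)
) -> torch.Tensor:
    # Shift so that tokens < n predict token n, then flatten batch and sequence dims.
    shift_hidden = hidden_states[..., :-1, :].contiguous().view(-1, hidden_states.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1)
    # The kernel fuses the lm_head matmul with the cross entropy, so logits are never materialized.
    loss_fct = LigerFusedLinearCrossEntropyLoss()
    return loss_fct(lm_head_weight, shift_hidden, shift_labels)
```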
31 changes: 31 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -10,6 +10,7 @@
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.mixtral import lce_forward as mixtral_lce_forward
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.pixtral import lce_forward as pixtral_lce_forward
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -139,6 +140,35 @@ def apply_liger_kernel_to_mixtral(
        modeling_mixtral.MixtralBlockSparseTop2MLP = LigerBlockSparseTop2MLP


def apply_liger_kernel_to_pixtral(
    rope: bool = True,
    rms_norm: bool = True,
    fused_linear_cross_entropy: bool = True,
    swiglu: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Pixtral models.

    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        fused_linear_cross_entropy (bool): Whether to patch `PixtralTransformer.forward` with the Liger version. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
    """
    from transformers.models.pixtral import modeling_pixtral

    if rope:
        modeling_pixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        # Pixtral defines its own copies of these modules, so the Pixtral class names are patched.
        modeling_pixtral.PixtralRMSNorm = LigerRMSNorm
    if fused_linear_cross_entropy:
        modeling_pixtral.PixtralTransformer.forward = pixtral_lce_forward
    if swiglu:
        modeling_pixtral.PixtralMLP = LigerSwiGLUMLP


def apply_liger_kernel_to_gemma(
    rope: bool = True,
    cross_entropy: bool = False,
@@ -339,6 +369,7 @@ def apply_liger_kernel_to_phi3(
"llama": apply_liger_kernel_to_llama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"pixtral": apply_liger_kernel_to_pixtral,
"qwen2": apply_liger_kernel_to_qwen2,
"qwen2_vl": apply_liger_kernel_to_qwen2_vl,
"phi3": apply_liger_kernel_to_phi3,
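A quick way to sanity-check the new patch function outside the test suite; a sketch, assuming the Liger import paths below (only `LigerRMSNorm`'s path appears in this diff):

```python
from transformers.models.pixtral import modeling_pixtral

from liger_kernel.transformers import apply_liger_kernel_to_pixtral
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

apply_liger_kernel_to_pixtral(
    rope=True, rms_norm=True, fused_linear_cross_entropy=True, swiglu=True
)

# The patch swaps module-level classes in place, so any Pixtral model
# instantiated afterwards picks up the Liger kernels.
assert modeling_pixtral.PixtralRMSNorm is LigerRMSNorm
assert modeling_pixtral.PixtralMLP is LigerSwiGLUMLP
```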
6 changes: 3 additions & 3 deletions src/liger_kernel/triton/monkey_patch.py
@@ -37,6 +37,6 @@ def apply_liger_triton_cache_manager():
    Experimental feature to get around transient FileNotFoundError in triton compilation.
    For more details please see https://github.com/triton-lang/triton/pull/4295
    """
    os.environ[
        "TRITON_CACHE_MANAGER"
    ] = "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
    os.environ["TRITON_CACHE_MANAGER"] = (
        "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
    )
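The only change in this hunk is the formatting of the assignment; behavior is identical. A short usage sketch for the cache-manager override, which should run before any Triton kernel is compiled:

```python
import os

from liger_kernel.triton.monkey_patch import apply_liger_triton_cache_manager

apply_liger_triton_cache_manager()

# Triton reads this environment variable lazily, when it first compiles a kernel.
assert (
    os.environ["TRITON_CACHE_MANAGER"]
    == "liger_kernel.triton.monkey_patch:LigerTritonFileCacheManager"
)
```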
36 changes: 36 additions & 0 deletions test/convergence/test_mini_models_no_logits.py
@@ -17,6 +17,7 @@
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
from transformers.models.pixtral import PixtralConfig, PixtralTransformer
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from liger_kernel.transformers import (
@@ -26,6 +27,7 @@
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_pixtral,
    apply_liger_kernel_to_qwen2,
    apply_liger_kernel_to_qwen2_vl,
)
@@ -174,6 +176,24 @@
attn_implementation="sdpa",
),
),
"mini_pixtral": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_pixtral,
model_class=PixtralTransformer,
mini_model_config=PixtralConfig(
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
image_size=1024,
patch_size=16,
hidden_activation="gelu",
layer_norm_eps=1e-5,
attention_dropout=0.0,
rope_theta=10000.0,
tie_word_embeddings=False,
),
),
"mini_gemma1": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gemma,
model_class=GemmaForCausalLM,
@@ -498,6 +518,22 @@ def run_mini_model(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        ("mini_pixtral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
            "mini_pixtral",
            32,
            1e-4,
            torch.bfloat16,
            1e-8,
            1e-5,
            1e-2,
            1e-5,
            1e-2,
            1e-5,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
        ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
        pytest.param(
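To exercise just the new convergence cases locally, a filter on the parametrize ids is enough; for example (shown as a sketch, any equivalent `pytest` invocation works):

```python
import pytest

# Run only the Pixtral convergence cases; "-k pixtral" matches the new "mini_pixtral" ids.
pytest.main(["test/convergence/test_mini_models_no_logits.py", "-k", "pixtral", "-v"])
```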
1 change: 1 addition & 0 deletions test/transformers/test_monkey_patch.py
@@ -21,6 +21,7 @@ def test_import_from_root():
            apply_liger_kernel_to_mistral,
            apply_liger_kernel_to_mixtral,
            apply_liger_kernel_to_phi3,
            apply_liger_kernel_to_pixtral,
            apply_liger_kernel_to_qwen2,
            apply_liger_kernel_to_qwen2_vl,
        )