
[RFC] Liger FlexChunkLoss: Alignment and Distillation loss #371

Open
2 of 11 tasks
shivam15s opened this issue Nov 8, 2024 · 8 comments

@shivam15s
Collaborator

shivam15s commented Nov 8, 2024

🚀 The feature, motivation and pitch

We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362

Progress

Alignment

Distillation

  • KL divergence
  • cosine_similarity
  • earth_mover_distance
  • JSD
  • KVD

Design

Approach Overview:

The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:

  1. Modular Optimization Process:
    • Every alignment algorithm’s optimization can be broken into three key steps:
      • Linear layer computation
      • Loss computation
      • Gradient calculation
  2. Fused Linear and Loss Computation:
    • Similar to FLCE, we aim to fuse the linear layer with the loss computation for efficiency.
  3. Chunking & Forward Optimization:
    • Since this is the final step in the model’s forward pass, we can also compute gradients directly during the forward pass instead of waiting for a separate backward pass.
    • We also chunk the input within the model’s forward pass, which significantly reduces the peak GPU memory required.
  4. Torch Compile for Kernel Optimization:
    • Instead of manually handling kernel-level optimizations, we let torch.compile automatically optimize kernel execution. This reduces the need for low-level optimizations while still achieving performance gains.
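
As a rough illustration of the memory saving in step 3, the arithmetic below sizes the logits tensor alone under assumed numbers that do not come from this RFC: 8 sequences of 4,096 tokens, a Llama-3-sized vocabulary of 128,256 entries, and fp32 logits. It ignores the logits gradient and all other activations, so treat it as an order-of-magnitude sketch.

```python
# Back-of-the-envelope size of a (num_rows x vocab_size) logits tensor.
# All sizes below are illustrative assumptions, not measurements.

def logits_bytes(num_rows: int, vocab_size: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed to materialize a (num_rows, vocab_size) logits tensor."""
    return num_rows * vocab_size * bytes_per_elem

VOCAB_SIZE = 128_256   # assumed Llama-3-sized vocabulary
NUM_TOKENS = 8 * 4096  # assumed batch: 8 sequences x 4,096 tokens
CHUNK_ROWS = 1024      # assumed number of rows processed per chunk

full = logits_bytes(NUM_TOKENS, VOCAB_SIZE)   # unchunked: all rows at once
chunk = logits_bytes(CHUNK_ROWS, VOCAB_SIZE)  # chunked: one chunk alive at a time

print(f"unchunked logits: {full / 2**30:.2f} GiB")      # 15.66 GiB
print(f"one chunk of logits: {chunk / 2**20:.0f} MiB")  # 501 MiB
```

With chunking, only one chunk's logits (and their gradient) need to exist at a time, so this part of the peak scales with the chunk size rather than with the total number of tokens.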

By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.

Key Findings

By leveraging torch.compile alongside optimization techniques such as chunking and online softmax, we observed performance close to that of custom Triton kernels while reducing development time. This is why we want to introduce torch.compile as a key component of Liger.
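
Online softmax lets a chunked computation obtain a softmax normalizer without ever holding a full row of probabilities. The sketch below is a hedged, pure-PyTorch illustration of the idea, not Liger's kernel: `online_logsumexp` (a hypothetical name) streams over the vocabulary dimension in chunks, keeping only a running maximum and a running sum that is rescaled whenever the maximum grows.

```python
import torch


def online_logsumexp(logits: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Row-wise log-sum-exp over the last dim of a (rows, vocab) tensor,
    computed one vocabulary chunk at a time.

    Keeps a running max `m` and a running sum `s` of exp(x - m). When a new
    chunk raises the max, the old sum is rescaled by exp(m_old - m_new).
    """
    rows = logits.shape[0]
    m = torch.full((rows,), float("-inf"), dtype=logits.dtype, device=logits.device)
    s = torch.zeros(rows, dtype=logits.dtype, device=logits.device)
    for start in range(0, logits.shape[-1], chunk_size):
        chunk = logits[:, start:start + chunk_size]
        m_new = torch.maximum(m, chunk.max(dim=-1).values)
        # Rescale the previous partial sum to the new max, then add this chunk.
        # On the first chunk m is -inf, so exp(m - m_new) is 0 and s stays 0.
        s = s * torch.exp(m - m_new) + torch.exp(chunk - m_new[:, None]).sum(dim=-1)
        m = m_new
    return m + torch.log(s)
```

For finite inputs it should agree with `torch.logsumexp(logits, dim=-1)` up to floating-point error, for any `chunk_size`.
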
References:

  1. Torch compiled FLCE is 2x faster than the current FLCE #227
  2. https://gist.github.com/Chillee/22cd93e11b887db1f596ab754d60a899#file-lce_benchmark-py

Interface

A base class, FlexChunkLoss, handles the chunking, accumulation, and compilation strategies.
A custom loss class subclasses FlexChunkLoss and implements the loss function (`loss_fn`) that operates on a given chunk.

```python
class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, *args, **kwargs):
        # Compute and return the loss for a single chunk here.
        ...
```
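
A minimal sketch of how such a base class could be organized, assuming PyTorch. Everything here is hypothetical (`FlexChunkLoss`, `_ChunkedLossFunction`, `chunk_size`, and the `loss_fn(logits, target)` signature are illustrative, not Liger's API). The base class splits the rows into chunks, runs the fused linear layer plus `loss_fn` on each chunk, computes that chunk's gradients immediately with `torch.autograd.grad`, and accumulates them, so `backward` only has to scale the stored gradients. In the design above the per-chunk step would also be wrapped in `torch.compile`; that is omitted here to keep the sketch simple.

```python
import torch
import torch.nn.functional as F


class _ChunkedLossFunction(torch.autograd.Function):
    """Computes a chunked fused-linear loss and its gradients in forward."""

    @staticmethod
    def forward(ctx, x, weight, target, loss_fn, chunk_size):
        n = x.shape[0]
        grad_x = torch.zeros_like(x)
        grad_w = torch.zeros_like(weight)
        total = torch.zeros((), dtype=x.dtype, device=x.device)
        w = weight.detach().requires_grad_(True)
        # Autograd is disabled inside Function.forward; re-enable it locally.
        with torch.enable_grad():
            for start in range(0, n, chunk_size):
                end = min(start + chunk_size, n)
                x_chunk = x[start:end].detach().requires_grad_(True)
                logits = x_chunk @ w.t()  # linear layer for this chunk only
                # loss_fn returns a summed loss; divide by n for a global mean.
                loss = loss_fn(logits, target[start:end]) / n
                gx, gw = torch.autograd.grad(loss, (x_chunk, w))
                grad_x[start:end] = gx  # rows of x are independent per chunk
                grad_w += gw            # weight gradient accumulates across chunks
                total += loss.detach()
        ctx.save_for_backward(grad_x, grad_w)
        return total

    @staticmethod
    def backward(ctx, grad_output):
        grad_x, grad_w = ctx.saved_tensors
        # Gradients were precomputed in forward; just scale by the upstream grad.
        return grad_output * grad_x, grad_output * grad_w, None, None, None


class FlexChunkLoss:
    """Hypothetical base class: subclasses only implement loss_fn."""

    def __init__(self, chunk_size: int = 1024):
        self.chunk_size = chunk_size

    def loss_fn(self, logits, target):
        raise NotImplementedError

    def __call__(self, x, weight, target):
        return _ChunkedLossFunction.apply(x, weight, target, self.loss_fn, self.chunk_size)


class MyCustomLoss(FlexChunkLoss):
    def loss_fn(self, logits, target):
        # Example per-chunk loss: summed cross entropy over the chunk's rows.
        return F.cross_entropy(logits, target, reduction="sum")
```

Calling `MyCustomLoss(chunk_size=3)(x, weight, target)` should match `F.cross_entropy(x @ weight.t(), target)` in both value and gradients, while only one chunk's logits exist at a time.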

Alternatives

No response

Additional context

No response

shivam15s changed the title from "Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions" to "[RFC] Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions" on Nov 8, 2024
@austin362667
Contributor

take DPO

@hongpeng-guo
Collaborator

I can take fused linear KL div. BTW, really nice illustration of the chunked linear op fusion from the paper; it's very clear for new contributors 😄

@pramodith
Collaborator

pramodith commented Nov 13, 2024

@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence Transformers.

Embedding models commonly require large batch sizes to train well. Since their batch/input structure is also similar to RLHF, where we have positive and negative pairs, I believe this could prove useful. I'd recommend supporting CoSENTLoss, MatryoshkaLoss and TripletLoss for starters: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cosentloss. Perhaps this could be its own roadmap, separate from this one, although the idea of chunking and fusing remains the same.

@ByronHsu
Collaborator

@pramodith That is a good idea! Do you know if embedding models also have large vocabularies and suffer from a memory bottleneck?

@pramodith
Collaborator

@ByronHsu Most embedding models have a final Linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them; you're right to point that out. However, it is common to have an effective batch size of 65k.

@ByronHsu
Collaborator

Then I think chunk loss is still helpful given the large batch size.

@pramodith
Collaborator

> Then I think chunk loss is still helpful given the large batch size

Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.
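
The batch-size concern can be made concrete with a CoSENT-style loss. The sketch below is written from the loss's published formulation, log(1 + Σ exp(s_i - s_j)) over pairs where pair i is labeled less similar than pair j, and is not copied from Sentence Transformers; `cosent_loss` and the scale of 20 are illustrative assumptions. Because it compares every pair of scores in the batch, it builds a B × B matrix: taking 65k as 65,536 pairs in fp32, that matrix alone would be 65,536² × 4 bytes = 16 GiB, which is the kind of intermediate a chunked loss would avoid materializing.

```python
import torch
import torch.nn.functional as F


def cosent_loss(emb_a, emb_b, labels, scale=20.0):
    """CoSENT-style loss for (B, D) embedding pairs and (B,) similarity labels.

    loss = log(1 + sum over (i, j) with labels[i] < labels[j] of exp(s_i - s_j)),
    where s = scale * cos(emb_a, emb_b) is the scaled score of each pair.
    """
    s = scale * F.cosine_similarity(emb_a, emb_b, dim=-1)  # (B,)
    diff = s[:, None] - s[None, :]                          # (B, B): the memory hot spot
    # Keep only pairs where pair i should score lower than pair j.
    mask = labels[:, None] < labels[None, :]
    diff = diff.masked_fill(~mask, float("-inf"))
    # Prepend a 0 so logsumexp computes log(1 + sum(exp(...))).
    flat = torch.cat([diff.new_zeros(1), diff.flatten()])
    return torch.logsumexp(flat, dim=0)
```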

ByronHsu changed the title from "[RFC] Liger FlexChunkLoss: Supporting various alignment (DPO, ORPO, IRPO, CPO, etc) and distillation (KL divergence, cosine_similarity, earth_mover_distance, etc) loss functions" to "[RFC] Liger FlexChunkLoss: Supporting various alignment and distillation loss functions" on Nov 15, 2024
ByronHsu pinned this issue on Nov 15, 2024
ByronHsu changed the title from "[RFC] Liger FlexChunkLoss: Supporting various alignment and distillation loss functions" to "[RFC] Liger FlexChunkLoss: Alignment and Distillation loss" on Nov 15, 2024
ByronHsu pushed a commit that referenced this issue Nov 15, 2024
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct
Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss
kernel, as requested in
#371.
This implementation is largely based on the excellent work done on ORPO
(#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free):

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) -
\log(\pi_\theta(y_r|x))$$

$$-\log(\sigma((\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) -
\log(\pi_{\theta_{\text{ref}}}(y_c|x)) +
\log(\pi_{\theta_{\text{ref}}}(y_r|x)))/\beta))$$

Corresponds to:
```python
import torch.nn.functional as F

# `log_probs` and `beta` are assumed to be defined elsewhere.
# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss
logits_diff = (chosen_advantages - rejected_advantages) / beta
losses = -F.logsigmoid(logits_diff)
```
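
The snippet above calls a `log_probs` helper that the PR does not define. One plausible reading, sketched here as an assumption, is a function returning the summed log probability of each sequence's target tokens, i.e. log π(y|x). The `labels` argument and the `ignore_index=-100` masking convention are our additions (the snippet passes only logits), and logits are assumed to be already shifted to align with their next-token labels.

```python
import torch
import torch.nn.functional as F


def log_probs(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: summed log pi(y | x) per sequence.

    logits: (B, T, V) scores already aligned with labels (i.e. shifted for
    next-token prediction); labels: (B, T) token ids, where positions equal
    to `ignore_index` (e.g. prompt or padding tokens) are excluded.
    """
    mask = labels != ignore_index
    safe_labels = labels.masked_fill(~mask, 0)  # any valid index works for gather
    token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), dim=-1, index=safe_labels.unsqueeze(-1)
    ).squeeze(-1)  # (B, T) log-prob of each target token
    return (token_logps * mask).sum(dim=-1)  # (B,) one value per sequence
```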

In this PR:

1. From the equations above, maximizing the reward margin amounts to maximizing:
    $$r_\theta(x,y_c) - r_\theta(x,y_r)$$
2. Dropping the reference-model terms, the loss reduces to:
    $$-\log\sigma\left(\frac{\log\pi_\theta(y_c|x) - \log\pi_\theta(y_r|x)}{\beta}\right)$$
3. So, the code implements:
    ```python
    logits_diff = (chosen_logps - rejected_logps) / beta  # (log π_θ(y_c|x) - log π_θ(y_r|x)) / β
    losses = -F.logsigmoid(logits_diff)  # -log(σ(logits_diff))
    ```
4. Add the NLL term to the DPO loss:
    $$L_{DPO+NLL} = L_{DPO} + \alpha L_{NLL}$$
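
The four steps above can be combined into a single function. This is a hedged sketch, not the PR's code: the name `dpo_nll_loss` and the `alpha` argument name are assumptions, the reference-model terms are dropped as in step 2, the NLL term is assumed to be precomputed on the chosen responses, and `beta` divides the log-ratio difference exactly as written in this PR (the original DPO paper multiplies by β instead).

```python
import torch
import torch.nn.functional as F


def dpo_nll_loss(chosen_logps, rejected_logps, chosen_nll, *, beta, alpha):
    """Reference-free DPO loss (batch mean) plus alpha times an NLL term.

    chosen_logps, rejected_logps: (B,) summed log pi_theta(y|x) per sequence.
    chosen_nll: scalar NLL on the chosen responses (assumed precomputed).
    """
    logits_diff = (chosen_logps - rejected_logps) / beta  # step 3, as in the PR
    dpo = -F.logsigmoid(logits_diff).mean()               # -log(sigma(.)), averaged
    return dpo + alpha * chosen_nll                       # step 4: L_DPO + alpha * L_NLL
```

With equal chosen and rejected log-probabilities and a zero NLL term, the loss is -log σ(0) = log 2 ≈ 0.693; it shrinks as the chosen log-probability pulls ahead of the rejected one.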

## Testing Done


![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)

![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)


- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu <[email protected]>
Co-authored-by: shivam15s <[email protected]>
@pramodith
Collaborator

take SimPO and IRPO, since they are just extensions of CPO.
