[RFC] Liger FlexChunkLoss: Alignment and Distillation loss #371
Comments
take DPO
I can take fused linear KL div. BTW, really nice illustration of the chunked linear op fusion from the paper. Very clear to new contributors 😄
@shivam15s @ByronHsu I think we should also consider including some of the loss functions commonly used for training embedding models, especially the popular ones supported in Sentence Transformers. Embedding models often need large batch sizes to train well. Their batch/input structure is also similar to RLHF, where we have positive and negative pairs, so I believe this could prove useful. I'd recommend supporting
@pramodith that is a good idea! Do you know if embedding models also have a large vocab and suffer from a memory bottleneck?
@ByronHsu most embedding models have a final Linear layer of shape (hidden_dim, hidden_dim), so vocab size doesn't really come into the picture for them. You're right to point that out, but it is common to have an effective batch size of 65k.
Then I think chunked loss is still helpful given the large batch size.
Yes, I think so too. I can give this a try after we wrap up all the important RLHF and distillation losses. I'll also get Tom Aarsen's perspective since he's the lead of Sentence Transformers.
## Summary

Add support for a fused, torch-compiled, and chunked DPO ([Direct Preference Optimization](https://arxiv.org/html/2305.18290v3)) loss kernel, as requested in #371.

This implementation is largely based on the excellent work done on ORPO (#362) by @shivam15s.

### DPO Loss Formulation

In a reference setting (not reference free):

$$r_\theta(x,y_c) - r_\theta(x,y_r) = \log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x))$$

$$-\log(\sigma(\beta(\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)) - \log(\pi_{\theta_{\text{ref}}}(y_c|x)) + \log(\pi_{\theta_{\text{ref}}}(y_r|x)))))$$

Corresponds to:

```python
import torch.nn.functional as F

# Policy model log probabilities
policy_chosen_logps = log_probs(policy_chosen_logits)
policy_rejected_logps = log_probs(policy_rejected_logits)

# Reference model log probabilities
ref_chosen_logps = log_probs(ref_chosen_logits)
ref_rejected_logps = log_probs(ref_rejected_logits)

# Compute advantages
chosen_advantages = policy_chosen_logps - ref_chosen_logps
rejected_advantages = policy_rejected_logps - ref_rejected_logps

# DPO loss (beta scales the advantage difference, as in the DPO paper)
logits_diff = beta * (chosen_advantages - rejected_advantages)
losses = -F.logsigmoid(logits_diff)
```

In this PR:

1. The equation above shows that the objective maximizes the reward difference:
   $$r_\theta(x,y_c) - r_\theta(x,y_r)$$
2. This can be further simplified to use only the policy log-probabilities:
   $$-\log(\sigma(\beta(\log(\pi_\theta(y_c|x)) - \log(\pi_\theta(y_r|x)))))$$
3. So, the code implements:
   ```python
   logits_diff = beta * (chosen_logps - rejected_logps)  # β(log π_θ(y_c|x) - log π_θ(y_r|x))
   losses = -F.logsigmoid(logits_diff)  # -log(σ(logits_diff))
   ```
4. Sum up DPO and NLL:
   $$L_{DPO+NLL} = L_{DPO}+\alpha L_{NLL}$$

## Testing Done

![dpo_loss_memory](https://github.com/user-attachments/assets/d48965a2-bab7-4a81-9872-a43826106731)
![dpo_loss_speed](https://github.com/user-attachments/assets/10ab33c3-a905-435f-886b-67c911b8fff6)

- Hardware Type: **NVIDIA L40S (48G)**
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Austin Liu
Co-authored-by: shivam15s
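For readers who want to run the DPO + NLL combination from the PR description above, here is a minimal PyTorch sketch. It is not the Liger kernel: the helper names `sequence_logps` and `dpo_with_nll` are hypothetical, the toy tensors are random, summing (rather than averaging) token log-probabilities is an assumption, and β is applied as a multiplier following the DPO paper.

```python
import torch
import torch.nn.functional as F


def sequence_logps(logits, labels):
    """Sum of per-token log-probabilities of `labels` under `logits`.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len).
    Summing over tokens is an assumption; averaging is also common.
    """
    logps = F.log_softmax(logits.float(), dim=-1)
    token_logps = logps.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return token_logps.sum(dim=-1)


def dpo_with_nll(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
                 chosen_nll, beta=0.1, alpha=1.0):
    """L_DPO + alpha * L_NLL on per-sequence log-probabilities."""
    chosen_adv = policy_chosen - ref_chosen
    rejected_adv = policy_rejected - ref_rejected
    dpo = -F.logsigmoid(beta * (chosen_adv - rejected_adv)).mean()
    return dpo + alpha * chosen_nll


torch.manual_seed(0)
batch, seq_len, vocab = 4, 6, 11
labels_c = torch.randint(vocab, (batch, seq_len))
labels_r = torch.randint(vocab, (batch, seq_len))
policy_logits_c = torch.randn(batch, seq_len, vocab)
policy_logits_r = torch.randn(batch, seq_len, vocab)

pc = sequence_logps(policy_logits_c, labels_c)
pr = sequence_logps(policy_logits_r, labels_r)

# Toy reference model identical to the policy: both advantages are zero,
# so the DPO term equals -log(sigmoid(0)) = log(2).
loss_no_nll = dpo_with_nll(pc, pr, pc.clone(), pr.clone(),
                           chosen_nll=torch.tensor(0.0))

# NLL of the chosen responses, mixed in with weight alpha.
chosen_nll = F.cross_entropy(policy_logits_c.reshape(-1, vocab),
                             labels_c.reshape(-1))
loss_with_nll = dpo_with_nll(pc, pr, pc.clone(), pr.clone(),
                             chosen_nll=chosen_nll, alpha=0.5)
```

With a reference model equal to the policy, the preference term carries no signal, which makes the `alpha`-weighted NLL contribution easy to see in isolation.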
#take SimPO and IRPO since they are just extensions of CPO.
🚀 The feature, motivation and pitch
We want to support various alignment and distillation loss functions.
Refer to this PR on ORPO: #362
Progress
Alignment
Distillation
Design
Approach Overview:
The core idea is to extend the methods used in chunked Fused Linear Cross Entropy (FLCE) to various alignment algorithms. Here's how the process is structured:
By combining these strategies, we efficiently optimize alignment algorithms while also simplifying development.
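The chunking idea carried over from FLCE can be sketched roughly as follows. This is an illustrative sketch, not Liger's implementation: `ChunkedLinearLoss` and `ChunkedCrossEntropy` are hypothetical names, plain cross entropy stands in for an alignment loss, and the `torch.compile` wrapping of the per-chunk function is omitted. The point is that only a `(chunk, vocab)` slice of logits exists at any time, while the loss and gradients are accumulated across chunks.

```python
import torch
import torch.nn.functional as F


class ChunkedLinearLoss:
    """Hypothetical base: splits the batch, runs the linear head and the
    loss one chunk at a time, and accumulates loss and gradients."""

    def __init__(self, chunk_size):
        self.chunk_size = chunk_size

    def chunk_loss(self, logits, target):
        raise NotImplementedError

    def __call__(self, hidden, weight, target):
        total_loss = torch.zeros((), dtype=hidden.dtype, device=hidden.device)
        grad_hidden = torch.zeros_like(hidden)
        grad_weight = torch.zeros_like(weight)
        for start in range(0, hidden.shape[0], self.chunk_size):
            end = start + self.chunk_size
            h = hidden[start:end].detach().requires_grad_(True)
            w = weight.detach().requires_grad_(True)
            # Only this chunk's logits are materialized.
            loss = self.chunk_loss(h @ w.t(), target[start:end])
            g_h, g_w = torch.autograd.grad(loss, (h, w))
            total_loss += loss.detach()
            grad_hidden[start:end] = g_h
            grad_weight += g_w
        return total_loss, grad_hidden, grad_weight


class ChunkedCrossEntropy(ChunkedLinearLoss):
    def chunk_loss(self, logits, target):
        # Sum reduction so per-chunk losses add up to the full-batch loss.
        return F.cross_entropy(logits, target, reduction="sum")


torch.manual_seed(0)
hidden = torch.randn(10, 8, dtype=torch.float64)
weight = torch.randn(32, 8, dtype=torch.float64)
target = torch.randint(32, (10,))

loss, grad_h, grad_w = ChunkedCrossEntropy(chunk_size=4)(hidden, weight, target)

# Unchunked reference for comparison.
h_ref = hidden.clone().requires_grad_(True)
w_ref = weight.clone().requires_grad_(True)
ref_loss = F.cross_entropy(h_ref @ w_ref.t(), target, reduction="sum")
ref_loss.backward()
```

A subclass only has to define `chunk_loss`, which is the same split of responsibilities the Interface section describes for `FlexChunkLoss`.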
Key Findings
By leveraging `torch.compile` alongside optimization techniques like chunking and online softmax, we observed performance close to that of custom Triton kernels, with reduced development time. This is why we want to introduce `torch.compile` as a key component of Liger.
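As a reminder of what online softmax refers to, the sketch below computes a log-sum-exp (the softmax normalizer) while visiting the values one chunk at a time, keeping only a running maximum and a rescaled running sum. It is written in plain Python for clarity, the function name `online_logsumexp` is hypothetical, it assumes finite inputs, and it is not the code Liger uses.

```python
import math


def online_logsumexp(chunks):
    """Compute log(sum(exp(x))) over all values, one chunk at a time."""
    running_max = -math.inf
    running_sum = 0.0  # sum of exp(x - running_max) over values seen so far
    for chunk in chunks:
        new_max = max(running_max, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms.
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(x - new_max) for x in chunk
        )
        running_max = new_max
    return running_max + math.log(running_sum)


values = [1.0, 2.0, 3.0, 1000.0, 999.0]
chunked = online_logsumexp([[1.0, 2.0], [3.0, 1000.0], [999.0]])

# Two-pass reference: find the global max first, then sum.
m = max(values)
reference = m + math.log(sum(math.exp(v - m) for v in values))
```

Because the running sum is rescaled whenever a larger maximum appears, the result matches the two-pass computation without ever holding all values at once.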
References:
Interface
Have a base class `FlexChunkLoss` that handles chunking, accumulation and compiling strategies.
A custom loss class wraps the `FlexChunkLoss` and implements the loss fn that operates on a given chunk.
Alternatives
No response
Additional context
No response