Adds the CPO Alignment Loss Function #382

Open · wants to merge 2 commits into base: main
Conversation

@pramodith (Collaborator) commented Nov 14, 2024

Summary

CPO is almost identical to DPO; the major difference is that the reference model in CPO is assumed to be a uniform distribution. This assumption leads to the cancellation of all terms involving the reference model, leaving:

$$CPOLoss = -\log(\sigma(\beta\log(\pi_\theta(y_c|x)) - \beta\log(\pi_\theta(y_r|x))))$$

This corresponds to Equation 3 in the paper. CPO additionally adds an NLL loss on the preferred response, scaled by a factor alpha. In TRL, this variant corresponds to `CPOTrainer` with `loss_type="sigmoid"`.
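Given per-sequence log-probabilities of the chosen and rejected responses under the policy, the sigmoid-variant CPO loss above can be sketched as follows. This is a minimal illustration, not the chunked kernel implementation in this PR; function and argument names are hypothetical.

```python
import torch
import torch.nn.functional as F


def cpo_loss_sketch(
    chosen_logps: torch.Tensor,    # log pi_theta(y_c | x), shape [batch]
    rejected_logps: torch.Tensor,  # log pi_theta(y_r | x), shape [batch]
    chosen_nll_loss: torch.Tensor, # NLL of the preferred response, shape [batch]
    beta: float = 0.1,
    alpha: float = 1.0,
) -> torch.Tensor:
    # Preference margin: beta * (log pi(y_c|x) - log pi(y_r|x)).
    # No reference-model terms appear because the reference is uniform.
    logits = beta * (chosen_logps - rejected_logps)
    # "sigmoid" loss_type: -log(sigmoid(logits)), computed stably.
    preference_loss = -F.logsigmoid(logits)
    # Add the alpha-scaled NLL term on the chosen response.
    return (preference_loss + alpha * chosen_nll_loss).mean()
```

`F.logsigmoid` is used instead of `torch.log(torch.sigmoid(...))` for numerical stability at large negative margins.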

We also refactor the test cases for chunked loss functions to include a generic HFAlignmentLoss base class that takes care of some of the plumbing work: correctly generating batches of input, calculating the NLL loss, etc. All future test cases can inherit from this class and implement only the alignment_loss function to compare the TRL implementation against the custom one.
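The pattern described above could look roughly like this. It is only a sketch of the inheritance structure; beyond the HFAlignmentLoss and alignment_loss names mentioned in this PR, all other names and signatures here are hypothetical.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class HFAlignmentLoss(ABC):
    """Base class for alignment-loss test references (illustrative sketch).

    Subclasses implement only `alignment_loss`; the base class would own
    the shared plumbing (batch generation, NLL computation, etc.).
    """

    def __init__(self, beta: float = 0.1):
        self.beta = beta

    @abstractmethod
    def alignment_loss(
        self, chosen_logps: torch.Tensor, rejected_logps: torch.Tensor
    ) -> torch.Tensor:
        """Per-example preference loss; each subclass defines its own."""
        ...


class CPOAlignmentLoss(HFAlignmentLoss):
    def alignment_loss(
        self, chosen_logps: torch.Tensor, rejected_logps: torch.Tensor
    ) -> torch.Tensor:
        # Sigmoid-variant CPO: -log(sigmoid(beta * (logp_c - logp_r))).
        return -F.logsigmoid(self.beta * (chosen_logps - rejected_logps))
```

A future DPO or SimPO test case would add another subclass with its own `alignment_loss`, reusing the rest of the harness unchanged.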

Testing Done

  • Hardware Type: A100-80G-SXM
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Benchmark Results:

(benchmark screenshots attached to the PR, 2024-11-14)
