Adds the CPO Alignment Loss Function #382
Open
Summary
CPO is almost identical to DPO; the major difference is that the reference model in CPO is assumed to be a uniform distribution. Under this assumption, all terms involving the reference model cancel. This corresponds to equation 3 in the paper. Additionally, CPO introduces a scaling factor alpha for the NLL loss on the preferred response. In TRL, this corresponds to the CPOTrainer with loss_type="sigmoid".
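The resulting loss can be sketched as follows. This is a minimal, per-pair scalar illustration of the formula described above, not the chunked implementation in this PR; the function name and default hyperparameters are assumptions.

```python
import math

def cpo_loss(logp_chosen, logp_rejected, beta=0.1, alpha=1.0):
    """CPO loss for a single preference pair (scalar sketch).

    With a uniform reference model, the DPO reference log-prob terms
    cancel, leaving only the policy log-probabilities (eq. 3 in the
    CPO paper), plus an alpha-scaled NLL term on the chosen response.
    """
    margin = beta * (logp_chosen - logp_rejected)
    # sigmoid preference loss: -log(sigmoid(margin)) = log(1 + exp(-margin))
    preference_loss = math.log(1.0 + math.exp(-margin))
    # NLL of the preferred response, scaled by alpha
    nll_loss = -logp_chosen
    return preference_loss + alpha * nll_loss
```

With alpha=0 this reduces to the pure sigmoid preference loss, and the loss shrinks as the margin between chosen and rejected log-probs grows.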
We also refactor the test cases for chunked loss functions to include a generic HFAlignmentLoss base class that takes care of some of the plumbing work: correctly generating batches of input, calculating the NLL loss, etc. Future test cases can inherit from this class and implement only the alignment_loss function to compare the TRL implementation against the custom one.

Testing Done
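The base-class pattern used by the tests can be sketched like this. All names and signatures here are illustrative assumptions, not the actual classes in this PR; the point is that a subclass only supplies alignment_loss while the base class handles batching and the NLL term.

```python
import math

class HFAlignmentLoss:
    """Illustrative test base class (names/signatures are assumptions).

    Subclasses implement `alignment_loss` on per-pair log-probs; the
    base class iterates over the batch and adds the alpha-scaled NLL
    term on the chosen response.
    """

    def __init__(self, beta=0.1, alpha=1.0):
        self.beta = beta
        self.alpha = alpha

    def alignment_loss(self, logp_chosen, logp_rejected):
        raise NotImplementedError

    def __call__(self, batch):
        # batch: list of (logp_chosen, logp_rejected) pairs
        total = 0.0
        for logp_c, logp_r in batch:
            total += self.alignment_loss(logp_c, logp_r) - self.alpha * logp_c
        return total / len(batch)

class CPOLoss(HFAlignmentLoss):
    """CPO: sigmoid preference loss with a uniform reference model."""

    def alignment_loss(self, logp_chosen, logp_rejected):
        margin = self.beta * (logp_chosen - logp_rejected)
        return math.log(1.0 + math.exp(-margin))
```

A test would then compute the same batch through the TRL reference and through the custom chunked kernel, asserting the outputs match.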
A100-80G-SXM
Benchmark Results:
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence