-
Notifications
You must be signed in to change notification settings - Fork 202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor LigerFusedLinearPreferenceBase
#381
Refactor LigerFusedLinearPreferenceBase
#381
Conversation
This was quick. Great refactor! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
_input, | ||
weight, | ||
target, | ||
bias, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I wrongly assumed that forward just takes positional arguments.
Could you also make input, weight target, bias keyword args?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
Summary
This PR refactors the
LigerFusedLinearPreferenceBase
class to contain an abstractmethod corresponding to the calculation of the loss that needs to be implemented by all sub-classes.It also adds a new function to the class called
_compute_loss
which is mostly the same as the_compute_orpo_loss
function introduced in #362 but makes it generic to calculate the NLL/Cross Entropy Loss plus accepts a custom loss function that implements a new alignment loss function.Most RLHF/RLAIF/Alignment algorithms state their final loss as
NLL + Beta * (Alignment_Loss)
so adding the NLL logic inside the base class reduces repeated code.The _compute_loss function accepts
Testing Done
On A100-80G-SXM
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence