
Improve support for customizing the LoRa train method #1224

Open
chimezie opened this issue Jan 26, 2025 · 1 comment

chimezie (Contributor) commented Jan 26, 2025

It is generally useful to allow third-party software built on top of mlx_lm to customize the LoRA train function (in mlx_lm.tuner.trainer). This was the motivation for allowing user-provided loss, iterate_batches, and training_callback functions. Another useful extension mechanism would be to let the user pass an instance of a custom TrainingArgs data class (e.g. a subclass with extra fields) to the train function.
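As a rough sketch of what this might look like from a third-party package (the subclass name, field name, and placeholder token ids below are illustrative only, not the current API):

```python
from dataclasses import dataclass, field
from typing import List

from mlx_lm.tuner.trainer import TrainingArgs


# Hypothetical subclass carrying project-specific configuration in addition
# to the stock TrainingArgs fields.
@dataclass
class MaskedTrainingArgs(TrainingArgs):
    # Tokens marking where the "response" part of each sequence begins
    # (illustrative field name only).
    response_generation_tokens: List[int] = field(default_factory=list)


# The instance would then be handed to train() exactly like a plain
# TrainingArgs instance, e.g. train(..., args=MaskedTrainingArgs(...)).
args = MaskedTrainingArgs(response_generation_tokens=[1, 2])  # placeholder token ids
```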

Other configuration-driven training capabilities could be built on the existing architecture with this change.

However, the TrainingArgs instance is currently not passed on to the training_callback or iterate_batches functions. Passing it through would allow third-party software to provide additional information without changing the signatures of those functions.

For a concrete example of where this would help, see #484 and #1103: there, the iterate_batches function needed a breaking change to its signature to support user-specified 'response generation tokens', which identify the boundary between the input tokens and the rest of the sequence and enable input-masking (completion-only) instruction tuning.

Specifying these tokens on a custom TrainingArgs instance that is passed down to the iterate_batches function would facilitate this extension, and similar ones, without breaking changes to the signatures of those methods.
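For example, a custom iterate_batches could then read the extra configuration from the forwarded instance rather than requiring a new positional parameter. A minimal sketch, assuming the proposed training_args keyword (defaulting to None) and only approximating the stock parameter names:

```python
def masked_iterate_batches(dataset, tokenizer, batch_size, max_seq_length,
                           train=False, training_args=None):
    # With the proposed change, train() would forward its TrainingArgs
    # instance here; existing callers that omit it are unaffected.
    boundary_tokens = []
    if training_args is not None:
        boundary_tokens = getattr(training_args, "response_generation_tokens", [])

    # The batching logic itself would be unchanged from the stock function,
    # except that boundary_tokens is used to split prompt tokens from
    # response tokens so the loss can be masked to the response portion.
    yield from ()  # batching body elided in this sketch
```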

The same principle could apply to the training_callback function in other scenarios. All that would be required is to pass the training arguments instance as an additional keyword argument that defaults to None. This would not break the current signatures of these methods, and it would reduce the need to break them for subsequent extensions.
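Concretely, the receiving-side signatures would only gain a keyword defaulting to None; the names below mirror the existing callback interface only approximately and are meant as a sketch of the proposal, not actual code:

```python
class TrainingCallback:
    def on_train_loss_report(self, train_info, training_args=None):
        # Callbacks that ignore the new keyword behave exactly as before;
        # custom callbacks can inspect training_args for extra configuration.
        pass

    def on_val_loss_report(self, val_info, training_args=None):
        pass
```

train() would then simply forward training_args=args at each call site.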

chimezie changed the title from Improve support for custom LoRa train method to Improve support for customizing the LoRa train method on Jan 26, 2025
awni (Member) commented Jan 28, 2025

It seems reasonable to me! If you are interested in submitting a PR, that would be great.
