It is generally useful to allow third-party software to customize the LoRA train function (in mlx_lm.tuner.trainer). This was the motivation for allowing user-provided loss, iterate_batches, and training_callback functions. Another useful extension mechanism would be to let the user pass a custom TrainingArgs data class and instance to the train function.
Other configuration-driven training capabilities could be built on the existing architecture with this change.
However, the TrainingArgs instance is currently not passed on to the training_callback or iterate_batches functions. Doing so would allow user code to hand additional information to those functions without changing their signatures.
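For illustration, the kind of custom TrainingArgs I have in mind could look roughly like the sketch below. Only TrainingArgs itself comes from mlx_lm.tuner.trainer; the subclass name and the extra field are hypothetical and just stand in for whatever extra configuration a user wants to carry:

```python
from dataclasses import dataclass, field
from typing import List

from mlx_lm.tuner.trainer import TrainingArgs


@dataclass
class MaskedTrainingArgs(TrainingArgs):
    # Hypothetical extension field: token ids that mark where the completion
    # begins, which a custom iterate_batches could use to build an input mask.
    response_generation_tokens: List[int] = field(default_factory=list)
```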
For a concrete example of where this would be useful, see #484 and #1103, where the iterate_batches function needed a breaking change to its signature to support user-specified 'response generation tokens' that identify the boundary between the input tokens and the rest of the sequence, enabling input-masking (completion-only) instruction tuning.
Specifying these tokens on a custom TrainingArgs instance passed down to the iterate_batches function would have accommodated this extension, and similar ones, without breaking changes to that signature.
The same principle could apply to the training_callback function for other scenarios. All that would be required is to pass the TrainingArgs instance as an additional keyword argument that defaults to None. This would not break the current signatures and would reduce the need to break them for subsequent extensions.
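A minimal sketch of the proposed, non-breaking extension point, assuming parameter lists resembling the current ones (the exact signatures of train and iterate_batches in mlx_lm may differ; this only illustrates forwarding the instance via a kwarg that defaults to None):

```python
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length,
                    train=False, training_args=None):
    # New keyword argument: the TrainingArgs instance, or None when the caller
    # does not supply one, so existing implementations keep working unchanged.
    if training_args is not None:
        # A custom implementation could read extension fields here, e.g. the
        # hypothetical response_generation_tokens field above, to build masks.
        ...
    ...

# Inside train(), the only change needed would be to forward the instance,
# e.g. iterate_batches(..., training_args=args), and to pass the same keyword
# argument along wherever the training_callback is invoked.
```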
chimezie changed the title from "Improve support for custom LoRa train method" to "Improve support for customizing the LoRa train method" on Jan 26, 2025.