-
Notifications
You must be signed in to change notification settings - Fork 73
fix(torchrun): Omit empty arguments and correct nproc_per_node type #661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The command generation logic is updated to dynamically build the torchrun command, excluding arguments that are empty or None. This prevents them from overriding environment variables, ensuring that torchrun can correctly inherit its configuration. An exception is made for integer arguments where 0 is a valid value. Additionally, the nproc_per_node argument type has been changed from int to str to support special values accepted by PyTorch, such as 'auto', 'gpu', and 'cpu'. Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88 Signed-off-by: Saad Zaher <[email protected]>
src/instructlab/training/config.py
Outdated
""" | ||
|
||
nproc_per_node: int | ||
nproc_per_node: str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean to make this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
src/instructlab/training/main_ds.py
Outdated
# build args for this file. Ignore empty or unset values except int values | ||
for key, value in train_args.model_dump(exclude_none=True).items(): | ||
# avoid ignoring int attrs with value = 0 | ||
if not isinstance(value, int) and (not value or value == ""): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How would this handle booleans?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated this one to only check for string types.
src/instructlab/training/main_ds.py
Outdated
# avoid ignoring int attrs with value = 0 | ||
if not isinstance(value, int) and (not value or value == ""): | ||
continue | ||
command.append(f"--{key}={value}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you verified that all of our CLI arguments are perfectly 1:1 with the variable names we're using here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have updated this one to only process torchrun
args and leave the scripts args as they're not perfectly 1:1 mapped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR @szaher , I think the changes here are reasonable and had a few questions about the implementation.
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
# this will tell the model construct to ignore | ||
# extra arguments that aren't part of this model | ||
class Config: | ||
extra = "ignore" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szaher Do you know when this would be the case? If our goal here is to dynamically build the torchrun
command using the defined interface, this seems like it now opens the floor up for users to pass invalid arguments through torchrun. This means that any incorrect interface usage wouldn't be detected until runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact this will actually drop additionally provided arguments and only keep torchrun ones
torchrun_defaults = {
'nnodes': 1, 'node_rank': 0, 'rdzv_id': 0, 'rdzv_endpoint': '',
'nproc_per_node': 2, "fake_arg": "what"
}
y = TorchrunArgs(**torchrun_defaults)
print(y)
TorchrunArgs(nproc_per_node=2, nnodes=1, node_rank=0, rdzv_id=0, rdzv_endpoint='')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, that's fine then.
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
Signed-off-by: Saad Zaher <[email protected]>
The command generation logic is updated to dynamically build the torchrun command, excluding arguments that are empty or None. This prevents them from overriding environment variables, ensuring that torchrun can
correctly inherit its configuration. An exception is made for integer arguments where 0 is a valid value.
Additionally, the nproc_per_node argument type has been changed from int to str to support special values
accepted by PyTorch, such as 'auto', 'gpu', and 'cpu'.
Reference: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py#L77-L88