Can you provide an end-to-end example (like a Transformer block) showing sharding propagation + SPMDization?
(Title edited by shoveller86, Jan 10, 2025)
I'm also interested in how to go from a Shardy-annotated graph after propagation to some sort of SPMD form (e.g., a loop over the modified ops, or the ops outlined into a function).
Hey! Shardy is integrated into JAX, so one option is to write a small MLP or Transformer block in JAX. Since Shardy is also integrated into XLA, you can pass the `--xla_dump_to` flag (e.g., through the `XLA_FLAGS` environment variable) to dump the module. The dump directory will contain a `shardy/` subdirectory, and inside it you should see:
sdy_module_before_xla_import.mlir // feel free to ignore; this relates to how Shardy is integrated into XLA
sdy_module_before_sdy_import.mlir // the initial MLIR module that Shardy sees
sdy_module_after_sdy_import.mlir // the module after preprocessing that makes propagation work correctly, e.g., constant splitting and adding data-flow edges
sdy_module_after_user_priority_0.mlir // only if you set user priorities on shardings (not currently exposed in JAX); one file is dumped per priority
sdy_module_after_propagation.mlir // the module after propagation finishes
sdy_module_after_sdy_export.mlir // the module after cleaning up what was done in sdy_import
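To make the workflow above concrete, here is a minimal sketch of a JAX script that shards a two-layer MLP and dumps the compiled modules. Assumptions, none of which come from the comment above: a recent JAX running on CPU, `--xla_force_host_platform_device_count=8` to simulate 8 devices, the `jax_use_shardy_partitioner` config option to opt in to Shardy (newer releases may already default to it), and an arbitrary choice of seed shardings (batch over `"data"`, hidden dimension over `"model"`). Exact dump file names and locations may differ between XLA versions, so the script just searches for anything containing `sdy_module`.

```python
import glob
import os
import tempfile

# XLA reads XLA_FLAGS when its backends initialize, so set it before importing jax.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_force_host_platform_device_count=8"  # simulate 8 CPU devices
    + f" --xla_dump_to={dump_dir}"  # write compiler dumps here
).strip()

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    # Assumption: recent JAX releases may already use Shardy by default.
    jax.config.update("jax_use_shardy_partitioner", True)
except AttributeError:
    pass  # option not recognized by this JAX version

# 2-D ("data", "model") mesh over the available CPU devices.
cpu = jax.devices("cpu")
model_size = 2 if len(cpu) % 2 == 0 else 1
mesh = Mesh(np.array(cpu).reshape(len(cpu) // model_size, model_size), ("data", "model"))


def mlp(x, w1, w2):
    # Two-layer MLP: (batch, d_in) -> (batch, d_hidden) -> (batch, d_out).
    return jax.nn.relu(x @ w1) @ w2


batch, d_in, d_hidden, d_out = 4 * mesh.shape["data"], 16, 32, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

# Seed shardings: batch over "data", hidden dimension over "model".
# Propagation infers shardings for the intermediate and the output.
x = jax.device_put(jax.random.normal(k1, (batch, d_in)), NamedSharding(mesh, P("data", None)))
w1 = jax.device_put(jax.random.normal(k2, (d_in, d_hidden)), NamedSharding(mesh, P(None, "model")))
w2 = jax.device_put(jax.random.normal(k3, (d_hidden, d_out)), NamedSharding(mesh, P("model", None)))

# The first call compiles the function; that is when the dump files are written.
y = jax.jit(mlp)(x, w1, w2)
y.block_until_ready()

print("output sharding:", y.sharding)
for path in sorted(glob.glob(os.path.join(dump_dir, "**", "*sdy_module*"), recursive=True)):
    print(os.path.relpath(path, dump_dir))
```

Opening `sdy_module_before_sdy_import.mlir` and `sdy_module_after_propagation.mlir` side by side should show which shardings were given explicitly and which ones propagation filled in.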
Longer term, we may add a short markdown tutorial that walks through each of these steps, starting from a basic MLIR module. Let me know if this is enough, or if you have any questions!
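On the follow-up question about going from the propagated graph to an SPMD form: one way to look at both sides in-process, without a dump directory, is JAX's ahead-of-time API. `jax.jit(f).lower(...)` returns the global-view module (`Lowered.as_text()`), and `.compile().as_text()` returns XLA's optimized HLO, which is produced after SPMD partitioning, so its shapes are per device. This is a minimal sketch, assuming a recent JAX on CPU with 8 simulated devices; the 1-D `"data"` mesh and the all-ones inputs are illustrative choices, and the printed text format varies across JAX/XLA versions. It shows XLA's partitioner output, not a loop or outlined-function form.

```python
import os

# Simulate 8 CPU devices; must be set before jax initializes its backends.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    jax.config.update("jax_use_shardy_partitioner", True)  # assumption: may be the default
except AttributeError:
    pass

cpu = jax.devices("cpu")
n = len(cpu)
mesh = Mesh(np.array(cpu), ("data",))  # 1-D mesh: pure data parallelism


def mlp(x, w1, w2):
    return jax.nn.relu(x @ w1) @ w2


# Global batch of 4 * n rows, so each device should end up with 4 rows.
x = jax.device_put(np.ones((4 * n, 16), np.float32), NamedSharding(mesh, P("data", None)))
w1 = jax.device_put(np.ones((16, 32), np.float32), NamedSharding(mesh, P()))  # replicated
w2 = jax.device_put(np.ones((32, 8), np.float32), NamedSharding(mesh, P()))  # replicated

lowered = jax.jit(mlp).lower(x, w1, w2)
global_text = lowered.as_text()  # StableHLO: global shapes plus sharding annotations
spmd_text = lowered.compile().as_text()  # optimized HLO after SPMD partitioning

print(global_text)
print(spmd_text)
```

With 8 simulated devices, you would expect the lowered module to type `x` as `tensor<32x16xf32>` and the compiled module to use `f32[4,16]`, the per-device shard.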