Skip to content

Commit 1bc9054

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Move sharding table above the code example
PiperOrigin-RevId: 751139503
1 parent 8450b71 commit 1bc9054

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

README.md

+6-6
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,12 @@ explicit in JAX types, inspectable using `jax.typeof`;
189189
where you have a per-device view of data
190190
and computation, and can communicate with explicit collectives.
191191

192+
| Mode | View? | Explicit sharding? | Explicit Collectives? |
193+
|---|---|---|---|
194+
| Auto | Global |||
195+
| Explicit | Global |||
196+
| Manual | Per-device |||
197+
192198
```python
193199
from jax.sharding import set_mesh, AxisType, PartitionSpec as P
194200
mesh = jax.make_mesh((8,), ('data',), axis_types=(AxisType.Explicit,))
@@ -210,12 +216,6 @@ param_grads = gradfun(params, (inputs, targets))
210216
See the [tutorial](https://docs.jax.dev/en/latest/sharded-computation.html) and
211217
[advanced guides](https://docs.jax.dev/en/latest/advanced_guide.html) for more.
212218

213-
| Mode | View? | Explicit sharding? | Explicit Collectives? |
214-
|---|---|---|---|
215-
| Auto | Global |||
216-
| Explicit | Global |||
217-
| Manual | Per-device |||
218-
219219
## Gotchas and sharp bits
220220

221221
See the [Gotchas

0 commit comments

Comments
 (0)