add optimization_barrier to overlap P2P in RS matmul
Signed-off-by: Xiaowei Ren <[email protected]>
xrennvidia committed Aug 1, 2023
1 parent 2ca689f commit 9af0707
Showing 2 changed files with 50 additions and 0 deletions.
47 changes: 47 additions & 0 deletions rs_matmul/barrier.py
@@ -0,0 +1,47 @@
import jax
# Default to CPU so the lowering below can be inspected without a GPU.
jax.config.update("jax_default_device", jax.devices("cpu")[0])

from jax import core
from jax.interpreters.xla import apply_primitive
from jax.tree_util import tree_flatten, tree_unflatten
import jaxlib.mlir.dialects.stablehlo as hlo
from jax._src import util
from jax._src.interpreters import mlir
from functools import partial
#print(jax.devices())

def _optimization_barrier_abstract_eval(*args):
    # Shape/dtype-preserving: the barrier returns its inputs unchanged.
    return args

def _optimization_barrier_lowering_rule(ctx, *args):
    # Lower to a stablehlo.optimization_barrier op, which XLA will not fuse
    # across or reorder around.
    barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
    flat_args = mlir.flatten_lowering_ir_args(args)
    barrier_op = hlo.OptimizationBarrierOp(flat_args)
    return util.unflatten(barrier_op.results, map(len, barrier_types))

def _optimization_barrier(arg):
    # User-facing wrapper: flatten an arbitrary pytree, bind the primitive on
    # the leaves, and rebuild the tree.
    flat_args, treedef = tree_flatten(arg)
    return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args))

optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
    partial(apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
                       _optimization_barrier_lowering_rule)

import jax.experimental.shard_map as shard_map
shard_map.register_standard(optimization_barrier_p) # doesn't change replication

#import jax.numpy as jnp
#def f(y, z, a):
# d = jnp.dot(y, z)
# d = _optimization_barrier(d)
# acc = d + a
# return acc
#
#y = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
#z = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
#a = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
#print(jax.jit(f).lower(y, z, a).as_text())
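
The commented-out block above shows the intended check. A minimal runnable variant (smaller shapes chosen for illustration; the repo root is assumed to be importable) confirms that the barrier survives into the lowered StableHLO:

import jax
import jax.numpy as jnp

from rs_matmul.barrier import _optimization_barrier  # the wrapper defined above

def f(y, z, a):
    d = jnp.dot(y, z)
    d = _optimization_barrier(d)  # keeps the dot from being fused into the add
    return d + a

x = jnp.ones((128, 128), jnp.float32)
text = jax.jit(f).lower(x, x, x).as_text()
assert "optimization_barrier" in text  # stablehlo.optimization_barrier between dot and add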
3 changes: 3 additions & 0 deletions rs_matmul/rs_matmul.py
@@ -11,6 +11,8 @@
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_async_collective_permute=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_highest_priority_async_stream=true'
#os.environ['XLA_FLAGS'] += ' --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=/results/hlo/rs_matmul_dp8tp1'

from .barrier import _optimization_barrier

with_sharding_constraint = nn_partitioning.with_sharding_constraint

def rs_matmul(lhs, rhs):
@@ -24,6 +26,7 @@ def collective_matmul(i, out):
            perm=[(j, (j + 1) % axis_size) for j in range(axis_size)])
        lhs_idx = (axis_idx - i - 1) % axis_size
        update = lhs[:, lhs_idx, ...] @ rhs
        # barrier: keep the per-step matmul separate from the accumulation so
        # the async collective_permute can overlap with it
        update = _optimization_barrier(update)
        out = out + update
        return out

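For context, a minimal, self-contained sketch of the reduce-scatter matmul loop this change targets. It is not the repo's exact rs_matmul: the mesh axis name, shapes, and shard_map specs below are illustrative assumptions. The point is the barrier's placement: each step forwards the partial accumulator to the next device with collective_permute while computing the local partial product, and the barrier keeps that matmul from being fused into the accumulation, so XLA's latency-hiding scheduler can overlap the P2P transfer with the matmul.

import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

from rs_matmul.barrier import _optimization_barrier  # primitive from this commit

AXIS = 'tp'  # illustrative mesh axis name

def rs_matmul_sketch(lhs, rhs, *, axis_size):
    # Per device: lhs is (m_chunk, axis_size, k_shard), rhs is (k_shard, n).
    # The partial accumulator travels around the ring; at each step every
    # device adds its local contribution for the output block it currently holds.
    axis_idx = jax.lax.axis_index(AXIS)
    ring = [(j, (j + 1) % axis_size) for j in range(axis_size)]

    def collective_matmul(i, out):
        # forward the partial accumulator to the next device (async P2P)
        out = jax.lax.ppermute(out, AXIS, perm=ring)
        lhs_idx = (axis_idx - i - 1) % axis_size
        update = lhs[:, lhs_idx, :] @ rhs
        # barrier: keep the dot and the add as separate ops so the scheduler
        # can run the matmul while the collective_permute is in flight
        update = _optimization_barrier(update)
        return out + update

    out = jnp.zeros((lhs.shape[0], rhs.shape[-1]), dtype=lhs.dtype)
    return jax.lax.fori_loop(0, axis_size, collective_matmul, out)

devices = np.array(jax.devices())
mesh = Mesh(devices, (AXIS,))
fn = shard_map(partial(rs_matmul_sketch, axis_size=devices.size),
               mesh=mesh,
               in_specs=(P(None, None, AXIS), P(AXIS, None)),  # shard the contraction dim
               out_specs=P(AXIS, None),  # each device keeps one fully-reduced row block
               check_rep=False)  # fori_loop + ppermute: skip shard_map's replication check

p = devices.size
lhs_g = jnp.ones((64, p, 128 * p), jnp.float32)
rhs_g = jnp.ones((128 * p, 64), jnp.float32)
out = jax.jit(fn)(lhs_g, rhs_g)  # shape (64 * p, 64)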
