Commit
Showing 4 changed files with 277 additions and 0 deletions.
@@ -0,0 +1,5 @@
*~
__pycache__
hlo_out
nsys
.DS_Store
@@ -0,0 +1,90 @@
import argparse
import numpy
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map
from jax.experimental.pjit import pjit
from flax.linen import partitioning as nn_partitioning

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_async_collective_permute=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_highest_priority_async_stream=true'
#os.environ['XLA_FLAGS'] += ' --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=/results/hlo/ag_matmul_dp8tp1'

with_sharding_constraint = nn_partitioning.with_sharding_constraint

def ag_matmul(lhs, rhs):
    axis_size = jax.lax.psum(1, axis_name='tp')
    axis_idx = jax.lax.axis_index('tp')
    def collective_matmul(i, carrys):
        out, lhs = carrys
        update = lhs @ rhs
        # In parallel with the matmul, shift the local lhs shard to the next device in the ring.
        lhs = jax.lax.ppermute(
            lhs,
            'tp',
            perm=[(j, (j + 1) % axis_size) for j in range(axis_size)])
        # After i forward shifts, the local shard originated on device (axis_idx - i),
        # so its result belongs at that block of the gathered sequence axis.
        update_idx = (axis_idx - i) % axis_size
        out = jax.lax.dynamic_update_slice(out, update, (0, update_idx*lhs.shape[1], 0))
        return out, lhs

    out = jnp.empty((lhs.shape[0], lhs.shape[1]*axis_size, rhs.shape[1]), dtype=lhs.dtype)
    out, lhs = jax.lax.fori_loop(
        0, axis_size - 1, collective_matmul, (out, lhs))

    # Last chunk: after axis_size - 1 shifts, the local shard originated on device (axis_idx + 1).
    update = lhs @ rhs
    update_idx = (axis_idx - (axis_size - 1)) % axis_size
    out = jax.lax.dynamic_update_slice(out, update, (0, update_idx*lhs.shape[1], 0))
    return out

def main():
    parser = argparse.ArgumentParser(description='Matmul overlap with all-gather communication')
    parser.add_argument("--dp", dest="dp", type=int, default=8)
    parser.add_argument("--tp", dest="tp", type=int, default=1)
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=2)
    parser.add_argument("--seq_len", dest="seq_len", type=int, default=2048)
    parser.add_argument("--hidden_size", dest="hidden_size", type=int, default=12288)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()

    assert args.dp * args.tp == len(jax.devices())
    assert args.seq_len % args.tp == 0
    assert args.hidden_size % args.tp == 0
    args.batch_size = args.batch_size * args.dp

    dtype = jnp.bfloat16
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key, 2)
    input = jax.random.uniform(key2, (args.batch_size, args.seq_len, args.hidden_size), dtype=dtype)
    weight = jax.random.uniform(key2, (args.hidden_size, 4*args.hidden_size), dtype=dtype)

    mesh_shape = {'dp': args.dp, 'tp': args.tp}
    in_specs = (PartitionSpec('dp', 'tp', None), PartitionSpec(None, 'tp'))
    out_specs = PartitionSpec('dp', None, 'tp')

    mesh = Mesh(numpy.array(jax.devices()).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))
    logical_axis_rules = (('batch', 'dp'), ('seq_rs', 'tp'), ('seq_ag', None), ('emb', None), ('mlp', 'tp'))
    pjitted_ag_matmul = pjit(shard_map(ag_matmul, mesh, in_specs=in_specs, out_specs=out_specs))

    if args.profile:
        import ctypes
        libcudart = ctypes.cdll.LoadLibrary('libcudart.so')
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_rs', 'emb'))
            weight = with_sharding_constraint(weight, ('emb', 'mlp'))
            for i in range(100):
                if i == 9:
                    libcudart.cudaProfilerStart()
                out = pjitted_ag_matmul(input, weight)
            libcudart.cudaProfilerStop()
    else:
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_rs', 'emb'))
            weight = with_sharding_constraint(weight, ('emb', 'mlp'))
            for i in range(100):
                out = pjitted_ag_matmul(input, weight)

    return out

if __name__ == "__main__":
    main()
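
For reference, the shard-mapped ag_matmul above should match a plain all-gather of the sequence-sharded lhs followed by a local matmul. A minimal sanity-check sketch, reusing mesh, in_specs, out_specs, input, weight, and out from main(); the helper name and tolerance are illustrative and not part of this commit:

# Illustrative reference: gather the 'tp'-sharded sequence axis, then matmul locally.
def reference_ag_matmul(lhs, rhs):
    lhs_full = jax.lax.all_gather(lhs, 'tp', axis=1, tiled=True)
    return lhs_full @ rhs

# ref_fn = pjit(shard_map(reference_ag_matmul, mesh, in_specs=in_specs, out_specs=out_specs))
# numpy.testing.assert_allclose(numpy.asarray(out, dtype=numpy.float32),
#                               numpy.asarray(ref_fn(input, weight), dtype=numpy.float32),
#                               rtol=2e-2)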
@@ -0,0 +1,95 @@
import argparse
from functools import partial
import numpy as np
from flax.linen import partitioning as nn_partitioning
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
from jax.experimental.shard_map import shard_map

from ag_matmul.ag_matmul import ag_matmul
from rs_matmul.rs_matmul import rs_matmul

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_async_collective_permute=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_highest_priority_async_stream=true'
#os.environ['XLA_FLAGS'] += ' --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=/results/hlo/collective_matmul_dp8tp1/'

with_sharding_constraint = nn_partitioning.with_sharding_constraint
shard_mapped_ag_matmul = None
shard_mapped_rs_matmul = None

def test_fn(tp_overlap, input, weight1, weight2):
    if tp_overlap:
        out = shard_mapped_ag_matmul(input, weight1)
    else:
        out = input @ weight1
        out = with_sharding_constraint(out, ('batch', 'seq_ag', 'mlp'))
    if tp_overlap:
        out = shard_mapped_rs_matmul(out, weight2)
    else:
        out = out @ weight2
    return out

def main():
    parser = argparse.ArgumentParser(description='Collective Matmul Unit Test')
    parser.add_argument("--dp", dest="dp", type=int, default=8)
    parser.add_argument("--tp", dest="tp", type=int, default=1)
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=2)
    parser.add_argument("--seq_len", dest="seq_len", type=int, default=2048)
    parser.add_argument("--hidden_size", dest="hidden_size", type=int, default=12288)
    parser.add_argument("--no_tp_overlap", dest="tp_overlap", action="store_false")
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()

    assert args.dp * args.tp == len(jax.devices())
    assert args.seq_len % args.tp == 0
    assert args.hidden_size % args.tp == 0
    args.batch_size = args.batch_size * args.dp

    dtype = jnp.bfloat16
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key, 2)
    input = jax.random.uniform(key2, (args.batch_size, args.seq_len, args.hidden_size), dtype=dtype)
    weight1 = jax.random.uniform(key2, (args.hidden_size, 4*args.hidden_size), dtype=dtype)
    weight2 = jax.random.uniform(key2, (4*args.hidden_size, args.hidden_size), dtype=dtype)

    mesh_shape = {'dp': args.dp, 'tp': args.tp}
    mesh = Mesh(np.array(jax.devices()).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))
    logical_axis_rules = (('batch', 'dp'), ('seq_rs', 'tp'), ('seq_ag', None), ('emb', None), ('mlp', 'tp'))

    if args.tp_overlap:
        global shard_mapped_ag_matmul, shard_mapped_rs_matmul
        ag_matmul_in_specs = (PartitionSpec('dp', 'tp', None), PartitionSpec(None, 'tp'))
        ag_matmul_out_specs = PartitionSpec('dp', None, 'tp')
        shard_mapped_ag_matmul = shard_map(ag_matmul, mesh, in_specs=ag_matmul_in_specs, out_specs=ag_matmul_out_specs)
        rs_matmul_in_specs = (PartitionSpec('dp', None, 'tp'), PartitionSpec('tp', None))
        rs_matmul_out_specs = PartitionSpec('dp', 'tp', None)
        shard_mapped_rs_matmul = shard_map(rs_matmul, mesh, in_specs=rs_matmul_in_specs, out_specs=rs_matmul_out_specs)

    pjitted_test_fn = pjit(partial(test_fn, args.tp_overlap), out_shardings=PartitionSpec('dp', 'tp', None))

    if args.profile:
        import ctypes
        libcudart = ctypes.cdll.LoadLibrary('libcudart.so')
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_rs', 'emb'))
            weight1 = with_sharding_constraint(weight1, ('emb', 'mlp'))
            weight2 = with_sharding_constraint(weight2, ('mlp', 'emb'))
            for i in range(100):
                if i == 9:
                    libcudart.cudaProfilerStart()
                out = pjitted_test_fn(input, weight1, weight2)
            libcudart.cudaProfilerStop()
    else:
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_rs', 'emb'))
            weight1 = with_sharding_constraint(weight1, ('emb', 'mlp'))
            weight2 = with_sharding_constraint(weight2, ('mlp', 'emb'))
            for i in range(100):
                out = pjitted_test_fn(input, weight1, weight2)

    return out

if __name__ == "__main__":
    main()
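
The unit test above only times and profiles the selected path; it never compares the overlapped and baseline results numerically. A hedged sketch of such a comparison, reusing mesh, logical_axis_rules, input, weight1, and weight2 from main() and assuming tp_overlap was left enabled so the shard-mapped functions exist; the bfloat16 tolerances are illustrative:

# Illustrative numerical comparison of the overlapped and baseline paths.
with mesh, nn_partitioning.axis_rules(logical_axis_rules):
    out_overlap = pjit(partial(test_fn, True),
                       out_shardings=PartitionSpec('dp', 'tp', None))(input, weight1, weight2)
    out_baseline = pjit(partial(test_fn, False),
                        out_shardings=PartitionSpec('dp', 'tp', None))(input, weight1, weight2)
np.testing.assert_allclose(np.asarray(out_overlap, dtype=np.float32),
                           np.asarray(out_baseline, dtype=np.float32),
                           rtol=5e-2, atol=5e-2)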
@@ -0,0 +1,87 @@
import argparse
import numpy
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.shard_map import shard_map
from jax.experimental.pjit import pjit
from flax.linen import partitioning as nn_partitioning

import os
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_async_collective_permute=true --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_highest_priority_async_stream=true'
#os.environ['XLA_FLAGS'] += ' --xla_dump_hlo_as_text --xla_dump_hlo_as_html --xla_dump_to=/results/hlo/rs_matmul_dp8tp1'

with_sharding_constraint = nn_partitioning.with_sharding_constraint

def rs_matmul(lhs, rhs):
    axis_size = jax.lax.psum(1, axis_name='tp')
    axis_idx = jax.lax.axis_index('tp')
    # Split the full sequence axis into one block per 'tp' device.
    lhs = lhs.reshape((lhs.shape[0], axis_size, lhs.shape[1]//axis_size, lhs.shape[2]))
    def collective_matmul(i, out):
        # Pass the running partial sum to the next device in the ring, then add this
        # device's contribution for the sequence block the accumulator will end up on.
        out = jax.lax.ppermute(
            out,
            'tp',
            perm=[(j, (j + 1) % axis_size) for j in range(axis_size)])
        lhs_idx = (axis_idx - i - 1) % axis_size
        update = lhs[:, lhs_idx, ...] @ rhs
        out = out + update
        return out

    # Seed the accumulator with the contribution for the block destined for the next device.
    lhs_idx = (axis_idx - 1) % axis_size
    out = lhs[:, lhs_idx, ...] @ rhs
    out = jax.lax.fori_loop(
        1, axis_size, collective_matmul, out)

    return out

def main():
    parser = argparse.ArgumentParser(description='Matmul overlap with reduce-scatter communication')
    parser.add_argument("--dp", dest="dp", type=int, default=8)
    parser.add_argument("--tp", dest="tp", type=int, default=1)
    parser.add_argument("--batch_size", dest="batch_size", type=int, default=2)
    parser.add_argument("--seq_len", dest="seq_len", type=int, default=2048)
    parser.add_argument("--hidden_size", dest="hidden_size", type=int, default=12288)
    parser.add_argument("--profile", action="store_true")
    args = parser.parse_args()

    assert args.dp * args.tp == len(jax.devices())
    assert args.seq_len % args.tp == 0
    assert args.hidden_size % args.tp == 0
    args.batch_size = args.batch_size * args.dp

    dtype = jnp.bfloat16
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key, 2)
    input = jax.random.uniform(key2, (args.batch_size, args.seq_len, 4*args.hidden_size), dtype=dtype)
    weight = jax.random.uniform(key2, (4*args.hidden_size, args.hidden_size), dtype=dtype)

    mesh_shape = {'dp': args.dp, 'tp': args.tp}
    in_specs = (PartitionSpec('dp', None, 'tp'), PartitionSpec('tp', None))
    out_specs = PartitionSpec('dp', 'tp', None)

    mesh = Mesh(numpy.array(jax.devices()).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))
    logical_axis_rules = (('batch', 'dp'), ('seq_rs', 'tp'), ('seq_ag', None), ('emb', None), ('mlp', 'tp'))
    pjitted_rs_matmul = pjit(shard_map(rs_matmul, mesh, in_specs=in_specs, out_specs=out_specs))

    if args.profile:
        import ctypes
        libcudart = ctypes.cdll.LoadLibrary('libcudart.so')
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_ag', 'mlp'))
            weight = with_sharding_constraint(weight, ('mlp', 'emb'))
            for i in range(100):
                if i == 9:
                    libcudart.cudaProfilerStart()
                out = pjitted_rs_matmul(input, weight)
            libcudart.cudaProfilerStop()
    else:
        with mesh, nn_partitioning.axis_rules(logical_axis_rules):
            input = with_sharding_constraint(input, ('batch', 'seq_ag', 'mlp'))
            weight = with_sharding_constraint(weight, ('mlp', 'emb'))
            for i in range(100):
                out = pjitted_rs_matmul(input, weight)

    return out

if __name__ == "__main__":
    main()
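
Analogously, rs_matmul should match a full local matmul followed by a reduce-scatter of the sequence axis over 'tp'. A minimal sanity-check sketch, reusing mesh, in_specs, out_specs, input, weight, and out from main(); the helper name and tolerance are illustrative and not part of this commit:

# Illustrative reference: the local matmul yields a partial sum; reduce-scatter it over 'tp'
# so each device keeps only its own block of the sequence axis.
def reference_rs_matmul(lhs, rhs):
    partial_out = lhs @ rhs
    return jax.lax.psum_scatter(partial_out, 'tp', scatter_dimension=1, tiled=True)

# ref_fn = pjit(shard_map(reference_rs_matmul, mesh, in_specs=in_specs, out_specs=out_specs))
# numpy.testing.assert_allclose(numpy.asarray(out, dtype=numpy.float32),
#                               numpy.asarray(ref_fn(input, weight), dtype=numpy.float32),
#                               rtol=2e-2)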