
Commit 96a19f1: fixes
dan-garvey committed Feb 19, 2025
1 parent: f3c8545
Showing 4 changed files with 8 additions and 8 deletions.
sharktank/sharktank/kernels/attention.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -63,7 +63,7 @@ def select(self, ksel: KernelSelection):
         v_desc.specialize_dims(0, 1, -1)

         # Result 0: Shape batch..., m, n
-        ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float16).specialize_dims(
+        ksel.return_new_tensor((*q_bs, q_l, v_e), dtype=torch.float32).specialize_dims(
             0, 1, -1
         )
```

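The change above only widens the declared result dtype of the attention kernel from float16 to float32. A minimal sketch of the motivation, not taken from the repo and using stock torch.nn.functional rather than the sharktank kernel; it assumes a PyTorch build that supports fp16 SDPA on the current device:

```python
# Illustrative only: compares an fp16 attention output against an fp32
# reference to show the rounding error an fp16 result dtype bakes in.
# Shapes are arbitrary.
import torch
import torch.nn.functional as F

q = torch.rand(1, 8, 128, 64, dtype=torch.float16)
k = torch.rand(1, 8, 128, 64, dtype=torch.float16)
v = torch.rand(1, 8, 128, 64, dtype=torch.float16)

out_fp16 = F.scaled_dot_product_attention(q, k, v)  # fp16 result
out_fp32 = F.scaled_dot_product_attention(q.float(), k.float(), v.float())

# The gap between the two paths is the precision a float32 result preserves.
print((out_fp32 - out_fp16.float()).abs().max())
```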
(Second changed file; filename not shown in the capture: 3 additions & 3 deletions)
```diff
@@ -313,16 +313,16 @@ def main(argv):
         type=str,
         default="7b",
         help="Base model to use for split sizes to decompose the qkv tensor. Default is 7b, 70b is also supported.",
-        choices=["7b", "70b"],
+        choices=["7b", "70b", "405b"],
     )
     args = cli.parse(parser, args=argv)

     config_json_path: Path = args.config_json
     params_path: Path = args.params
     # TODO: find a way to get this programatically so we don't have to flag for it
     split_sizes = [4096, 4096, 4096] if args.model_base == "7b" else [8192, 1024, 1024]
-    num_layers = 32 if args.model_base == "7b" else 80
-
+    layers_per_base = {"7b": 32, "70b": 40, "405b": 125}
+    num_layers = layers_per_base[args.model_base]
     # Construct the pre-transform dataset.
     dataset_props = _get_dataset_props(_load_json(config_json_path))
     with safetensors.safe_open(params_path, framework="pt", device="cpu") as st:
```
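The hard-coded ternary is replaced with a lookup table, so adding a base model only needs a new dictionary entry. A small sketch, not part of the commit, of how the same table could also carry the split sizes that the TODO above refers to; the 405b split sizes are placeholders, not values from the repo:

```python
# Hypothetical extension of the commit's lookup-table pattern. The 7b and 70b
# split sizes come from the existing code; the 405b splits are placeholders.
import argparse

MODEL_BASES = {
    "7b":   {"num_layers": 32,  "split_sizes": [4096, 4096, 4096]},
    "70b":  {"num_layers": 40,  "split_sizes": [8192, 1024, 1024]},
    "405b": {"num_layers": 125, "split_sizes": [8192, 1024, 1024]},  # placeholder splits
}

parser = argparse.ArgumentParser()
parser.add_argument("--model-base", default="7b", choices=sorted(MODEL_BASES))
args = parser.parse_args(["--model-base", "70b"])

base = MODEL_BASES[args.model_base]
num_layers, split_sizes = base["num_layers"], base["split_sizes"]
print(num_layers, split_sizes)
```

Keeping the choices list and the table in sync this way also means an unsupported base fails at argument parsing instead of surfacing later as a KeyError.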
sharktank/sharktank/ops/signatures.py (2 changes: 1 addition & 1 deletion)
```diff
@@ -859,7 +859,7 @@ def _scaled_dot_product_attention(
 ):
     tensors = (q, k, v, a)
     for override in d.find_overrides(tensors):
-        result = override(q, k, v, a, is_causal=is_causal, scale=scale)
+        result = override(q, k, v, a, scale=scale)
         if result is not NotImplemented:
             return override, result
     else:
```
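For context, the loop being edited follows the usual NotImplemented dispatch convention: each registered override is tried in turn, and the first one that returns something other than NotImplemented wins. A standalone sketch of that pattern, illustrative only and not the sharktank dispatcher:

```python
# Minimal stand-in for the override dispatch loop above; these names are made
# up for illustration and do not exist in sharktank.ops.
def dispatch(overrides, *args, **kwargs):
    for override in overrides:
        result = override(*args, **kwargs)
        if result is not NotImplemented:
            return override, result
    raise NotImplementedError("no override accepted the given tensors")

def picky_impl(q, k, v, a, scale=None):
    return NotImplemented  # declines; dispatch moves on to the next override

def generic_impl(q, k, v, a, scale=None):
    return "attention result placeholder"

chosen, result = dispatch([picky_impl, generic_impl], "q", "k", "v", None, scale=None)
print(chosen.__name__, result)
```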
sharktank/tests/kernels/attention_template_test.py (6 changes: 3 additions & 3 deletions)
```diff
@@ -47,7 +47,7 @@ def test_compare_torch_spda(self, dtype, atol, rtol, use_mask):
         mask = None
         scale = torch.tensor(1.0, dtype=dtype)
         if use_mask:
-            mask = torch.rand([N, H, L, S], dtype=dtype)
+            mask = torch.rand([L, S], dtype=dtype)

         res2 = kernels.masked_flash_attention(q, k, v, mask, scale=scale)
```
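The test now builds a 2-D [L, S] mask instead of a full [N, H, L, S] one; an additive mask of that shape broadcasts across the batch and head dimensions of the attention scores. A quick sketch of that equivalence using stock PyTorch SDPA rather than the sharktank kernel:

```python
# Shows that a [L, S] additive mask gives the same result as its expanded
# [N, H, L, S] form; shapes are arbitrary illustration values.
import torch
import torch.nn.functional as F

N, H, L, S, E = 2, 4, 16, 16, 8
q = torch.rand(N, H, L, E)
k = torch.rand(N, H, S, E)
v = torch.rand(N, H, S, E)

mask_2d = torch.rand(L, S)            # what the test builds now
mask_4d = mask_2d.expand(N, H, L, S)  # explicit per-batch/per-head copy

out_2d = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_2d)
out_4d = F.scaled_dot_product_attention(q, k, v, attn_mask=mask_4d)
torch.testing.assert_close(out_2d, out_4d)
```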

```diff
@@ -88,7 +88,7 @@ def test_export_dynamic(self, dtype, static, use_mask):
         v = torch.rand([N, H, S, Ev], dtype=dtype)
         if use_mask:
             # mask is same type as inputs, therefore its added to score
-            mask = torch.rand([N, H, L, S], dtype=dtype)
+            mask = torch.rand([L, S], dtype=dtype)
         if cast:
             q = q.to(torch.float8_e4m3fnuz)
             k = q.to(torch.float8_e4m3fnuz)
```
```diff
@@ -108,7 +108,7 @@ def test_export_dynamic(self, dtype, static, use_mask):
             "scale": {},
         }
         if use_mask:
-            dynamic_shapes["mask"] = {2: L_dim, 3: S_dim}
+            dynamic_shapes["mask"] = {0: L_dim, 1: S_dim}

         class MyModule(torch.nn.Module):
             def forward(self, q, k, v, mask, scale):
```
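With the mask now 2-D, its dynamic dimensions sit at indices 0 and 1 rather than 2 and 3, which is what the dynamic_shapes fix tracks. A self-contained sketch of the same idea with torch.export; the module, names, and shapes are illustrative, not the repo's test:

```python
# Exports a toy SDPA module with a dynamic 2-D mask; the mask entry in
# dynamic_shapes indexes dims 0 and 1, mirroring the change above.
import torch
from torch.export import Dim, export

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v, mask):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)

N, H, L, S, E = 2, 4, 16, 16, 8
q = torch.rand(N, H, L, E)
k = torch.rand(N, H, S, E)
v = torch.rand(N, H, S, E)
mask = torch.rand(L, S)

L_dim, S_dim = Dim("l"), Dim("s")
dynamic_shapes = {
    "q": {2: L_dim},
    "k": {2: S_dim},
    "v": {2: S_dim},
    "mask": {0: L_dim, 1: S_dim},  # 2-D mask: dims 0 and 1 are the dynamic ones
}
exported = export(ToyAttention(), (q, k, v, mask), dynamic_shapes=dynamic_shapes)
print(exported.graph_signature)
```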
