Skip to content

Commit

Permalink
Reverts 30ba7f3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686833554
  • Loading branch information
WindQAQ authored and Google-ML-Automation committed Oct 17, 2024
1 parent 9027fb3 commit a01f187
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@
tpu_custom_call_p.multiple_results = True


def get_target_shape(hardware_generation: int) -> tuple[int, int]:
"""Returns the target shape for the given hardware generation."""
del hardware_generation
return (8, 128)


class MemorySpace(enum.Enum):
HBM = enum.auto()
VMEM = enum.auto()
Expand Down Expand Up @@ -423,9 +429,9 @@ def _lower_mosaic_module_to_asm(
"tpu_custom_call cannot be lowered on a machine without TPUs "
"when mosaic_use_python_pipeline=True.")
hardware_generation = int(device_kind[len("TPU v")])
# TODO(b/369418606): Infer the target shape from the hardware generation.
target_shape = get_target_shape(hardware_generation)
module = _lower_tpu_kernel(
module, hardware_generation, target_shape=(8, 128)
module, hardware_generation, target_shape=target_shape
)
needs_hlo_passes = False
needs_layout_passes = False
Expand Down

0 comments on commit a01f187

Please sign in to comment.