diff --git a/triton_viz/core/config.py b/triton_viz/core/config.py
index 8291c53..63581cc 100644
--- a/triton_viz/core/config.py
+++ b/triton_viz/core/config.py
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
     verbose: bool
     sanitizer_activated: bool
+    virtual_memory: bool
     disable_sanitizer: bool
     report_grid_execution_progress: bool
@@ -34,6 +35,13 @@ def reset(self) -> None:
             os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0") == "1"
         )  # verify using setter
+        # --- Virtual memory flag ---
+        self._virtual_memory = os.getenv("TRITON_VIZ_VIRTUAL_MEMORY", "0") == "1"
+
+    @property
+    def virtual_memory(self) -> bool:
+        return self._virtual_memory
+
 
     # ---------- disable_sanitizer ----------
     @property
     def disable_sanitizer(self) -> bool:
diff --git a/triton_viz/core/patch.py b/triton_viz/core/patch.py
index f12da25..3334d9c 100644
--- a/triton_viz/core/patch.py
+++ b/triton_viz/core/patch.py
@@ -5,7 +5,9 @@
 from tqdm import tqdm
 
 from . import config as cfg
+from dataclasses import dataclass
 from .callbacks import OpCallbacks, ForLoopCallbacks
+
 from .data import (
     Op,
     RawLoad,
@@ -50,6 +52,8 @@
 )
 from triton.runtime.interpreter import _patch_lang as triton_patch_lang
 from triton.runtime.interpreter import ASTTransformer as _OrigASTTransformer
+from triton.runtime.interpreter import _tuple_create, _unwrap_tensor, _rewrap_tensor
+from triton.tools.tensor_descriptor import TensorDescriptor
 from triton.runtime import JITFunction
 
 op_list = [
@@ -430,6 +434,69 @@ def unpatch_lang():
     importlib.reload(tl)
 
 
+@dataclass(frozen=True)
+class FakeTensor:
+    _data_ptr: int
+    dtype: str
+    shape: tuple[int, ...] = ()
+    _stride: tuple[int, ...] = ()
+    _is_contiguous: bool = True
+    _element_size: int = 1
+
+    def data_ptr(self) -> int:
+        return self._data_ptr
+
+    def stride(self) -> tuple[int, ...]:
+        return self._stride
+
+    def is_contiguous(self) -> bool:
+        return self._is_contiguous
+
+    def numel(self) -> int:
+        size = 1
+        for dim in self.shape:
+            size *= dim
+        return size
+
+    def element_size(self) -> int:
+        return self._element_size
+
+
+def _init_args_hst(args_dev, kwargs):
+    def _to_cpu(arg):
+        if isinstance(arg, tuple):
+            return _tuple_create(arg, map(_to_cpu, arg))
+        elif isinstance(arg, TensorDescriptor):
+            return TensorDescriptor(
+                _to_cpu(arg.base),
+                arg.shape,
+                arg.strides,
+                arg.block_shape,
+            )
+        elif not hasattr(arg, "data_ptr"):
+            return arg
+
+        unwrapped_arg = _unwrap_tensor(arg)
+        cpu_arg = FakeTensor(
+            _data_ptr=unwrapped_arg.data_ptr(),
+            dtype=unwrapped_arg.dtype,
+            shape=unwrapped_arg.shape,
+            _stride=unwrapped_arg.stride(),
+            _is_contiguous=unwrapped_arg.is_contiguous(),
+            _element_size=unwrapped_arg.element_size(),
+        )
+        cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
+        return cpu_arg
+
+    args_hst = [_to_cpu(arg) for arg in args_dev]
+
+    # Process keyword arguments
+    kwargs_hst = {}
+    for key, value in kwargs.items():
+        kwargs_hst[key] = _to_cpu(value)
+    return args_hst, kwargs_hst
+
+
 def _grid_executor_call(self, *args_dev, **kwargs):
     if kwargs.pop("warmup", False):
         return
@@ -470,7 +537,10 @@ def run_grid_loops(grid):
         k: v for k, v in kwargs.items() if k in argspec.args or k in triton_viz_args
     }
     client_manager = kwargs.pop("client_manager")
-    args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
+    if cfg.virtual_memory:
+        args_hst, kwargs_hst = _init_args_hst(args_dev, kwargs)
+    else:
+        args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
     # Prepare call arguments
     args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
     call_args = {}
@@ -491,7 +561,8 @@ def run_grid_loops(grid):
         client_manager.grid_callback(grid)
         run_grid_loops(grid)
     # Copy arguments back to propagate side-effects
-    self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
+    if not cfg.virtual_memory:
+        self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
 
 
 def _jit_function_call(self, *args, **kwargs):
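A note on the new flag: reset() reads TRITON_VIZ_VIRTUAL_MEMORY once via os.getenv, so the variable must be exported before triton_viz initializes its config. A minimal sketch of toggling the mode, assuming the config module exposes the property at module level, as the TYPE_CHECKING annotations and the cfg usage in patch.py suggest:

    # Sketch only: the env var must be set before triton_viz's config
    # runs reset(), which samples it a single time.
    import os
    os.environ["TRITON_VIZ_VIRTUAL_MEMORY"] = "1"

    from triton_viz.core import config as cfg

    assert cfg.virtual_memory  # read-only: the diff adds a getter, no setter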
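For illustration of what the virtual-memory path substitutes for a host copy: FakeTensor carries only the metadata the interpreter queries (pointer, dtype, shape, strides, contiguity, element size), never the underlying buffer, which is why the copy-back via _restore_args_dev can be skipped. All values below are hypothetical:

    # Sketch: a metadata-only stand-in for a 2x3 float32 device tensor.
    from triton_viz.core.patch import FakeTensor

    fake = FakeTensor(
        _data_ptr=0x7F0000000000,  # hypothetical device address
        dtype="float32",
        shape=(2, 3),
        _stride=(3, 1),            # row-major strides, in elements
        _is_contiguous=True,
        _element_size=4,           # float32 occupies 4 bytes
    )

    assert fake.numel() == 6       # 2 * 3 elements
    assert fake.stride() == (3, 1)
    assert fake.data_ptr() == 0x7F0000000000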