8 changes: 8 additions & 0 deletions triton_viz/core/config.py
@@ -7,6 +7,7 @@
if TYPE_CHECKING:
verbose: bool
sanitizer_activated: bool
virtual_memory: bool
disable_sanitizer: bool
report_grid_execution_progress: bool

@@ -34,6 +35,13 @@ def reset(self) -> None:
os.getenv("REPORT_GRID_EXECUTION_PROGRESS", "0") == "1"
) # verify using setter

# --- Virtual memory flag ---
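# When enabled, host-side kernel arguments are built as metadata-only
# FakeTensor stand-ins instead of CPU copies of the device tensors
# (see core/patch.py).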
self._virtual_memory = os.getenv("TRITON_VIZ_VIRTUAL_MEMORY", "0") == "1"

@property
def virtual_memory(self) -> bool:
return self._virtual_memory

# ---------- disable_sanitizer ----------
@property
def disable_sanitizer(self) -> bool:
75 changes: 73 additions & 2 deletions triton_viz/core/patch.py
@@ -5,7 +5,9 @@
from tqdm import tqdm

from . import config as cfg
from dataclasses import dataclass
from .callbacks import OpCallbacks, ForLoopCallbacks

from .data import (
Op,
RawLoad,
@@ -50,6 +52,8 @@
)
from triton.runtime.interpreter import _patch_lang as triton_patch_lang
from triton.runtime.interpreter import ASTTransformer as _OrigASTTransformer
from triton.runtime.interpreter import _tuple_create, _unwrap_tensor, _rewrap_tensor
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.runtime import JITFunction

op_list = [
@@ -430,6 +434,69 @@ def unpatch_lang():
importlib.reload(tl)


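# FakeTensor mimics the small subset of the tensor interface that the
# interpreter inspects (data_ptr, dtype, shape, stride, contiguity, numel,
# element_size) without holding any underlying data.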
@dataclass(frozen=True)
class FakeTensor:
_data_ptr: int
dtype: str
shape: tuple[int, ...] = ()
_stride: tuple[int, ...] = ()
_is_contiguous: bool = True
_element_size: int = 1

def data_ptr(self) -> int:
return self._data_ptr

def stride(self) -> tuple[int, ...]:
return self._stride

def is_contiguous(self) -> bool:
return self._is_contiguous

def numel(self) -> int:
size = 1
for dim in self.shape:
size *= dim
return size

def element_size(self) -> int:
return self._element_size


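# Virtual-memory variant of the interpreter's host-side argument preparation:
# rather than copying each device tensor to the CPU, wrap its metadata in a
# FakeTensor so the kernel can be interpreted without materializing the data
# on the host.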
def _init_args_hst(args_dev, kwargs):
def _to_cpu(arg):
if isinstance(arg, tuple):
return _tuple_create(arg, map(_to_cpu, arg))
elif isinstance(arg, TensorDescriptor):
return TensorDescriptor(
_to_cpu(arg.base),
arg.shape,
arg.strides,
arg.block_shape,
)
elif not hasattr(arg, "data_ptr"):
return arg

unwrapped_arg = _unwrap_tensor(arg)
cpu_arg = FakeTensor(
_data_ptr=unwrapped_arg.data_ptr(),
dtype=unwrapped_arg.dtype,
shape=unwrapped_arg.shape,
_stride=unwrapped_arg.stride(),
_is_contiguous=unwrapped_arg.is_contiguous(),
_element_size=unwrapped_arg.element_size(),
)
cpu_arg = _rewrap_tensor(cpu_arg, original_tensor=arg)
return cpu_arg

args_hst = [_to_cpu(arg) for arg in args_dev]

# Process keyword arguments
kwargs_hst = {}
for key, value in kwargs.items():
kwargs_hst[key] = _to_cpu(value)
return args_hst, kwargs_hst


def _grid_executor_call(self, *args_dev, **kwargs):
if kwargs.pop("warmup", False):
return
@@ -470,7 +537,10 @@ def run_grid_loops(grid):
k: v for k, v in kwargs.items() if k in argspec.args or k in triton_viz_args
}
client_manager = kwargs.pop("client_manager")
args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
if cfg.virtual_memory:
args_hst, kwargs_hst = _init_args_hst(args_dev, kwargs)
else:
args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
# Prepare call arguments
args = inspect.getcallargs(self.fn, *args_hst, **kwargs_hst)
call_args = {}
@@ -491,7 +561,8 @@ def run_grid_loops(grid):
client_manager.grid_callback(grid)
run_grid_loops(grid)
# Copy arguments back to propagate side-effects
self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)
if not cfg.virtual_memory:
self._restore_args_dev(args_dev, args_hst, kwargs, kwargs_hst)


def _jit_function_call(self, *args, **kwargs):