Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/mthreads-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,13 @@ jobs:
run: |
set -x
python3 -m pytest -s third_party/mthreads/python/test/unit --device musa

- name: Direct Kernel Test on Mthreads
if: steps.check_backend.outputs.should_skip != 'true'
shell: bash
run: |
set -x
export TRITON_DEFAULT_BACKEND=mthreads
python3 third_party/mthreads/python/test/test_verify_service.py \
--chip moore \
--all
29 changes: 29 additions & 0 deletions python/test/unit/language/test_hint_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from types import SimpleNamespace

from triton.backends.nvidia.nvidia_hint_handler import NvidiaHintHandler


class ParseMustNotRun:

def parse(self):
raise AssertionError("hint lookup must not reparse the JIT function")


def test_nvidia_hint_lookup_uses_codegen_attached_map():
code_generator = SimpleNamespace(
flagtree_line_hints={17: "cache_global"},
jit_fn=ParseMustNotRun(),
)
node = SimpleNamespace(lineno=17)

assert NvidiaHintHandler.get_node_hints(code_generator, node) == "cache_global"


def test_nvidia_hint_source_cache_returns_independent_dicts():
jit_fn = SimpleNamespace(src="def kernel(x):\n y = x # @hint:cache_global\n return y\n")

first = NvidiaHintHandler.maps_line_numbers_to_comment_hints(jit_fn)
first[2] = "mutated"
second = NvidiaHintHandler.maps_line_numbers_to_comment_hints(jit_fn)

assert second == {2: "cache_global"}
9 changes: 7 additions & 2 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def __init__(self, context, prototype, gscope, function_name, jit_fn: JITFunctio

self.lscope = {}
self.jit_fn = jit_fn
self.flagtree_line_hints = {}
# TODO: we currently generate illegal names for non-kernel functions involving constexprs!
if is_kernel:
function_name = function_name[function_name.rfind('.') + 1:]
Expand Down Expand Up @@ -1413,7 +1414,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs, caller_context=None):
module_map=self.builder.module_map, caller_context=caller_context,
is_gluon=self.is_gluon)
try:
generator.visit(fn.parse())
tree = fn.parse()
generator.flagtree_line_hints = getattr(tree.body[0], 'line_flagtree_hints', {}) or {}
generator.visit(tree)
except Exception as e:
# Wrap the error in the callee with the location of the call.
if knobs.compilation.front_end_debugging:
Expand Down Expand Up @@ -1743,7 +1746,9 @@ def apply_constexpr_types(argument, indices, value):
generator = CodeGenerator(context, prototype, gscope=fn.get_capture_scope(), function_name=fn.repr(proxy),
jit_fn=fn, is_kernel=True, file_name=file_name, begin_line=begin_line, options=options,
codegen_fns=codegen_fns, module_map=module_map, module=module, is_gluon=fn.is_gluon())
generator.visit(fn.parse())
tree = fn.parse()
generator.flagtree_line_hints = getattr(tree.body[0], 'line_flagtree_hints', {}) or {}
generator.visit(tree)
module = generator.module
# module takes ownership of the context
module.context = context
Expand Down
Loading
Loading