TLE
Triton is an operator programming language in the form of a Python DSL. It follows a block-based programming model that abstracts away hardware details such as memory hierarchy, layout, pipelining, and synchronization, while achieving strong operator performance through compiler optimization. These advantages have attracted a large developer community and ecosystem.
In recent years, however, Triton has faced growth challenges:
- Adaptation to DSA platforms and new GPU architectures has progressed slowly.
- Compared with emerging languages like TileLang, Triton lacks abstractions for fine-grained control of memory hierarchy and parallel granularity, which can lead to weaker performance in some cases.
To address these issues, we propose TLE (Triton Language Extensions), which extends Triton at three levels to meet urgent needs from users with different skill profiles.
We analyzed mainstream DSLs in the industry (Triton, TileLang, and cuTile) and summarized a target language design.
All three are Python-syntax-based DSLs, indicating that developers prefer Python-like syntax for kernel development, even if only a subset of Python is available.
All three support block-level programming. In essence, current block programming mainly performs tiling on global memory. cuTile goes further by supporting multi-level tiling, making it possible to design a unified language across multiple memory hierarchy architectures.
Triton, however, does not explicitly model tile/slice concepts, so users can only tile at the global memory level, limiting further language evolution.
TileLang is similar to Triton in that it does not provide explicit tiling primitives. In addition, except for copy and GEMM, it lacks higher-level tensor ops, which makes GPU programming less convenient. Without automatic vectorization, utilizing SIMD hardware well often requires adding many SIMD-specific ops.
To address the memory wall, modern hardware uses multi-level memory hierarchies.
- Triton/cuTile expose only two levels: global memory and local tensor.
- TileLang directly exposes native hardware memory hierarchy without abstraction.
Problems:
- Exposing too few levels pushes tiling and buffer promotion work to the compiler.
- Directly exposing native hierarchy significantly hurts portability.
Preferred direction:
- Developers perform tiling, but do not explicitly select memory levels.
- Compiler performs buffer promotion.
- Developers may provide hints; tile sizes are treated as hyperparameters.
This keeps portability while leaving room for further optimization.
- Triton/cuTile expose only block-level parallelism, and intra-block parallelism is fully compiler-controlled.
- TileLang lets developers explicitly control intra-block parallelism (Parallel and Vectorize), improving expressiveness but reducing portability and reuse across hardware.
None of these languages directly covers cross-block or cross-node communication, which limits compute-communication fusion (ongoing external work includes Triton Distributed and TileScale).
- Level 1: Numpy/PyTorch-like algorithm-level programming. Users focus on algorithm logic only; compiler handles hardware mapping and communication.
- Level 2: cuTile-like tile-level programming plus distributed descriptions. Users explicitly provide tiling and sharding, while compiler handles memory hierarchy, parallelism, and communication, with optional hardware/scenario hints.
- Level 3: Hardware-specific extensions (memory hierarchy, thread binding, vectorize, etc.). This level is confined to specific regions with explicit interaction contracts with Level 2. Compiler performs only essential optimizations.
Detailed principles:
- Tile semantics to avoid manual address arithmetic.
- Do not require tensor shapes to be powers of two.
Open question: what other strong design ideas should be added?
TLE sits in the middle layer of the AI software stack:
- Upstream: serves AI frameworks through graph compilers and operator libraries.
- Downstream: integrates with various hardware runtimes.
Content not available outside Feishu document yet.
TLE is split into three layers:
- TLE-Lite: lightweight extension over Triton. Features are backend-compatible, and only small changes to existing Triton kernels are needed to gain significant speedups. Targets algorithm engineers and fast optimization workflows.
- TLE-Struct: architecture-clustered abstractions (e.g., GPGPU, DSA) for deeper performance tuning. Requires moderate hardware knowledge.
- TLE-Raw: direct hardware control, including vendor-native programming languages for maximum performance. Targets expert performance engineers.
Lowering paths:
- TLE-Lite and TLE-Struct lower to LLVM IR via FLIR.
- TLE-Raw lowers to LLVM IR via language-specific pipelines (e.g., vendor private compilers).
- All parts are finally linked into a complete kernel loaded/executed by runtime.
- Design philosophy: write once, run anywhere.
- Core idea: use high-level semantic hints (instead of hard constraints) to guide compiler heuristics. Keep backward compatibility and achieve cross-platform speedups with minimal code changes.
tle.load extends tl.load with async hint support.
- Signature: tle.load(ptr, mask=None, other=None, is_async=False)
- Purpose: keep tl.load semantics while adding async scheduling hints.
- Practical guidance:
  - Use is_async=True for global-memory reads that are later reused in compute-heavy regions.
  - Keep mask and other explicit on boundary tiles to avoid undefined values.
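To make the role of mask and other concrete, here is a plain-Python reference model of the masked-load semantics that tle.load keeps from tl.load. It runs on the host and uses neither Triton nor TLE; the helper name masked_load_reference is hypothetical. The is_async hint is left out on the assumption that, as a scheduling hint, it does not change the loaded values.

```python
def masked_load_reference(memory, offsets, mask=None, other=None):
    """Host-side model: lanes with mask=False never touch memory and yield `other`."""
    if mask is None:
        mask = [True] * len(offsets)
    # memory[o] is only evaluated for active lanes, so out-of-bounds offsets are safe.
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


n_elements = 6
memory = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0]
BLOCK = 4
base = 4  # second block: offsets 4..7, two of which are out of bounds

offs = [base + i for i in range(BLOCK)]
mask = [o < n_elements for o in offs]
x = masked_load_reference(memory, offs, mask=mask, other=0.0)
print(x)  # boundary lanes carry `other` instead of an undefined value
```

Leaving other unset in this model yields None for inactive lanes, which stands in for the undefined values the guidance above warns about.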
Simple examples:
offs = base + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tle.load(x_ptr + offs, mask=mask, other=0.0, is_async=True)
for k in tl.range(0, K, BK, num_stages=2):
    a = tle.load(a_ptr + k * stride_a, is_async=True)
    b = tle.load(b_ptr + k * stride_b, is_async=True)
    acc = tl.dot(a, b, acc)
extract_tile splits the input tensor into a sub-tile grid using a child-tile shape and extracts the tile at the specified coordinates.
- Signature: x.extract_tile(index, shape)
- Returns: sub-tile view at index.
- GPU: supports extraction from registers and shared memory.
- Typical use: local transforms on sub-regions without manual pointer arithmetic.
insert_tile splits the input tensor into a sub-tile grid using the child-tile shape and updates the tile at the specified coordinates.
- Signature: x.insert_tile(tile, index)
- Returns: full tensor with the selected sub-tile updated.
- GPU: supports updates in registers and shared memory.
- Typical use: write processed activation, quant/dequant, or normalization results back into a larger tile.
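The tile-grid indexing used by extract_tile and insert_tile can be modeled on the host with nested lists. This is only a sketch of the semantics described above, assuming that index counts tiles (not elements) along each axis; the helper names extract_tile_reference and insert_tile_reference are hypothetical and are not part of TLE.

```python
def extract_tile_reference(x, index, shape):
    """Return the sub-tile at grid coordinate `index` of a 2D list-of-lists."""
    r0, c0 = index[0] * shape[0], index[1] * shape[1]
    return [row[c0:c0 + shape[1]] for row in x[r0:r0 + shape[0]]]


def insert_tile_reference(x, tile, index):
    """Return a copy of `x` with the sub-tile at grid coordinate `index` replaced."""
    th, tw = len(tile), len(tile[0])
    r0, c0 = index[0] * th, index[1] * tw
    out = [list(row) for row in x]
    for i in range(th):
        out[r0 + i][c0:c0 + tw] = tile[i]
    return out


x = [
    [-1, 2, -3, 4],
    [5, -6, 7, -8],
    [-9, 10, -11, 12],
    [13, -14, 15, -16],
]
sub = extract_tile_reference(x, index=[1, 0], shape=[2, 2])  # rows [2:4], cols [0:2]
sub = [[max(v, 0) for v in row] for row in sub]               # ReLU on the sub-tile only
y = insert_tile_reference(x, sub, index=[1, 0])
```

Only the selected 2x2 region changes; the rest of the tensor is carried over untouched, which is the property that lets kernels skip manual pointer arithmetic.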
Simple example:
# x: [4, 4]
sub = x.extract_tile(index=[1, 0], shape=[2, 2]) # rows [2:4], cols [0:2]
sub = tl.maximum(sub, 0.0)
x = x.insert_tile(sub, index=[1, 0])
Scan/sort ops provide prefix, rank, and selection primitives for kernels such as histogram-based top-k, stream compaction, and block-local ordering. TLE-Lite keeps these operations semantic rather than hardware-bound: users describe the scan/sort intent, and each backend chooses an appropriate register/shared-memory lowering strategy.
tle.cumsum(input, axis=0, reverse=False, dtype=None) computes an exclusive cumulative sum and the total sum along axis in one operation.
- Signature: tle.cumsum(input, axis=0, reverse=False, dtype=None)
- Purpose: compute the exclusive prefix/suffix sum and the total sum of a block tensor with one semantic scan op.
- Returns: (exclusive_sum, total_sum).
- Typical use: build per-block ranks for top-k, histogram prefixing, stream compaction, and block-local partitioning.
- exclusive has the same shape as input; total is the scalar sum of the scanned block.
- reverse=True computes a reverse-exclusive sum, which is useful for suffix counts in descending radix/top-k selection.
- dtype optionally controls the accumulation/result type. By default, narrow integer inputs are widened to 32-bit integers, and bfloat16 is promoted to float32.
- Add the original input back to exclusive_sum when an inclusive cumulative sum is needed.
- Keep masked loads explicit and feed zero for inactive lanes so total_sum describes only valid elements.
- Supported inputs are static rank-1 block tensors with axis=0, which matches the histogram and radix-selection workloads already used by TLE top-k kernels.
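The return pair described above can be modeled on the host in a few lines of plain Python. This is a sketch, not the TLE implementation: the helper name exclusive_cumsum_reference is hypothetical, and the reverse branch assumes that "reverse-exclusive" means the sum of all elements strictly after each position.

```python
def exclusive_cumsum_reference(values, reverse=False):
    """Return (exclusive_sum, total_sum) for a 1D list."""
    order = reversed(range(len(values))) if reverse else range(len(values))
    exclusive = [0] * len(values)
    running = 0
    for i in order:
        exclusive[i] = running  # sum of the elements visited before i
        running += values[i]
    return exclusive, running


active = [1, 0, 1, 1, 0, 1]
exclusive, total = exclusive_cumsum_reference(active)
suffix, _ = exclusive_cumsum_reference(active, reverse=True)
inclusive = [e + v for e, v in zip(exclusive, active)]

# Stream compaction: every active lane writes to its exclusive rank.
data = [7, 8, 9, 10, 11, 12]
compacted = [None] * total
for value, keep, dst in zip(data, active, exclusive):
    if keep:
        compacted[dst] = value
```

Because inactive lanes contribute zero, total equals the number of kept elements, so it can be used directly as the per-block output count.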
Simple example:
exclusive, total = tle.cumsum(x, axis=0)
inclusive = exclusive + x
tle.pipe describes an explicit dataflow edge between a producer and one or more consumers. It records both the shared-memory stage that carries a logical chunk and the synchronization needed to make that chunk visible, so CTA-local load/compute overlap and warp-specialized producer/consumer code can use one typed descriptor instead of hand-managed barriers.
- Signature: tle.pipe(*, capacity, scope="cta", name=None, readers=None, one_shot=False, **fields)
- Purpose: create a typed pipe for explicit CTA-local producer/consumer dataflow, ring-buffer stage reuse, and synchronization edges.
- Parameters:
  - capacity: compile-time positive integer, the number of pipe stages. The leading dimension of every payload field must equal capacity.
  - scope: supported value is "cta".
  - name: optional pipe name for IR/diagnostics; if provided, it must be a string.
  - readers: optional reader-name list. Omit it for the default SPSC reader; pass values such as ("left", "right") for SPMC.
  - one_shot: whether the pipe is a single ready/full edge, useful for one-time broadcast data. one_shot=True does not support close.
  - **fields: one or more payload buffers. Each field must be a shared-memory buffered tensor returned by tle.gpu.alloc(..., scope=tle.gpu.smem), with rank >= 2.
- Name rules:
  - Pipe field names and reader names must be Python identifiers.
  - Names must not start with _.
  - fields and readers are reserved names.
- The object returned by tle.pipe(...) is the pipe descriptor. It owns the staged payload fields and creates producer/consumer endpoints through writer() and reader(...).
- capacity stages form a ring buffer. iter maps to stage = iter % capacity, with a phase bit distinguishing reuse rounds.
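The ring-buffer mapping in the last bullet can be written out as a small host-side helper. The stage formula comes from the text; the phase formula is an assumption of this sketch (parity of the reuse round, flipping each time the ring wraps), and stage_and_phase is a hypothetical name.

```python
def stage_and_phase(iteration, capacity):
    """Map a logical chunk id to (stage, phase).

    Assumption: the phase bit is the parity of the reuse round.
    """
    stage = iteration % capacity
    phase = (iteration // capacity) & 1
    return stage, phase


capacity = 2
schedule = [stage_and_phase(it, capacity) for it in range(6)]
for it, (stage, phase) in enumerate(schedule):
    print(f"iter={it} -> stage={stage} phase={phase}")
```

Two chunks that land on the same stage in consecutive rounds always differ in phase, which is what lets a waiter tell a fresh publication from a stale one.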
A producer is the code that owns pipe.writer(). It acquires a writable stage, fills every required field for the logical chunk, and commits the chunk so consumers can observe it.
- pipe_value.writer() -> pipe_writer: creates the single writer endpoint for the pipe.
- The writer always sees all payload fields.
- writer.acquire(iter) -> pipe_slot: acquires the stage for producer writes and returns a slot whose fields have the leading capacity dimension removed.
- Produce field data between writer.acquire(iter) and writer.commit(iter).
- writer.commit(iter) -> None: marks the stage ready for all subscribed consumers. All field writes for that logical chunk must be complete before commit.
- writer.close(iter) -> None: publishes a closed stage for close-aware consumer loops. one_shot=True pipes do not support close.
- commit is the producer-side visibility boundary.
A consumer is the code that owns pipe.reader(...). It waits for a published chunk, reads the returned slot, and releases the stage after all reads are complete.
- pipe_value.reader(name=None, fields=None) -> pipe_reader: creates a consumer endpoint.
- For an SPSC pipe (readers=None), name must be omitted.
- For an SPMC pipe (readers=("mma", "epilogue")), name is required and must match a declared reader.
- fields may be a non-empty compile-time tuple/list of unique payload field names. If omitted, the consumer subscribes to all fields.
- A field-subset consumer narrows the endpoint view and the returned wait().slot; it does not create a separate pipe.
- reader.wait(iter) -> pipe_wait_result: waits until the selected stage is ready or closed and returns both the slot and the closed flag.
- Normal consumers use wait_result.slot; close-aware consumers also inspect wait_result.is_closed.
- reader.release(iter) -> None: releases the consumed stage for reuse by the producer. Call it after all reads from wait(iter).slot are complete.
- wait is the consumer-side visibility boundary. release is the consumer-side free signal.
- **fields defines the data carried by each stage. Each field is exposed on pipe_slot by name, for example slot.q or slot.scale.
- pipe_slot also exposes fields: dict[str, tle.gpu.buffered_tensor].
- pipe_wait_result contains slot: pipe_slot and is_closed: tl.tensor.
- A pipe may carry one field or multiple fields. Split pipes by logical lifecycle and reader protocol, not by the low-level transport used to fill each field.
- Different fields in the same slot may be produced by different mechanisms, such as TMA copy, cp.async-style copy, or tle.gpu.local_ptr plus tl.store. Users still call one writer.commit(iter) after all fields for that logical chunk are produced.
- The compiler infers each field's transport from producer-side IR; transport is not a user-facing pipe attribute and should not be encoded in pipe names, field names, or extra attributes.
- Use pipe.reader(name, fields=(...)) when a reader only consumes a subset of fields; this narrows the reader view without creating another token.
- Keep pipe-field provenance visible. Opaque shared-memory pointer escapes, untracked shared stores, or unprovable overlapping writes are rejected instead of lowered with a silent fallback.
- NVIDIA lowering maps CTA-scoped SMEM pipes to NVWS/mbarrier synchronization. Multi-field payloads are accepted when payload window, field ownership, participant count, and source-order safety can be proven at pipe-field root granularity.
- In an SPSC pipe, one producer publishes to one default consumer.
- In an SPMC pipe, one producer publishes the same logical chunk to named consumers such as ("mma", "epilogue").
- iter is the logical chunk id. The same iter should be used for the producer and every consumer that participates in that chunk.
- The normal cyclic lifecycle is writer.acquire(iter) -> produce fields -> writer.commit(iter) -> reader.wait(iter) -> consume fields -> reader.release(iter).
- one_shot=True models a single ready/full edge, usually with capacity=1; do not rely on cyclic reuse or close in that mode.
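The cyclic lifecycle above can be checked with a small host-side state machine. This is a protocol sketch only, not TLE code: it tracks stage state without data or real blocking, the class name PipeModel and the state labels are hypothetical, and it assumes that a stage becomes writable again only after every declared reader has released it.

```python
class PipeModel:
    """Protocol checker for acquire -> commit -> wait -> release."""

    def __init__(self, capacity, readers=("default",)):
        self.capacity = capacity
        self.readers = tuple(readers)
        self.state = ["free"] * capacity                 # free -> writing -> ready -> free
        self.owner = [None] * capacity                   # chunk id currently held by a stage
        self.pending = [set() for _ in range(capacity)]  # readers that have not released yet

    def acquire(self, it):
        s = it % self.capacity
        assert self.state[s] == "free", f"stage {s} still holds chunk {self.owner[s]}"
        self.state[s], self.owner[s] = "writing", it
        return s

    def commit(self, it):
        s = it % self.capacity
        assert self.state[s] == "writing" and self.owner[s] == it
        self.state[s] = "ready"
        self.pending[s] = set(self.readers)

    def wait(self, it, reader="default"):
        s = it % self.capacity
        # A real reader would block here; the model only checks that the chunk is published.
        assert self.state[s] == "ready" and self.owner[s] == it, "chunk not committed yet"
        assert reader in self.pending[s], "reader already released this chunk"
        return s

    def release(self, it, reader="default"):
        s = it % self.capacity
        assert self.owner[s] == it and reader in self.pending[s]
        self.pending[s].discard(reader)
        if not self.pending[s]:
            self.state[s] = "free"  # every subscribed reader is done; the producer may reuse


pipe = PipeModel(capacity=2, readers=("mma", "epilogue"))
trace = []
for it in range(4):
    pipe.acquire(it)
    pipe.commit(it)
    for name in pipe.readers:
        trace.append((it, name, pipe.wait(it, name)))
        pipe.release(it, name)
```

Skipping a release makes the next acquire on that stage fail in the model, which mirrors why release is described as the consumer-side free signal.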
Automatic software pipelining can still be triggered with tl.range(..., num_stages=...). Use an explicit pipe when the producer/consumer split should be visible in the program.
stage_buf = tle.gpu.alloc([2, BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
pipe = tle.pipe(capacity=2, scope="cta", name="x_pipe", x=stage_buf)
writer = pipe.writer()
reader = pipe.reader()
offs = tl.arange(0, BLOCK)
slot = writer.acquire(k)
tl.store(tle.gpu.local_ptr(slot.x), tl.load(x_ptr + k * BLOCK + offs))
writer.commit(k)
ready = reader.wait(k)
x = tl.load(tle.gpu.local_ptr(ready.slot.x))
reader.release(k)
Triton distributed API has four core parts: device mesh definition, sharding specification, resharding (collective communication), and remote access (point-to-point communication).
Recommended workflow:
- Define topology with tle.device_mesh.
- Mark tensor layout with tle.sharding.
- Transform layout with tle.reshard.
- Keep compute kernels operating on logical tensor views.
tle.device_mesh defines physical device topology and serves as the context foundation for distributed operations.
class device_mesh:
    def __init__(self, topology: dict):
        """
        Initialize DeviceMesh.
        Args:
            topology (dict): Hardware hierarchy description.
                Keys are hierarchy names; values are int (1D)
                or tuple lists (multi-dimensional).
        """
        self._physical_ids = ...  # Internal flattened physical IDs (0..N-1)
        self._shape = ...  # Current logical shape, e.g. (2, 2, 4, 2, 2, 4)
        self._dim_names = ...  # Current dimension names

    @property
    def shape(self):
        """Return logical mesh shape."""
        return self._shape

    @property
    def ndim(self):
        """Return number of dimensions."""
        return len(self._shape)

    def flatten(self):
        """Flatten mesh to 1D, typically for ring communication."""
        return self.reshape(prod(self._shape))

    def __getitem__(self, key):
        """
        Supports slicing and returns a sub-mesh.
        Supports standard slice and integer indexing.
        """
        return sub_mesh

    def __repr__(self):
        return f"DeviceMesh(shape={self._shape}, names={self._dim_names})"
# Define complex hardware hierarchy
topology = {
    # Cross-node hierarchy (2x2 = 4 nodes)
    "node": [("node_x", 2), ("node_y", 2)],
    # In-node GPUs (4 devices)
    "device": 4,
    # In-GPU cluster (2x2)
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    # In-cluster blocks (4 blocks)
    "block": 4,
}
# mesh.shape -> (2, 2, 4, 2, 2, 4)
# total size = 256
mesh = tle.device_mesh(topology=topology)
tle.sharding declares tensor distribution state on the device mesh:
- splits: how each tensor axis is partitioned on mesh axes.
- partials: whether tensor is partial-sum state.
- Unspecified mesh axes are treated as broadcast.
Symbols:
- tle.S(axis): split.
- tle.B: broadcast/replicate.
- tle.P(axis): partial; requires reduce on specified axis.
def sharding(tensor, splits, partials):
    """
    Annotation only: marks tensor state, emits no direct code,
    but guides compiler checks and optimizations.
    """
    return tensor
# Split axis0 on cluster, axis1 on device, and partial on block axis
x_shard = tle.sharding(
    mesh,
    split=[["cluster_x", "cluster_y"], "device"],
    partial=["block"],
)
# Define a sharded tensor
x = tle.make_sharded_tensor(x_ptr, sharding=x_shard, shape=[4, 4])
- Signature: tle.shard_id(mesh, axis)
- Returns: current program's coordinate on a mesh axis.
- axis can be a mesh-axis name (e.g. "node", "device", "cluster_x") or an axis index.
- Typical use: build peer shard IDs for ring exchange, staged all-reduce, and cluster-cooperative kernels.
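To picture what a per-axis coordinate means, the following host-side sketch flattens a topology dict into ordered mesh axes and unravels a flat physical id into one coordinate per axis. It does not use TLE: mesh_axes and shard_coords are hypothetical helpers, and the row-major ordering (last axis fastest) is an assumption of the sketch rather than a documented property of tle.device_mesh.

```python
from math import prod


def mesh_axes(topology):
    """Flatten a topology dict into ordered (name, size) axes.

    Int values become one axis named after the key; list values
    contribute one axis per (name, size) pair, in order.
    """
    axes = []
    for level, spec in topology.items():
        if isinstance(spec, int):
            axes.append((level, spec))
        else:
            axes.extend(spec)
    return axes


def shard_coords(linear_id, axes):
    """Unravel a flat id into per-axis coordinates (assumed row-major)."""
    coords = {}
    for name, size in reversed(axes):
        coords[name] = linear_id % size
        linear_id //= size
    return coords


topology = {
    "node": [("node_x", 2), ("node_y", 2)],
    "device": 4,
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    "block": 4,
}
axes = mesh_axes(topology)
shape = tuple(size for _, size in axes)
total = prod(shape)
coords = shard_coords(37, axes)
```

With this topology the shape and total size agree with the mesh.shape and total size comments in the device_mesh example earlier in this section.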
Simple example:
mesh = tle.device_mesh({"node": 2, "device": 4})
node_rank = tle.shard_id(mesh, "node") # 0..1
device_rank = tle.shard_id(mesh, "device")  # 0..3
In complex distributed kernels (e.g., ring all-reduce or row/column-independent pipelines), only “same-row” or “same-column” blocks often need synchronization rather than the whole cluster. Global synchronization introduces unnecessary waiting.
def distributed_barrier(mesh):
    """
    If sub_mesh is passed, synchronize only devices in this sub-mesh.
    Devices outside this sub-mesh should treat it as No-Op
    (or compiler guarantees control flow does not enter).
    """
    pass
tle.remote obtains a handle for tensor data located on other devices. This maps to point-to-point communication or direct memory access (RDMA/NVLink load).
- tle.remote reads/writes explicit remote shards.
- tle.distributed_barrier synchronizes only the mesh/sub-mesh you pass in.
def remote(tensor, shard_id, scope):
    """
    Get a RemoteTensor handle to a shard on a target device.
    :param tensor: logically distributed tensor (already marked by tle.sharding)
    :param shard_id: tuple coordinate in device mesh
    :return: RemoteTensor, supporting load/store and related ops
    """
Simple example:
next_device = (tle.shard_id(mesh, "device") + 1) % mesh.shape[1]
remote_x = tle.remote(x, shard_id=(tle.shard_id(mesh, "node"), next_device), scope=mesh)
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(remote_x)
tle.reshard is the entrypoint for collectives. Compiler compares source and target specs and inserts communication primitives automatically.
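One way to picture the spec comparison is a lookup from (source state, target state) to a collective. The table below copies the typical transitions documented for tle.reshard; everything else is an assumption of this host-side sketch: the function name plan_reshard is hypothetical, "-" is just a label for the "[ ]" state, and treating an unchanged state as a no-op is a guess.

```python
# Transition table taken from the documented "typical transitions" of tle.reshard.
TRANSITIONS = {
    ("-", "S"): "scatter",
    ("S", "-"): "gather",
    ("P", "-"): "reduce",
    ("B", "S"): "local slice (no communication)",
    ("S", "B"): "all-gather",
    ("P", "B"): "all-reduce",
}


def plan_reshard(src, dst):
    """Return the collective implied by a state change, or raise if it is invalid."""
    if src == dst:
        return "no-op"  # assumption: identical specs need no communication
    if (src, dst) == ("B", "P"):
        raise ValueError("[B] -> [P] is listed as an error")
    try:
        return TRANSITIONS[(src, dst)]
    except KeyError:
        raise NotImplementedError(f"transition {src} -> {dst} is not covered by this sketch")


for src, dst in [("S", "B"), ("P", "B"), ("B", "S")]:
    print(f"[{src}] -> [{dst}]: {plan_reshard(src, dst)}")
```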
def reshard(tensor, spec):
    """
    Action: transform tensor to a new distribution state.
    Typical transitions:
    1. [ ] -> [S]: Scatter
    2. [S] -> [ ]: Gather
    3. [P] -> [ ]: Reduce
    4. [B] -> [S]: Local slice (no communication)
    5. [S] -> [B]: All-gather
    6. [P] -> [B]: All-reduce
    7. [B] -> [P]: Error
    """
Simple example:
x_spec = tle.sharding(mesh, split=["device"], partial=[])
x = tle.make_sharded_tensor(x_ptr, sharding=x_spec, shape=[M, K])
x_full = tle.reshard(x, spec=tle.sharding(mesh, split=[], partial=[]))
NVIDIA Hopper (H100) and newer architectures introduce Thread Block Cluster, allowing groups of CTAs to cooperate via DSMEM for high-bandwidth, low-latency exchange.
tle.distributed_dot is designed to use this feature so developers can write cross-block matrix multiplication without manually handling DSMEM barriers and data movement.
def distributed_dot(a, b, c=None):
    """
    Execute distributed matrix multiplication within current
    Thread Block Cluster scope.
    Behavior depends on sharding specs of input tensors `a` and `b`
    over the cluster mesh.
    Args:
        a (Tensor): left operand with cluster-level sharding annotation.
        b (Tensor): right operand with cluster-level sharding annotation.
        c (Tensor, optional): accumulator.
    Returns:
        Tensor: result tensor with distribution inferred from inputs.
    """
Open question: what additional distributed primitives are needed?
This combines tle.load and tle.cumsum: masked async loads bring a tile into registers, a predicate builds per-lane active flags, and tle.cumsum assigns compact output offsets plus the per-block active count.
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tle.load(x_ptr + offs, mask=mask, other=0.0, is_async=True)
keep = (x > threshold) & mask
active = keep.to(tl.int32)
write_offsets, active_total = tle.cumsum(active, axis=0)
block_base = tl.load(block_offsets_ptr + pid)
tl.store(out_ptr + block_base + write_offsets, x, mask=keep)
tl.store(block_counts_ptr + pid, active_total)
This combines tle.pipe, shared-memory pointer views, and tile slicing. The producer stages global tiles into shared memory, while the consumer reads the ready stage, applies a local transform to one sub-tile, and writes the updated tile to the output.
tile_buf = tle.gpu.alloc([2, BM, BN], dtype=tl.float32, scope=tle.gpu.smem)
pipe = tle.pipe(capacity=2, scope="cta", name="post_pipe", tile=tile_buf)
writer = pipe.writer()
reader = pipe.reader()
rows = tl.arange(0, BM)[:, None]
cols = tl.arange(0, BN)[None, :]
for t in tl.range(0, n_tiles):
    slot = writer.acquire(t)
    vals = tle.load(x_ptr + tile_offset(t, rows, cols), mask=tile_mask(t, rows, cols), other=0.0, is_async=True)
    tl.store(tle.gpu.local_ptr(slot.tile, (rows, cols)), vals)
    writer.commit(t)
for t in tl.range(0, n_tiles):
    ready = reader.wait(t)
    tile = tl.load(tle.gpu.local_ptr(ready.slot.tile, (rows, cols)))
    sub = tile.extract_tile(index=[0, 0], shape=[BM // 2, BN // 2])
    sub = tl.maximum(sub, 0.0)
    tile = tile.insert_tile(sub, index=[0, 0])
    tl.store(out_ptr + tile_offset(t, rows, cols), tile, mask=tile_mask(t, rows, cols))
    reader.release(t)
This combines one multi-field pipe payload with reader field subsets. One pipe slot carries both a TMA-produced tile and a locally produced scale vector; the MMA reader subscribes only to the tile, while the epilogue reader consumes both fields.
q = tle.gpu.alloc([PIPE_CAPACITY, BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
scale = tle.gpu.alloc(
    [PIPE_CAPACITY, BM],
    dtype=tl.float32,
    scope=tle.gpu.smem,
    nv_mma_shared_layout=False,
)
pipe = tle.pipe(
    capacity=PIPE_CAPACITY,
    scope="cta",
    name="multi_field_inputs",
    readers=("mma", "epilogue"),
    q=q,
    scale=scale,
)
writer = pipe.writer()
mma_reader = pipe.reader("mma", fields=("q",))
epilogue_reader = pipe.reader("epilogue", fields=("q", "scale"))
slot = writer.acquire(k)
tle.gpu.copy(q_desc, slot.q, [BM, BK], [q_block, k_block])
tl.store(tle.gpu.local_ptr(slot.scale, (tl.arange(0, BM),)), scale_values)
writer.commit(k)
q_ready = mma_reader.wait(k).slot
q_tile = tl.load(tle.gpu.local_ptr(q_ready.q))
mma_reader.release(k)
epilogue_ready = epilogue_reader.wait(k).slot
q_for_epilogue = tl.load(tle.gpu.local_ptr(epilogue_ready.q))
scale_tile = tl.load(tle.gpu.local_ptr(epilogue_ready.scale, (tl.arange(0, BM),)))
epilogue_reader.release(k)
This combines tle.device_mesh, tle.sharding, tle.shard_id, tle.remote, tle.distributed_barrier, and tle.reshard. A shard reads a neighbor's tile for ring-style exchange, then materializes a replicated view before local compute.
mesh = tle.device_mesh({"node": 2, "device": 4})
x_spec = tle.sharding(mesh, split=["device"], partial=[])
x = tle.make_sharded_tensor(x_ptr, sharding=x_spec, shape=[M, K])
node_rank = tle.shard_id(mesh, "node")
device_rank = tle.shard_id(mesh, "device")
next_device = (device_rank + 1) % mesh.shape[1]
neighbor_x = tle.remote(x, shard_id=(node_rank, next_device), scope=mesh)
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(neighbor_x)
x_full = tle.reshard(x, spec=tle.sharding(mesh, split=[], partial=[]))
acc = local_compute(x_full, neighbor_vals)
- Design philosophy: architecture-aware, fine-grained tuning.
- Core idea: classify backends by hardware-topology families (e.g., GPGPU, DSA), expose common hierarchical parallel/storage structures, and let developers explicitly define structured compute/data mappings (e.g., warp-group control, pipeline scheduling). This decouples algorithm logic from hardware physical implementation at the abstraction level.
Specify tensor memory_space:
x = ...
x = tle.gpu.memory_space(x, "shared_memory")
Allocate memory:
a_smem = tle.gpu.alloc(
    [XBLOCK, YBLOCK],
    dtype=tl.float32,
    layout=None,
    scope=tle.gpu.storage_kind.smem,
)
Get memory pointers:
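# pointers for a_smem[0, :]: [(0, 0), (0, 1), ..., (0, YBLOCK-1)]
a_smem_ptrs = tle.gpu.local_ptr(
    a_smem,
    # row index 0 for every lane; tl.zeros replaces tl.broadcast(0, [YBLOCK]), which does not take a shape
    indices=(tl.zeros([YBLOCK], dtype=tl.int32), tl.arange(0, YBLOCK)),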
)
- Signature: tle.gpu.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr
- Purpose: Build arbitrary-shaped pointer views over shared memory buffers for tl.load/tl.store/tl.atomic*.
- Parameters:
  - buffer: buffered tensor returned by tle.gpu.alloc (SMEM/TMEM).
  - indices: optional tuple of integer tensors. Tuple length must equal rank(buffer), and all tensors must have identical shapes. If omitted/None, backend treats it as full indices.
- Semantics:
  - If indices is provided: output pointer tensor shape equals common shape of index tensors.
  - For each logical output index (i0, i1, ...), pointer value corresponds to buffer[indices0(i0,...), indices1(i0,...), ...].
  - If indices=None: build full-view pointers over buffer shape (rank>0 returns pointer tensor with shape(buffer), rank=0 returns scalar pointer).
  - Returned pointers live in shared-memory address space (LLVM addrspace=3). Indices must be integers (i32/i64, etc.; lowered to i32).
  - Linearization is row-major (last dimension fastest); shared-memory layout/encoding follows buffer memdesc.
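The indexing rule above (one index tensor per buffer dimension, row-major linearization) can be modeled on the host as follows. This sketch covers only the logical element offset; the physical shared-memory layout follows the buffer memdesc and is not modeled. The helper name local_ptr_offsets is hypothetical, index tensors are represented as flat Python lists, and the bounds check is a convenience of the model rather than a documented TLE behavior.

```python
def local_ptr_offsets(buffer_shape, indices):
    """Return one row-major element offset per output position."""
    assert len(indices) == len(buffer_shape), "tuple length must equal rank(buffer)"
    lengths = {len(ix) for ix in indices}
    assert len(lengths) == 1, "all index tensors must have identical shapes"
    offsets = []
    for coord in zip(*indices):
        linear = 0
        for c, dim in zip(coord, buffer_shape):
            assert 0 <= c < dim, "index out of bounds"
            linear = linear * dim + c  # last dimension varies fastest
        offsets.append(linear)
    return offsets


# Row 0 of a [4, 8] buffer, matching the a_smem[0, :] example above.
YBLOCK = 8
row0 = local_ptr_offsets([4, YBLOCK], ([0] * YBLOCK, list(range(YBLOCK))))

# A [2, 3] window starting at column 2, flattened in row-major output order.
rows = [r for r in range(2) for _ in range(3)]
cols = [2 + c for _ in range(2) for c in range(3)]
window = local_ptr_offsets([4, YBLOCK], (rows, cols))
```

Because each output position reads its own coordinate tuple, the same rule covers contiguous slices, strided windows, and arbitrary gathers.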
Example 1: 1D slice
smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
# Slice [offset, offset + SLICE)
idx = offset + tl.arange(0, SLICE)
slice_ptr = tle.gpu.local_ptr(smem, (idx,))
vals = tl.load(slice_ptr)
Example 2: K-dimension tiling (matrix slice)
smem_a = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
# Slice (BM, KW), where KW is K-dimension slice
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, KW))
cols = tl.broadcast_to(tl.arange(0, KW)[None, :] + k_start, (BM, KW))
a_slice = tle.gpu.local_ptr(smem_a, (rows, cols))
a_vals = tl.load(a_slice)
Example 3: arbitrary gather view
smem = tle.gpu.alloc([H, W], dtype=tl.float32, scope=tle.gpu.smem)
# Take an offset column per row
rows = tl.broadcast_to(tl.arange(0, H)[:, None], (H, SLICE))
cols = tl.broadcast_to(1 + tl.arange(0, SLICE)[None, :], (H, SLICE))
gather_ptr = tle.gpu.local_ptr(smem, (rows, cols))
out = tl.load(gather_ptr)
Supported downstream ops:
- tl.load
- tl.store
- tl.atomic_add/and/cas/max/min/or/xchg/xor
Practical notes:
- Atomic ops require element dtype/backend support; use integer/float types supported by target hardware.
- For local-pointer load-after-store hazards, TLE backend pass TleInsertLocalPointerBarriers inserts barriers automatically; add manual barriers only for custom synchronization patterns outside pass coverage.
Example 4: load/store/atomic on the same local_ptr
smem_i32 = tle.gpu.alloc([BLOCK], dtype=tl.int32, scope=tle.gpu.smem)
ptr = tle.gpu.local_ptr(smem_i32, (tl.arange(0, BLOCK),))
tl.store(ptr, tl.zeros([BLOCK], dtype=tl.int32))
tl.atomic_add(ptr, 1)
vals = tl.load(ptr)
- Signature: tle.gpu.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr
- Purpose: materialize pointer views for remote shared/local buffers returned by tle.remote(...).
- Inputs:
  - remote_buffer: result of tle.remote(buffer, shard_id, scope), where buffer is typically allocated by tle.gpu.alloc.
  - indices: same rules as local mode (None for full view, or tuple of integer tensors with identical shapes).
- Semantics:
  - Pointer shape/linearization rules are identical to local tle.gpu.local_ptr.
  - Address resolution targets the remote shard selected by shard_id.
  - Use tle.distributed_barrier(...) when cross-shard producer/consumer ordering is required.
Example: read remote SMEM tile from neighbor shard
smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)
remote_smem = tle.remote(smem, shard_id=(node_rank, next_device), scope=mesh)
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
remote_ptr = tle.gpu.local_ptr(remote_smem, (rows, cols))
vals = tl.load(remote_ptr)
Memory copy:
tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK])
tle.gpu.warp_specialize creates an explicit warp-specialized region inside one CTA, placing different JIT functions into separate warp partitions. Typical uses include splitting TMA/cp.async producers, WGMMA consumers, epilogues, and reductions, with shared-memory data passed through tle.pipe or other explicit synchronization primitives.
- Signature: tle.gpu.warp_specialize(functions_and_args, worker_num_warps, worker_num_regs)
- Parameters:
  - functions_and_args: [(fn0, args0), (fn1, args1), ...]. Entry 0 is emitted into the default partition; later entries are emitted into worker partitions.
  - worker_num_warps: warp counts for worker partitions. Length must equal len(functions_and_args) - 1.
  - worker_num_regs: requested register counts for worker partitions. Length must equal len(functions_and_args) - 1.
- Semantics:
  - Each args value must be a tuple. Plain Python int/float/bool/tl.dtype values are passed as constexpr arguments.
  - The default partition may return values; the return value of tle.gpu.warp_specialize(...) comes from the default partition. Worker partitions perform side effects and end with warp return.
  - Worker callees receive the corresponding "ttg.num-warps" attribute, and the region records requestedRegisters.
  - Captured worker arguments are deduplicated in IR, so multiple workers can share the same pipe endpoint or buffer handle.
  - warp_specialize does not itself provide data-visibility ordering. Producer/consumer ordering should be expressed with tle.pipe commit/wait/release, barriers, or other synchronization primitives.
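The argument-shape rules above are easy to get wrong when adding or removing a partition, so a host-side pre-check can help. The function below is a hypothetical helper, not part of TLE; it only restates the documented constraints (tuple args, one warp count and one register count per worker) and says nothing about how the compiler itself validates them.

```python
def check_warp_specialize_args(functions_and_args, worker_num_warps, worker_num_regs):
    """Check the documented shape rules; return the number of worker partitions."""
    if not functions_and_args:
        raise ValueError("need at least the default partition")
    for fn, args in functions_and_args:
        if not callable(fn):
            raise TypeError("each entry must start with a function")
        if not isinstance(args, tuple):
            raise TypeError("each args value must be a tuple")
    n_workers = len(functions_and_args) - 1  # entry 0 is the default partition
    if len(worker_num_warps) != n_workers:
        raise ValueError("worker_num_warps needs one entry per worker partition")
    if len(worker_num_regs) != n_workers:
        raise ValueError("worker_num_regs needs one entry per worker partition")
    return n_workers


# Placeholder callables; real entries would be @triton.jit functions.
def producer():
    pass


def consumer():
    pass


n_workers = check_warp_specialize_args(
    [(producer, ()), (consumer, ())],
    worker_num_warps=[4],
    worker_num_regs=[168],
)
```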
Example: one producer partition loads shared memory, one consumer worker computes.
@triton.jit
def producer(writer, x_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    for i in tl.range(0, n_tiles):
        slot = writer.acquire(i)
        vals = tl.load(x_ptr + i * BLOCK + offs)
        tl.store(tle.gpu.local_ptr(slot.tile), vals)
        writer.commit(i)

@triton.jit
def consumer(reader, out_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for i in tl.range(0, n_tiles):
        ready = reader.wait(i)
        tile = tl.load(tle.gpu.local_ptr(ready.slot.tile))
        acc += tile
        reader.release(i)
    tl.store(out_ptr + offs, acc)

@triton.jit
def kernel(x_ptr, out_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    smem = tle.gpu.alloc([2, BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
    pipe = tle.pipe(capacity=2, scope="cta", name="x_pipe", tile=smem)
    tle.gpu.warp_specialize(
        [
            (producer, (pipe.writer(), x_ptr, n_tiles, BLOCK)),
            (consumer, (pipe.reader(), out_ptr, n_tiles, BLOCK)),
        ],
        [4],  # consumer worker uses 4 warps
        [168],  # consumer worker requested registers
    )
Example: multiple workers with an SPMC pipe.
tile = tle.gpu.alloc([2, BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
pipe = tle.pipe(
    capacity=2,
    scope="cta",
    name="spmc_tile",
    readers=("qk", "value"),
    tile=tile,
)
tle.gpu.warp_specialize(
    [
        (load_tile_producer, (pipe.writer(), a_desc, b_desc)),
        (qk_consumer, (pipe.reader("qk"), acc_qk)),
        (value_consumer, (pipe.reader("value", fields=("tile",)), acc_v)),
    ],
    [4, 4],
    [240, 168],
)
This section is rewritten from triton_v3.2.x (python/triton/experimental/tle/language/dsa and its README).
DSA APIs are split into:
- Generic DSA APIs under `tle.dsa.*`
- Backend-specific address spaces under `tle.dsa.ascend.*`
- Signature: `tle.dsa.alloc(shape, dtype, mem_addr_space)`
- Purpose: allocate DSA local buffers in a target memory space.

Ascend memory spaces exposed in source:
- `tle.dsa.ascend.UB`
- `tle.dsa.ascend.L1`
- `tle.dsa.ascend.L0A`
- `tle.dsa.ascend.L0B`
- `tle.dsa.ascend.L0C`

```python
a_ub = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
b_l1 = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.L1)
```

- Signature: `tle.dsa.copy(src, dst, shape, inter_no_alias=False)`
- Purpose: explicit movement between GMEM pointers and DSA local buffers (both directions).

```python
tle.dsa.copy(x_ptrs, a_ub, [tail_m, tail_n])  # GMEM -> local buffer
tle.dsa.copy(a_ub, out_ptrs, [tail_m, tail_n])  # local buffer -> GMEM
```

- Signature: `tle.dsa.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr`
- Purpose: build pointer views over DSA local buffers (for example UB/L1) for explicit local-memory access patterns.
- Parameters:
  - `buffer`: DSA buffered tensor, typically from `tle.dsa.alloc`.
  - `indices`: optional tuple of integer tensors. If omitted or `None`, the backend treats it as full indices.
- Semantics:
  - Shape and indexing behavior follow `tle.gpu.local_ptr` (same pointer-view model).
  - Intended for DSA-local data access paths that require explicit pointer materialization.
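As a rough mental model of the pointer view, each index pair selects one element of the buffer, and the view takes the shape of the index tensors. The sketch below is plain Python over a flat list and assumes a 2-D buffer in row-major order; the real layout is chosen by the backend, and `local_ptr_offsets` and `load` are hypothetical helper names, not TLE APIs. Reading "full indices" as every element in row-major order is also an assumption.

```python
def local_ptr_offsets(shape, indices=None):
    """Return one flat element offset per (row, col) index pair."""
    n_rows, n_cols = shape
    if indices is None:
        # Assumed reading of "full indices": every element of the buffer.
        rows = [[r] * n_cols for r in range(n_rows)]
        cols = [list(range(n_cols)) for _ in range(n_rows)]
    else:
        rows, cols = indices
    return [
        [r * n_cols + c for r, c in zip(row_of_rows, row_of_cols)]
        for row_of_rows, row_of_cols in zip(rows, cols)
    ]


def load(flat_buffer, offsets):
    """Dereference a grid of offsets, keeping the grid's shape."""
    return [[flat_buffer[o] for o in row] for row in offsets]


buf = [10, 11, 12, 13, 14, 15]                      # a 2 x 3 buffer, flattened
whole = load(buf, local_ptr_offsets((2, 3)))        # full view
last_col = load(buf, local_ptr_offsets((2, 3), ([[0], [1]], [[2], [2]])))
```

Here `whole` has shape 2 x 3 and `last_col` has shape 2 x 1, matching the shape of the index tensors rather than the shape of the buffer.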
Example:
```python
a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
a_ptr = tle.dsa.local_ptr(a_ub, (rows, cols))
a_val = tl.load(a_ptr)
```

- Signature: `tle.dsa.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr`
- Purpose: materialize pointer views over remote DSA local buffers obtained from `tle.remote(...)`.
- Inputs:
  - `remote_buffer`: result of `tle.remote(dsa_buffer, shard_id, scope)`.
  - `indices`: same rules as local DSA mode.
- Semantics:
  - Same pointer-view semantics as local DSA mode.
  - Pointer dereference is routed to the remote shard selected by `shard_id`.
  - Pair with `tle.distributed_barrier` when cross-shard ordering is required.
Example:
```python
a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
remote_a_ub = tle.remote(a_ub, shard_id=peer_rank, scope=mesh)
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
remote_ptr = tle.dsa.local_ptr(remote_a_ub, (rows, cols))
remote_val = tl.load(remote_ptr)
```

- `tle.dsa.to_tensor(buffer, writable=True)`: convert a DSA buffer to a tensor view for tensor expressions.
- `tle.dsa.to_buffer(tensor, space)`: convert a tensor value back to a buffer in a target DSA address space.
```python
c_val = tle.dsa.to_tensor(c_ub, writable=True)
result = c_val * 0.5
d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB)
tle.dsa.copy(d_ub, out_ptrs, [tail_m, tail_n])
```

Builtins provided by source:
- `tle.dsa.add`
- `tle.dsa.sub`
- `tle.dsa.mul`
- `tle.dsa.div`
- `tle.dsa.max`
- `tle.dsa.min`
- Common signature: `tle.dsa.<op>(lhs, rhs, out)`
- Compute model: elementwise binary op over DSA local buffers.
- Shape rules:
  - `lhs`, `rhs`, and `out` must have the same rank and shape.
  - No implicit broadcast is assumed in this API layer.
- Dtype rules:
  - All three operands should use the same dtype in practice.
  - Integer dtypes are typical for index/count paths; float dtypes are typical for activation/math paths.
- Memory-space rules:
  - Buffers should be allocated in compatible DSA local spaces (for example UB/L1 combinations allowed by backend).
  - Keep hot operands/results in local space to avoid extra GMEM traffic.
Per-op semantics:
- `tle.dsa.add(lhs, rhs, out)`: `out = lhs + rhs`
- `tle.dsa.sub(lhs, rhs, out)`: `out = lhs - rhs`
- `tle.dsa.mul(lhs, rhs, out)`: `out = lhs * rhs`
- `tle.dsa.div(lhs, rhs, out)`: `out = lhs / rhs` (backend-dependent precision/rounding)
- `tle.dsa.max(lhs, rhs, out)`: `out = max(lhs, rhs)`
- `tle.dsa.min(lhs, rhs, out)`: `out = min(lhs, rhs)`

In-place usage:
- You can reuse the same output buffer across steps, for example `tle.dsa.mul(tmp, b, tmp)`.
- Avoid aliasing inputs/outputs unless backend semantics explicitly allow it.
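The destination-passing style and the shape rule can be checked against a small host-side reference. The sketch below models buffers as nested Python lists. `binary_op` and `shape_of` are hypothetical names, and the model says nothing about backend precision or about which aliasing patterns a backend accepts.

```python
import operator

OPS = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "div": operator.truediv,
    "max": max,
    "min": min,
}


def shape_of(buf):
    """Shape of a rectangular 2-D nested list."""
    return (len(buf), len(buf[0]))


def binary_op(name, lhs, rhs, out):
    """Write op(lhs, rhs) elementwise into out; all shapes must match."""
    if not (shape_of(lhs) == shape_of(rhs) == shape_of(out)):
        raise ValueError("lhs, rhs, and out must share one shape (no broadcast)")
    fn = OPS[name]
    for i, (lhs_row, rhs_row) in enumerate(zip(lhs, rhs)):
        for j, (a, b) in enumerate(zip(lhs_row, rhs_row)):
            out[i][j] = fn(a, b)  # element (i, j) is read before it is written


# ((a - b) * b) / scale, reusing tmp as both input and output of the multiply.
a = [[6.0, 8.0]]
b = [[2.0, 4.0]]
scale = [[4.0, 8.0]]
tmp = [[0.0, 0.0]]
out = [[0.0, 0.0]]
binary_op("sub", a, b, tmp)
binary_op("mul", tmp, b, tmp)
binary_op("div", tmp, scale, out)
```

In this model, reusing `tmp` is harmless only because each output element depends on the input elements at the same position. Whether a real backend tolerates the same reuse is a property of that backend.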
Example 1: arithmetic chain ((a - b) * b) / scale
```python
a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
scale_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
out_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tle.dsa.copy(a_ptrs, a_ub, [BM, BK])
tle.dsa.copy(b_ptrs, b_ub, [BM, BK])
tle.dsa.copy(scale_ptrs, scale_ub, [BM, BK])
tle.dsa.sub(a_ub, b_ub, tmp_ub)  # tmp = a - b
tle.dsa.mul(tmp_ub, b_ub, tmp_ub)  # tmp = tmp * b
tle.dsa.div(tmp_ub, scale_ub, out_ub)  # out = tmp / scale
tle.dsa.copy(out_ub, out_ptrs, [BM, BK])
```

Example 2: clamp by max + min
```python
x_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
floor_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
ceil_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
y_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tle.dsa.copy(x_ptrs, x_ub, [BM, BK])
tle.dsa.copy(floor_ptrs, floor_ub, [BM, BK])
tle.dsa.copy(ceil_ptrs, ceil_ub, [BM, BK])
tle.dsa.max(x_ub, floor_ub, tmp_ub)  # tmp = max(x, floor)
tle.dsa.min(tmp_ub, ceil_ub, y_ub)  # y = min(tmp, ceil)
tle.dsa.copy(y_ub, y_ptrs, [BM, BK])
```

Source includes:
- `tle.dsa.pipeline(...)`
- `tle.dsa.parallel(...)`
- `tle.dsa.hint(...)` (used as `with tle.dsa.hint(...)` for compile-time hints)
```python
with tle.dsa.hint(inter_no_alias=True):
    tle.dsa.copy(x_ptr + offs, a_ub, [tail_size], inter_no_alias=True)
```

Source includes:
- `tle.dsa.extract_slice`
- `tle.dsa.insert_slice`
- `tle.dsa.extract_element`
- `tle.dsa.subview`
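The `offsets`, `sizes`, and `strides` arguments used by the slice helpers can be read as strided window coordinates. The sketch below is a host-side model on nested lists. It assumes MLIR-style slice semantics (element `(i, j)` of the slice maps to `(offset + i * stride)` along each axis), which the source does not confirm, and the function names mirror the list above purely for readability.

```python
def extract_slice(full, offsets, sizes, strides):
    """Copy a strided 2-D window out of `full`."""
    (r0, c0), (n_rows, n_cols), (s_row, s_col) = offsets, sizes, strides
    return [
        [full[r0 + i * s_row][c0 + j * s_col] for j in range(n_cols)]
        for i in range(n_rows)
    ]


def insert_slice(full, sub, offsets, sizes, strides):
    """Return a copy of `full` with the strided window replaced by `sub`."""
    (r0, c0), (n_rows, n_cols), (s_row, s_col) = offsets, sizes, strides
    result = [row[:] for row in full]  # value semantics: `full` is not mutated
    for i in range(n_rows):
        for j in range(n_cols):
            result[r0 + i * s_row][c0 + j * s_col] = sub[i][j]
    return result


full = [[4 * r + c for c in range(4)] for r in range(4)]
window = dict(offsets=(1, 0), sizes=(2, 2), strides=(1, 2))
sub = extract_slice(full, **window)                   # rows 1-2, columns 0 and 2
cleared = insert_slice(full, [[0, 0], [0, 0]], **window)
```

Extracting a window and inserting it back at the same coordinates leaves the tensor unchanged, which is a convenient sanity check when porting a kernel to explicit slices.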
```python
sub = tle.dsa.extract_slice(full, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1))
full = tle.dsa.insert_slice(full, sub, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1))
elem = tle.dsa.extract_element(sub, indice=(i, j))
```

Use this pattern when data is reused across multiple math operations.
```python
# 1) Allocate SMEM tile
a_smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)
# 2) Copy GMEM -> SMEM
tle.gpu.copy(a_ptrs, a_smem, [BM, BK])
# 3) Build local pointer view and load
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
a_ptr_local = tle.gpu.local_ptr(a_smem, (rows, cols))
a_tile = tl.load(a_ptr_local)
```

Useful for histogram, bucketization, and radix-select style counting.
```python
bins = 256
counts = tle.gpu.alloc([bins], dtype=tl.int32, scope=tle.gpu.storage_kind.smem)
idx = tl.arange(0, BLOCK) % bins
count_ptr = tle.gpu.local_ptr(counts, (idx,))
tl.atomic_add(count_ptr, 1)
```

Use this for DSA backends that expose dedicated local buffer spaces.
```python
a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
c_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tle.dsa.copy(a_ptrs, a_ub, [BM, BK])
tle.dsa.copy(b_ptrs, b_ub, [BM, BK])
tle.dsa.add(a_ub, b_ub, c_ub)
c_val = tle.dsa.to_tensor(c_ub, writable=True)
out_ub = tle.dsa.to_buffer(c_val, tle.dsa.ascend.UB)
tle.dsa.copy(out_ub, out_ptrs, [BM, BK])
```

- Design philosophy: native passthrough and maximal control.
- Core idea: break DSL abstraction boundaries and support inlined vendor-native code. Target instructions are generated through vendor-private pipelines, bypassing general compiler middle layers and giving experts strong control over instruction scheduling, register allocation, and low-level synchronization primitives.
Content not available outside Feishu document yet.
Open question: should Raw integration be limited to Python DSL only?
```python
from typing import Annotated

from mlir import ir
from mlir.dialects import arith, nvvm, tensor
import triton.language as tl
from triton.experimental.flagtree.edsl import dialect
import triton.experimental.flagtree.language as fl


# 1. Dialect declaration
@tle.raw.language(name="mlir")
# 2. Hardware constraints
@tle.hardware_constraint(threads_dim=1, sync_scope="block")
# 3. Function implementation
def vector_add_tile(
    x: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    y: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    output: Annotated[ir.RankedTensorType, "tensor<1024xf32>"]
):
    tidx = nvvm.ThreadIdXOp(ir.IntegerType.get_signless(32)).res
    bidx = nvvm.BlockIdXOp(ir.IntegerType.get_signless(32)).res
    bdimx = nvvm.BlockDimXOp(ir.IntegerType.get_signless(32)).res
    idx = arith.addi(arith.muli(bidx, bdimx), tidx)
    idx = arith.index_cast(ir.IndexType.get(), idx)
    xval = tensor.extract(x, [idx])
    yval = tensor.extract(y, [idx])
    result = arith.addf(xval, yval)
    tensor.insert(result, output, [idx])


@tle.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = tl.zeros_like(x)
    # 4. Function call
    tle.call(
        vector_add_tile,
        args=[x, y, output],
        hardware={
            "threads": (BLOCK_SIZE,),
        },
        layout={
            x: {"space": "shared", "order": [0]},
            y: {"space": "shared", "order": [0]},
            output: {"space": "shared", "order": [0]},
        }
    )
    tl.store(output_ptr + offsets, output, mask=mask)
```

Optimization and tests have been conducted for SparseMLA in DSA on H800.
- TileLang version: `v0.1.7`
- Example code: `python/tutorials/tle/deepseek_v32/02-sparse-mla.py`
The cases match the FlashMLA V3.2 sparse prefill performance fixture, with attn_sink omitted because the local Triton, TLE, and TileLang kernels do not implement it:
B=1, S=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048.
Latency in milliseconds:
| SKV | Triton | TLE | TLE-Pipe-Pipelined | TLE-FlashMLA-Prefill | TileLang | TileLang-Pipelined | TileLang-Seesaw | FlashMLA |
|---|---|---|---|---|---|---|---|---|
| 8192 | 9.896 | 6.927 | 4.832 | 4.273 | 62.409 | 5.432 | 5.160 | 3.850 |
| 32768 | 11.210 | 7.624 | 5.321 | 4.834 | 75.160 | 6.428 | 5.577 | 4.117 |
| 65536 | 11.655 | 8.378 | 5.731 | 5.305 | 84.432 | 6.865 | 5.786 | 4.348 |
| 98304 | 11.835 | 8.658 | 5.972 | 5.599 | 86.561 | 7.139 | 5.873 | 4.447 |
| 131072 | 11.923 | 8.863 | 6.122 | 5.887 | 87.448 | 7.143 | 5.916 | 4.534 |
Speedup summary:
| SKV | TLE-Pipe over Triton | TLE-FlashMLA over Triton | TLE-FlashMLA over TLE-Pipe | TLE-FlashMLA over FlashMLA | TLE-FlashMLA over TileLang-Seesaw |
|---|---|---|---|---|---|
| 8192 | 2.05x | 2.32x | 1.13x | 0.90x | 1.21x |
| 32768 | 2.11x | 2.32x | 1.10x | 0.85x | 1.15x |
| 65536 | 2.03x | 2.20x | 1.08x | 0.82x | 1.09x |
| 98304 | 1.98x | 2.11x | 1.07x | 0.79x | 1.05x |
| 131072 | 1.95x | 2.03x | 1.04x | 0.77x | 1.00x |
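Each "X over Y" entry in the speedup summary is the latency of Y divided by the latency of X, so values above 1.00x mean X is faster. The sketch below recomputes a few entries from two rows of the latency table; the dictionary layout and the `speedup` helper are illustrative only.

```python
# Latencies in milliseconds, copied from the SKV=8192 and SKV=131072 rows.
latency_ms = {
    8192: {
        "Triton": 9.896,
        "TLE-Pipe-Pipelined": 4.832,
        "TLE-FlashMLA-Prefill": 4.273,
        "FlashMLA": 3.850,
    },
    131072: {
        "Triton": 11.923,
        "TLE-Pipe-Pipelined": 6.122,
        "TLE-FlashMLA-Prefill": 5.887,
        "FlashMLA": 4.534,
    },
}


def speedup(skv, candidate, baseline):
    """Format 'candidate over baseline' as baseline latency / candidate latency."""
    row = latency_ms[skv]
    return f"{row[baseline] / row[candidate]:.2f}x"


pipe_vs_triton = speedup(8192, "TLE-Pipe-Pipelined", "Triton")
prefill_vs_flashmla = speedup(131072, "TLE-FlashMLA-Prefill", "FlashMLA")
```

A value below 1.00x, such as `prefill_vs_flashmla`, means the candidate kernel is still slower than the baseline at that sequence length.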
With shared-memory extensions in tle-struct, it is possible to implement vllm/sglang-style moe_align_block_size and improve performance.
- Example code: `python/tutorials/tle/02-moe_align_block_size.py`
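For orientation, the operation being benchmarked can be written as a short host-side reference: count tokens per expert, round each count up to a multiple of `block_size`, and scatter token indices into the padded layout. The sketch below is plain Python and is not the tutorial kernel. Its output conventions (padding with `len(topk_ids)` as a sentinel, one expert id per block, and the padded total) follow the usual vLLM/SGLang-style operator as an assumption, and it returns tightly sized lists rather than filling preallocated buffers.

```python
def moe_align_block_size(topk_ids, num_experts, block_size):
    """Group token indices by expert, padding each group to whole blocks."""
    pad = len(topk_ids)  # sentinel: an index no real token can have

    # 1) Histogram: how many tokens were routed to each expert.
    counts = [0] * num_experts
    for expert in topk_ids:
        counts[expert] += 1

    # 2) Round each count up to a multiple of block_size, then prefix-sum.
    padded = [-(-c // block_size) * block_size for c in counts]
    starts = [0] * (num_experts + 1)
    for expert in range(num_experts):
        starts[expert + 1] = starts[expert] + padded[expert]

    # 3) Scatter token indices into their expert's padded segment.
    sorted_token_ids = [pad] * starts[-1]
    cursor = starts[:-1]  # slicing copies, so `starts` stays intact
    for token, expert in enumerate(topk_ids):
        sorted_token_ids[cursor[expert]] = token
        cursor[expert] += 1

    # 4) One expert id per block of the padded layout.
    expert_ids = [
        expert
        for expert in range(num_experts)
        for _ in range(padded[expert] // block_size)
    ]
    return sorted_token_ids, expert_ids, starts[-1]


ids, experts, total = moe_align_block_size([2, 0, 2, 1, 2], num_experts=3, block_size=2)
```

Steps 1 and 3 are the counting and scatter phases that benefit from fast on-chip counters, which is where shared-memory atomics come in.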
| num_tokens | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|
| 256 | 0.0348 | 0.0302 | 0.0323 | 0.0097 | 0.0138 | 1.42x |
| 512 | 0.0369 | 0.0301 | 0.0240 | 0.0117 | 0.0138 | 1.18x |
| 1024 | 0.0369 | 0.0313 | 0.0179 | 0.0117 | 0.0139 | 1.19x |
| 2048 | 0.0368 | 0.0313 | 0.0158 | 0.0131 | 0.0138 | 1.05x |
| 4096 | 0.0369 | 0.0301 | 0.0138 | 0.0143 | 0.0148 | 1.07x |
| 8192 | 0.0369 | 0.0313 | 0.0138 | 0.0164 | 0.0179 | 1.30x |
| 16384 | 0.0369 | 0.0301 | 0.0158 | 0.0205 | 0.0240 | 1.52x |
| 32768 | 0.0389 | 0.0322 | 0.0179 | 0.0301 | 0.0312 | 1.74x |
| 65536 | 0.0430 | 0.0374 | 0.0225 | 0.0486 | 0.0507 | 2.25x |
| 163840 | 0.0609 | 0.0512 | 0.0384 | 0.1036 | 0.1001 | 2.61x |
| num_tokens | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|
| 256 | 0.0260 | 0.0408 | 0.0445 | 0.0133 | 0.0160 | 1.20x |
| 512 | 0.0262 | 0.0399 | 0.0315 | 0.0140 | 0.0162 | 1.16x |
| 1024 | 0.0274 | 0.0401 | 0.0239 | 0.0158 | 0.0163 | 1.03x |
| 2048 | 0.0509 | 0.0422 | 0.0226 | 0.0169 | 0.0173 | 1.02x |
| 4096 | 0.0265 | 0.0412 | 0.0200 | 0.0177 | 0.0187 | 1.06x |
| 8192 | 0.0476 | 0.0416 | 0.0192 | 0.0211 | 0.0230 | 1.20x |
| 16384 | 0.0548 | 0.0441 | 0.0219 | 0.0256 | 0.0286 | 1.31x |
| 32768 | 0.0443 | 0.0441 | 0.0221 | 0.0358 | 0.0401 | 1.81x |
| 65536 | 0.0361 | 0.0481 | 0.0273 | 0.0561 | 0.0645 | 2.36x |
| 163840 | 0.0509 | 0.0626 | 0.0451 | 0.1177 | 0.1323 | 2.93x |
- Runtime config: `num_tokens=163840`, `num_experts=512`, `block_size=16`, `source=real`.
| num_tokens | num_experts | block_size | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|---|---|
| 163840 | 512 | 16 | 0.0471 | 0.0535 | 0.0387 | 0.0750 | 0.1467 | 3.79x |
- Runtime config: `num_tokens=163840`, `num_experts=512`, `block_size=16`, `source=real`.
- Runtime command: `conda run -n flagtree python python/tutorials/tle/02-moe_align_block_size.py --skip_correctness --real_data build/gems/moe_topk_ids.pt --num_experts 512 --block_size 16`
| num_tokens | num_experts | block_size | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|---|---|
| 163840 | 512 | 16 | 0.0507 | 0.0395 | 0.0261 | 0.0532 | 0.1060 | 4.06x |
With shared-memory extensions in tle-struct, radix-select-based TopK can improve performance in MoE scenarios with large N and small K.
- Example code: `python/tutorials/tle/03-topk.py`
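The idea behind radix select is to find the value of the k-th largest element one digit at a time, using a histogram per digit instead of a full sort. The sketch below is a sequential Python illustration on non-negative integers, not the tutorial kernel; `radix_select_topk` and its parameters are hypothetical, and `bits` is assumed to be a multiple of `radix_bits`.

```python
def radix_select_topk(values, k, bits=32, radix_bits=8):
    """Return the k largest non-negative integers, in descending order."""
    assert 0 < k <= len(values)
    digit_mask = (1 << radix_bits) - 1
    prefix, prefix_mask, remaining = 0, 0, k

    # Resolve the k-th largest value from the most significant digit down.
    for shift in range(bits - radix_bits, -1, -radix_bits):
        # Histogram the current digit over values matching the prefix so far.
        hist = [0] * (1 << radix_bits)
        for v in values:
            if (v & prefix_mask) == prefix:
                hist[(v >> shift) & digit_mask] += 1

        # Walk buckets from the largest digit until the k-th element's bucket.
        for digit in range(digit_mask, -1, -1):
            if hist[digit] >= remaining:
                break
            remaining -= hist[digit]

        prefix |= digit << shift
        prefix_mask |= digit_mask << shift

    threshold = prefix  # exact value of the k-th largest element
    above = sorted((v for v in values if v > threshold), reverse=True)
    return above + [threshold] * (k - len(above))


top3 = radix_select_topk([5, 1, 9, 3, 9, 7, 2], k=3)
```

Each pass needs only a small histogram (256 counters for 8-bit digits), which is the kind of structure that fits in shared memory. Floating-point inputs are typically mapped to order-preserving unsigned keys first; that step is omitted here.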
| M | N | K | Triton-RadixSelect | Torch-TopK | Speedup (Torch / Triton-RadixSelect) |
|---|---|---|---|---|---|
| 64 | 128 | 8 | 0.008192 | 0.010240 | 1.25x |
| 64 | 1024 | 32 | 0.008192 | 0.020480 | 2.50x |
| 64 | 8192 | 128 | 0.026624 | 0.059392 | 2.23x |
| 128 | 32768 | 256 | 0.124928 | 0.192512 | 1.54x |
| M | N | K | Triton-RadixSelect | Torch-TopK | Speedup (Torch / Triton-RadixSelect) |
|---|---|---|---|---|---|
| 64 | 128 | 8 | 0.008384 | 0.017536 | 2.09x |
| 64 | 1024 | 32 | 0.010688 | 0.024304 | 2.27x |
| 64 | 8192 | 128 | 0.029952 | 0.057184 | 1.91x |
| 128 | 32768 | 256 | 0.092256 | 0.117856 | 1.28x |
TopK selector performance uses python/tutorials/tle/deepseek_v32/01-topk_selector.py; the table below is taken from build/topk_selector/zhihu.md. Environment: single NVIDIA H800. Latency is reported in milliseconds.
Providers: Triton, TRT-LLM prefill, TRT-LLM prefill-1024T, FlashInfer, TileLang, and TLE (ours). The batch=1 table also includes TLE cluster (ours).
| seq_len | topk | Triton (ms) | TRT-LLM prefill (ms) | TRT-LLM prefill-1024T (ms) | FlashInfer (ms) | TLE (ours) (ms) | TLE cluster (ours) (ms) | TileLang (ms) |
|---|---|---|---|---|---|---|---|---|
| 8192 | 256 | 0.044672 | 0.011456 | 0.010400 | 0.013312 | 0.010848 | 0.018048 | 0.016416 |
| 32768 | 1024 | 0.141888 | 0.025184 | 0.018592 | 0.022400 | 0.021952 | 0.022656 | 0.034880 |
| 131072 | 2048 | 0.565440 | 0.075456 | 0.048880 | 0.044544 | 0.052368 | 0.029984 | 0.126576 |
| 262144 | 2048 | 1.116624 | 0.129600 | 0.079360 | 0.048160 | 0.090480 | 0.038448 | N/A |
| 524288 | 2048 | 2.172256 | 0.237504 | 0.139712 | 0.048064 | 0.164864 | 0.054832 | N/A |
| seq_len | topk | Triton (ms) | TRT-LLM prefill (ms) | TRT-LLM prefill-1024T (ms) | FlashInfer (ms) | TLE (ours) (ms) | TileLang (ms) |
|---|---|---|---|---|---|---|---|
| 4096 | 128 | 0.031040 | 0.010336 | 0.009840 | 0.012832 | 0.010144 | 0.015040 |
| 8192 | 256 | 0.046656 | 0.012640 | 0.011488 | 0.014512 | 0.012304 | 0.018496 |
| 32768 | 1024 | 0.144480 | 0.026912 | 0.020416 | 0.025376 | 0.024000 | 0.037792 |
| 131072 | 2048 | 0.601968 | 0.092256 | 0.061152 | 0.067392 | 0.063040 | 0.152448 |
| 262144 | 2048 | 1.251968 | 0.173760 | 0.106656 | 0.126032 | 0.112032 | N/A |
| 524288 | 2048 | 2.412192 | 0.311104 | 0.183168 | 0.195776 | 0.198592 | N/A |
TileLang is skipped by this benchmark script for seq_len >= 262144; the table records those entries as N/A.