sunnycase edited this page May 29, 2026 · 10 revisions

TLE Architecture Design

1. Introduction

Triton is an operator programming language in the form of a Python DSL. It follows a block-based programming model that abstracts away hardware details such as memory hierarchy, layout, pipelining, and synchronization, while achieving strong operator performance through compiler optimization. These advantages have attracted a large developer community and ecosystem.

In recent years, however, Triton has faced growth challenges:

  • Adaptation to DSA platforms and new GPU architectures has progressed slowly.
  • Compared with emerging languages like TileLang, Triton lacks abstractions for fine-grained control of memory hierarchy and parallel granularity, which can lead to weaker performance in some cases.

To address these issues, we propose TLE (Triton Language Extensions), which extends Triton at three levels to meet the urgent needs of users with different skill profiles.

2. Observations and Proposed Solutions

We analyzed mainstream industry DSLs (Triton, TileLang, and cuTile) and distilled from them the target language design described below.

2.1 Pythonic

All three are Python-syntax-based DSLs, indicating that developers prefer Python-like syntax for kernel development, even if only a subset of Python is available.

2.2 Tile Programming

All three support block-level programming. In essence, current block programming mainly performs tiling on global memory. cuTile goes further by supporting multi-level tiling, making it possible to design a unified language across multiple memory hierarchy architectures.

Triton, however, does not explicitly model tile/slice concepts, so users can only tile at the global memory level, limiting further language evolution.

TileLang is similar to Triton in that it does not provide explicit tiling primitives. In addition, except for copy and GEMM, it lacks higher-level tensor ops, which makes GPU programming less convenient. Without automatic vectorization, utilizing SIMD hardware well often requires adding many SIMD-specific ops.

2.3 Memory Hierarchy Abstraction

To address the memory wall, modern hardware uses multi-level memory hierarchies.

  • Triton/cuTile expose only two levels: global memory and local tensor.
  • TileLang directly exposes native hardware memory hierarchy without abstraction.

Problems:

  • Exposing too few levels pushes tiling and buffer promotion work to the compiler.
  • Directly exposing native hierarchy significantly hurts portability.

Preferred direction:

  • Developers perform tiling, but do not explicitly select memory levels.
  • Compiler performs buffer promotion.
  • Developers may provide hints; tile sizes are treated as hyperparameters.

This keeps portability while leaving room for further optimization.

2.4 Parallelism Abstraction

  • Triton/cuTile expose only block-level parallelism, and intra-block parallelism is fully compiler-controlled.
  • TileLang lets developers explicitly control intra-block parallelism (Parallel and Vectorize), improving expressiveness but reducing portability and reuse across hardware.

2.5 Distributed Abstraction

None of these languages directly covers cross-block or cross-node communication, which limits compute-communication fusion (ongoing external work includes Triton Distributed and TileScale).

2.6 Ideal Language Design

  • Level 1: NumPy/PyTorch-like algorithm-level programming. Users focus on algorithm logic only; the compiler handles hardware mapping and communication.
  • Level 2: cuTile-like tile-level programming plus distributed descriptions. Users explicitly provide tiling and sharding, while compiler handles memory hierarchy, parallelism, and communication, with optional hardware/scenario hints.
  • Level 3: Hardware-specific extensions (memory hierarchy, thread binding, vectorize, etc.). This level is confined to specific regions with explicit interaction contracts with Level 2. Compiler performs only essential optimizations.

Detailed principles:

  • Tile semantics to avoid manual address arithmetic.
  • Do not require tensor shapes to be powers of two.

Open question: what other strong design ideas should be added?

3. Architecture Design

3.1 Architecture Overview

TLE sits in the middle layer of the AI software stack:

  • Upstream: serves AI frameworks through graph compilers and operator libraries.
  • Downstream: integrates with various hardware runtimes.

[Figure: architecture overview diagram, not available outside the Feishu document]

TLE is split into three layers:

  • TLE-Lite: lightweight extension over Triton. Features are backend-compatible, and only small changes to existing Triton kernels are needed to gain significant speedups. Targets algorithm engineers and fast optimization workflows.
  • TLE-Struct: architecture-clustered abstractions (e.g., GPGPU, DSA) for deeper performance tuning. Requires moderate hardware knowledge.
  • TLE-Raw: direct hardware control, including vendor-native programming languages for maximum performance. Targets expert performance engineers.

Lowering paths:

  • TLE-Lite and TLE-Struct lower to LLVM IR via FLIR.
  • TLE-Raw lowers to LLVM IR via language-specific pipelines (e.g., vendor private compilers).
  • All parts are finally linked into a complete kernel, which the runtime loads and executes.

3.2 TLE-Lite

  • Design philosophy: write once, run anywhere.
  • Core idea: use high-level semantic hints (instead of hard constraints) to guide compiler heuristics. Keep backward compatibility and achieve cross-platform speedups with minimal code changes.

3.2.1 Memory Management

3.2.1.1 tle.load

tle.load extends tl.load with async hint support.

  • Signature: tle.load(ptr, mask=None, other=None, is_async=False)
  • Purpose: keep tl.load semantics while adding async scheduling hints.
  • Practical guidance:
    • Use is_async=True for global-memory reads that are later reused in compute-heavy regions.
    • Keep mask and other explicit on boundary tiles to avoid undefined values.

Simple examples:

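pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tle.load(x_ptr + offs, mask=mask, other=0.0, is_async=True)

# GEMM main loop: async A/B tile loads feed the software pipeliner (assumes K % BK == 0).
offs_m = tl.arange(0, BM)
offs_n = tl.arange(0, BN)
offs_k = tl.arange(0, BK)
acc = tl.zeros([BM, BN], dtype=tl.float32)
for k in tl.range(0, K, BK, num_stages=2):
    a = tle.load(a_ptr + offs_m[:, None] * stride_am + (k + offs_k)[None, :] * stride_ak, is_async=True)
    b = tle.load(b_ptr + (k + offs_k)[:, None] * stride_bk + offs_n[None, :] * stride_bn, is_async=True)
    acc = tl.dot(a, b, acc)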
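To make the boundary-tile guidance concrete, the plain-Python sketch below models only the value semantics of a masked load and ignores asynchrony entirely: lanes whose offset falls outside `n_elements` receive `other` instead of reading memory. The helper `masked_load` is hypothetical and not part of TLE.

```python
def masked_load(data, base, block, other=0.0):
    """Model tl.load/tle.load value semantics for one block of lanes.

    Lanes with offset >= len(data) are masked off and return `other`,
    mirroring `mask = offs < n_elements` in the kernel above.
    """
    n_elements = len(data)
    offs = [base + i for i in range(block)]
    mask = [o < n_elements for o in offs]
    return [data[o] if m else other for o, m in zip(offs, mask)]


data = [float(i) for i in range(10)]  # n_elements = 10
BLOCK = 4
# Three programs cover offsets 0..11 (cdiv(10, 4) == 3).
tiles = [masked_load(data, pid * BLOCK, BLOCK) for pid in range(3)]
print(tiles[-1])  # [8.0, 9.0, 0.0, 0.0]: offsets 10 and 11 are padded with `other`
```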

3.2.2 Tensor Slicing

3.2.2.1 tle.extract_tile

Split the input tensor into a grid of sub-tiles with the given child-tile shape, and extract the tile at the specified grid coordinates.

  • Signature: x.extract_tile(index, shape)
  • Returns: sub-tile view at index.
  • GPU: supports extraction from registers and shared memory.
  • Typical use: local transforms on sub-regions without manual pointer arithmetic.
3.2.2.2 tle.insert_tile

Split the input tensor into a grid of sub-tiles with the child-tile shape, and update the tile at the specified grid coordinates.

  • Signature: x.insert_tile(tile, index)
  • Returns: full tensor with the selected sub-tile updated.
  • GPU: supports updates in registers and shared memory.
  • Typical use: write processed activation, quant/dequant, or normalization results back into a larger tile.

Simple example:

# x: [4, 4]
sub = x.extract_tile(index=[1, 0], shape=[2, 2])  # rows [2:4], cols [0:2]
sub = tl.maximum(sub, 0.0)
x = x.insert_tile(sub, index=[1, 0])
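The index-to-region mapping in the example above (grid coordinate `[1, 0]` with child shape `[2, 2]` selects rows `[2:4]` and columns `[0:2]`) can be reproduced with a nested-list reference model. This is a host-side sketch assuming the tensor shape is an exact multiple of the child-tile shape; `extract_tile` and `insert_tile` here are illustrative functions, not the TLE methods, which operate on register or shared-memory tensors.

```python
def extract_tile(x, index, shape):
    """Return a copy of the sub-tile at grid coordinate `index`."""
    (ti, tj), (h, w) = index, shape
    return [row[tj * w:(tj + 1) * w] for row in x[ti * h:(ti + 1) * h]]


def insert_tile(x, tile, index):
    """Return a copy of `x` with the sub-tile at grid coordinate `index` replaced."""
    h, w = len(tile), len(tile[0])
    ti, tj = index
    out = [list(row) for row in x]
    for r in range(h):
        out[ti * h + r][tj * w:(tj + 1) * w] = tile[r]
    return out


x = [[r * 4 + c - 12 for c in range(4)] for r in range(4)]  # values -12..3
sub = extract_tile(x, index=[1, 0], shape=[2, 2])             # rows [2:4], cols [0:2]
sub = [[max(v, 0) for v in row] for row in sub]               # like tl.maximum(sub, 0.0)
y = insert_tile(x, sub, index=[1, 0])
```

Only the selected quadrant changes; the other rows and columns of `y` keep the values of `x`.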

3.2.3 Scan/Sort Ops

Scan/sort ops provide prefix, rank, and selection primitives for kernels such as histogram-based top-k, stream compaction, and block-local ordering. TLE-Lite keeps these operations semantic rather than hardware-bound: users describe the scan/sort intent, and each backend chooses an appropriate register/shared-memory lowering strategy.

3.2.3.1 tle.cumsum

tle.cumsum(input, axis=0, reverse=False, dtype=None) computes an exclusive cumulative sum and the total sum along axis in one operation.

  • Signature: tle.cumsum(input, axis=0, reverse=False, dtype=None)
  • Purpose: compute the exclusive prefix/suffix sum and the total sum of a block tensor with one semantic scan op.
  • Returns: (exclusive_sum, total_sum).
  • Typical use: build per-block ranks for top-k, histogram prefixing, stream compaction, and block-local partitioning.
  • exclusive_sum has the same shape as input; total_sum is the scalar sum of the scanned block.
  • reverse=True computes a reverse-exclusive sum, which is useful for suffix counts in descending radix/top-k selection.
  • dtype optionally controls the accumulation/result type. By default, narrow integer inputs are widened to 32-bit integers, and bfloat16 is promoted to float32.
  • Add the original input back to exclusive_sum when an inclusive cumulative sum is needed.
  • Keep masked loads explicit and feed zero for inactive lanes so total_sum describes only valid elements.
  • Supported inputs are static rank-1 block tensors with axis=0, which matches the histogram and radix-selection workloads already used by TLE top-k kernels.

Simple example:

exclusive, total = tle.cumsum(x, axis=0)
inclusive = exclusive + x
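The documented return values, including `reverse=True` and the inclusive-sum recovery, can be checked against a small reference model on the host. This sketch covers only the rank-1, `axis=0` case described above and does not model dtype widening; the name `cumsum_exclusive` is illustrative.

```python
def cumsum_exclusive(values, reverse=False):
    """Reference model of tle.cumsum on a rank-1 block: (exclusive, total).

    exclusive[i] sums the elements strictly before i (strictly after i
    when reverse=True); total is the sum of all elements.
    """
    n = len(values)
    order = range(n - 1, -1, -1) if reverse else range(n)
    exclusive = [0] * n
    running = 0
    for i in order:
        exclusive[i] = running
        running += values[i]
    return exclusive, running


x = [3, 1, 0, 2]
exclusive, total = cumsum_exclusive(x)              # [0, 3, 4, 4], 6
inclusive = [e + v for e, v in zip(exclusive, x)]   # [3, 4, 4, 6]
suffix, _ = cumsum_exclusive(x, reverse=True)       # [3, 2, 2, 0]
```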

3.2.4 Pipeline

3.2.4.1 Pipe and Stages

tle.pipe describes an explicit dataflow edge between a producer and one or more consumers. It records both the shared-memory stage that carries a logical chunk and the synchronization needed to make that chunk visible, so CTA-local load/compute overlap and warp-specialized producer/consumer code can use one typed descriptor instead of hand-managed barriers.

  • Signature: tle.pipe(*, capacity, scope="cta", name=None, readers=None, one_shot=False, **fields)
  • Purpose: create a typed pipe for explicit CTA-local producer/consumer dataflow, ring-buffer stage reuse, and synchronization edges.
  • Parameters:
    • capacity: compile-time positive integer, the number of pipe stages. The leading dimension of every payload field must equal capacity.
    • scope: the only supported value is "cta".
    • name: optional pipe name for IR/diagnostics; if provided, it must be a string.
    • readers: optional reader-name list. Omit it for the default SPSC reader; pass values such as ("left", "right") for SPMC.
    • one_shot: whether the pipe is a single ready/full edge, useful for one-time broadcast data. one_shot=True does not support close.
    • **fields: one or more payload buffers. Each field must be a shared-memory buffered tensor returned by tle.gpu.alloc(..., scope=tle.gpu.smem), with rank >= 2.
  • Name rules:
    • Pipe field names and reader names must be Python identifiers.
    • Names must not start with _.
    • fields and readers are reserved names.
  • The object returned by tle.pipe(...) is the pipe descriptor. It owns the staged payload fields and creates producer/consumer endpoints through writer() and reader(...).
  • capacity stages form a ring buffer. iter maps to stage = iter % capacity, with a phase bit distinguishing reuse rounds.
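The stage mapping above can be tabulated directly. The exact phase-bit encoding is backend-defined; this sketch assumes the common mbarrier-style convention in which the phase flips each time the ring wraps around, and `stage_and_phase` is an illustrative name.

```python
def stage_and_phase(it, capacity):
    """Map a logical chunk id to (ring-buffer stage, phase bit).

    Assumption: the phase flips once per full trip around the ring.
    """
    return it % capacity, (it // capacity) & 1


mapping = [stage_and_phase(it, capacity=2) for it in range(6)]
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 0), (1, 0)]
```

One bit is enough because the producer cannot run more than one full round ahead of the consumer: it must wait for a stage to be released before acquiring it again, so chunks `iter` and `iter + capacity` are the only ones that can ever compete for the same stage.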
3.2.4.2 Producer

A producer is the code that owns pipe.writer(). It acquires a writable stage, fills every required field for the logical chunk, and commits the chunk so consumers can observe it.

  • pipe_value.writer() -> pipe_writer: creates the single writer endpoint for the pipe.
  • The writer always sees all payload fields.
  • writer.acquire(iter) -> pipe_slot: acquires the stage for producer writes and returns a slot whose fields have the leading capacity dimension removed.
  • Produce field data between writer.acquire(iter) and writer.commit(iter).
  • writer.commit(iter) -> None: marks the stage ready for all subscribed consumers. All field writes for that logical chunk must be complete before commit.
  • writer.close(iter) -> None: publishes a closed stage for close-aware consumer loops. one_shot=True pipes do not support close.
  • commit is the producer-side visibility boundary.
3.2.4.3 Consumer

A consumer is the code that owns pipe.reader(...). It waits for a published chunk, reads the returned slot, and releases the stage after all reads are complete.

  • pipe_value.reader(name=None, fields=None) -> pipe_reader: creates a consumer endpoint.
  • For an SPSC pipe (readers=None), name must be omitted.
  • For an SPMC pipe (readers=("mma", "epilogue")), name is required and must match a declared reader.
  • fields may be a non-empty compile-time tuple/list of unique payload field names. If omitted, the consumer subscribes to all fields.
  • A field-subset consumer narrows the endpoint view and the returned wait().slot; it does not create a separate pipe.
  • reader.wait(iter) -> pipe_wait_result: waits until the selected stage is ready or closed and returns both the slot and the closed flag.
  • Normal consumers use wait_result.slot; close-aware consumers also inspect wait_result.is_closed.
  • reader.release(iter) -> None: releases the consumed stage for reuse by the producer. Call it after all reads from wait(iter).slot are complete.
  • wait is the consumer-side visibility boundary. release is the consumer-side free signal.
3.2.4.4 Payload Fields
  • **fields defines the data carried by each stage. Each field is exposed on pipe_slot by name, for example slot.q or slot.scale.
  • pipe_slot also exposes fields: dict[str, tle.gpu.buffered_tensor].
  • pipe_wait_result contains slot: pipe_slot and is_closed: tl.tensor.
  • A pipe may carry one field or multiple fields. Split pipes by logical lifecycle and reader protocol, not by the low-level transport used to fill each field.
  • Different fields in the same slot may be produced by different mechanisms, such as TMA copy, cp.async-style copy, or tle.gpu.local_ptr plus tl.store. Users still call one writer.commit(iter) after all fields for that logical chunk are produced.
  • The compiler infers each field's transport from producer-side IR; transport is not a user-facing pipe attribute and should not be encoded in pipe names, field names, or extra attributes.
  • Use pipe.reader(name, fields=(...)) when a reader only consumes a subset of fields; this narrows the reader view without creating another token.
  • Keep pipe-field provenance visible. Opaque shared-memory pointer escapes, untracked shared stores, or unprovable overlapping writes are rejected instead of lowered with a silent fallback.
  • NVIDIA lowering maps CTA-scoped SMEM pipes to NVWS/mbarrier synchronization. Multi-field payloads are accepted when payload window, field ownership, participant count, and source-order safety can be proven at pipe-field root granularity.
3.2.4.5 Lifecycle
  • In an SPSC pipe, one producer publishes to one default consumer.
  • In an SPMC pipe, one producer publishes the same logical chunk to named consumers such as ("mma", "epilogue").
  • iter is the logical chunk id. The same iter should be used for the producer and every consumer that participates in that chunk.
  • The normal cyclic lifecycle is writer.acquire(iter) -> produce fields -> writer.commit(iter) -> reader.wait(iter) -> consume fields -> reader.release(iter).
  • one_shot=True models a single ready/full edge, usually with capacity=1; do not rely on cyclic reuse or close in that mode.
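The cyclic lifecycle above can be modelled on the host with two Python threads and one pair of semaphores per stage. This is a minimal sketch assuming an SPSC pipe with in-order chunk ids; `PipeModel` and `run` are hypothetical names, and the semaphores only stand in for whatever barriers a backend emits (the text mentions NVWS/mbarrier on NVIDIA).

```python
import threading


class PipeModel:
    """Host-side model of an SPSC tle.pipe ring buffer (illustrative only)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffers = [None] * capacity
        # 'empty' = stage may be written; 'full' = stage holds a published chunk.
        self.empty = [threading.Semaphore(1) for _ in range(capacity)]
        self.full = [threading.Semaphore(0) for _ in range(capacity)]

    def _stage(self, it):
        return it % self.capacity

    def acquire(self, it):                      # writer.acquire(iter)
        self.empty[self._stage(it)].acquire()   # block until the stage is free
        return self._stage(it)

    def commit(self, it):                       # writer.commit(iter)
        self.full[self._stage(it)].release()    # publish the chunk

    def wait(self, it):                         # reader.wait(iter)
        self.full[self._stage(it)].acquire()    # block until the chunk is ready
        return self.buffers[self._stage(it)]

    def release(self, it):                      # reader.release(iter)
        self.empty[self._stage(it)].release()   # hand the stage back


def run(n_tiles, capacity=2):
    pipe = PipeModel(capacity)
    results = []

    def producer():
        for it in range(n_tiles):
            stage = pipe.acquire(it)
            pipe.buffers[stage] = [it * 10 + j for j in range(4)]  # produce fields
            pipe.commit(it)

    def consumer():
        for it in range(n_tiles):
            tile = pipe.wait(it)
            results.append(sum(tile))                              # consume fields
            pipe.release(it)

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run(n_tiles=5))  # [6, 46, 86, 126, 166]
```

Because the producer blocks in `acquire` once all `capacity` stages are in flight, the producer and consumer loops must run concurrently (in a kernel, in different warp partitions); calling them back to back on one thread would deadlock whenever `n_tiles > capacity`.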
3.2.4.6 Simple Example

Automatic software pipelining can still be triggered with tl.range(..., num_stages=...). Use an explicit pipe when the producer/consumer split should be visible in the program.

stage_buf = tle.gpu.alloc([2, BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
pipe = tle.pipe(capacity=2, scope="cta", name="x_pipe", x=stage_buf)
writer = pipe.writer()
reader = pipe.reader()
offs = tl.arange(0, BLOCK)

slot = writer.acquire(k)
tl.store(tle.gpu.local_ptr(slot.x), tl.load(x_ptr + k * BLOCK + offs))
writer.commit(k)

ready = reader.wait(k)
x = tl.load(tle.gpu.local_ptr(ready.slot.x))
reader.release(k)

3.2.5 Distributed

The TLE distributed API has four core parts: device mesh definition, sharding specification, resharding (collective communication), and remote access (point-to-point communication).

Recommended workflow:

  1. Define topology with tle.device_mesh.
  2. Mark tensor layout with tle.sharding.
  3. Transform layout with tle.reshard.
  4. Keep compute kernels operating on logical tensor views.
3.2.5.1 Device Mesh

tle.device_mesh defines physical device topology and serves as the context foundation for distributed operations.

class device_mesh:
    def __init__(self, topology: dict):
        """
        Initialize DeviceMesh.

        Args:
            topology (dict): Hardware hierarchy description.
                             Keys are hierarchy names; values are int (1D)
                             or tuple lists (multi-dimensional).
        """
        self._physical_ids = ... # Internal flattened physical IDs (0..N-1)
        self._shape = ...        # Current logical shape, e.g. (2, 2, 4, 2, 2, 4)
        self._dim_names = ...    # Current dimension names

    @property
    def shape(self):
        """Return logical mesh shape."""
        return self._shape

    @property
    def ndim(self):
        """Return number of dimensions."""
        return len(self._shape)

    def flatten(self):
        """Flatten mesh to 1D, typically for ring communication."""
        return self.reshape(prod(self._shape))

    def __getitem__(self, key):
        """
        Supports slicing and returns a sub-mesh.
        Supports standard slice and integer indexing.
        """
        return sub_mesh

    def __repr__(self):
        return f"DeviceMesh(shape={self._shape}, names={self._dim_names})"


# Define complex hardware hierarchy
topology = {
    # Cross-node hierarchy (2x2 = 4 nodes)
    "node": [("node_x", 2), ("node_y", 2)],
    # In-node GPUs (4 devices)
    "device": 4,
    # In-GPU cluster (2x2)
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    # In-cluster blocks (4 blocks)
    "block": 4,
}

# mesh.shape -> (2, 2, 4, 2, 2, 4)
# total size = 256
mesh = tle.device_mesh(topology=topology)
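One plausible reading of the constructor docstring is that each `int` value adds one mesh dimension named after its key, while a list of `(name, size)` tuples adds one dimension per tuple, in insertion order. The sketch below follows that assumption (it is not the TLE implementation) and reproduces the shape and size stated in the comments above. It keeps only leaf dimension names, so whether a level name such as `"block_cluster"` can still address its grouped axes is not modelled.

```python
from math import prod


def parse_topology(topology):
    """Flatten a topology dict into (shape, dim_names), in insertion order."""
    shape, names = [], []
    for key, value in topology.items():
        if isinstance(value, int):          # 1D level: the key names the axis
            names.append(key)
            shape.append(value)
        else:                               # multi-dimensional level
            for name, size in value:
                names.append(name)
                shape.append(size)
    return tuple(shape), tuple(names)


topology = {
    "node": [("node_x", 2), ("node_y", 2)],
    "device": 4,
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    "block": 4,
}
shape, names = parse_topology(topology)
print(shape, prod(shape))  # (2, 2, 4, 2, 2, 4) 256
```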
3.2.5.2 Sharding Specification

tle.sharding declares tensor distribution state on the device mesh:

  • split: how each tensor axis is partitioned across mesh axes.
  • partial: which mesh axes the tensor holds partial sums over.
  • Unspecified mesh axes are treated as broadcast.

Symbols:

  • tle.S(axis): split.
  • tle.B: broadcast/replicate.
  • tle.P(axis): partial; requires a reduction over the specified axis.
def sharding(mesh, split, partial):
    """
    Build a sharding spec over `mesh`. Annotation only: emits no direct
    code, but guides compiler checks and optimizations once attached to
    a tensor (e.g. via tle.make_sharded_tensor or tle.reshard).
    """

# Split axis0 on cluster, axis1 on device, and partial on block axis
x_shard = tle.sharding(
    mesh,
    split=[["cluster_x", "cluster_y"], "device"],
    partial=["block"],
)

# Define a sharded tensor
x = tle.make_sharded_tensor(x_ptr, sharding=x_shard, shape=[4, 4])
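To see what a split spec implies for per-device storage, this sketch computes the local shard shape, reading `split[i]` as the mesh axis (or list of axes) that tensor axis `i` is divided across, as the example comment suggests. It assumes even division and uses the mesh sizes from the topology example; `local_shard_shape` is a hypothetical helper.

```python
from math import prod


def local_shard_shape(global_shape, split, mesh_sizes):
    """Per-device shard shape for a split spec (illustrative model).

    Missing or None entries in `split` leave that tensor axis unsplit.
    """
    local = []
    for i, dim in enumerate(global_shape):
        axes = split[i] if i < len(split) else None
        if axes is None:
            local.append(dim)
            continue
        if isinstance(axes, str):
            axes = [axes]
        ways = prod(mesh_sizes[a] for a in axes)
        if dim % ways:
            raise ValueError(f"axis {i} of size {dim} is not divisible by {ways}")
        local.append(dim // ways)
    return local


mesh_sizes = {"node_x": 2, "node_y": 2, "device": 4,
              "cluster_x": 2, "cluster_y": 2, "block": 4}
print(local_shard_shape([4, 4], [["cluster_x", "cluster_y"], "device"], mesh_sizes))
# [1, 1]
```

The `partial=["block"]` part of the spec does not change the shard shape; it changes the meaning of the data, since each block then holds a partial sum that must be reduced before use.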
3.2.5.3 tle.shard_id
  • Signature: tle.shard_id(mesh, axis)
  • Returns current program's coordinate on a mesh axis.
  • axis can be a mesh-axis name (e.g. "node", "device", "cluster_x") or an axis index.
  • Typical use: build peer shard IDs for ring exchange, staged all-reduce, and cluster-cooperative kernels.

Simple example:

mesh = tle.device_mesh({"node": 2, "device": 4})
node_rank = tle.shard_id(mesh, "node")      # 0..1
device_rank = tle.shard_id(mesh, "device")  # 0..3
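The mapping from a flat device or program index to mesh coordinates is not specified above. A minimal sketch assuming row-major order (last mesh axis fastest) shows how such coordinates, and the ring neighbour `next_device` used in the tle.remote examples, could be derived; `mesh_coords` is an illustrative name.

```python
def mesh_coords(rank, shape):
    """Row-major coordinates of a flat rank in a mesh of the given shape."""
    coords = []
    for size in reversed(shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))


shape = (2, 4)  # ("node", "device"), as in tle.device_mesh({"node": 2, "device": 4})
for rank in range(8):
    node_rank, device_rank = mesh_coords(rank, shape)
    next_device = (device_rank + 1) % shape[1]  # ring neighbour inside the node
    print(rank, (node_rank, device_rank), "->", (node_rank, next_device))
```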
3.2.5.4 Synchronization

In complex distributed kernels (e.g., ring all-reduce or row/column-independent pipelines), often only blocks in the same row or the same column need to synchronize, not the whole cluster; a global barrier would introduce unnecessary waiting.

def distributed_barrier(mesh):
    """
    Synchronize the devices in `mesh`. If a sub-mesh is passed, only the
    devices in that sub-mesh synchronize; for devices outside it the call
    is a no-op (or the compiler guarantees their control flow never
    reaches it).
    """
    pass
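The benefit of sub-mesh barriers can be illustrated on the host with one `threading.Barrier` per row of a 2 x 4 mesh: threads in row 0 never wait for row 1. This is an analogy that assumes `mesh[row, :]` yields the row sub-mesh (the device_mesh sketch says slicing returns a sub-mesh); it does not describe how TLE lowers the barrier.

```python
import threading

ROWS, COLS = 2, 4  # a 2 x 4 mesh; each row synchronizes independently

# One barrier per row sub-mesh: only the COLS members of a row wait on it.
row_barriers = [threading.Barrier(COLS) for _ in range(ROWS)]
log = []
log_lock = threading.Lock()


def worker(row, col):
    with log_lock:
        log.append(("before", row, col))
    row_barriers[row].wait()  # analogous to tle.distributed_barrier(mesh[row, :])
    with log_lock:
        log.append(("after", row, col))


threads = [threading.Thread(target=worker, args=(r, c))
           for r in range(ROWS) for c in range(COLS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Within each row, every "before" entry precedes every "after" entry, while no ordering is imposed between different rows.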
3.2.5.5 Remote Access

tle.remote obtains a handle for tensor data located on other devices. This maps to point-to-point communication or direct memory access (RDMA/NVLink load).

  • tle.remote reads/writes explicit remote shards.
  • tle.distributed_barrier synchronizes only the mesh/sub-mesh you pass in.
def remote(tensor, shard_id, scope):
    """
    Get a RemoteTensor handle to a shard on a target device.

    :param tensor: logically distributed tensor (already marked by tle.sharding)
    :param shard_id: tuple coordinate in the device mesh
    :param scope: device mesh (or sub-mesh) in which shard_id is interpreted
    :return: RemoteTensor, supporting load/store and related ops
    """

Simple example:

next_device = (tle.shard_id(mesh, "device") + 1) % mesh.shape[1]
remote_x = tle.remote(x, shard_id=(tle.shard_id(mesh, "node"), next_device), scope=mesh)
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(remote_x)
3.2.5.6 Resharding

tle.reshard is the entry point for collectives. The compiler compares the source and target specs and automatically inserts the required communication primitives.

def reshard(tensor, spec):
    """
    Action: transform tensor to a new distribution state.

    Typical transitions:
    1. [ ] -> [S]: Scatter
    2. [S] -> [ ]: Gather
    3. [P] -> [ ]: Reduce
    4. [B] -> [S]: Local slice (no communication)
    5. [S] -> [B]: All-gather
    6. [P] -> [B]: All-reduce
    7. [B] -> [P]: Error
    """

Simple example:

x_spec = tle.sharding(mesh, split=["device"], partial=[])
x = tle.make_sharded_tensor(x_ptr, sharding=x_spec, shape=[M, K])
x_full = tle.reshard(x, spec=tle.sharding(mesh, split=[], partial=[]))
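The transition list in the reshard docstring amounts to a lookup from (source state, target state) to a collective. The sketch below encodes exactly the listed transitions for a single mesh axis, writing the docstring's empty state `[ ]` as `"none"`; treating identical states as a no-op is an added assumption, and a real compiler would presumably apply such a decision per mesh axis. `plan_reshard` is a hypothetical name.

```python
TRANSITIONS = {
    ("none", "S"): "scatter",
    ("S", "none"): "gather",
    ("P", "none"): "reduce",
    ("B", "S"): "local slice (no communication)",
    ("S", "B"): "all-gather",
    ("P", "B"): "all-reduce",
}


def plan_reshard(src, dst):
    """Return the collective for a src -> dst transition ('none' = the [ ] state)."""
    if src == dst:
        return "no-op"  # assumption: nothing to do when the state is unchanged
    if (src, dst) == ("B", "P"):
        raise ValueError("[B] -> [P] is not a valid reshard")
    try:
        return TRANSITIONS[(src, dst)]
    except KeyError:
        raise NotImplementedError(f"transition {src} -> {dst} is not listed") from None
```

For the example above, going from `split=["device"]` to a replicated view corresponds to the `("S", "B")` entry, i.e. an all-gather over the device axis.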
3.2.5.7 Distributed GEMM

NVIDIA Hopper (H100) and newer architectures introduce thread block clusters, which let groups of CTAs cooperate through distributed shared memory (DSMEM) for high-bandwidth, low-latency data exchange.

tle.distributed_dot is designed to use this feature so developers can write cross-block matrix multiplication without manually handling DSMEM barriers and data movement.

def distributed_dot(a, b, c=None):
    """
    Execute distributed matrix multiplication within current
    Thread Block Cluster scope.

    Behavior depends on sharding specs of input tensors `a` and `b`
    over the cluster mesh.

    Args:
        a (Tensor): left operand with cluster-level sharding annotation.
        b (Tensor): right operand with cluster-level sharding annotation.
        c (Tensor, optional): accumulator.

    Returns:
        Tensor: result tensor with distribution inferred from inputs.
    """

Open question: what additional distributed primitives are needed?
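As an illustration of one sharding combination tle.distributed_dot might handle, the sketch below models a K-split GEMM: `a` is split along K (its columns) and `b` along K (its rows) across the CTAs of a cluster, each CTA produces a partial product (the `[P]` state), and summing the partials, which is the reduction a DSMEM exchange would perform, yields the full result. The choice of K-split is an assumption made for illustration; the docstring only says that behavior follows the input sharding specs.

```python
def matmul(a, b):
    """Plain nested-list matrix multiply."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def k_split_dot(a, b, n_ctas):
    """Each 'CTA' multiplies one K-slice; partial sums are then reduced."""
    K = len(b)
    assert K % n_ctas == 0, "assumes K divides evenly across the cluster"
    kw = K // n_ctas
    partials = []
    for c in range(n_ctas):
        a_c = [row[c * kw:(c + 1) * kw] for row in a]  # a sharded on K
        b_c = b[c * kw:(c + 1) * kw]                  # b sharded on K
        partials.append(matmul(a_c, b_c))             # [P] state on CTA c
    M, N = len(a), len(b[0])
    return [[sum(p[i][j] for p in partials) for j in range(N)] for i in range(M)]


a = [[1, 2, 3, 4], [5, 6, 7, 8]]        # M=2, K=4
b = [[1, 0], [0, 1], [1, 1], [2, -1]]  # K=4, N=2
assert k_split_dot(a, b, n_ctas=2) == matmul(a, b)  # [[12, 1], [28, 5]]
```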

3.2.6 Integrated Examples

3.2.6.1 Async Filter and Compaction

This combines tle.load and tle.cumsum: masked async loads bring a tile into registers, a predicate builds per-lane active flags, and tle.cumsum assigns compact output offsets plus the per-block active count.

pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements

x = tle.load(x_ptr + offs, mask=mask, other=0.0, is_async=True)
keep = (x > threshold) & mask
active = keep.to(tl.int32)

write_offsets, active_total = tle.cumsum(active, axis=0)
block_base = tl.load(block_offsets_ptr + pid)

tl.store(out_ptr + block_base + write_offsets, x, mask=keep)
tl.store(block_counts_ptr + pid, active_total)
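The kernel reads `block_offsets_ptr`, so per-block output bases must exist before it runs. A host-side model of the whole compaction, assuming a two-pass arrangement in which per-block counts are computed first and then exclusive-scanned into block bases, shows how the pieces fit and that the output stays in input order. `exclusive_scan` and `compact` are illustrative names.

```python
def exclusive_scan(values):
    """Return (exclusive prefix sums, total), like tle.cumsum on a rank-1 block."""
    out, running = [], 0
    for v in values:
        out.append(running)
        running += v
    return out, running


def compact(x, threshold, block):
    """Host model of the filter/compaction kernel under a two-pass scheme."""
    tiles = [x[i:i + block] for i in range(0, len(x), block)]
    # Pass 1: per-block active counts (what the kernel writes to block_counts_ptr).
    counts = [sum(v > threshold for v in tile) for tile in tiles]
    # Block bases: exclusive scan over the counts (the block_offsets_ptr input).
    block_offsets, total = exclusive_scan(counts)
    out = [None] * total
    # Pass 2: each block scatters kept values to block_base + write_offset.
    for pid, tile in enumerate(tiles):
        active = [int(v > threshold) for v in tile]
        write_offsets, active_total = exclusive_scan(active)
        assert active_total == counts[pid]
        for v, keep, off in zip(tile, active, write_offsets):
            if keep:
                out[block_offsets[pid] + off] = v
    return out, counts


out, counts = compact([5, 1, 7, 0, 9, 2, 8, 3, 6], threshold=4, block=4)
# out == [5, 7, 9, 8, 6], counts == [2, 2, 1]
```

The shorter last tile here plays the role of the masked tail block in the kernel, where `keep = (x > threshold) & mask` keeps inactive lanes from contributing.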
3.2.6.2 Pipelined Tile Post-processing

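This combines tle.pipe, shared-memory pointer views, and tile slicing. The producer stages global tiles into shared memory, while the consumer reads each ready stage, applies a local transform to one sub-tile, and writes the updated tile to the output. The two loops are meant to run concurrently, for example as separate partitions of tle.gpu.warp_specialize: run back to back in one partition, the producer would block in writer.acquire as soon as all capacity stages are in flight (n_tiles > 2), because only the consumer loop releases stages.

tile_buf = tle.gpu.alloc([2, BM, BN], dtype=tl.float32, scope=tle.gpu.smem)
pipe = tle.pipe(capacity=2, scope="cta", name="post_pipe", tile=tile_buf)
writer = pipe.writer()
reader = pipe.reader()

# local_ptr requires all index tensors to share one shape, so broadcast both to (BM, BN).
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BN))
cols = tl.broadcast_to(tl.arange(0, BN)[None, :], (BM, BN))
# tile_offset / tile_mask: user helpers (not shown) giving global offsets and the bounds mask of tile t.

# Producer loop (its own warp partition): stage tile t into shared memory.
for t in tl.range(0, n_tiles):
    slot = writer.acquire(t)
    vals = tle.load(x_ptr + tile_offset(t, rows, cols), mask=tile_mask(t, rows, cols), other=0.0, is_async=True)
    tl.store(tle.gpu.local_ptr(slot.tile, (rows, cols)), vals)
    writer.commit(t)

# Consumer loop (a different warp partition): ReLU on the top-left quadrant, then write out.
for t in tl.range(0, n_tiles):
    ready = reader.wait(t)
    tile = tl.load(tle.gpu.local_ptr(ready.slot.tile, (rows, cols)))
    sub = tile.extract_tile(index=[0, 0], shape=[BM // 2, BN // 2])
    sub = tl.maximum(sub, 0.0)
    tile = tile.insert_tile(sub, index=[0, 0])
    tl.store(out_ptr + tile_offset(t, rows, cols), tile, mask=tile_mask(t, rows, cols))
    reader.release(t)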
3.2.6.3 Multi-field Pipe with Selective Readers

This combines one multi-field pipe payload with reader field subsets. One pipe slot carries both a TMA-produced tile and a locally produced scale vector; the MMA reader subscribes only to the tile, while the epilogue reader consumes both fields.

q = tle.gpu.alloc([PIPE_CAPACITY, BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
scale = tle.gpu.alloc(
    [PIPE_CAPACITY, BM],
    dtype=tl.float32,
    scope=tle.gpu.smem,
    nv_mma_shared_layout=False,
)

pipe = tle.pipe(
    capacity=PIPE_CAPACITY,
    scope="cta",
    name="multi_field_inputs",
    readers=("mma", "epilogue"),
    q=q,
    scale=scale,
)

writer = pipe.writer()
mma_reader = pipe.reader("mma", fields=("q",))
epilogue_reader = pipe.reader("epilogue", fields=("q", "scale"))

slot = writer.acquire(k)
tle.gpu.copy(q_desc, slot.q, [BM, BK], [q_block, k_block])
tl.store(tle.gpu.local_ptr(slot.scale, (tl.arange(0, BM),)), scale_values)
writer.commit(k)

q_ready = mma_reader.wait(k).slot
q_tile = tl.load(tle.gpu.local_ptr(q_ready.q))
mma_reader.release(k)

epilogue_ready = epilogue_reader.wait(k).slot
q_for_epilogue = tl.load(tle.gpu.local_ptr(epilogue_ready.q))
scale_tile = tl.load(tle.gpu.local_ptr(epilogue_ready.scale, (tl.arange(0, BM),)))
epilogue_reader.release(k)
3.2.6.4 Distributed Exchange and Reshard

This combines tle.device_mesh, tle.sharding, tle.shard_id, tle.remote, tle.distributed_barrier, and tle.reshard. A shard reads a neighbor's tile for ring-style exchange, then materializes a replicated view before local compute.

mesh = tle.device_mesh({"node": 2, "device": 4})
x_spec = tle.sharding(mesh, split=["device"], partial=[])
x = tle.make_sharded_tensor(x_ptr, sharding=x_spec, shape=[M, K])

node_rank = tle.shard_id(mesh, "node")
device_rank = tle.shard_id(mesh, "device")
next_device = (device_rank + 1) % mesh.shape[1]

neighbor_x = tle.remote(x, shard_id=(node_rank, next_device), scope=mesh)
tle.distributed_barrier(mesh)
neighbor_vals = tl.load(neighbor_x)

x_full = tle.reshard(x, spec=tle.sharding(mesh, split=[], partial=[]))
acc = local_compute(x_full, neighbor_vals)

3.3 TLE-Struct

  • Design philosophy: architecture-aware, fine-grained tuning.
  • Core idea: classify backends by hardware-topology families (e.g., GPGPU, DSA), expose common hierarchical parallel/storage structures, and let developers explicitly define structured compute/data mappings (e.g., warp-group control, pipeline scheduling). This decouples algorithm logic from hardware physical implementation at the abstraction level.

3.3.1 GPU

3.3.1.1 Memory Management
3.3.1.1.1 tle.gpu.memory_space

Specify the memory space of a tensor:

x = ...
x = tle.gpu.memory_space(x, "shared_memory")
3.3.1.1.2 tle.gpu.alloc

Allocate memory:

a_smem = tle.gpu.alloc(
    [XBLOCK, YBLOCK],
    dtype=tl.float32,
    layout=None,
    scope=tle.gpu.storage_kind.smem,
)
3.3.1.1.3 tle.gpu.local_ptr

Get memory pointers:

# pointers for a_smem[0, :]: [(0, 0), (0, 1), ..., (0, YBLOCK-1)]
a_smem_ptrs = tle.gpu.local_ptr(
    a_smem,
    # row index: a YBLOCK-long tensor of zeros; column index: 0..YBLOCK-1
    indices=(tl.zeros([YBLOCK], dtype=tl.int32), tl.arange(0, YBLOCK)),
)
  • Signature: tle.gpu.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr
  • Purpose: Build arbitrary-shaped pointer views over shared memory buffers for tl.load/tl.store/tl.atomic*.
  • Parameters:
    • buffer: buffered tensor returned by tle.gpu.alloc (SMEM/TMEM).
    • indices: optional tuple of integer tensors. Tuple length must equal rank(buffer), and all tensors must have identical shapes. If omitted/None, backend treats it as full indices.
  • Semantics:
    • If indices is provided: output pointer tensor shape equals common shape of index tensors.
    • For each logical output index (i0, i1, ...), pointer value corresponds to buffer[indices0(i0,...), indices1(i0,...), ...].
    • If indices=None: build full-view pointers over buffer shape (rank>0 returns pointer tensor with shape(buffer), rank=0 returns scalar pointer).
    • Returned pointers live in shared-memory address space (LLVM addrspace=3). Indices must be integers (i32/i64, etc.; lowered to i32).
    • Linearization is row-major (last dimension fastest); shared-memory layout/encoding follows buffer memdesc.
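The row-major linearization rule can be checked with a small host-side sketch that computes the logical element offset each pointer refers to. It deliberately ignores the physical shared-memory encoding (swizzling, padding), which the text says follows the buffer's memdesc, and represents each index tensor as a flat list; the function names are illustrative.

```python
def row_major_strides(shape):
    """Element strides for a row-major (last dimension fastest) buffer."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def local_offsets(shape, indices):
    """Logical element offsets addressed by a local_ptr-style index tuple."""
    assert len(indices) == len(shape), "tuple length must equal rank(buffer)"
    assert len({len(ix) for ix in indices}) == 1, "index tensors must share a shape"
    strides = row_major_strides(shape)
    return [sum(i * s for i, s in zip(point, strides)) for point in zip(*indices)]


# K-dimension slice of a [BM, BK] buffer: all rows, columns k_start..k_start+KW-1.
BM, BK, KW, k_start = 4, 8, 2, 2
rows = [r for r in range(BM) for _ in range(KW)]
cols = [k_start + c for _ in range(BM) for c in range(KW)]
print(local_offsets((BM, BK), (rows, cols)))  # [2, 3, 10, 11, 18, 19, 26, 27]
```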

Example 1: 1D slice

smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
# Slice [offset, offset + SLICE)
idx = offset + tl.arange(0, SLICE)
slice_ptr = tle.gpu.local_ptr(smem, (idx,))
vals = tl.load(slice_ptr)

Example 2: K-dimension tiling (matrix slice)

smem_a = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
# Slice (BM, KW), where KW is K-dimension slice
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, KW))
cols = tl.broadcast_to(tl.arange(0, KW)[None, :] + k_start, (BM, KW))
a_slice = tle.gpu.local_ptr(smem_a, (rows, cols))
a_vals = tl.load(a_slice)

Example 3: arbitrary gather view

smem = tle.gpu.alloc([H, W], dtype=tl.float32, scope=tle.gpu.smem)
# Row r reads SLICE consecutive columns starting at column r (wrapping modulo W)
rows = tl.broadcast_to(tl.arange(0, H)[:, None], (H, SLICE))
cols = (rows + tl.arange(0, SLICE)[None, :]) % W
gather_ptr = tle.gpu.local_ptr(smem, (rows, cols))
out = tl.load(gather_ptr)

Supported downstream ops:

  • tl.load
  • tl.store
  • tl.atomic_add/and/cas/max/min/or/xchg/xor

Practical notes:

  • Atomic ops require element dtype/backend support; use integer/float types supported by target hardware.
  • For local-pointer load-after-store hazards, the TLE backend pass TleInsertLocalPointerBarriers inserts barriers automatically; add manual barriers only for custom synchronization patterns that the pass does not cover.

Example 4: load/store/atomic on the same local_ptr

smem_i32 = tle.gpu.alloc([BLOCK], dtype=tl.int32, scope=tle.gpu.smem)
ptr = tle.gpu.local_ptr(smem_i32, (tl.arange(0, BLOCK),))

tl.store(ptr, tl.zeros([BLOCK], dtype=tl.int32))
tl.atomic_add(ptr, 1)
vals = tl.load(ptr)
3.3.1.1.4 tle.gpu.local_ptr (for remote)
  • Signature: tle.gpu.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr
  • Purpose: materialize pointer views for remote shared/local buffers returned by tle.remote(...).
  • Inputs:
    • remote_buffer: result of tle.remote(buffer, shard_id, scope), where buffer is typically allocated by tle.gpu.alloc.
    • indices: same rules as local mode (None for full view, or tuple of integer tensors with identical shapes).
  • Semantics:
    • Pointer shape/linearization rules are identical to local tle.gpu.local_ptr.
    • Address resolution targets the remote shard selected by shard_id.
    • Use tle.distributed_barrier(...) when cross-shard producer/consumer ordering is required.

Example: read remote SMEM tile from neighbor shard

smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)
remote_smem = tle.remote(smem, shard_id=(node_rank, next_device), scope=mesh)

rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
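remote_ptr = tle.gpu.local_ptr(remote_smem, (rows, cols))

# Order the neighbor's writes to its tile before this shard reads it.
tle.distributed_barrier(mesh)
vals = tl.load(remote_ptr)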
3.3.1.1.5 tle.gpu.copy

Copy a tile from global memory into a shared-memory buffer:

tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK])
3.3.1.2 Execution Orchestration
3.3.1.2.1 tle.gpu.warp_specialize

tle.gpu.warp_specialize creates an explicit warp-specialized region inside one CTA, placing different JIT functions into separate warp partitions. Typical uses include splitting TMA/cp.async producers, WGMMA consumers, epilogues, and reductions, with shared-memory data passed through tle.pipe or other explicit synchronization primitives.

  • Signature: tle.gpu.warp_specialize(functions_and_args, worker_num_warps, worker_num_regs)
  • Parameters:
    • functions_and_args: [(fn0, args0), (fn1, args1), ...]. Entry 0 is emitted into the default partition; later entries are emitted into worker partitions.
    • worker_num_warps: warp counts for worker partitions. Length must equal len(functions_and_args) - 1.
    • worker_num_regs: requested register counts for worker partitions. Length must equal len(functions_and_args) - 1.
  • Semantics:
    • Each args value must be a tuple. Plain Python int/float/bool/tl.dtype values are passed as constexpr arguments.
    • The default partition may return values; the return value of tle.gpu.warp_specialize(...) comes from the default partition. Worker partitions perform side effects and end with warp return.
    • Worker callees receive the corresponding "ttg.num-warps" attribute, and the region records requestedRegisters.
    • Captured worker arguments are deduplicated in IR, so multiple workers can share the same pipe endpoint or buffer handle.
    • warp_specialize does not itself provide data-visibility ordering. Producer/consumer ordering should be expressed with tle.pipe commit/wait/release, barriers, or other synchronization primitives.
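The argument rules above can be captured in a small validation helper. This is a hypothetical host-side sketch: plan_warp_specialize is not a TLE function. It does three things:

- checks the list lengths against the number of entries;
- checks that each args value is a tuple;
- reports which argument positions would be treated as constexpr.

tl.dtype values are left out because the sketch avoids importing Triton, and plain strings stand in for pipe endpoints and pointers.

```python
def plan_warp_specialize(functions_and_args, worker_num_warps, worker_num_regs):
    """Check warp_specialize argument rules and describe each partition (illustration only)."""
    if not functions_and_args:
        raise ValueError("functions_and_args needs at least the default partition")
    num_workers = len(functions_and_args) - 1
    if len(worker_num_warps) != num_workers:
        raise ValueError(f"expected {num_workers} worker_num_warps entries, got {len(worker_num_warps)}")
    if len(worker_num_regs) != num_workers:
        raise ValueError(f"expected {num_workers} worker_num_regs entries, got {len(worker_num_regs)}")

    plan = []
    for index, (fn, args) in enumerate(functions_and_args):
        if not isinstance(args, tuple):
            raise TypeError(f"args for entry {index} must be a tuple, got {type(args).__name__}")
        # Plain Python scalars become constexpr arguments; other values are runtime values.
        constexpr_positions = [i for i, a in enumerate(args) if isinstance(a, (bool, int, float))]
        is_default = index == 0
        plan.append({
            "partition": "default" if is_default else f"worker{index - 1}",
            "fn": fn.__name__,
            # The default partition's warp count and registers are not set by this call.
            "num_warps": None if is_default else worker_num_warps[index - 1],
            "requested_regs": None if is_default else worker_num_regs[index - 1],
            "constexpr_args": constexpr_positions,
        })
    return plan


def producer(*args):
    pass


def consumer(*args):
    pass


plan = plan_warp_specialize(
    [(producer, ("writer", "x_ptr", 8, 128)), (consumer, ("reader", "out_ptr", 8, 128))],
    [4],
    [168],
)
```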

Example: the producer runs in the default partition and loads tiles into shared memory; a single consumer worker accumulates them.

@triton.jit
def producer(writer, x_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    for i in tl.range(0, n_tiles):
        slot = writer.acquire(i)
        vals = tl.load(x_ptr + i * BLOCK + offs)
        tl.store(tle.gpu.local_ptr(slot.tile), vals)
        writer.commit(i)


@triton.jit
def consumer(reader, out_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    for i in tl.range(0, n_tiles):
        ready = reader.wait(i)
        tile = tl.load(tle.gpu.local_ptr(ready.slot.tile))
        acc += tile
        reader.release(i)
    tl.store(out_ptr + offs, acc)


@triton.jit
def kernel(x_ptr, out_ptr, n_tiles: tl.constexpr, BLOCK: tl.constexpr):
    smem = tle.gpu.alloc([2, BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
    pipe = tle.pipe(capacity=2, scope="cta", name="x_pipe", tile=smem)

    tle.gpu.warp_specialize(
        [
            (producer, (pipe.writer(), x_ptr, n_tiles, BLOCK)),
            (consumer, (pipe.reader(), out_ptr, n_tiles, BLOCK)),
        ],
        [4],      # consumer worker uses 4 warps
        [168],    # consumer worker requested registers
    )
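The acquire/commit/wait/release protocol used by producer and consumer can be modeled on the host with two counting semaphores. This is a hedged sketch, not the real pipe:

- PipeModel and its methods are hypothetical stand-ins that only mirror the call pattern above;
- it assumes one writer and one reader, both walking iterations in order, with slots reused as i % capacity;
- the real tle.pipe is a device-side construct whose implementation is not described here.

```python
import threading


class PipeModel:
    """Hypothetical ring buffer mimicking acquire/commit/wait/release (illustration only)."""

    def __init__(self, capacity, slot_size):
        self.capacity = capacity
        self.slots = [[0.0] * slot_size for _ in range(capacity)]
        self.free = threading.Semaphore(capacity)  # slots the writer may fill
        self.ready = threading.Semaphore(0)        # slots the reader may consume

    # Writer side
    def acquire(self, i):
        self.free.acquire()                 # block until slot i % capacity has been released
        return self.slots[i % self.capacity]

    def commit(self, i):
        self.ready.release()                # publish iteration i to the reader

    # Reader side
    def wait(self, i):
        self.ready.acquire()                # block until iteration i has been committed
        return self.slots[i % self.capacity]

    def release(self, i):
        self.free.release()                 # hand the slot back to the writer


def producer(pipe, x, n_tiles, block):
    for i in range(n_tiles):
        slot = pipe.acquire(i)
        slot[:] = x[i * block:(i + 1) * block]   # "load from GMEM, store into the SMEM slot"
        pipe.commit(i)


def consumer(pipe, out, n_tiles, block):
    acc = [0.0] * block
    for i in range(n_tiles):
        tile = pipe.wait(i)
        acc = [a + t for a, t in zip(acc, tile)]  # read the slot before releasing it
        pipe.release(i)
    out[:] = acc


N_TILES, BLOCK = 6, 4
x = [float(v) for v in range(N_TILES * BLOCK)]
out = [0.0] * BLOCK
pipe = PipeModel(capacity=2, slot_size=BLOCK)
threads = [
    threading.Thread(target=producer, args=(pipe, x, N_TILES, BLOCK)),
    threading.Thread(target=consumer, args=(pipe, out, N_TILES, BLOCK)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

With capacity=2, the producer can run at most two tiles ahead of the consumer. This bounded lead is what lets loads overlap with compute without overwriting a slot that is still being read.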

Example: one producer and two consumer workers sharing a single-producer, multiple-consumer (SPMC) pipe.

tile = tle.gpu.alloc([2, BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
pipe = tle.pipe(
    capacity=2,
    scope="cta",
    name="spmc_tile",
    readers=("qk", "value"),
    tile=tile,
)

tle.gpu.warp_specialize(
    [
        (load_tile_producer, (pipe.writer(), a_desc, b_desc)),
        (qk_consumer, (pipe.reader("qk"), acc_qk)),
        (value_consumer, (pipe.reader("value", fields=("tile",)), acc_v)),
    ],
    [4, 4],
    [240, 168],
)
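With several readers, a slot can be reused only after every reader has released it. The hypothetical model below tracks a per-slot count of outstanding readers under a threading.Condition. Its assumptions:

- SpmcPipeModel is illustrative, not TLE code;
- reader names follow the example above;
- both readers here see the whole payload, whereas the value reader in the example requests only the tile field.

```python
import threading


class SpmcPipeModel:
    """Hypothetical single-producer, multiple-consumer ring buffer (illustration only)."""

    def __init__(self, capacity, readers):
        self.capacity = capacity
        self.num_readers = len(readers)
        self.slots = [None] * capacity
        self.iteration = [-1] * capacity  # iteration currently stored in each slot
        self.pending = [0] * capacity     # readers that still have to release each slot
        self.cv = threading.Condition()

    def acquire(self, i):
        s = i % self.capacity
        with self.cv:
            # Slot s is free once every reader has released iteration i - capacity.
            self.cv.wait_for(lambda: self.pending[s] == 0)
        return s

    def commit(self, i, value):
        s = i % self.capacity
        with self.cv:
            self.slots[s] = value
            self.iteration[s] = i
            self.pending[s] = self.num_readers
            self.cv.notify_all()

    def wait(self, i):
        s = i % self.capacity
        with self.cv:
            self.cv.wait_for(lambda: self.iteration[s] == i)
            return self.slots[s]

    def release(self, i):
        s = i % self.capacity
        with self.cv:
            self.pending[s] -= 1
            self.cv.notify_all()


N_TILES = 8
pipe = SpmcPipeModel(capacity=2, readers=("qk", "value"))
seen = {"qk": [], "value": []}


def writer():
    for i in range(N_TILES):
        pipe.acquire(i)
        pipe.commit(i, i * 10)


def reader(name):
    for i in range(N_TILES):
        seen[name].append(pipe.wait(i))
        pipe.release(i)


threads = [threading.Thread(target=writer)] + [
    threading.Thread(target=reader, args=(name,)) for name in seen
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```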

3.3.2 DSA

This section is adapted from triton_v3.2.x (python/triton/experimental/tle/language/dsa and its README). The DSA APIs are split into:

  • Generic DSA APIs under tle.dsa.*
  • Backend-specific address spaces under tle.dsa.ascend.*
3.3.2.1 Memory and Data Movement
3.3.2.1.1 tle.dsa.alloc
  • Signature: tle.dsa.alloc(shape, dtype, mem_addr_space)
  • Purpose: allocate DSA local buffers in a target memory space.

Ascend memory spaces exposed by the source:

  • tle.dsa.ascend.UB
  • tle.dsa.ascend.L1
  • tle.dsa.ascend.L0A
  • tle.dsa.ascend.L0B
  • tle.dsa.ascend.L0C
a_ub = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
b_l1 = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.L1)
3.3.2.1.2 tle.dsa.copy
  • Signature: tle.dsa.copy(src, dst, shape, inter_no_alias=False)
  • Purpose: explicit movement between GMEM pointers and DSA local buffers (both directions).
tle.dsa.copy(x_ptrs, a_ub, [tail_m, tail_n])          # GMEM -> local buffer
tle.dsa.copy(a_ub, out_ptrs, [tail_m, tail_n])        # local buffer -> GMEM
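The tail_m and tail_n arguments suggest the usual boundary-tile pattern, where the last tile along each axis is smaller than the block. The minimal host-side sketch below rests on two assumptions:

- GMEM is row-major with a row stride;
- the copy moves only the leading [tail_m, tail_n] region of the local buffer.

The helpers tail_sizes, copy_in, and copy_out are hypothetical and not part of tle.dsa.

```python
def tail_sizes(M, N, XBLOCK, YBLOCK, pid_m, pid_n):
    """Valid extent of tile (pid_m, pid_n) when M or N is not a multiple of the block."""
    tail_m = min(XBLOCK, M - pid_m * XBLOCK)
    tail_n = min(YBLOCK, N - pid_n * YBLOCK)
    return tail_m, tail_n


def copy_in(gmem, base, row_stride, local, shape):
    """GMEM -> local buffer: copy a [rows, cols] region starting at flat offset base."""
    rows, cols = shape
    for r in range(rows):
        for c in range(cols):
            local[r][c] = gmem[base + r * row_stride + c]


def copy_out(local, gmem, base, row_stride, shape):
    """Local buffer -> GMEM: only the valid region is written back."""
    rows, cols = shape
    for r in range(rows):
        for c in range(cols):
            gmem[base + r * row_stride + c] = local[r][c]


M, N, XBLOCK, YBLOCK = 5, 7, 4, 4
gmem = [float(v) for v in range(M * N)]
pid_m, pid_n = 1, 1                                  # bottom-right tile
tail_m, tail_n = tail_sizes(M, N, XBLOCK, YBLOCK, pid_m, pid_n)
base = pid_m * XBLOCK * N + pid_n * YBLOCK           # flat offset of the tile origin

a_ub = [[0.0] * YBLOCK for _ in range(XBLOCK)]       # full-size local buffer
copy_in(gmem, base, N, a_ub, (tail_m, tail_n))
a_ub = [[2.0 * v for v in row] for row in a_ub]      # some local compute
copy_out(a_ub, gmem, base, N, (tail_m, tail_n))
```

Passing the tail shape, rather than the full block shape, keeps the copy from reading or writing past the end of the matrix.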
3.3.2.1.3 tle.dsa.local_ptr
  • Signature: tle.dsa.local_ptr(buffer, indices=None) -> tl.tensor | tl.ptr
  • Purpose: build pointer views over DSA local buffers (for example UB/L1) for explicit local-memory access patterns.
  • Parameters:
    • buffer: DSA buffered tensor, typically from tle.dsa.alloc.
    • indices: optional tuple of integer tensors. If omitted or None, the backend uses the full index range (a view over the whole buffer).
  • Semantics:
    • Shape and indexing behavior follow tle.gpu.local_ptr (same pointer-view model).
    • Intended for DSA-local data access paths that require explicit pointer materialization.

Example:

a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
a_ptr = tle.dsa.local_ptr(a_ub, (rows, cols))
a_val = tl.load(a_ptr)
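The pointer-view model can be mimicked on the host to show what the broadcast rows and cols tensors compute. This sketch assumes a row-major logical linearization (offset = row * BK + col). It does not model the physical layout the compiler picks for UB or SMEM. The helper names are hypothetical.

```python
def broadcast_rows(BM, BK):
    """Host equivalent of tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))."""
    return [[r] * BK for r in range(BM)]


def broadcast_cols(BM, BK):
    """Host equivalent of tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))."""
    return [list(range(BK)) for _ in range(BM)]


def shape_of(t):
    return (len(t), len(t[0]))


def local_ptr(base, buffer_shape, indices=None):
    """Return a tensor of element offsets into a buffer of shape buffer_shape."""
    BM, BK = buffer_shape
    if indices is None:                       # None means a view over the full buffer
        indices = (broadcast_rows(BM, BK), broadcast_cols(BM, BK))
    rows, cols = indices
    if shape_of(rows) != shape_of(cols):
        raise ValueError("all index tensors must have identical shapes")
    return [[base + r * BK + c for r, c in zip(rr, cc)] for rr, cc in zip(rows, cols)]


def load(memory, ptrs):
    """Gather memory[p] for every offset p in the pointer tensor."""
    return [[memory[p] for p in row] for row in ptrs]


BM, BK = 2, 3
ub = [10 * i for i in range(BM * BK)]                      # flat stand-in for a_ub
full = local_ptr(0, (BM, BK))                               # same as passing (rows, cols)
col0 = local_ptr(0, (BM, BK), ([[0], [1]], [[0], [0]]))     # column 0 as a [BM, 1] view
```

The index tensors decide the shape of the resulting pointer tensor. Passing [BM, 1] index tensors therefore yields a [BM, 1] view regardless of the buffer's own shape.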
3.3.2.1.4 tle.dsa.local_ptr (for remote)
  • Signature: tle.dsa.local_ptr(remote_buffer, indices=None) -> tl.tensor | tl.ptr
  • Purpose: materialize pointer views over remote DSA local buffers obtained from tle.remote(...).
  • Inputs:
    • remote_buffer: result of tle.remote(dsa_buffer, shard_id, scope).
    • indices: same rules as local DSA mode.
  • Semantics:
    • Same pointer-view semantics as local DSA mode.
    • Pointer dereference is routed to the remote shard selected by shard_id.
    • Pair with tle.distributed_barrier when cross-shard ordering is required.

Example:

a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
remote_a_ub = tle.remote(a_ub, shard_id=peer_rank, scope=mesh)

rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
remote_ptr = tle.dsa.local_ptr(remote_a_ub, (rows, cols))
remote_val = tl.load(remote_ptr)
3.3.2.1.5 tle.dsa.to_tensor / tle.dsa.to_buffer
  • tle.dsa.to_tensor(buffer, writable=True): convert a DSA buffer to a tensor view for tensor expressions.
  • tle.dsa.to_buffer(tensor, space): convert a tensor value back to a buffer in a target DSA address space.
c_val = tle.dsa.to_tensor(c_ub, writable=True)
result = c_val * 0.5
d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB)
tle.dsa.copy(d_ub, out_ptrs, [tail_m, tail_n])
3.3.2.2 Elementwise Compute Ops (buffer-based)

Builtins provided by the source:

  • tle.dsa.add
  • tle.dsa.sub
  • tle.dsa.mul
  • tle.dsa.div
  • tle.dsa.max
  • tle.dsa.min

Common rules:

  • Signature: tle.dsa.<op>(lhs, rhs, out)
  • Compute model: elementwise binary op over DSA local buffers.
  • Shape rules:
    • lhs, rhs, and out must have the same rank and shape.
    • No implicit broadcasting is assumed at this API layer.
  • Dtype rules:
    • All three operands should use the same dtype in practice.
    • Integer dtypes are typical for index/count paths; float dtypes are typical for activation/math paths.
  • Memory-space rules:
    • Buffers should be allocated in compatible DSA local spaces (for example, UB/L1 combinations allowed by the backend).
    • Keep hot operands and results in local space to avoid extra GMEM traffic.

Per-op semantics:

  • tle.dsa.add(lhs, rhs, out): out = lhs + rhs
  • tle.dsa.sub(lhs, rhs, out): out = lhs - rhs
  • tle.dsa.mul(lhs, rhs, out): out = lhs * rhs
  • tle.dsa.div(lhs, rhs, out): out = lhs / rhs (backend-dependent precision/rounding)
  • tle.dsa.max(lhs, rhs, out): out = max(lhs, rhs)
  • tle.dsa.min(lhs, rhs, out): out = min(lhs, rhs)

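In-place usage:

  • An input buffer may also serve as the output, so one temporary can be reused across steps, for example tle.dsa.mul(tmp, b, tmp) in Example 1 below.
  • Avoid other forms of input/output aliasing (for example, partially overlapping views) unless the backend semantics explicitly allow it.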
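A toy model shows why exact reuse of a buffer is benign while partial overlap is not. Views are (offset, length) windows into one flat storage list. The sketch assumes a simple forward, element-by-element execution order. Real hardware may process chunks in other orders, which is exactly why partially overlapping views make results order-dependent. The function name add_views is hypothetical.

```python
def add_views(storage, lhs_off, rhs_off, out_off, n, snapshot_inputs):
    """out[k] = lhs[k] + rhs[k] for views into one flat storage list.

    snapshot_inputs=True models ideal out-of-place semantics (inputs read up front);
    False models a forward in-place pass that may observe its own earlier writes.
    """
    if snapshot_inputs:
        lhs = storage[lhs_off:lhs_off + n]
        rhs = storage[rhs_off:rhs_off + n]
        for k in range(n):
            storage[out_off + k] = lhs[k] + rhs[k]
    else:
        for k in range(n):
            storage[out_off + k] = storage[lhs_off + k] + storage[rhs_off + k]


# Exact aliasing: out is the same view as the inputs.
exact_ref, exact_inplace = [1, 10, 100, 1000], [1, 10, 100, 1000]
add_views(exact_ref, 0, 0, 0, 4, snapshot_inputs=True)
add_views(exact_inplace, 0, 0, 0, 4, snapshot_inputs=False)

# Partial overlap: out starts one element after the inputs.
shifted_ref, shifted_inplace = [1, 10, 100, 1000], [1, 10, 100, 1000]
add_views(shifted_ref, 0, 0, 1, 3, snapshot_inputs=True)
add_views(shifted_inplace, 0, 0, 1, 3, snapshot_inputs=False)
```

With exact aliasing, each element is read before it is overwritten, so both executions agree. With a one-element shift, the in-place pass consumes values it has already overwritten and yields [1, 2, 4, 8] instead of [1, 2, 20, 200].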

Example 1: arithmetic chain ((a - b) * b) / scale

a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
scale_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
out_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)

tle.dsa.copy(a_ptrs, a_ub, [BM, BK])
tle.dsa.copy(b_ptrs, b_ub, [BM, BK])
tle.dsa.copy(scale_ptrs, scale_ub, [BM, BK])

tle.dsa.sub(a_ub, b_ub, tmp_ub)      # tmp = a - b
tle.dsa.mul(tmp_ub, b_ub, tmp_ub)    # tmp = tmp * b
tle.dsa.div(tmp_ub, scale_ub, out_ub)  # out = tmp / scale

tle.dsa.copy(out_ub, out_ptrs, [BM, BK])

Example 2: clamp by max + min

x_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
floor_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
ceil_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
tmp_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
y_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)

tle.dsa.copy(x_ptrs, x_ub, [BM, BK])
tle.dsa.copy(floor_ptrs, floor_ub, [BM, BK])
tle.dsa.copy(ceil_ptrs, ceil_ub, [BM, BK])

tle.dsa.max(x_ub, floor_ub, tmp_ub)  # tmp = max(x, floor)
tle.dsa.min(tmp_ub, ceil_ub, y_ub)   # y = min(tmp, ceil)

tle.dsa.copy(y_ub, y_ptrs, [BM, BK])
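The per-op semantics and the two examples above can be cross-checked against a plain-Python reference of the buffer-based contract. This sketch has these limits:

- it models buffers as 2-D nested lists and writes results into out;
- it enforces the same-shape, no-broadcast rule;
- it does not model dtype conversion or the backend-dependent rounding of div.

The names make_binary and dsa_add through dsa_min are hypothetical stand-ins, not TLE functions.

```python
import operator


def shape_of(buf):
    return (len(buf), len(buf[0]))


def make_binary(fn):
    """Build a reference op with the tle.dsa.<op>(lhs, rhs, out) calling convention."""
    def op(lhs, rhs, out):
        if not (shape_of(lhs) == shape_of(rhs) == shape_of(out)):
            raise ValueError("lhs, rhs, and out must have the same shape (no broadcasting)")
        for i in range(len(out)):
            for j in range(len(out[0])):
                # Each element is read before it is written, so out may be lhs or rhs itself.
                out[i][j] = fn(lhs[i][j], rhs[i][j])
    return op


dsa_add = make_binary(operator.add)
dsa_sub = make_binary(operator.sub)
dsa_mul = make_binary(operator.mul)
dsa_div = make_binary(operator.truediv)
dsa_max = make_binary(max)
dsa_min = make_binary(min)

# Example 1: ((a - b) * b) / scale
a, b, scale = [[6.0, 8.0]], [[2.0, 4.0]], [[2.0, 2.0]]
tmp, out = [[0.0, 0.0]], [[0.0, 0.0]]
dsa_sub(a, b, tmp)        # tmp = a - b
dsa_mul(tmp, b, tmp)      # tmp = tmp * b (in place)
dsa_div(tmp, scale, out)  # out = tmp / scale

# Example 2: clamp x into [floor, ceil]
x, floor_, ceil_ = [[-3.0, 0.5, 9.0]], [[0.0, 0.0, 0.0]], [[1.0, 1.0, 1.0]]
t, y = [[0.0] * 3], [[0.0] * 3]
dsa_max(x, floor_, t)     # t = max(x, floor)
dsa_min(t, ceil_, y)      # y = min(t, ceil)
```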
3.3.2.3 Loop and Hint APIs

The source provides:

  • tle.dsa.pipeline(...)
  • tle.dsa.parallel(...)
  • tle.dsa.hint(...): compile-time hints, applied as a context manager (with tle.dsa.hint(...):)
with tle.dsa.hint(inter_no_alias=True):
    tle.dsa.copy(x_ptr + offs, a_ub, [tail_size], inter_no_alias=True)
3.3.2.4 Slice/View Utilities

The source provides:

  • tle.dsa.extract_slice
  • tle.dsa.insert_slice
  • tle.dsa.extract_element
  • tle.dsa.subview
sub = tle.dsa.extract_slice(full, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1))
full = tle.dsa.insert_slice(full, sub, offsets=(0, k0), sizes=(BM, BK), strides=(1, 1))
elem = tle.dsa.extract_element(sub, indice=(i, j))
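The offsets/sizes/strides arguments read like those of MLIR's tensor.extract_slice and tensor.insert_slice. A 2-D host-side reference is sketched below under two assumptions:

- the slice semantics match those MLIR ops;
- insert_slice returns a new value rather than mutating its input, which matches the full = ... reassignment above.

The function names mirror the TLE builtins but are plain Python stand-ins.

```python
def extract_slice(full, offsets, sizes, strides):
    """Element (i, j) of the result is full[o0 + i * s0][o1 + j * s1]."""
    (o0, o1), (n0, n1), (s0, s1) = offsets, sizes, strides
    return [[full[o0 + i * s0][o1 + j * s1] for j in range(n1)] for i in range(n0)]


def insert_slice(full, sub, offsets, sizes, strides):
    """Return a copy of full with sub written at the strided window (full is not mutated)."""
    (o0, o1), (n0, n1), (s0, s1) = offsets, sizes, strides
    result = [row[:] for row in full]
    for i in range(n0):
        for j in range(n1):
            result[o0 + i * s0][o1 + j * s1] = sub[i][j]
    return result


def extract_element(tensor, indice):
    i, j = indice
    return tensor[i][j]


full = [[4 * r + c for c in range(4)] for r in range(4)]
sub = extract_slice(full, offsets=(0, 2), sizes=(2, 2), strides=(1, 1))
elem = extract_element(sub, indice=(1, 0))
updated = insert_slice(full, [[-1, -2], [-3, -4]], offsets=(0, 2), sizes=(2, 2), strides=(1, 1))
every_other = extract_slice(full, offsets=(0, 0), sizes=(2, 2), strides=(2, 2))
```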

3.3.3 Struct API Cookbook

3.3.3.1 Shared-memory staging (alloc + copy + local_ptr)

Use this pattern when data is reused across multiple math operations.

# 1) Allocate SMEM tile
a_smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)

# 2) Copy GMEM -> SMEM
tle.gpu.copy(a_ptrs, a_smem, [BM, BK])

# 3) Build local pointer view and load
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
a_ptr_local = tle.gpu.local_ptr(a_smem, (rows, cols))
a_tile = tl.load(a_ptr_local)
3.3.3.2 Shared-memory atomics with local_ptr

Useful for histogram, bucketization, and radix-select style counting.

bins = 256
counts = tle.gpu.alloc([bins], dtype=tl.int32, scope=tle.gpu.storage_kind.smem)
idx = tl.arange(0, BLOCK) % bins
count_ptr = tle.gpu.local_ptr(counts, (idx,))
tl.atomic_add(count_ptr, 1)
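The counting pattern above can be modeled on the host. This is an illustrative sketch with these stand-ins:

- a threading.Lock plays the role of the atomicity that tl.atomic_add provides on the shared-memory counters;
- lanes are split across a few hypothetical worker threads;
- the counters are explicitly zeroed first, because this section does not say whether tle.gpu.alloc initializes memory.

```python
import threading


def smem_histogram(BLOCK, bins, num_workers=4):
    counts = [0] * bins                 # explicit zero-initialization of the counters
    lock = threading.Lock()             # models the atomicity of tl.atomic_add

    def worker(lanes):
        for lane in lanes:
            idx = lane % bins           # idx = tl.arange(0, BLOCK) % bins
            with lock:
                counts[idx] += 1        # tl.atomic_add(count_ptr, 1)

    chunks = [range(w, BLOCK, num_workers) for w in range(num_workers)]
    threads = [threading.Thread(target=worker, args=(chunk,)) for chunk in chunks]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return counts


counts = smem_histogram(BLOCK=1000, bins=256)
```

Without the lock, which stands in for the atomic, concurrent increments of the same bin could be lost. Keeping the counters in shared memory keeps these contended updates on-chip.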
3.3.3.3 DSA local-buffer flow (dsa.alloc + dsa.copy + dsa.to_tensor/to_buffer)

Use this for DSA backends that expose dedicated local buffer spaces.

a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
c_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)

tle.dsa.copy(a_ptrs, a_ub, [BM, BK])
tle.dsa.copy(b_ptrs, b_ub, [BM, BK])
tle.dsa.add(a_ub, b_ub, c_ub)

c_val = tle.dsa.to_tensor(c_ub, writable=True)
out_ub = tle.dsa.to_buffer(c_val, tle.dsa.ascend.UB)
tle.dsa.copy(out_ub, out_ptrs, [BM, BK])

3.4 TLE-Raw

  • Design philosophy: native passthrough and maximal control.
  • Core idea: break DSL abstraction boundaries and support inlined vendor-native code. Target instructions are generated through vendor-private pipelines, bypassing general compiler middle layers and giving experts strong control over instruction scheduling, register allocation, and low-level synchronization primitives.


Open question: should Raw integration be limited to the Python DSL?

3.4.1 Language Extensions

3.4.1.1 MLIR
from typing import Annotated
from mlir import ir
from mlir.dialects import arith, nvvm, tensor
import triton.language as tl
from triton.experimental.flagtree.edsl import dialect
import triton.experimental.flagtree.language as fl

# 1. Dialect declaration
@tle.raw.language(name="mlir")
# 2. Hardware constraints
@tle.hardware_constraint(threads_dim=1, sync_scope="block")
# 3. Function implementation
def vector_add_tile(
    x: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    y: Annotated[ir.RankedTensorType, "tensor<1024xf32>"],
    output: Annotated[ir.RankedTensorType, "tensor<1024xf32>"]
):
    tidx = nvvm.ThreadIdXOp(ir.IntegerType.get_signless(32)).res
    bidx = nvvm.BlockIdXOp(ir.IntegerType.get_signless(32)).res
    bdimx = nvvm.BlockDimXOp(ir.IntegerType.get_signless(32)).res
    idx = arith.addi(arith.muli(bidx, bdimx), tidx)
    idx = arith.index_cast(ir.IndexType.get(), idx)
    xval = tensor.extract(x, [idx])
    yval = tensor.extract(y, [idx])
    result = arith.addf(xval, yval)
    tensor.insert(result, output, [idx])

@tle.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = tl.zeros_like(x)

    # 4. Function call
    tle.call(
        vector_add_tile,
        args=[x, y, output],
        hardware={
            "threads": (BLOCK_SIZE,),
        },
        layout={
            x: {"space": "shared", "order": [0]},
            y: {"space": "shared", "order": [0]},
            output: {"space": "shared", "order": [0]},
        }
    )
    tl.store(output_ptr + offsets, output, mask=mask)

4. Examples and Evaluation

4.1 SparseMLA

SparseMLA, the sparse MLA attention kernel used by DeepSeek Sparse Attention (DSA), has been optimized and tested on H800.

4.1.1 DeepSeek V3.2 SparseMLA Prefill

The cases match the FlashMLA V3.2 sparse prefill performance fixture, with attn_sink omitted because the local Triton, TLE, and TileLang kernels do not implement it: B=1, S=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048.

Latency in milliseconds:

| SKV | Triton | TLE | TLE-Pipe-Pipelined | TLE-FlashMLA-Prefill | TileLang | TileLang-Pipelined | TileLang-Seesaw | FlashMLA |
|---|---|---|---|---|---|---|---|---|
| 8192 | 9.896 | 6.927 | 4.832 | 4.273 | 62.409 | 5.432 | 5.160 | 3.850 |
| 32768 | 11.210 | 7.624 | 5.321 | 4.834 | 75.160 | 6.428 | 5.577 | 4.117 |
| 65536 | 11.655 | 8.378 | 5.731 | 5.305 | 84.432 | 6.865 | 5.786 | 4.348 |
| 98304 | 11.835 | 8.658 | 5.972 | 5.599 | 86.561 | 7.139 | 5.873 | 4.447 |
| 131072 | 11.923 | 8.863 | 6.122 | 5.887 | 87.448 | 7.143 | 5.916 | 4.534 |

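Speedup summary. Each entry is the baseline latency divided by the TLE latency. TLE-Pipe denotes TLE-Pipe-Pipelined and TLE-FlashMLA denotes TLE-FlashMLA-Prefill. Values below 1.00x mean the TLE kernel is slower than the baseline.

| SKV | TLE-Pipe over Triton | TLE-FlashMLA over Triton | TLE-FlashMLA over TLE-Pipe | TLE-FlashMLA over FlashMLA | TLE-FlashMLA over TileLang-Seesaw |
|---|---|---|---|---|---|
| 8192 | 2.05x | 2.32x | 1.13x | 0.90x | 1.21x |
| 32768 | 2.11x | 2.32x | 1.10x | 0.85x | 1.15x |
| 65536 | 2.03x | 2.20x | 1.08x | 0.82x | 1.09x |
| 98304 | 1.98x | 2.11x | 1.07x | 0.79x | 1.05x |
| 131072 | 1.95x | 2.03x | 1.04x | 0.77x | 1.00x |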
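The speedup summary can be re-derived from the latency table. Each entry is the baseline latency divided by the TLE latency, rounded to two decimals. The short check below reuses the measured numbers from the latency table; it only recomputes ratios and adds no new data.

```python
# Latency (ms) from the latency table: SKV -> (Triton, TLE-Pipe-Pipelined,
# TLE-FlashMLA-Prefill, TileLang-Seesaw, FlashMLA)
LATENCY_MS = {
    8192: (9.896, 4.832, 4.273, 5.160, 3.850),
    32768: (11.210, 5.321, 4.834, 5.577, 4.117),
    65536: (11.655, 5.731, 5.305, 5.786, 4.348),
    98304: (11.835, 5.972, 5.599, 5.873, 4.447),
    131072: (11.923, 6.122, 5.887, 5.916, 4.534),
}


def speedups(triton, tle_pipe, tle_flashmla, tilelang_seesaw, flashmla):
    """Speedup of X over Y = latency(Y) / latency(X), in the column order of the summary."""
    ratios = (
        triton / tle_pipe,               # TLE-Pipe over Triton
        triton / tle_flashmla,           # TLE-FlashMLA over Triton
        tle_pipe / tle_flashmla,         # TLE-FlashMLA over TLE-Pipe
        flashmla / tle_flashmla,         # TLE-FlashMLA over FlashMLA (< 1 means slower)
        tilelang_seesaw / tle_flashmla,  # TLE-FlashMLA over TileLang-Seesaw
    )
    return tuple(round(r, 2) for r in ratios)


SUMMARY = {skv: speedups(*row) for skv, row in LATENCY_MS.items()}
```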

4.2 MoeAlignBlockSize

With shared-memory extensions in tle-struct, it is possible to implement vLLM/SGLang-style moe_align_block_size and improve its performance.
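For context, moe_align_block_size groups the flattened top-k expert assignments by expert and pads each expert's segment to a multiple of block_size, so that every block of block_size entries belongs to a single expert. A sequential plain-Python reference of that contract is sketched below. It follows vLLM-style outputs:

- sorted_token_ids, padded with a sentinel equal to the number of entries;
- one expert id per block;
- the padded total.

The function name is hypothetical. A GPU kernel would typically parallelize the per-expert counting step, for example with shared-memory atomics.

```python
def moe_align_block_size_ref(topk_ids, num_experts, block_size):
    """Sequential reference: group entries by expert and pad each group to block_size."""
    # 1) Count entries per expert (the step a GPU kernel typically does with atomics).
    counts = [0] * num_experts
    for expert in topk_ids:
        counts[expert] += 1

    # 2) Round each count up to a multiple of block_size and take an exclusive prefix sum.
    padded = [-(-c // block_size) * block_size for c in counts]
    offsets, total = [], 0
    for p in padded:
        offsets.append(total)
        total += p

    # 3) Scatter entry indices into their expert's segment; padding keeps the sentinel.
    sentinel = len(topk_ids)
    sorted_token_ids = [sentinel] * total
    cursor = offsets[:]
    for entry, expert in enumerate(topk_ids):
        sorted_token_ids[cursor[expert]] = entry
        cursor[expert] += 1

    # 4) One expert id per block of block_size entries.
    expert_ids = []
    for expert, p in enumerate(padded):
        expert_ids.extend([expert] * (p // block_size))

    return sorted_token_ids, expert_ids, total


# 3 tokens with top-2 routing, flattened: token t contributes entries 2t and 2t + 1.
topk_ids = [0, 1, 1, 2, 0, 1]
sorted_ids, expert_ids, num_post_pad = moe_align_block_size_ref(topk_ids, num_experts=3, block_size=4)
```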

4.2.1 RTX 5060 Ti

| num_tokens | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|
| 256 | 0.0348 | 0.0302 | 0.0323 | 0.0097 | 0.0138 | 1.42x |
| 512 | 0.0369 | 0.0301 | 0.0240 | 0.0117 | 0.0138 | 1.18x |
| 1024 | 0.0369 | 0.0313 | 0.0179 | 0.0117 | 0.0139 | 1.19x |
| 2048 | 0.0368 | 0.0313 | 0.0158 | 0.0131 | 0.0138 | 1.05x |
| 4096 | 0.0369 | 0.0301 | 0.0138 | 0.0143 | 0.0148 | 1.07x |
| 8192 | 0.0369 | 0.0313 | 0.0138 | 0.0164 | 0.0179 | 1.30x |
| 16384 | 0.0369 | 0.0301 | 0.0158 | 0.0205 | 0.0240 | 1.52x |
| 32768 | 0.0389 | 0.0322 | 0.0179 | 0.0301 | 0.0312 | 1.74x |
| 65536 | 0.0430 | 0.0374 | 0.0225 | 0.0486 | 0.0507 | 2.25x |
| 163840 | 0.0609 | 0.0512 | 0.0384 | 0.1036 | 0.1001 | 2.61x |

4.2.2 H800

| num_tokens | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|
| 256 | 0.0260 | 0.0408 | 0.0445 | 0.0133 | 0.0160 | 1.20x |
| 512 | 0.0262 | 0.0399 | 0.0315 | 0.0140 | 0.0162 | 1.16x |
| 1024 | 0.0274 | 0.0401 | 0.0239 | 0.0158 | 0.0163 | 1.03x |
| 2048 | 0.0509 | 0.0422 | 0.0226 | 0.0169 | 0.0173 | 1.02x |
| 4096 | 0.0265 | 0.0412 | 0.0200 | 0.0177 | 0.0187 | 1.06x |
| 8192 | 0.0476 | 0.0416 | 0.0192 | 0.0211 | 0.0230 | 1.20x |
| 16384 | 0.0548 | 0.0441 | 0.0219 | 0.0256 | 0.0286 | 1.31x |
| 32768 | 0.0443 | 0.0441 | 0.0221 | 0.0358 | 0.0401 | 1.81x |
| 65536 | 0.0361 | 0.0481 | 0.0273 | 0.0561 | 0.0645 | 2.36x |
| 163840 | 0.0509 | 0.0626 | 0.0451 | 0.1177 | 0.1323 | 2.93x |

4.2.3 H800 Real Data (build/gems/moe_topk_ids.pt)

  • Runtime config: num_tokens=163840, num_experts=512, block_size=16, source=real.

| num_tokens | num_experts | block_size | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|---|---|
| 163840 | 512 | 16 | 0.0471 | 0.0535 | 0.0387 | 0.0750 | 0.1467 | 3.79x |

4.2.4 RTX 5060 Ti Real Data (build/gems/moe_topk_ids.pt, Local Measurement)

  • Runtime config: num_tokens=163840, num_experts=512, block_size=16, source=real.
  • Runtime command: conda run -n flagtree python python/tutorials/tle/02-moe_align_block_size.py --skip_correctness --real_data build/gems/moe_topk_ids.pt --num_experts 512 --block_size 16

| num_tokens | num_experts | block_size | triton | triton_atomic | tle_atomic_fused [ours] | tle_cluster_fused [ours] | sglang_cuda | Speedup (sglang_cuda / min(tle_atomic_fused, tle_cluster_fused)) |
|---|---|---|---|---|---|---|---|---|
| 163840 | 512 | 16 | 0.0507 | 0.0395 | 0.0261 | 0.0532 | 0.1060 | 4.06x |

4.3 TopK

With shared-memory extensions in tle-struct, radix-select-based TopK can improve performance in MoE scenarios with large N and small K.
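Radix select finds the K-th largest key without a full sort. It works one digit at a time, from the most significant digit down:

- build a histogram of the current digit over the keys that still match the known prefix;
- keep only the bucket that contains the K-th element and extend the prefix with its digit.

The histogram step is the part that benefits from shared-memory counters. The sketch below is a sequential Python version for non-negative 32-bit integer keys. Float keys would first need an order-preserving bit transform, which is omitted here. The function names and the 8-bit radix are illustrative choices, not the parameters of the TLE kernel.

```python
def radix_select_kth_largest(keys, k, key_bits=32, radix_bits=8):
    """Return the k-th largest (1-based) non-negative integer key via MSB-first radix select."""
    if not 1 <= k <= len(keys):
        raise ValueError("k must be in [1, len(keys)]")
    if key_bits % radix_bits:
        raise ValueError("key_bits must be a multiple of radix_bits")
    num_buckets = 1 << radix_bits
    digit_mask = num_buckets - 1
    prefix, prefix_mask, remaining = 0, 0, k
    for shift in range(key_bits - radix_bits, -1, -radix_bits):
        # Histogram of the current digit over keys that still match the known prefix.
        hist = [0] * num_buckets
        for x in keys:
            if (x & prefix_mask) == prefix:
                hist[(x >> shift) & digit_mask] += 1
        # Scan buckets from the largest digit down to find the one holding the k-th key.
        for digit in range(num_buckets - 1, -1, -1):
            if hist[digit] >= remaining:
                break
            remaining -= hist[digit]
        prefix |= digit << shift
        prefix_mask |= digit_mask << shift
    return prefix


def topk(keys, k):
    """Top-k values (unordered): everything above the threshold plus enough ties."""
    threshold = radix_select_kth_largest(keys, k)
    above = [x for x in keys if x > threshold]
    return above + [threshold] * (k - len(above))


keys = [5, 1_000_000, 42, 42, 7, 3_000_000_000, 42, 0, 999]
kth = radix_select_kth_largest(keys, 4)
top4 = topk(keys, 4)
```

Each pass touches every key but performs only counting. For large N and small K, this can be cheaper than a sort-based top-k, which is the regime the measurements below target.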

4.3.1 RTX 5060 Ti (tle-topk-radix-vs-torch)

| M | N | K | Triton-RadixSelect | Torch-TopK | Speedup (Torch / Triton-RadixSelect) |
|---|---|---|---|---|---|
| 64 | 128 | 8 | 0.008192 | 0.010240 | 1.25x |
| 64 | 1024 | 32 | 0.008192 | 0.020480 | 2.50x |
| 64 | 8192 | 128 | 0.026624 | 0.059392 | 2.23x |
| 128 | 32768 | 256 | 0.124928 | 0.192512 | 1.54x |

4.3.2 H800 (tle-topk-radix-vs-torch)

| M | N | K | Triton-RadixSelect | Torch-TopK | Speedup (Torch / Triton-RadixSelect) |
|---|---|---|---|---|---|
| 64 | 128 | 8 | 0.008384 | 0.017536 | 2.09x |
| 64 | 1024 | 32 | 0.010688 | 0.024304 | 2.27x |
| 64 | 8192 | 128 | 0.029952 | 0.057184 | 1.91x |
| 128 | 32768 | 256 | 0.092256 | 0.117856 | 1.28x |

4.4 TopK Selector

TopK selector performance is measured with python/tutorials/tle/deepseek_v32/01-topk_selector.py; the tables below are taken from build/topk_selector/zhihu.md. Environment: a single NVIDIA H800. Latency is reported in milliseconds.

Providers: Triton, TRT-LLM prefill, TRT-LLM prefill-1024T, FlashInfer, TileLang, and TLE (ours). The batch=1 table also includes TLE cluster (ours).

4.4.1 Batch=1

| seq_len | topk | Triton (ms) | TRT-LLM prefill (ms) | TRT-LLM prefill-1024T (ms) | FlashInfer (ms) | TLE (ours) (ms) | TLE cluster (ours) (ms) | TileLang (ms) |
|---|---|---|---|---|---|---|---|---|
| 8192 | 256 | 0.044672 | 0.011456 | 0.010400 | 0.013312 | 0.010848 | 0.018048 | 0.016416 |
| 32768 | 1024 | 0.141888 | 0.025184 | 0.018592 | 0.022400 | 0.021952 | 0.022656 | 0.034880 |
| 131072 | 2048 | 0.565440 | 0.075456 | 0.048880 | 0.044544 | 0.052368 | 0.029984 | 0.126576 |
| 262144 | 2048 | 1.116624 | 0.129600 | 0.079360 | 0.048160 | 0.090480 | 0.038448 | N/A |
| 524288 | 2048 | 2.172256 | 0.237504 | 0.139712 | 0.048064 | 0.164864 | 0.054832 | N/A |

4.4.2 Batch=64

| seq_len | topk | Triton (ms) | TRT-LLM prefill (ms) | TRT-LLM prefill-1024T (ms) | FlashInfer (ms) | TLE (ours) (ms) | TileLang (ms) |
|---|---|---|---|---|---|---|---|
| 4096 | 128 | 0.031040 | 0.010336 | 0.009840 | 0.012832 | 0.010144 | 0.015040 |
| 8192 | 256 | 0.046656 | 0.012640 | 0.011488 | 0.014512 | 0.012304 | 0.018496 |
| 32768 | 1024 | 0.144480 | 0.026912 | 0.020416 | 0.025376 | 0.024000 | 0.037792 |
| 131072 | 2048 | 0.601968 | 0.092256 | 0.061152 | 0.067392 | 0.063040 | 0.152448 |
| 262144 | 2048 | 1.251968 | 0.173760 | 0.106656 | 0.126032 | 0.112032 | N/A |
| 524288 | 2048 | 2.412192 | 0.311104 | 0.183168 | 0.195776 | 0.198592 | N/A |

TileLang is skipped by this benchmark script for seq_len >= 262144; the table records those entries as N/A.
