# Source code for quadrants.lang.simt.block

# type: ignore
# pyright: reportInvalidTypeForm=false, reportOperatorIssue=false, reportArgumentType=false

import warnings

from quadrants._lib import core as _qd_core
from quadrants.lang import impl
from quadrants.lang import ops as _ops
from quadrants.lang.expr import make_expr_group
from quadrants.lang.kernel_impl import func as _func

# Import order matters: ``subgroup`` must come before ``reductions``.  ``reductions.py`` does ``from
# quadrants.lang.simt.subgroup import (ballot, invocation_id, ...)`` at its top, and ``subgroup.py`` does ``from
# quadrants.lang.simt.reductions import *`` at its bottom.  If ``reductions`` is imported here first, it triggers a
# circular load that leaves ``subgroup``'s wildcard re-export running while ``reductions.__all__`` isn't yet defined,
# so ``subgroup.reduce_add_tiled`` etc. silently end up missing.  Importing ``subgroup`` first (which then drives
# ``reductions`` to completion via the wildcard) keeps the fully-loaded layout downstream callers expect.  ``isort:
# skip_file`` would disable sorting for the whole file; the local ``noqa`` comments below scope the override to just
# these two lines.
from quadrants.lang.simt import subgroup as _subgroup  # noqa: I001  isort: skip
from quadrants.lang.simt import reductions as _reductions  # noqa: I001
from quadrants.lang.simt.reductions import _bin_add, _bin_max, _bin_min
from quadrants.lang.util import quadrants_scope
from quadrants.types.annotations import template
from quadrants.types.primitive_types import i32 as _i32
from quadrants.types.primitive_types import i64 as _i64
from quadrants.types.primitive_types import u32 as _u32
from quadrants.types.primitive_types import u64 as _u64


def arch_uses_spv(arch):
    """Return True when ``arch`` is lowered through SPIR-V (Vulkan or Metal)."""
    return arch in (_qd_core.vulkan, _qd_core.metal)
def sync():
    """Block-wide execution barrier.

    Dispatches per backend: CUDA / AMDGPU lower to the internal ``block_barrier``
    op; SPIR-V targets (Vulkan / Metal) lower to ``workgroupBarrier``.

    Raises:
        ValueError: if the active arch has no block barrier.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda or arch == _qd_core.amdgpu:
        return impl.call_internal("block_barrier", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupBarrier", with_runtime_context=False)
    # Bug fix: the original error message referenced ``qd.block.shared_array``
    # (copy-pasted from another op); this function is ``qd.block.sync``.
    raise ValueError(f"qd.block.sync is not supported for arch {arch}")
def sync_all_nonzero(predicate):
    """Barrier fused with an all-nonzero reduction over ``predicate``.

    CUDA uses the hardware-fused NVPTX op (`barrier.cta.red.and.aligned.all.sync`);
    AMDGPU and SPIR-V backends fall back to the shared-memory emulation
    (see `_block_reduce_all_nonzero_emulated`).
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda:
        return impl.call_internal("block_barrier_and_i32", predicate, with_runtime_context=False)
    if arch == _qd_core.amdgpu or arch_uses_spv(arch):
        return _block_reduce_all_nonzero_emulated(predicate)
    raise ValueError(f"qd.block.sync_all_nonzero is not supported for arch {arch}")
def sync_any_nonzero(predicate):
    """Barrier fused with an any-nonzero reduction over ``predicate``.

    CUDA path is hardware-fused; AMDGPU / SPIR-V emulate via shared memory.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda:
        return impl.call_internal("block_barrier_or_i32", predicate, with_runtime_context=False)
    if arch == _qd_core.amdgpu or arch_uses_spv(arch):
        return _block_reduce_any_nonzero_emulated(predicate)
    raise ValueError(f"qd.block.sync_any_nonzero is not supported for arch {arch}")
def sync_count_nonzero(predicate):
    """Barrier fused with a count-of-nonzero reduction over ``predicate``.

    CUDA path is hardware-fused; AMDGPU / SPIR-V emulate via shared memory.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda:
        return impl.call_internal("block_barrier_count_i32", predicate, with_runtime_context=False)
    if arch == _qd_core.amdgpu or arch_uses_spv(arch):
        return _block_reduce_count_nonzero_emulated(predicate)
    raise ValueError(f"qd.block.sync_count_nonzero is not supported for arch {arch}")
def mem_fence():
    """Block-scope memory fence (no execution barrier)."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda or arch == _qd_core.amdgpu:
        return impl.call_internal("block_mem_fence", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupMemoryBarrier", with_runtime_context=False)
    raise ValueError(f"qd.block.mem_fence is not supported for arch {arch}")
def mem_sync():
    """Deprecated alias for `mem_fence` — kept for backward compatibility."""
    warnings.warn(
        "qd.simt.block.mem_sync() is deprecated; use qd.simt.block.mem_fence() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return mem_fence()
def thread_idx():
    """Return the calling thread's block-local (intra-block) index."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda or arch == _qd_core.amdgpu:
        return impl.call_internal("block_thread_idx", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("localInvocationId", with_runtime_context=False)
    raise ValueError(f"qd.block.thread_idx is not supported for arch {arch}")
def global_thread_idx():
    """Return the calling thread's grid-global index."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _qd_core.cuda or arch == _qd_core.amdgpu:
        return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
    if arch_uses_spv(arch):
        return impl.call_internal("globalInvocationId", with_runtime_context=False)
    raise ValueError(f"qd.block.global_thread_idx is not supported for arch {arch}")
class SharedArray:
    """Block-shared (threadgroup / LDS) array usable inside Quadrants kernels.

    ``shape`` may be an ``int`` or a tuple/list of ``int``; ``dtype`` may be a
    scalar dtype or an ``impl.MatrixType`` (unwrapped to its tensor type).
    """

    # Marks this class for the Quadrants AST transformer.
    _is_quadrants_class = True

    def __init__(self, shape, dtype):
        if isinstance(shape, int):
            self.shape = (shape,)
        elif isinstance(shape, (tuple, list)) and all(isinstance(dim, int) for dim in shape):
            self.shape = shape
        else:
            raise ValueError(
                f"qd.simt.block.shared_array shape must be an integer or a tuple of integers, but got {shape}"
            )
        # Matrix dtypes are stored via their underlying tensor type.
        if isinstance(dtype, impl.MatrixType):
            dtype = dtype.tensor_type
        self.dtype = dtype
        self.shared_array_proxy = impl.expr_init_shared_array(self.shape, dtype)

    @quadrants_scope
    def subscript(self, *indices):
        """Build the subscript expression backing ``shared[i, j, ...]``."""
        builder = impl.get_runtime().compiling_callable.ast_builder()
        subscript_expr = builder.expr_subscript(
            self.shared_array_proxy,
            make_expr_group(*indices),
            _qd_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        return impl.Expr(subscript_expr)
# --- Block reductions ------------------------------------------------------------------ #
#
# Two-stage block reduce: each subgroup reduces its lanes via `shuffle_down`, lane 0 of every
# subgroup publishes the subgroup aggregate to shared memory, a `block.sync()` retires the
# publish, and thread 0 sequentially folds the subgroup aggregates with `op`. Cost:
# `log2(subgroup_size)` shuffles + 1 shared-mem write/read per subgroup + 1 `block.sync` +
# (NUM_SUBGROUPS - 1) ops on thread 0. The subgroup size is read from `subgroup.group_size()`
# (a compile-time Python int) at the top of every block op, so callers never plumb it in.
#
# The per-subgroup step delegates to `reductions._reduce_tiled`, the generic-op private helper
# (alongside `reductions._inclusive_scan_tiled` / `_exclusive_scan_tiled`) that mirrors
# `subgroup.reduce_add_tiled` / `_min_tiled` / `_max_tiled` but takes a caller-supplied template
# operator -- so the same block skeleton covers add / min / max / mul / bitwise / custom monoids.
@_func
def reduce(value, block_dim: template(), op: template(), dtype: template()):
    """Block-scope reduction under a generic associative ``op``.

    Result is valid in **thread 0 only**; other threads retain partial values.
    Use `reduce_all` if you need the result on every thread.

    Args:
        value: per-thread input.
        block_dim: threads per block (template). Must be a positive multiple of
            ``subgroup.group_size()`` (32 on CUDA / Metal / Vulkan-on-NVIDIA, 64 on AMDGPU).
        op: ``@qd.func`` taking two values and returning the same type as ``value``;
            any associative monoid works. See `reduce_add` for the standard sum.
        dtype: scalar dtype for the inter-subgroup shared-memory staging slot
            (must match ``value``'s type).

    When the block is exactly one subgroup the shared-memory path is
    short-circuited at compile time and the call costs only the per-subgroup tree.
    """
    SUBGROUP_SIZE = impl.static(_subgroup.group_size())
    log2_subgroup = impl.static(_subgroup.log2_group_size())
    impl.static_assert(
        impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE),
        "block.reduce: block_dim must be a positive multiple of subgroup size",
    )
    NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE)
    # Stage 1: intra-subgroup tree reduction.
    subgroup_agg = _reductions._reduce_tiled(value, op, log2_subgroup)
    if impl.static(NUM_SUBGROUPS == 1):
        return subgroup_agg
    # Stage 2: cross-subgroup fold through shared memory.
    tid = thread_idx()
    subgroup_id = tid // SUBGROUP_SIZE
    lane_id = tid & impl.static(SUBGROUP_SIZE - 1)
    shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype)
    if lane_id == 0:
        shared[subgroup_id] = subgroup_agg
    sync()
    result = subgroup_agg
    if tid == 0:
        result = shared[0]
        for w in impl.static(range(1, NUM_SUBGROUPS)):
            result = op(result, shared[impl.static(w)])
    return result
@_func
def reduce_all(value, block_dim: template(), op: template(), dtype: template()):
    """Block-scope reduction under a generic associative ``op``, broadcast to every thread.

    Costs one extra ``block.sync()`` plus a one-slot shared-memory broadcast vs.
    `reduce`. See `reduce` for the operator contract.
    """
    total = reduce(value, block_dim, op, dtype)
    broadcast_slot = SharedArray((1,), dtype)
    if thread_idx() == 0:
        broadcast_slot[0] = total
    sync()
    return broadcast_slot[0]
@_func
def reduce_add(value, block_dim: template(), dtype: template()):
    """Block-scope sum reduction. Result valid in **thread 0 only**. See `reduce` for the argument contract."""
    return reduce(value, block_dim, _bin_add, dtype)
@_func
def reduce_min(value, block_dim: template(), dtype: template()):
    """Block-scope min reduction. Result valid in **thread 0 only**. See `reduce` for the argument contract."""
    return reduce(value, block_dim, _bin_min, dtype)
@_func
def reduce_max(value, block_dim: template(), dtype: template()):
    """Block-scope max reduction. Result valid in **thread 0 only**. See `reduce` for the argument contract."""
    return reduce(value, block_dim, _bin_max, dtype)
@_func
def reduce_all_add(value, block_dim: template(), dtype: template()):
    """Block-scope sum reduction with the result broadcast to every thread.

    See `reduce_add` for the cheaper thread-0-only variant and `reduce` for the
    argument contract.
    """
    return reduce_all(value, block_dim, _bin_add, dtype)
@_func
def reduce_all_min(value, block_dim: template(), dtype: template()):
    """Block-scope min reduction broadcast to every thread. See `reduce_all_add`."""
    return reduce_all(value, block_dim, _bin_min, dtype)
@_func
def reduce_all_max(value, block_dim: template(), dtype: template()):
    """Block-scope max reduction broadcast to every thread. See `reduce_all_add`."""
    return reduce_all(value, block_dim, _bin_max, dtype)
# --- Block scans ----------------------------------------------------------------------- #
#
# Two-stage block scan. Each subgroup does a Hillis-Steele scan via
# `reductions.{_inclusive_scan_tiled, _exclusive_scan_tiled}`, the last lane of every subgroup
# publishes the subgroup aggregate to shared memory, then every thread sequentially folds the
# subgroup prefixes and applies its own subgroup's prefix to its scan value. All threads receive
# a valid result; cost: one subgroup scan + 1 shared-mem write/read per subgroup + 1
# `block.sync()` + (NUM_SUBGROUPS - 1) ops on every thread (the cross-subgroup prefix is
# computed redundantly to avoid a second barrier).
#
# Inclusive: subgroup aggregate at the last lane is just the inclusive value, written directly.
# Exclusive: subgroup aggregate = `op(exclusive[last_lane], value[last_lane])`, since the
# exclusive scan does not include the last lane's input - we recover the inclusive total with
# one extra `op` on the publish path.
@_func
def inclusive_scan(value, block_dim: template(), op: template(), dtype: template()):
    """Block-scope inclusive scan under a generic associative ``op``.

    Every thread receives a valid result.

    Args:
        value: per-thread input.
        block_dim: threads per block (template). Must be a positive multiple of
            ``subgroup.group_size()`` (32 on CUDA / Metal / Vulkan-on-NVIDIA, 64 on AMDGPU).
        op: ``@qd.func`` taking two values and returning the same type as ``value``.
            See `inclusive_add` for the standard sum specialization.
        dtype: scalar dtype for the inter-subgroup shared-memory staging slot;
            must match ``value``'s type.

    When the block is exactly one subgroup the cross-subgroup shared-memory path
    is short-circuited at compile time and the call costs only the per-subgroup
    Hillis-Steele tree.
    """
    SUBGROUP_SIZE = impl.static(_subgroup.group_size())
    log2_subgroup = impl.static(_subgroup.log2_group_size())
    impl.static_assert(
        impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE),
        "block.inclusive_scan: block_dim must be a positive multiple of subgroup size",
    )
    NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE)
    inclusive = _reductions._inclusive_scan_tiled(value, op, log2_subgroup)
    if impl.static(NUM_SUBGROUPS == 1):
        return inclusive
    tid = thread_idx()
    subgroup_id = tid // SUBGROUP_SIZE
    lane_id = tid & impl.static(SUBGROUP_SIZE - 1)
    shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype)
    # The last lane's inclusive value is exactly the subgroup aggregate.
    if lane_id == impl.static(SUBGROUP_SIZE - 1):
        shared[subgroup_id] = inclusive
    sync()
    # Sequential exclusive prefix scan over subgroup aggregates; each thread captures its own
    # subgroup's prefix. Subgroup 0's prefix is unused (its inclusive value is already the
    # prefix sum from the start of the block), so we never read `subgroup_prefix` on subgroup 0;
    # the placeholder there exists only to give the variable a definite type.
    block_aggregate = shared[0]
    subgroup_prefix = block_aggregate
    for w in impl.static(range(1, NUM_SUBGROUPS)):
        if subgroup_id == impl.static(w):
            subgroup_prefix = block_aggregate
        addend = shared[impl.static(w)]
        block_aggregate = op(block_aggregate, addend)
    if subgroup_id != 0:
        inclusive = op(subgroup_prefix, inclusive)
    return inclusive
@_func
def exclusive_scan(value, block_dim: template(), op: template(), identity, dtype: template()):
    """Block-scope exclusive scan under a generic associative ``op`` with explicit ``identity``.

    Every thread receives a valid result; thread 0 holds ``identity`` and thread
    ``i > 0`` holds ``op(v[0], ..., v[i-1])``.

    See `inclusive_scan` for the per-arg contract; in addition this op takes an
    explicit ``identity`` because exclusive scan needs a definite value for
    thread 0 (and for the sentinel paths in `exclusive_min` / `exclusive_max`).
    See `exclusive_add` for the additive specialization which derives a zero
    identity automatically.
    """
    SUBGROUP_SIZE = impl.static(_subgroup.group_size())
    log2_subgroup = impl.static(_subgroup.log2_group_size())
    impl.static_assert(
        impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE),
        "block.exclusive_scan: block_dim must be a positive multiple of subgroup size",
    )
    NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE)
    exclusive = _reductions._exclusive_scan_tiled(value, op, identity, log2_subgroup)
    if impl.static(NUM_SUBGROUPS == 1):
        return exclusive
    tid = thread_idx()
    subgroup_id = tid // SUBGROUP_SIZE
    lane_id = tid & impl.static(SUBGROUP_SIZE - 1)
    shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype)
    if lane_id == impl.static(SUBGROUP_SIZE - 1):
        # Subgroup aggregate = inclusive at last lane = exclusive[last] + value[last] under `op`.
        shared[subgroup_id] = op(exclusive, value)
    sync()
    block_aggregate = shared[0]
    # Subgroup 0's prefix is the identity; subsequent subgroups overwrite this
    # in their own iteration of the fold below.
    subgroup_prefix = identity
    for w in impl.static(range(1, NUM_SUBGROUPS)):
        if subgroup_id == impl.static(w):
            subgroup_prefix = block_aggregate
        addend = shared[impl.static(w)]
        block_aggregate = op(block_aggregate, addend)
    if subgroup_id != 0:
        exclusive = op(subgroup_prefix, exclusive)
    return exclusive
@_func
def inclusive_add(value, block_dim: template(), dtype: template()):
    """Block-scope inclusive prefix sum.

    After the call, thread ``i`` holds ``v[0] + v[1] + ... + v[i]``.
    See `inclusive_scan` for the argument contract.
    """
    return inclusive_scan(value, block_dim, _bin_add, dtype)
@_func
def inclusive_min(value, block_dim: template(), dtype: template()):
    """Block-scope inclusive prefix min. See `inclusive_scan` for the argument contract."""
    return inclusive_scan(value, block_dim, _bin_min, dtype)
@_func
def inclusive_max(value, block_dim: template(), dtype: template()):
    """Block-scope inclusive prefix max. See `inclusive_scan` for the argument contract."""
    return inclusive_scan(value, block_dim, _bin_max, dtype)
@_func
def exclusive_add(value, block_dim: template(), dtype: template()):
    """Block-scope exclusive prefix sum.

    After the call, thread ``i > 0`` holds ``v[0] + v[1] + ... + v[i-1]`` and
    thread 0 holds the additive identity (zero, in ``value``'s dtype, derived as
    ``value - value``). See `exclusive_scan` for the argument contract.
    """
    return exclusive_scan(value, block_dim, _bin_add, value - value, dtype)
# Plain Python wrappers (not ``@func``): the identity for an exclusive min / max scan is uniquely determined by # ``value``'s dtype, so we introspect it at compile time and emit a typed-constant identity Expr rather than asking # callers to provide one. Mirrors the subgroup convention (``subgroup.exclusive_min`` and friends). The identity # helpers (``_typed_min_identity`` / ``_typed_max_identity``) are reused from ``reductions.py`` so the per-dtype # sentinel choices stay consistent across the two scopes.
def exclusive_min(value, block_dim: template(), dtype: template()):
    """Block-scope exclusive prefix min.

    After the call, thread ``i > 0`` holds ``min(v[0], ..., v[i-1])`` and thread 0
    holds the dtype-derived identity: ``+inf`` for real dtypes,
    ``np.iinfo(dtype).max`` for integer dtypes (``UINT_MAX`` for unsigned,
    ``INT_MAX`` for signed). See `exclusive_scan` for the underlying contract.
    """
    return exclusive_scan(value, block_dim, _bin_min, _reductions._typed_min_identity(value), dtype)
def exclusive_max(value, block_dim: template(), dtype: template()):
    """Block-scope exclusive prefix max.

    After the call, thread ``i > 0`` holds ``max(v[0], ..., v[i-1])`` and thread 0
    holds the dtype-derived identity: ``-inf`` for real dtypes,
    ``np.iinfo(dtype).min`` for signed integer dtypes, ``0`` for unsigned and
    bool. See `exclusive_scan` for the underlying contract.
    """
    return exclusive_scan(value, block_dim, _bin_max, _reductions._typed_max_identity(value), dtype)
# --- Block radix rank ------------------------------------------------------------------ # # Block-level radix ranking via the atomic-OR match-and-count strategy. Each thread holds a single ``u32`` key; the # function returns the key's stable rank within the block under the digit `(key >> bit_start) & ((1 << num_bits) - 1)`, # and writes the per-digit count and exclusive-prefix arrays to caller-supplied shared-memory outparams. # # The algorithm runs in six steps: # # 1. ComputeHistogramsSubgroup: each subgroup builds a private digit histogram in shared memory via ``atomic_add``. # 2. ComputeOffsetsSubgroupUpsweep: every thread sums per-subgroup histograms column-wise to produce a block-wide # bin count for digit ``= tid``, while rewriting the subgroup histogram entries into per-subgroup running # exclusive prefixes. # 3. ExclusiveSum on the per-thread bin counts - uses the block exclusive scan defined above. # 4. ComputeOffsetsSubgroupDownsweep: add the block-wide exclusive prefix into every subgroup's offset entry. # 5. ComputeRanksItem (atomic-OR match): per-subgroup match via ``atomic_or`` on a per-digit lane-mask, then leader # (highest set lane) does a single ``atomic_add`` on the subgroup offset and broadcasts via ``subgroup.shuffle``; # each thread's rank is ``subgroup_offset + popc(bin_mask & lanemask_le) - 1``. # 6. Write bin count + exclusive prefix to the outparam shared arrays. # # Shared-memory layout (all i32, total ``2 * BLOCK_SUBGROUPS * RADIX_DIGITS`` ints, 4096 ints = 16 KiB at the default # 8-subgroup / 256-digit configuration): # # subgroup_offsets / subgroup_histograms : [0, BLOCK_SUBGROUPS * RADIX_DIGITS) (union backing) # match_masks : [BLOCK_SUBGROUPS * RADIX_DIGITS, 2 * ...) # # Subgroup-scope barriers use ``subgroup.sync()`` (lowers to ``__syncwarp`` on CUDA, # ``OpControlBarrier(ScopeSubgroup, ...)`` on SPIR-V, ``s_barrier`` on AMDGPU). 
# ``LaneMaskLe()`` (the PTX intrinsic that gives a lane its less-than-or-equal lane mask) is
# replaced by ``subgroup.lanemask_le(lane)`` from the portable subgroup primitives.
@_func
def _subgroup_sync_fence():
    """Subgroup-scope barrier + memory fence - CUDA ``__syncwarp`` semantics across every backend.

    Why both ops: on CUDA, `subgroup.sync()` already lowers to `__syncwarp` which folds in a
    memory fence, so the extra `subgroup.mem_fence()` is redundant (a `__threadfence_block`). On
    SPIR-V, however, the codegen emits `subgroupBarrier` as
    `OpControlBarrier(ScopeSubgroup, ScopeSubgroup, 0)` - i.e. with **no** memory semantics - so
    a bare `subgroup.sync()` does *not* publish prior shared-memory writes to other lanes. The
    radix rank algorithm relies on the `__syncwarp` invariant that, after the barrier, every
    lane sees every other lane's prior `atomic_or` / `atomic_add` to shared memory; pairing the
    barrier with `subgroup.mem_fence()` (which emits a real
    `OpMemoryBarrier(ScopeSubgroup, AcquireRelease | UniformMemory | WorkgroupMemory)`) restores
    that invariant.
    """
    _subgroup.sync()
    _subgroup.mem_fence()


@_func
def _radix_rank_match_atomic_or_wave32(
    key,
    block_dim: template(),
    radix_bits: template(),
    bit_start: template(),
    num_bits: template(),
    bins,
    excl_prefix,
):
    """Wave32 implementation of `radix_rank_match_atomic_or`. See the public wrapper for the contract.

    Match-mask region is ``i32``; atomic_or, ballot, clz, popcnt all operate on 32 bits. This
    path is taken on CUDA, Vulkan-on-NVIDIA, and Metal - none of which require ``i64``
    threadgroup atomics.
    """
    SUBGROUP_THREADS = impl.static(_subgroup.group_size())
    RADIX_DIGITS = impl.static(1 << radix_bits)
    BLOCK_SUBGROUPS = impl.static(block_dim // SUBGROUP_THREADS)
    NUM_BITS_MASK = impl.static((1 << num_bits) - 1)
    BINS_PER_LANE = impl.static(RADIX_DIGITS // SUBGROUP_THREADS)
    # ``smem_offsets`` (i32) backs the per-subgroup histograms (step 1), in-place column-sum
    # upsweep (step 2), folded prefixes (step 4), and the leader's atomic_add slot (step 5).
    # ``smem_match`` (i32) backs the per-digit ballot mask in step 5. Split (rather than
    # unioned) so the wave64 sibling can pick ``i64`` for its match region independently.
    smem_offsets = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32)
    smem_match = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32)
    tid = thread_idx()
    subgroup_idx = tid // SUBGROUP_THREADS
    lane = _ops.cast(_subgroup.invocation_id(), _i32)
    # Step 1: zero per-subgroup histograms and match_masks.
    for b in impl.static(range(BINS_PER_LANE)):
        bin_idx = lane + impl.static(b * SUBGROUP_THREADS)
        smem_offsets[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0)
        smem_match[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0)
    _subgroup_sync_fence()
    # Each thread atomic-adds 1 to its subgroup's bin for ``digit``.
    digit = _ops.cast(_ops.bit_and(_ops.bit_shr(key, _u32(bit_start)), _u32(NUM_BITS_MASK)), _i32)
    _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _i32(1))
    sync()  # Publish per-subgroup histograms before column-sum.
    # Step 2: per-thread column sum across subgroups for digit == tid. Each thread collects the
    # running exclusive prefix into ``bin_count`` while overwriting the subgroup histogram
    # entries with their per-subgroup exclusive prefix. After the loop, ``bin_count`` is the
    # block-wide total for digit == tid.
    bin_count = _i32(0)
    for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)):
        subgroup_count = smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid]
        smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = bin_count
        bin_count = bin_count + subgroup_count
    # Step 3: block-wide exclusive sum on the per-thread bin counts.
    exclusive_digit_prefix = exclusive_add(bin_count, block_dim, _i32)
    # Step 4: ComputeOffsetsSubgroupDownsweep - fold the block-wide exclusive prefix into every
    # subgroup's offset.
    for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)):
        smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = (
            smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + exclusive_digit_prefix
        )
    sync()  # Publish subgroup offsets before the per-key match phase.
    # Step 5: per-key atomic-OR match. ``items_per_thread == 1``, so this runs once per thread.
    lane_mask = _i32(1) << lane
    lane_mask_le_v = _subgroup.lanemask_le(_subgroup.invocation_id())
    match_idx = subgroup_idx * RADIX_DIGITS + digit
    # Every thread ORs its lane_mask into the per-digit match mask of its subgroup. Threads with
    # the same digit collide on the same shared-memory cell and produce a bitmask of "lanes in
    # this subgroup that share this digit".
    _ops.atomic_or(smem_match[match_idx], lane_mask)
    _subgroup_sync_fence()
    # Read the bin_mask back and find the leader (highest matching lane) + intra-subgroup rank.
    # ``clz`` here MUST run on the u32 (FindUMsb on SPIR-V): casting to i32 first triggers
    # SPIR-V's FindSMsb, which for negative i32 (top bit set) returns the most-significant 0-bit
    # instead of MSB-of-1, giving a leader that's one less than the actual highest matching
    # lane. Concretely, with lane 31 holding the only key for its digit,
    # bin_mask = 0x80000000; FindSMsb on -2147483648 returns 30 (highest 0-bit), so 31 - 30 = 1
    # elects lane 1 instead of lane 31, and lane 31's shuffle reads from lane 1 (= 0) - observed
    # as last-lane ranks off by one on Vulkan / Metal. Now that the subgroup layer dispatches
    # FindUMsb for unsigned ``clz``, passing the u32 directly emits the right intrinsic on every
    # backend.
    bin_mask = _ops.cast(smem_match[match_idx], _u32)
    leader = _i32(31) - _ops.cast(_ops.clz(bin_mask), _i32)
    popc = _ops.popcnt(_ops.bit_and(bin_mask, lane_mask_le_v))
    # Leader claims `popc` slots from this subgroup's slice of the subgroup_offsets entry.
    subgroup_offset = _i32(0)
    if lane == leader:
        subgroup_offset = _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _ops.cast(popc, _i32))
    # Leader broadcasts its claimed offset to every lane in the subgroup.
    subgroup_offset = _subgroup.shuffle(subgroup_offset, _ops.cast(leader, _u32))
    # Leader resets the match mask so subsequent passes (or items_per_thread > 1) start clean.
    if lane == leader:
        smem_match[match_idx] = _i32(0)
    _subgroup_sync_fence()
    rank = subgroup_offset + _ops.cast(popc, _i32) - _i32(1)
    # Step 6: publish bins + exclusive_digit_prefix to the caller-supplied outparams.
    # ``block_dim == RADIX_DIGITS`` so every thread writes exactly one digit. Followed by a
    # ``block.sync()`` so the caller can read these arrays without their own retiring barrier.
    bins[tid] = bin_count
    excl_prefix[tid] = exclusive_digit_prefix
    sync()
    return rank


@_func
def _radix_rank_match_atomic_or_wave64(
    key,
    block_dim: template(),
    radix_bits: template(),
    bit_start: template(),
    num_bits: template(),
    bins,
    excl_prefix,
):
    """Wave64 implementation of `radix_rank_match_atomic_or`. See the public wrapper for the contract.

    Match-mask region is ``i64``; atomic_or on shared ``i64`` is native on AMDGPU LDS. Subgroup
    ``lanemask_le`` is u32-only by contract (see ``subgroup.py``: "lane_id in [0, 31]"), so the
    64-lane form is synthesized inline as ``one_at_lane | (one_at_lane - 1)`` - avoids the UB of
    shifting by 64 when lane == 63.

    Structural twin of the wave32 path; duplicated rather than parameterised because Quadrants'
    AST transformer doesn't carry locals across ``if impl.static`` branches and the smem_match
    dtype + match-phase widths are the only things that differ.
    """
    SUBGROUP_THREADS = impl.static(_subgroup.group_size())
    RADIX_DIGITS = impl.static(1 << radix_bits)
    BLOCK_SUBGROUPS = impl.static(block_dim // SUBGROUP_THREADS)
    NUM_BITS_MASK = impl.static((1 << num_bits) - 1)
    BINS_PER_LANE = impl.static(RADIX_DIGITS // SUBGROUP_THREADS)
    smem_offsets = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32)
    smem_match = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i64)
    tid = thread_idx()
    subgroup_idx = tid // SUBGROUP_THREADS
    lane = _ops.cast(_subgroup.invocation_id(), _i32)
    # Step 1: zero per-subgroup histograms and match_masks.
    for b in impl.static(range(BINS_PER_LANE)):
        bin_idx = lane + impl.static(b * SUBGROUP_THREADS)
        smem_offsets[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0)
        smem_match[subgroup_idx * RADIX_DIGITS + bin_idx] = _i64(0)
    _subgroup_sync_fence()
    digit = _ops.cast(_ops.bit_and(_ops.bit_shr(key, _u32(bit_start)), _u32(NUM_BITS_MASK)), _i32)
    _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _i32(1))
    sync()
    bin_count = _i32(0)
    for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)):
        subgroup_count = smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid]
        smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = bin_count
        bin_count = bin_count + subgroup_count
    exclusive_digit_prefix = exclusive_add(bin_count, block_dim, _i32)
    for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)):
        smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = (
            smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + exclusive_digit_prefix
        )
    sync()
    # Step 5 - wave64 specifics: u64 ballot mask via inline ``one_at_lane | (one_at_lane - 1)``
    # (avoids UB on lane=63), atomic_or on the i64 match cell, clz / popcnt on u64. Leader
    # formula is ``63 - clz(u64)``.
    lane_u64 = _ops.cast(lane, _u64)
    lane_mask = _u64(1) << lane_u64
    lane_mask_le_v = lane_mask | (lane_mask - _u64(1))
    match_idx = subgroup_idx * RADIX_DIGITS + digit
    _ops.atomic_or(smem_match[match_idx], _ops.cast(lane_mask, _i64))
    _subgroup_sync_fence()
    # u64 clz via FindUMsb-equivalent on every backend; the wave32 path's caveat about FindSMsb
    # vs FindUMsb on i64 would apply on SPIR-V wave64 devices if those existed (today wave64 =
    # AMDGPU only).
    bin_mask = _ops.cast(smem_match[match_idx], _u64)
    leader = _i32(63) - _ops.cast(_ops.clz(bin_mask), _i32)
    popc = _ops.popcnt(_ops.bit_and(bin_mask, lane_mask_le_v))
    subgroup_offset = _i32(0)
    if lane == leader:
        subgroup_offset = _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _ops.cast(popc, _i32))
    subgroup_offset = _subgroup.shuffle(subgroup_offset, _ops.cast(leader, _u32))
    if lane == leader:
        smem_match[match_idx] = _i64(0)
    _subgroup_sync_fence()
    rank = subgroup_offset + _ops.cast(popc, _i32) - _i32(1)
    bins[tid] = bin_count
    excl_prefix[tid] = exclusive_digit_prefix
    sync()
    return rank
@_func
def radix_rank_match_atomic_or(
    key,
    block_dim: template(),
    radix_bits: template(),
    bit_start: template(),
    num_bits: template(),
    bins,
    excl_prefix,
):
    """Block-level radix rank via the atomic-OR match-and-count strategy.

    Returns the calling thread's stable rank within the block under digit
    ``(key >> bit_start) & ((1 << num_bits) - 1)``.

    Args:
        key: ``u32`` key, one per thread.
        block_dim: threads per block (template). Must equal ``RADIX_DIGITS = 1 << radix_bits``:
            each digit gets exactly one thread for the per-thread bin/excl_prefix output.
        radix_bits: number of bits in the digit (template). Typical onesweep value is 8,
            giving 256 digits.
        bit_start: starting bit of the digit (template). Used as ``key >> bit_start``.
        num_bits: actual digit width in bits (template), with ``num_bits <= radix_bits``.
            Bits ``[bit_start, bit_start + num_bits)`` of ``key`` are extracted.
        bins: ``block.SharedArray((1 << radix_bits,), qd.i32)`` outparam. After the call,
            ``bins[d]`` holds the count of keys whose digit equals ``d``. Caller is responsible
            for allocating this array exactly once per kernel.
        excl_prefix: ``block.SharedArray((1 << radix_bits,), qd.i32)`` outparam. After the
            call, ``excl_prefix[d]`` holds the exclusive prefix sum of ``bins`` up to digit
            ``d``. Caller allocates as for ``bins``.

    Supports both wave32 (CUDA, Vulkan-on-NVIDIA, Metal) and wave64 (AMDGPU - Quadrants pins
    every AMDGPU target to ``+wavefrontsize64``). Dispatches to one of two private
    implementations at compile time based on subgroup size; the match-mask shared-memory
    region's dtype is the only semantic difference (``i32`` on wave32, ``i64`` on wave64), but
    Quadrants' AST transformer doesn't carry locals across ``if impl.static`` branches so the
    two paths are written as separate ``@func`` bodies. Atomic ``or`` on ``i64`` shared memory
    is native on AMDGPU's LDS; wave32 backends never see the ``i64`` path so portability does
    not depend on SPIR-V / Metal supporting 64-bit threadgroup atomics.

    Pre/post: caller must guarantee uniform control flow on entry; the function inserts the
    necessary ``block.sync()`` and ``subgroup.sync()`` retires. After the call, ``bins`` and
    ``excl_prefix`` are visible to every thread without a further ``block.sync()`` (we sync
    internally before exit).

    Cost: ``~items_per_thread`` atomic_or + atomic_add per pass on shared memory + 2
    ``block.sync()`` + 1 block exclusive scan + ``BLOCK_SUBGROUPS`` ops per thread for the
    column-sum upsweep. Shared-memory footprint at the default ``radix_bits=8``: 4 KiB ``i32``
    for subgroup offsets + 4 KiB ``i32`` (wave32) or 8 KiB ``i64`` (wave64) for the match-mask
    region - so 8 KiB total on wave32, 12 KiB on wave64.
    """
    SUBGROUP_THREADS = impl.static(_subgroup.group_size())
    impl.static_assert(
        impl.static(SUBGROUP_THREADS == 32 or SUBGROUP_THREADS == 64),
        "block.radix_rank_match_atomic_or: subgroup size must be 32 or 64",
    )
    RADIX_DIGITS = impl.static(1 << radix_bits)
    impl.static_assert(
        impl.static(block_dim == RADIX_DIGITS),
        "block.radix_rank_match_atomic_or: block_dim must equal RADIX_DIGITS (1 << radix_bits)",
    )
    # Compile-time dispatch on the wave width.
    if impl.static(SUBGROUP_THREADS == 32):
        return _radix_rank_match_atomic_or_wave32(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix)
    return _radix_rank_match_atomic_or_wave64(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix)
# Shared-memory emulation of CUDA's hardware-fused barrier-with-reduction ops, used on backends that lack a direct # equivalent (AMDGPU has no NVPTX `barrier.cta.red.*` analog; SPIR-V's `OpGroupNonUniform*` only operate at subgroup # scope reliably across Vulkan + Metal). # # Pattern: lane 0 zeroes a 1-element shared `i32` -> block.sync() -> every thread atomically folds its predicate into # the slot -> block.sync() -> every thread reads the broadcasted result. Costs 2 barriers + 1 atomic (vs. CUDA's # hardware fast path of 1 barrier+reduction). Slower than the CUDA path but functionally equivalent and portable. Each # call-site allocates a fresh `SharedArray` so multiple calls in the same kernel do not alias each other. # # IMPORTANT: every thread must participate in the `atomic_add` call unconditionally (guarding with # `if predicate: atomic_add(...)` is NOT safe). On Metal, `workgroupBarrier` does not propagate atomic writes from # divergent branches to threads that did not enter the branch -- non-participating SIMD groups never see the updated # slot. By having every thread call `atomic_add(counter, select(...))` the control flow stays uniform, the barrier # synchronises correctly, and all threads read the final count. # # We also use `atomic_add` rather than `atomic_or` because Metal / MoltenVK silently no-ops `OpAtomicOr` on threadgroup # memory in some configurations. 
@_func
def _block_reduce_count_nonzero_emulated(predicate: _i32) -> _i32:
    # Count of threads with a nonzero predicate, via one shared counter.
    counter = SharedArray((1,), _i32)
    if thread_idx() == 0:
        counter[0] = 0
    sync()
    # Unconditional atomic_add with a select keeps control flow uniform (required on Metal).
    _ops.atomic_add(counter[0], _ops.select(predicate != 0, 1, 0))
    sync()
    return counter[0]


@_func
def _block_reduce_any_nonzero_emulated(predicate: _i32) -> _i32:
    # 1 iff any thread's predicate is nonzero.
    counter = SharedArray((1,), _i32)
    if thread_idx() == 0:
        counter[0] = 0
    sync()
    _ops.atomic_add(counter[0], _ops.select(predicate != 0, 1, 0))
    sync()
    return _ops.min(counter[0], 1)


@_func
def _block_reduce_all_nonzero_emulated(predicate: _i32) -> _i32:
    # 1 iff every thread's predicate is nonzero: count the zeros, then invert.
    counter = SharedArray((1,), _i32)
    if thread_idx() == 0:
        counter[0] = 0
    sync()
    _ops.atomic_add(counter[0], _ops.select(predicate == 0, 1, 0))
    sync()
    return 1 - _ops.min(counter[0], 1)