Source code for quadrants.lang.simt.block
# type: ignore
from quadrants._lib import core as _ti_core
from quadrants.lang import impl
from quadrants.lang.expr import make_expr_group
from quadrants.lang.util import quadrants_scope

def arch_uses_spv(arch):
    """Whether `arch` compiles through SPIR-V (Vulkan and Metal in this build)."""
    return arch == _ti_core.vulkan or arch == _ti_core.metal

def sync():
    """Synchronizes all threads in the current block (a block-wide barrier)."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.call_internal("block_barrier", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupBarrier", with_runtime_context=False)
    raise ValueError(f"ti.block.sync is not supported for arch {arch}")

def sync_all_nonzero(predicate):
    """Blocks until all threads in the block arrive; returns non-zero iff
    `predicate` is non-zero for every thread (CUDA only)."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda:
        return impl.call_internal("block_barrier_and_i32", predicate, with_runtime_context=False)
    raise ValueError(f"ti.block.sync_all_nonzero is not supported for arch {arch}")

def sync_any_nonzero(predicate):
    """Blocks until all threads in the block arrive; returns non-zero iff
    `predicate` is non-zero for at least one thread (CUDA only)."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda:
        return impl.call_internal("block_barrier_or_i32", predicate, with_runtime_context=False)
    raise ValueError(f"ti.block.sync_any_nonzero is not supported for arch {arch}")

def sync_count_nonzero(predicate):
    """Blocks until all threads in the block arrive; returns the number of
    threads whose `predicate` is non-zero (CUDA only)."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda:
        return impl.call_internal("block_barrier_count_i32", predicate, with_runtime_context=False)
    raise ValueError(f"ti.block.sync_count_nonzero is not supported for arch {arch}")

def mem_sync():
    """Issues a block-wide memory barrier."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda:
        return impl.call_internal("block_barrier", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupMemoryBarrier", with_runtime_context=False)
    raise ValueError(f"ti.block.mem_sync is not supported for arch {arch}")

def thread_idx():
    """Returns the thread index within the current block (SPIR-V backends only)."""
    arch = impl.get_runtime().prog.config().arch
    if arch_uses_spv(arch):
        return impl.call_internal("localInvocationId", with_runtime_context=False)
    raise ValueError(f"ti.block.thread_idx is not supported for arch {arch}")

def global_thread_idx():
    """Returns the global (grid-wide) linear thread index."""
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
    if arch_uses_spv(arch):
        return impl.call_internal("globalInvocationId", with_runtime_context=False)
    raise ValueError(f"ti.block.global_thread_idx is not supported for arch {arch}")

class SharedArray:
    _is_quadrants_class = True

    def __init__(self, shape, dtype):
        if isinstance(shape, int):
            self.shape = (shape,)
        elif isinstance(shape, (tuple, list)) and all(isinstance(s, int) for s in shape):
            self.shape = shape
        else:
            raise ValueError(
                f"ti.simt.block.shared_array shape must be an integer or a tuple of integers, but got {shape}"
            )
        if isinstance(dtype, impl.MatrixType):
            dtype = dtype.tensor_type
        self.shared_array_proxy = impl.expr_init_shared_array(self.shape, dtype)

    @quadrants_scope
    def subscript(self, *indices):
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()
        return impl.Expr(
            ast_builder.expr_subscript(
                self.shared_array_proxy,
                make_expr_group(*indices),
                _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
            )
        )
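
# End-to-end sketch (illustrative, not part of the module): a block-local
# staging buffer built from the pieces above. The surrounding API
# (`ti.init`, `ti.field`, `ti.kernel`, `ti.loop_config`) is assumed to follow
# the Taichi-style interface this package mirrors:
#
#     import quadrants as ti
#     ti.init(arch=ti.cuda)
#
#     BLOCK = 64
#     src = ti.field(ti.f32, shape=4096)
#     dst = ti.field(ti.f32, shape=4096)
#
#     @ti.kernel
#     def reverse_tiles():
#         ti.loop_config(block_dim=BLOCK)          # one tile per thread block
#         for i in range(4096):
#             tile = ti.simt.block.SharedArray((BLOCK,), ti.f32)
#             tid = i % BLOCK                      # this thread's slot in the tile
#             tile[tid] = src[i]
#             ti.simt.block.sync()                 # tile fully populated
#             dst[i] = tile[BLOCK - 1 - tid]       # read a peer's slot safely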