Source code for quadrants.lang.func

from typing import Any, Callable

from quadrants._lib import core as _ti_core
from quadrants._lib.core.quadrants_python import (
    Function as FunctionCxx,
)
from quadrants._lib.core.quadrants_python import FunctionKey
from quadrants.lang import _kernel_impl_dataclass, impl, ops
from quadrants.lang.any_array import AnyArray
from quadrants.lang.ast import (
    transform_tree,
)
from quadrants.lang.ast.ast_transformer_utils import ReturnStatus
from quadrants.lang.exception import (
    QuadrantsSyntaxError,
    QuadrantsTypeError,
)
from quadrants.lang.expr import Expr
from quadrants.lang.matrix import MatrixType
from quadrants.lang.struct import StructType
from quadrants.types import (
    ndarray_type,
    primitive_types,
    template,
)
from quadrants.types.enums import AutodiffMode

from ._func_base import FuncBase

# Module-level alias of AutodiffMode.NONE so the autodiff-mode comparison in
# Func.__call__ reads a module global instead of an enum attribute.
_NONE: AutodiffMode = AutodiffMode.NONE


class Func(FuncBase):
    """A Quadrants function (as opposed to a kernel) wrapping a Python callable.

    When called inside a kernel, a regular function has its body transformed
    directly into the calling kernel's AST builder. A "real" function
    (``is_real_function=True``) is instead compiled once per template instance
    into a separate C++ function object and invoked through a function-call
    expression.
    """

    function_counter = 1  # MUST start from >= 1, because 0 means "kernel".

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
        """Wrap ``_func`` as a Quadrants function.

        Args:
            _func: The Python callable being wrapped.
            _classfunc: Forwarded to ``FuncBase`` as ``is_classfunc`` and stored
                as ``self.classfunc``.
            _pyfunc: If True, the function may also be called from Python scope,
                in which case the wrapped callable is run directly.
            is_real_function: If True, kernel calls go through a separately
                compiled function instead of transforming the body in place.
        """
        super().__init__(
            func=_func,
            is_classfunc=_classfunc,
            is_kernel=False,
            is_classkernel=False,
            is_real_function=is_real_function,
            func_id=Func.function_counter,
        )
        # The counter is shared across all Func instances so each gets a distinct func_id.
        Func.function_counter += 1

        # instance_id -> function-body closure; only populated for real funcs.
        self.compiled: dict[int, Callable] = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        # instance_id -> C++ function object created by do_compile (real funcs only).
        self.cxx_function_by_id: dict[int, FunctionCxx] = {}
        self.has_print = False

    def __call__(self: "Func", *py_args, **kwargs) -> Any:
        """Invoke the function from Python scope or from inside a kernel.

        Raises:
            QuadrantsSyntaxError: if called from Python scope without ``pyfunc``;
                if a real function is called from a kernel whose autodiff mode is
                not NONE; or if the function declares a return type but its body
                never returns a value.
        """
        runtime = impl.get_runtime()
        global_context = runtime._current_global_context
        current_kernel = global_context.current_kernel if global_context is not None else None
        py_args = self.fuse_args(
            is_func=True, is_pyfunc=self.pyfunc, py_args=py_args, kwargs=kwargs, global_context=global_context
        )

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise QuadrantsSyntaxError("Quadrants functions cannot be called from Python-scope.")
            # Python-scope call of a pyfunc: just run the wrapped callable.
            return self.func(*py_args)

        assert current_kernel is not None
        assert global_context is not None

        if self.is_real_function:
            if current_kernel.autodiff_mode != _NONE:
                raise QuadrantsSyntaxError("Real function in gradient kernels unsupported.")
            instance_id, arg_features = self.mapper.lookup(impl.current_cfg().raise_on_templated_floats, py_args)
            key = FunctionKey(self.func.__name__, self.func_id, instance_id)
            # Compile each template instance only once; later calls reuse it.
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=py_args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=py_args)

        # Regular function: transform its body directly into the calling kernel's AST builder.
        tree, ctx = self.get_tree_and_ctx(
            is_kernel=False,
            py_args=py_args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )
        struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
        tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        # Real functions returned above, so this check always applies here.
        if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
            raise QuadrantsSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
        """Emit a call to the compiled real function ``key`` and return its result.

        Template-annotated arguments (e.g. ``self``) are skipped. Primitive-typed
        arguments are cast to their annotation, ``RefType`` arguments are passed
        by reference, and ndarray arguments are expanded into the runtime's
        real-function argument list.

        Returns:
            ``None`` if the function has no return type, the single value if it
            returns one, otherwise a tuple of values.

        Raises:
            QuadrantsTypeError: if an ndarray-annotated argument is not an
                ``AnyArray``, or a return type is not primitive/struct/matrix.
        """
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arg_metas):
            anno = kernel_arg.annotation
            if isinstance(anno, template):
                continue
            if id(anno) in primitive_types.type_ids:
                non_template_args.append(ops.cast(args[i], anno))
            elif isinstance(anno, primitive_types.RefType):
                non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
            elif isinstance(anno, ndarray_type.NdarrayType):
                if not isinstance(args[i], AnyArray):
                    raise QuadrantsTypeError(
                        f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                    )
                non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
            else:
                non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)

        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.cxx_function_by_id[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None

        func_call = Expr(func_call)
        ret = []
        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(Expr(_ti_core.make_get_element_expr(func_call.ptr, (i,), dbg_info)))
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_quadrants_object(func_call, (i,)))
            else:
                raise QuadrantsTypeError(f"Unsupported return type for return value {i}: {return_type}")
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
        """Create the C++ function for instance ``key`` and register its body (real funcs only).

        The body closure temporarily makes the new function the runtime's
        compiling callable while its AST is transformed. It then restores the
        previous callable, even if the transform raises.
        """
        tree, ctx = self.get_tree_and_ctx(
            is_kernel=False,
            py_args=args,
            arg_features=arg_features,
            is_real_function=self.is_real_function,
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime()._compiling_callable = fn
            try:
                ctx.ast_builder = fn.ast_builder()
                transform_tree(tree, ctx)
            finally:
                # Without this, an exception during the transform would leave the
                # runtime emitting into this function's builder.
                impl.get_runtime()._compiling_callable = old_callable

        self.cxx_function_by_id[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        fn.set_function_body(func_body)