Source code for quadrants.lang.struct

# type: ignore

import numbers
from types import MethodType

import numpy as np

from quadrants._lib import core as _ti_core
from quadrants.lang import expr, impl, ops
from quadrants.lang.exception import (
    QuadrantsRuntimeTypeError,
    QuadrantsSyntaxError,
    QuadrantsTypeError,
)
from quadrants.lang.expr import Expr
from quadrants.lang.field import Field, ScalarField, SNodeHostAccess
from quadrants.lang.matrix import Matrix, MatrixType
from quadrants.lang.util import (
    cook_dtype,
    in_python_scope,
    python_scope,
    quadrants_scope,
)
from quadrants.types import primitive_types
from quadrants.types.compound_types import CompoundType
from quadrants.types.enums import Layout
from quadrants.types.utils import is_signed


class Struct:
    """The Struct type class.

    A struct is a dictionary-like data structure that stores members as
    (key, value) pairs. Valid data members of a struct can be scalars,
    matrices or other dictionary-like structures.

    Args:
        entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]): keys and
            values for struct members. Entries can optionally include a
            dictionary of functions with the key '__struct_methods' which
            will be attached to the struct for executing on the struct data.

    Returns:
        An instance of this struct.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> a = ti.Struct(v=vec3([0, 0, 0]), t=1.0)
        >>> print(a.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0)])
        >>>
        >>> B = ti.Struct(v=vec3([0., 0., 0.]), t=1.0, A=a)
        >>> print(B.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
    """

    _is_quadrants_class = True
    # Counter used to give every per-instance subclass a unique name.
    _instance_count = 0

    def __init__(self, *args, **kwargs):
        """Builds a struct from a single dict or from keyword arguments.

        Lists/tuples are converted to matrices and nested dicts to structs.

        Raises:
            QuadrantsSyntaxError: If the arguments are neither a single
                positional dict nor keyword-only.
        """
        # NOTE: a dict passed positionally is adopted, not copied: the
        # reserved keys are popped from it and its values are replaced in
        # place by the loop below.
        if len(args) == 1 and not kwargs and isinstance(args[0], dict):
            self.__entries = args[0]
        elif len(args) == 0:
            self.__entries = kwargs
        else:
            raise QuadrantsSyntaxError(
                "Custom structs need to be initialized using either dictionary or keyword arguments"
            )
        self.__methods = self.__entries.pop("__struct_methods", {})
        # "__matrix_ndim" is a reserved key (see to_dict); it is stripped so
        # that it does not become a member. Its value is not used here.
        self.__entries.pop("__matrix_ndim", None)
        self._register_methods()

        for k, v in self.__entries.items():
            if isinstance(v, (list, tuple)):
                v = Matrix(v)
            if isinstance(v, dict):
                v = Struct(v)
            self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
        self._register_members()
        self.__dtype = None

    @property
    def keys(self):
        """Returns the list of member names in string format.

        Example::

           >>> vec3 = ti.types.vector(3, ti.f32)
           >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
           >>> sphere.keys
           ['center', 'radius']
        """
        return list(self.__entries.keys())

    @property
    def _members(self):
        """Returns the member values as a list, in insertion order."""
        return list(self.__entries.values())

    @property
    def entries(self):
        """Returns the underlying member dictionary (not a copy)."""
        return self.__entries

    @property
    def methods(self):
        """Returns the dictionary of functions attached to this struct."""
        return self.__methods

    @property
    def items(self):
        """Returns the items in this struct.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.items
            dict_items([('center', [0. 0. 0.]), ('radius', 1.0)])
        """
        return self.__entries.items()

    def _register_members(self):
        """Exposes every member as a property by re-classing this instance."""
        # Properties live on classes, not instances, so a fresh subclass is
        # created per instance and swapped in:
        # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
        cls = self.__class__
        new_cls_name = cls.__name__ + str(cls._instance_count)
        cls._instance_count += 1
        properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys}
        self.__class__ = type(new_cls_name, (cls,), properties)

    def _register_methods(self):
        """Binds every function in the methods dictionary to this instance."""
        for name, method in self.__methods.items():
            # use MethodType to pass self (this object) to the method
            setattr(self, name, MethodType(method, self))

    def __getitem__(self, key):
        ret = self.__entries[key]
        if isinstance(ret, SNodeHostAccess):
            # Resolve the host accessor to the value it points at.
            ret = ret.accessor.getter(*ret.key)
        return ret

    def __setitem__(self, key, value):
        current = self.__entries[key]
        if isinstance(current, SNodeHostAccess):
            current.accessor.setter(value, *current.key)
            return
        if not in_python_scope():
            self.__entries[key] = value
            return
        if isinstance(current, (Struct, Matrix)):
            # Compound members are updated in place rather than replaced.
            current._set_entries(value)
        elif isinstance(value, numbers.Number):
            self.__entries[key] = value
        else:
            raise TypeError("A number is expected when assigning struct members")

    def _set_entries(self, value):
        """Copies every member of `value` (a dict or Struct) into this struct."""
        if isinstance(value, dict):
            value = Struct(value)
        for k in self.keys:
            self[k] = value[k]
        # Struct-like objects that bypass Struct.__init__ may not carry a
        # dtype attribute; treat that the same as an unset (None) dtype.
        self.__dtype = getattr(value, "_Struct__dtype", None)

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self[key] = value

        return setter

    @quadrants_scope
    def _assign(self, other):
        """Assigns `other` to this struct member by member and returns self.

        Raises:
            QuadrantsTypeError: If `other` is neither a dict nor a Struct, or
                if its member names differ from this struct's.
        """
        if not isinstance(other, (dict, Struct)):
            raise QuadrantsTypeError("Only dict or Struct can be assigned to a Struct")
        if isinstance(other, dict):
            other = Struct(other)
        if self.__entries.keys() != other.__entries.keys():
            raise QuadrantsTypeError(f"Member mismatch between structs {self.keys}, {other.keys}")
        for k, v in self.items:
            v._assign(other.__entries[k])
        # See _set_entries: tolerate Struct-like objects without a dtype.
        self.__dtype = getattr(other, "_Struct__dtype", None)
        return self

    def __len__(self):
        """Get the number of entries in a custom struct"""
        return len(self.__entries)

    def __iter__(self):
        """Iterates over the member values."""
        # __iter__ must return an iterator; a bare dict view is only an
        # iterable and makes iter(struct) raise TypeError.
        return iter(self.__entries.values())

    def __str__(self):
        """Python scope struct array print support."""
        if impl.inside_kernel():
            item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
            item_str += f", struct_methods={self.__methods}"
            return f"<ti.Struct {item_str}>"
        return str(self.to_dict())

    def __repr__(self):
        return str(self.to_dict())

    def to_dict(self, include_methods=False, include_ndim=False):
        """Converts the Struct to a dictionary.

        Args:
            include_methods (bool): Whether any struct methods should be included
                in the result dictionary under the key '__struct_methods'.
            include_ndim (bool): Whether the `ndim` of every matrix member
                should be included in the result dictionary under the key
                '__matrix_ndim'.

        Returns:
            Dict: The result dictionary.
        """
        res_dict = {
            k: (
                v.to_dict(include_methods=include_methods, include_ndim=include_ndim)
                if isinstance(v, Struct)
                else v.to_list() if isinstance(v, Matrix) else v
            )
            for k, v in self.__entries.items()
        }
        if include_methods:
            res_dict["__struct_methods"] = self.__methods
        if include_ndim:
            res_dict["__matrix_ndim"] = {k: v.ndim for k, v in self.__entries.items() if isinstance(v, Matrix)}
        return res_dict

    @classmethod
    @python_scope
    def field(
        cls,
        members,
        methods=None,
        shape=None,
        name="<Struct>",
        offset=None,
        needs_grad=False,
        needs_dual=False,
        layout=Layout.AOS,
    ):
        """Creates a :class:`~quadrants.StructField` with each element
        has this struct as its type.

        Args:
            members (dict): a dict, each item is like `name: type`.
            methods (dict): a dict of methods that should be included with
                the field. Each struct item of the field will have the
                methods as instance functions. Defaults to no methods.
            shape (Tuple[int]): width and height of the field.
            name (str): name of the field; member fields are named
                `name + "." + key`.
            offset (Tuple[int]): offset of the indices of the created field.
                For example if `offset=(-10, -10)` the indices of the field
                will start at `(-10, -10)`, not `(0, 0)`.
            needs_grad (bool): enabling grad field (reverse mode autodiff) or not.
            needs_dual (bool): enabling dual field (forward mode autodiff) or not.
            layout: AOS or SOA.

        Raises:
            QuadrantsSyntaxError: If `offset` is given without `shape`, or if
                `shape` and `offset` have different dimensionality.

        Example:

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = {"center": vec3, "radius": float}
            >>> F = ti.Struct.field(sphere, shape=(3, 3))
            >>> F
            {'center': array([[[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]]], dtype=float32), 'radius': array([[0., 0., 0.],
               [0., 0., 0.],
               [0., 0., 0.]], dtype=float32)}
        """
        # A fresh dict per call instead of a shared mutable default.
        if methods is None:
            methods = {}

        if shape is None and offset is not None:
            raise QuadrantsSyntaxError("shape cannot be None when offset is being set")

        # Member fields are created unplaced (shape=None) and placed together
        # further down once the layout is known.
        field_dict = {}
        for key, dtype in members.items():
            field_name = name + "." + key
            if isinstance(dtype, CompoundType):
                if isinstance(dtype, StructType):
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                else:
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                        ndim=getattr(dtype, "ndim", 2),
                    )
            else:
                field_dict[key] = impl.field(
                    dtype,
                    shape=None,
                    name=field_name,
                    offset=offset,
                    needs_grad=needs_grad,
                    needs_dual=needs_dual,
                )

        if shape is not None:
            if isinstance(shape, numbers.Number):
                shape = (shape,)
            if isinstance(offset, numbers.Number):
                offset = (offset,)

            if offset is not None and len(shape) != len(offset):
                raise QuadrantsSyntaxError(
                    f"The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})"
                )
            dim = len(shape)
            if layout == Layout.SOA:
                # SOA: every member gets its own dense block.
                for e in field_dict.values():
                    impl.root.dense(impl.index_nd(dim), shape).place(e, offset=offset)
                if needs_grad:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.grad, offset=offset)
                if needs_dual:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.dual, offset=offset)
            else:
                # AOS: all members share one dense block.
                impl.root.dense(impl.index_nd(dim), shape).place(*tuple(field_dict.values()), offset=offset)
                if needs_grad:
                    grads = tuple(e.grad for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*grads, offset=offset)
                if needs_dual:
                    duals = tuple(e.dual for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*duals, offset=offset)

        return StructField(field_dict, methods, name=name)
class _IntermediateStruct(Struct):
    """Intermediate struct class for compiler internal use only.

    Unlike :class:`Struct`, the given entries are stored as they are: no
    list-to-matrix or dict-to-struct conversion is performed.

    Args:
        entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
            Any methods included under the key '__struct_methods' will be applied to each
            struct instance.
    """

    def __init__(self, entries):
        # Deliberately does not call Struct.__init__; the name-mangled
        # private attributes of Struct are therefore populated by hand.
        assert isinstance(entries, dict)
        self._Struct__methods = entries.pop("__struct_methods", {})
        self._register_methods()
        self._Struct__entries = entries
        self._register_members()
        # Struct.__init__ always defines the dtype slot; define it here too so
        # inherited code reading it does not raise AttributeError.
        self._Struct__dtype = None
class StructField(Field):
    """Quadrants struct field with SNode implementation.

    Instead of directly holding Expr entries, the StructField object
    directly hosts members as `Field` instances to support nested structs.

    Args:
        field_dict (Dict[str, Field]): Struct field members.
        struct_methods (Dict[str, callable]): Dictionary of functions to apply
            to each struct instance in the field.
        name (string, optional): The custom name of the field.
        is_primal (bool, optional): Whether this is a primal field. Primal
            fields also build their `.grad` and `.dual` companion fields.
    """

    def __init__(self, field_dict, struct_methods, name=None, is_primal=True):
        # will not call Field initializer
        # TODO: Why?
        self.field_dict = field_dict
        self.struct_methods = struct_methods
        self.name = name
        self.grad = None
        self.dual = None
        self._shape: tuple[int, ...] | None = None
        self._dtype: DataTypeCxx | None = None
        if is_primal:
            # `name` is optional, so only derive companion names when one was
            # given; concatenating onto None would raise TypeError.
            grad_name = None if name is None else name + ".grad"
            dual_name = None if name is None else name + ".dual"

            grad_field_dict = {k: v.grad for k, v in self.field_dict.items()}
            self.grad = StructField(grad_field_dict, struct_methods, grad_name, is_primal=False)

            dual_field_dict = {k: v.dual for k, v in self.field_dict.items()}
            self.dual = StructField(dual_field_dict, struct_methods, dual_name, is_primal=False)
        self._register_fields()

    @property
    def keys(self):
        """Returns the list of names of the field members.

        Example::

            >>> f1 = ti.Vector.field(3, ti.f32, shape=(3, 3))
            >>> f2 = ti.field(ti.f32, shape=(3, 3))
            >>> F = ti.StructField({"center": f1, "radius": f2}, {})
            >>> F.keys
            ['center', 'radius']
        """
        return list(self.field_dict.keys())

    @property
    def _members(self):
        """Returns the member fields as a list, in insertion order."""
        return list(self.field_dict.values())

    @property
    def _items(self):
        """Returns the (name, member field) pairs."""
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self.field_dict[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self.field_dict[key] = value

        return setter

    def _register_fields(self):
        """Exposes every member field as an instance attribute of the same name."""
        for k in self.keys:
            setattr(self, k, self.field_dict[k])

    def _get_field_members(self):
        """Gets A flattened list of all struct elements.

        Returns:
            A list of struct elements.
        """
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._members[0]._snode

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            quadrants_python.SNode: SNode of representative (first) field member.
        """
        return self._members[0]._loop_range()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field. It must expose the same member
                names as this field.
        """
        assert isinstance(other, Field)
        assert set(self.keys) == set(other.keys)
        for k in self.keys:
            self.field_dict[k].copy_from(other.get_member_field(k))

    @python_scope
    def fill(self, val):
        """Fills this struct field with a specified value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        for v in self._members:
            v.fill(val)

    def _initialize_host_accessors(self):
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Returns the member field stored under a specific key.

        Args:
            key (str): Specified key of the field member.

        Returns:
            Field: The member field registered under `key`.
        """
        return self.field_dict[key]

    @python_scope
    def from_numpy(self, array_dict):
        """Copies the data from a set of `numpy.array` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_numpy(array_dict[k])

    @python_scope
    def from_torch(self, array_dict):
        """Copies the data from a set of `torch.tensor` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_torch(array_dict[k])

    @python_scope
    def to_numpy(self):
        """Converts the Struct field instance to a dictionary of NumPy arrays.

        The dictionary may be nested when converting nested structs.

        Returns:
            Dict[str, Union[numpy.ndarray, Dict]]: The result NumPy array.
        """
        return {k: v.to_numpy() for k, v in self._items}

    @python_scope
    def to_torch(self, device=None):
        """Converts the Struct field instance to a dictionary of PyTorch tensors.

        The dictionary may be nested when converting nested structs.

        Args:
            device (torch.device, optional): The desired device of returned tensor.

        Returns:
            Dict[str, Union[torch.Tensor, Dict]]: The result
            PyTorch tensor.
        """
        return {k: v.to_torch(device=device) for k, v in self._items}

    @python_scope
    def __setitem__(self, indices, element):
        self._initialize_host_accessors()
        # self[indices] builds a Struct over this element's members;
        # _set_entries then writes `element` through it member by member.
        self[indices]._set_entries(element)

    @python_scope
    def __getitem__(self, indices):
        self._initialize_host_accessors()
        # scalar fields does not instantiate SNodeHostAccess by default
        entries = {
            k: v._host_access(self._pad_key(indices))[0] if isinstance(v, ScalarField) else v[indices]
            for k, v in self._items
        }
        entries["__struct_methods"] = self.struct_methods
        return Struct(entries)
class StructType(CompoundType):
    """Compound type describing a struct: named member types plus methods.

    Args:
        **kwargs: `name=type` pairs for the members. The reserved key
            '__struct_methods' carries the dictionary of methods instead of
            a member type.
    """

    def __init__(self, **kwargs):
        self.members = {}
        self.methods = {}
        elements = []
        for member_name, member_type in kwargs.items():
            if member_name == "__struct_methods":
                self.methods = member_type
                continue
            # Pick the type object handed to the type factory for this member.
            if isinstance(member_type, StructType):
                core_type = member_type.dtype
            elif isinstance(member_type, MatrixType):
                core_type = member_type.tensor_type
            else:
                member_type = cook_dtype(member_type)
                core_type = member_type
            self.members[member_name] = member_type
            elements.append([core_type, member_name])
        self.dtype = _ti_core.get_type_factory_instance().get_struct_type(elements)

    def _build_struct(self, entries):
        """Attaches this type's methods to `entries` and wraps them in a Struct."""
        entries["__struct_methods"] = self.methods
        result = Struct(entries)
        result._Struct__dtype = self.dtype
        return result

    @staticmethod
    def _scalar_matches(val, type_group, host_types):
        """Checks a scalar member: Expr by element type, host value by Python type."""
        if isinstance(val, Expr):
            return not (val.is_tensor() or val.is_struct() or val.element_type() not in type_group)
        return isinstance(val, host_types)

    def __call__(self, *args, **kwargs):
        """Create an instance of this struct type."""
        values = {}
        n_positional = len(args)
        # Positional arguments fill members in declaration order; remaining
        # members come from keyword arguments and default to 0.
        for position, (name, dtype) in enumerate(self.members.items()):
            data = args[position] if position < n_positional else kwargs.get(name, 0)
            # If dtype is CompoundType and data is a scalar, it cannot be
            # casted in the self.cast call later. We need an initialization here.
            if isinstance(dtype, CompoundType) and not isinstance(data, (dict, Struct)):
                data = dtype(data)
            values[name] = data

        raw = Struct(values)
        raw._Struct__dtype = self.dtype
        result = self.cast(raw)
        result._Struct__dtype = self.dtype
        return result

    def __instancecheck__(self, instance):
        if not isinstance(instance, Struct):
            return False
        if list(self.members.keys()) != list(instance._Struct__entries.keys()):
            return False
        # A dtype that is present and set must agree with this type's dtype.
        known_dtype = getattr(instance, "_Struct__dtype", None)
        if known_dtype is not None and known_dtype != self.dtype:
            return False

        for dtype, val in zip(self.members.values(), instance._members):
            if isinstance(dtype, StructType):
                if not isinstance(val, dtype):
                    return False
            elif isinstance(dtype, MatrixType):
                # Only Expr values are checked for matrix members.
                if isinstance(val, Expr):
                    if not val.is_tensor():
                        return False
                    if val.get_shape() != dtype.get_shape():
                        return False
            elif dtype in primitive_types.integer_types:
                if not self._scalar_matches(val, primitive_types.integer_types, (int, np.integer)):
                    return False
            elif dtype in primitive_types.real_types:
                if not self._scalar_matches(val, primitive_types.real_types, (float, np.floating)):
                    return False
        return True

    def from_quadrants_object(self, func_ret, ret_index=()):
        """Builds a Struct whose members are read out of `func_ret` by index path."""
        values = {}
        for position, (name, dtype) in enumerate(self.members.items()):
            path = ret_index + (position,)
            if isinstance(dtype, CompoundType):
                values[name] = dtype.from_quadrants_object(func_ret, path)
                continue
            values[name] = expr.Expr(
                _ti_core.make_get_element_expr(
                    func_ret.ptr,
                    path,
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )
            )
        return self._build_struct(values)

    def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
        """Builds a Struct from the values `launch_ctx` returns by index path.

        Raises:
            QuadrantsRuntimeTypeError: If a non-compound member is neither an
                integer nor a real primitive type.
        """
        values = {}
        for index, (name, dtype) in enumerate(self.members.items()):
            if isinstance(dtype, CompoundType):
                values[name] = dtype.from_kernel_struct_ret(launch_ctx, ret_index + (index,))
            elif dtype in primitive_types.integer_types:
                if is_signed(cook_dtype(dtype)):
                    values[name] = launch_ctx.get_struct_ret_int(ret_index + (index,))
                else:
                    values[name] = launch_ctx.get_struct_ret_uint(ret_index + (index,))
            elif dtype in primitive_types.real_types:
                values[name] = launch_ctx.get_struct_ret_float(ret_index + (index,))
            else:
                raise QuadrantsRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}")
        return self._build_struct(values)

    def set_kernel_struct_args(self, struct, launch_ctx, ret_index=()):
        """Writes every member of `struct` into `launch_ctx` by index path.

        Raises:
            QuadrantsRuntimeTypeError: If a non-compound member is neither an
                integer nor a real primitive type.
        """
        # TODO: move this to class Struct after we add dtype to Struct
        for index, (name, dtype) in enumerate(self.members.items()):
            if isinstance(dtype, CompoundType):
                dtype.set_kernel_struct_args(struct[name], launch_ctx, ret_index + (index,))
            elif dtype in primitive_types.integer_types:
                if is_signed(cook_dtype(dtype)):
                    launch_ctx.set_struct_arg_int(ret_index + (index,), struct[name])
                else:
                    launch_ctx.set_struct_arg_uint(ret_index + (index,), struct[name])
            elif dtype in primitive_types.real_types:
                launch_ctx.set_struct_arg_float(ret_index + (index,), struct[name])
            else:
                raise QuadrantsRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")

    def cast(self, struct):
        """Returns a new Struct with every member of `struct` cast to this type.

        Raises:
            QuadrantsSyntaxError: If the member names of `struct` differ from
                the member names of this type.
        """
        # sanity check members
        source = struct._Struct__entries
        if self.members.keys() != source.keys():
            raise QuadrantsSyntaxError("Incompatible arguments for custom struct members!")

        converted = {}
        for name, dtype in self.members.items():
            raw = source[name]
            if isinstance(dtype, MatrixType):
                converted[name] = dtype(raw)
            elif isinstance(dtype, CompoundType):
                converted[name] = dtype.cast(raw)
            elif in_python_scope():
                # Host-side values are converted with plain Python casts.
                if dtype in primitive_types.integer_types:
                    converted[name] = int(raw)
                else:
                    converted[name] = float(raw)
            else:
                converted[name] = ops.cast(raw, dtype)
        return self._build_struct(converted)

    def filled_with_scalar(self, value):
        """Returns a Struct of this type with every member filled with `value`."""
        filled = {}
        for name, dtype in self.members.items():
            if isinstance(dtype, MatrixType):
                filled[name] = dtype(value)
            elif isinstance(dtype, CompoundType):
                filled[name] = dtype.filled_with_scalar(value)
            else:
                filled[name] = value
        return self._build_struct(filled)

    def field(self, **kwargs):
        """Creates a struct field of this type; see :meth:`Struct.field`."""
        return Struct.field(self.members, self.methods, **kwargs)

    def __str__(self):
        """Python scope struct type print support."""
        parts = [str(name) + "=" + str(dtype) for name, dtype in self.members.items()]
        item_str = ", ".join(parts)
        item_str += f", struct_methods={self.methods}"
        return f"<ti.StructType {item_str}>"
def dataclass(cls):
    """Converts a class with field annotations and methods into a quadrants struct type.

    This will return a normal custom struct type, with the functions added to it.
    Struct fields can be generated in the normal way from the struct type.
    Functions in the class can be run on the struct instance.

    This class decorator inspects the class for annotations and methods and

        1.  Sets the annotations as fields for the struct
        2.  Attaches the methods to the struct type

    Example::

        >>> @ti.dataclass
        >>> class Sphere:
        >>>     center: vec3
        >>>     radius: ti.f32
        >>>
        >>>     @ti.func
        >>>     def area(self):
        >>>         return 4 * 3.14 * self.radius * self.radius
        >>>
        >>> my_spheres = Sphere.field(shape=(n, ))
        >>> my_spheres[2].area()

    Args:
        cls (Class): the class with annotations and methods to convert to a struct

    Returns:
        A quadrants struct with the annotations as fields
        and methods from the class attached.

    Raises:
        QuadrantsSyntaxError: If an annotated field also has a class-level
            default value.
    """
    # Work on a copy of the annotations: adding the reserved
    # "__struct_methods" key below must not leak into the class's own
    # (or an inherited) __annotations__ dict.
    fields = dict(getattr(cls, "__annotations__", {}))
    # raise error if there are default values
    for k in fields:
        if hasattr(cls, k):
            raise QuadrantsSyntaxError("Default value in @dataclass is not supported.")
    # get the class methods to be attached to the struct types
    fields["__struct_methods"] = {
        attribute: getattr(cls, attribute)
        for attribute in dir(cls)
        if callable(getattr(cls, attribute)) and not attribute.startswith("__")
    }
    return StructType(**fields)
# Names exported by `from quadrants.lang.struct import *`.
__all__ = ["Struct", "StructField", "dataclass"]