# Source code for gratopy.operator.base

"""Generic implementation of operators including basic arithmetic."""

from __future__ import annotations

import numpy as np

from enum import Enum
from typing import Any
from numbers import Number
from copy import deepcopy

from gratopy.utilities import Numeric


class OperatorArithmeticOperation(Enum):
    """Kinds of arithmetic by which a composite operator combines its operands."""

    # Summation of operands: (A + B)(x) evaluates A(x) + B(x).
    ADDITION = "sum"
    # Composition of operands: (A * B)(x) evaluates A(B(x)).
    MULTIPLICATION = "prod"
def _compute_sum_shapes( operands: list[Operator], ) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]: input_shape = None output_shape = None for op in operands: if op.input_shape is not None: if input_shape is not None and input_shape != op.input_shape: raise ValueError( f"Input shape mismatch in sum: expected {input_shape}, " f"but {op} has input shape {op.input_shape}" ) input_shape = op.input_shape if op.output_shape is not None: if output_shape is not None and output_shape != op.output_shape: raise ValueError( f"Output shape mismatch in sum: expected {output_shape}, " f"but {op} has output shape {op.output_shape}" ) output_shape = op.output_shape return input_shape, output_shape def _compute_product_shapes( operands: list[Operator], ) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]: for i in range(len(operands) - 1): left, right = operands[i], operands[i + 1] if left.input_shape is not None and right.output_shape is not None: if left.input_shape != right.output_shape: raise ValueError( f"Shape mismatch in composition: {left} expects input " f"{left.input_shape}, but {right} produces {right.output_shape}" ) input_shape = operands[-1].input_shape output_shape = operands[0].output_shape return input_shape, output_shape
class Operator:
    """Base class for all operators.

    An operator is either *atomic* (a subclass overriding :meth:`apply_to`) or
    *composite*: a sum or product of other operators, built with ``+``, ``-``
    and ``*``. Every operator carries a scalar factor; for composite operators
    that factor is pushed down into the operands, so a composite's own
    ``scalar`` stays 1.
    """

    def __init__(
        self,
        name: str | None = None,
        scalar: Numeric = 1,
        state: dict[str, Any] | None = None,
        arithmetic_operation: OperatorArithmeticOperation | None = None,
        operands: list[Operator] | None = None,
        input_shape: tuple[int, ...] | None = None,
        output_shape: tuple[int, ...] | None = None,
    ):
        """Initialize the operator.

        Parameters
        ----------
        name
            Display name; defaults to the class name.
        scalar
            Scalar factor. For composite operators it is distributed into
            ``operands`` (see :attr:`scalar`), mutating them in place.
        state
            Per-operator state dictionary; a fresh empty dict when omitted.
        arithmetic_operation
            ``None`` for atomic operators, otherwise how ``operands`` combine.
        operands
            Child operators of a composite operator (a fresh list when omitted).
        input_shape, output_shape
            Optional shapes used for compatibility checks; ``None`` if unknown.
        """
        if name is None:
            name = self.__class__.__name__
        self.name = name
        if state is None:
            state = {}
        self.state = state
        # These must be set before ``self.scalar`` because the setter consults them.
        self._arithmetic_operation = arithmetic_operation
        if operands is None:
            operands = []
        self._operands = operands
        self._scalar: Numeric = 1
        self.scalar = scalar
        self.input_shape = input_shape
        self.output_shape = output_shape

    def _scalar_repr_(self) -> str:
        """Return the scalar prefix used in representations ("" for a factor of 1)."""
        scalar_repr = ""
        if self.scalar != 1:
            scalar_repr = repr(self.scalar)
            try:
                is_negative = self.scalar < 0
            except TypeError:
                # Complex scalars are unordered; Python already wraps complex
                # values with a real part (e.g. "(1+2j)") in parentheses.
                is_negative = False
            if is_negative:
                scalar_repr = f"({scalar_repr})"
        return scalar_repr

    def _composite_repr_(self) -> str:
        """Representation of a composite operator."""
        assert self.is_composite(), "This method is for composite operators only."
        scalar_repr = self._scalar_repr_()
        if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
            op_repr = " + ".join(repr(op) for op in self._operands)
            if scalar_repr:
                return f"{scalar_repr}*({op_repr})"
            return op_repr
        elif self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
            op_reprs = []
            for op in self._operands:
                # Parenthesize composite factors (e.g. sums) to keep precedence.
                if op.is_composite():
                    op_reprs.append(f"({repr(op)})")
                else:
                    op_reprs.append(repr(op))
            op_repr = "*".join(op_reprs)
            if scalar_repr:
                return f"{scalar_repr}*{op_repr}"
            return op_repr
        raise ValueError(f"Unknown arithmetic operation: {self._arithmetic_operation}")

    def _repr_name_(self) -> str:
        """Return the operator name used in string representations."""
        return self.name

    def __repr__(self) -> str:
        if not self.is_composite():
            scalar_repr = self._scalar_repr_()
            name = self._repr_name_()
            if scalar_repr:
                return f"{scalar_repr}*{name}"
            return name
        return self._composite_repr_()

    # NOTE: defining __eq__ without __hash__ makes Operator instances unhashable.
    def __eq__(self, other: Any) -> bool:
        """Check structural equality of two operators.

        Operators are equal when they share concrete type, name, scalar, state,
        arithmetic operation, operands and shapes. The comparison stops at the
        first difference.
        """
        if not isinstance(other, Operator):
            return False
        return bool(
            type(self) is type(other)
            and self.name == other.name
            and self.scalar == other.scalar
            and self.state == other.state
            and self._arithmetic_operation == other._arithmetic_operation
            and self._operands == other._operands
            and self.input_shape == other.input_shape
            and self.output_shape == other.output_shape
        )

    @property
    def scalar(self) -> Numeric:
        """Scalar factor of the operator (stays 1 for composite operators)."""
        return self._scalar

    @scalar.setter
    def scalar(self, value: Numeric):
        """Set the scalar value of the operator.

        For composite operators the value is *multiplied into* the operands
        instead of being stored: into every summand of a sum and into the first
        (leftmost) factor of a product. Because a composite's own scalar reads
        as 1, ``op.scalar = op.scalar * c`` scales any operator by ``c``.
        """
        if self.is_composite():
            if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
                for child_operator in self._operands:
                    child_operator.scalar = child_operator.scalar * value
            elif self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
                self._operands[0].scalar = self._operands[0].scalar * value
        else:
            self._scalar = value

    def apply_to(
        self,
        argument: Any,
        output: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Application of this operator to some given argument.

        Composite operators dispatch to their operands:

        * sums pass ``output`` to the first operand and accumulate the other
          operands' results into that result with ``+=`` (in place for
          NumPy arrays);
        * products apply their operands right-to-left and pass ``output`` only
          to the final (leftmost) operand, so intermediate results are computed
          normally while the final result can reuse the caller-provided array.

        Parameters
        ----------
        argument
            Input the operator is applied to.
        output
            Optional pre-allocated container for the result.
        **kwargs
            Forwarded unchanged to every child operator.

        Returns
        -------
        The result of applying the operator.

        Raises
        ------
        NotImplementedError
            For atomic operators whose subclass does not override this method,
            and for composites with an unknown arithmetic operation.
        """
        if self.is_composite():
            if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
                # output=None is the same as omitting it, so one path suffices.
                result = self._operands[0].apply_to(argument, output=output, **kwargs)
                for child_op in self._operands[1:]:
                    result += child_op.apply_to(argument, **kwargs)
                return result
            if self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
                # Operands are stored as written (A*B*C), so C is applied first.
                result = argument
                last_index = len(self._operands) - 1
                for index, child_op in enumerate(reversed(self._operands)):
                    child_output = output if index == last_index else None
                    result = child_op.apply_to(result, output=child_output, **kwargs)
                return result
        raise NotImplementedError(
            "apply_to needs to be implemented in specialized subclasses"
        )

    def is_composite(self) -> bool:
        """Check if the operator is composite."""
        return self._arithmetic_operation is not None

    def __add__(self, other: Operator) -> Operator:
        """Return the composite sum ``self + other``.

        Both summands are deep-copied and nested sums are flattened into one
        operand list. Adding the zero operator returns ``self`` unchanged.

        Raises
        ------
        TypeError
            If ``other`` is not an :class:`Operator`.
        ValueError
            If the summands' known input or output shapes disagree.
        """
        if not isinstance(other, Operator):
            raise TypeError(f"Cannot add {type(other)} to {type(self)}")
        if isinstance(other, _ZeroOperator):
            return self
        operands = []
        for operator in [deepcopy(self), deepcopy(other)]:
            if (
                operator.is_composite()
                and operator._arithmetic_operation == OperatorArithmeticOperation.ADDITION
            ):
                # Flatten nested sums, folding the sum's scalar into each summand.
                for child_operator in operator._operands:
                    child_operator.scalar *= operator.scalar
                    operands.append(child_operator)
            else:
                operands.append(operator)
        input_shape, output_shape = _compute_sum_shapes(operands)
        return Operator(
            name=None,
            scalar=1,
            arithmetic_operation=OperatorArithmeticOperation.ADDITION,
            operands=operands,
            input_shape=input_shape,
            output_shape=output_shape,
        )

    def __neg__(self) -> Operator:
        """Negate this operator."""
        return self.__rmul__(-1)

    def __sub__(self, other: Operator) -> Operator:
        """Subtract another operator from this one."""
        return self + (-other)

    def __rmul__(self, other: Numeric) -> Operator:  # type: ignore[misc]
        """Right-multiply this operator by a scalar.

        Returns the zero operator for 0, ``self`` itself (not a copy) for 1, and
        otherwise a deep copy with its scalar multiplied by ``other``.
        """
        if other == 0:
            return ZERO
        if other == 1:
            return self
        operator_copy = deepcopy(self)
        operator_copy.scalar = operator_copy.scalar * other
        return operator_copy

    def __mul__(self, other: Operator | Any) -> Operator | Any:
        """Multiply this operator by another operator, or apply it to appropriate input.

        * a ``Number`` yields a scaled copy (same as ``other * self``);
        * an ``Operator`` yields the composition ``self * other`` in which
          ``other`` is applied first;
        * anything else is passed to :meth:`apply_to`.

        Raises
        ------
        ValueError
            If adjacent factors have incompatible known shapes.
        """
        if not isinstance(other, Operator):
            if isinstance(other, Number):
                return self.__rmul__(other)  # type: ignore[arg-type, operator]
            # attempt to apply the operator to the input
            return self.apply_to(other)
        if isinstance(other, _ZeroOperator):
            return other
        if isinstance(other, _IdentityOperator):
            # The identity may carry a scalar (e.g. ``2 * IDENTITY``); returning
            # ``self`` unchanged would silently drop it. For scalar 1 this is
            # ``self`` itself.
            return self.__rmul__(other.scalar)
        operands = []
        scalar: Numeric = 1
        for operator in [deepcopy(self), deepcopy(other)]:
            if (
                operator.is_composite()
                and operator._arithmetic_operation
                == OperatorArithmeticOperation.MULTIPLICATION
            ):
                # Flatten nested products; their scalar already lives in their
                # first operand, so the factors can be reused as-is.
                operands.extend(operator._operands)
            else:
                # Collect the scalar on the new product (the setter pushes it
                # into the first factor) and normalize this factor to 1.
                scalar *= operator.scalar
                operator.scalar = 1
                operands.append(operator)
        input_shape, output_shape = _compute_product_shapes(operands)
        return Operator(
            name=None,
            scalar=scalar,
            arithmetic_operation=OperatorArithmeticOperation.MULTIPLICATION,
            operands=operands,
            input_shape=input_shape,
            output_shape=output_shape,
        )
class _IdentityOperator(Operator):
    """Base class for identity operator; applying it yields ``scalar * argument``."""

    def __mul__(self, other: Operator | Any) -> Operator | Any:
        """Compose the identity with ``other``.

        Composing with an operator returns that operator scaled by the
        identity's own scalar. ``IDENTITY * A`` is therefore ``A`` itself, while
        ``(2 * IDENTITY) * A`` equals ``2 * A`` instead of silently dropping the
        factor. Non-operators are handled by :meth:`Operator.__mul__`.
        """
        if isinstance(other, Operator):
            return other.__rmul__(self.scalar)
        return super().__mul__(other)

    def apply_to(self, argument: Any, output: Any | None = None, **kwargs: Any) -> Any:
        """Return ``scalar * argument``, written into ``output`` when one is given."""
        result = self.scalar * argument
        if output is not None:
            output[...] = result
            return output
        return result


class _ZeroOperator(Operator):
    """Base class for zero operator: neutral for addition, absorbing for composition."""

    def __add__(self, other: Operator) -> Operator:
        """Adding zero operator to any operator returns the other operator.

        Raises
        ------
        TypeError
            If ``other`` is not an :class:`Operator`, matching
            :meth:`Operator.__add__`.
        """
        if not isinstance(other, Operator):
            raise TypeError(f"Cannot add {type(other)} to {type(self)}")
        return other

    def __mul__(self, other: Operator | Any) -> Operator | Any:
        """Composing the zero operator with any operator yields the zero operator.

        This mirrors ``A * ZERO`` returning ``ZERO`` in :meth:`Operator.__mul__`.
        Numbers and other inputs are handled by the base class.
        """
        if isinstance(other, Operator):
            return self
        return super().__mul__(other)

    @property
    def scalar(self) -> Numeric:
        return self._scalar

    @scalar.setter
    def scalar(self, value: Numeric):
        # Deliberately ignored: any multiple of the zero operator is still zero,
        # so the scalar stays at its initial value of 1.
        pass

    def apply_to(self, argument: Any, output: Any | None = None, **kwargs: Any) -> Any:
        """Applying the zero operator returns a zero-multiplied version of the input."""
        if output is not None:
            output[...] = 0
            return output
        # Prefer arithmetic (works for numbers and array-likes supporting ``*``),
        # then fall back to NumPy for inputs that only convert to arrays.
        try:
            return 0 * argument  # type: ignore
        except TypeError:
            pass
        try:
            return np.zeros_like(argument)
        except (ValueError, TypeError):
            pass
        raise TypeError(f"Cannot apply zero operator to {type(argument)}")


IDENTITY = _IdentityOperator(name="[Id]")
ZERO = _ZeroOperator(name="[0]")