"""Generic implementation of operators including basic arithmetic."""
from __future__ import annotations
import numpy as np
from enum import Enum
from typing import Any
from numbers import Number
from copy import deepcopy
from gratopy.utilities import Numeric
# [docs] -- documentation-link label left over from extraction; not part of the code
class OperatorArithmeticOperation(Enum):
    """Kinds of arithmetic combination that produce a composite operator.

    The member values are the short string tags ``"sum"`` and ``"prod"``.
    """

    # Composite obtained by adding operators.
    ADDITION = "sum"

    # Composite obtained by multiplying operators.
    MULTIPLICATION = "prod"
def _compute_sum_shapes(
    operands: list[Operator],
) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]:
    """Determine the common input and output shape of a sum of operators.

    Operands whose shape is ``None`` impose no constraint.  Every other
    operand has to agree with the shapes collected from the operands before
    it.

    :param operands: The summands, in order.
    :return: ``(input_shape, output_shape)``; an entry is ``None`` when no
        operand specifies that shape.
    :raises ValueError: If two operands specify different input shapes or
        different output shapes.
    """

    def reconcile(known, found, operand, label):
        # An unspecified shape leaves the shape collected so far untouched.
        if found is None:
            return known
        if known is not None and known != found:
            raise ValueError(
                f"{label.capitalize()} shape mismatch in sum: expected {known}, "
                f"but {operand} has {label} shape {found}"
            )
        return found

    common_in = None
    common_out = None
    for operand in operands:
        # Per operand, the input shape is validated before the output shape.
        common_in = reconcile(common_in, operand.input_shape, operand, "input")
        common_out = reconcile(common_out, operand.output_shape, operand, "output")
    return common_in, common_out
def _compute_product_shapes(
    operands: list[Operator],
) -> tuple[tuple[int, ...] | None, tuple[int, ...] | None]:
    """Validate a product of operators and return its overall shapes.

    In a product the rightmost operand is applied first, so for each pair of
    neighbours the input shape of the left one has to equal the output shape
    of the right one.  A pair is not checked when either of those two shapes
    is ``None``.

    :param operands: The factors, ordered left to right.
    :return: ``(input_shape, output_shape)``, i.e. the input shape of the
        last operand and the output shape of the first operand.
    :raises ValueError: If two neighbouring operands have incompatible shapes.
    :raises IndexError: If ``operands`` is empty.
    """
    for outer, inner in zip(operands, operands[1:]):
        expected = outer.input_shape
        if expected is None:
            continue
        produced = inner.output_shape
        if produced is None:
            continue
        if expected != produced:
            raise ValueError(
                f"Shape mismatch in composition: {outer} expects input "
                f"{expected}, but {inner} produces {produced}"
            )
    return operands[-1].input_shape, operands[0].output_shape
# [docs] -- documentation-link label left over from extraction; not part of the code
class Operator:
    """Base class for all operators.

    An operator is either a plain ("leaf") operator, whose :meth:`apply_to`
    has to be supplied by a subclass, or a composite operator that combines a
    list of operands by addition or multiplication (see
    ``OperatorArithmeticOperation``).  Composites are produced by the ``+``,
    ``-`` and ``*`` operators defined here.

    A composite does not store a scalar factor itself: assigning to
    :attr:`scalar` on a composite multiplies the factor into its operands,
    and the composite's own ``scalar`` keeps reading as 1.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.
    """

    def __init__(
        self,
        name: str | None = None,
        scalar: Numeric = 1,
        state: dict[str, Any] | None = None,
        arithmetic_operation: OperatorArithmeticOperation | None = None,
        operands: list[Operator] | None = None,
        input_shape: tuple[int, ...] | None = None,
        output_shape: tuple[int, ...] | None = None,
    ):
        """Create an operator.

        :param name: Name used in ``repr``; defaults to the class name.
        :param scalar: Scalar factor.  For a composite it is multiplied into
            the given ``operands`` (which are modified in place).
        :param state: Arbitrary operator state; defaults to an empty dict.
        :param arithmetic_operation: ``None`` for a leaf operator, otherwise
            the operation that combines ``operands``.
        :param operands: Child operators of a composite; defaults to an
            empty list.
        :param input_shape: Shape of accepted arguments, or ``None`` if
            unspecified.
        :param output_shape: Shape of produced results, or ``None`` if
            unspecified.
        """
        self.name = self.__class__.__name__ if name is None else name
        self.state = {} if state is None else state
        self._arithmetic_operation = arithmetic_operation
        self._operands = [] if operands is None else operands
        # ``_scalar`` has to exist before the property setter runs, and the
        # operation/operands have to be set because the setter inspects them.
        self._scalar: Numeric = 1
        self.scalar = scalar
        self.input_shape = input_shape
        self.output_shape = output_shape

    def _scalar_repr_(self) -> str:
        """Return the scalar prefix used in ``repr`` (empty for scalar 1)."""
        if self.scalar == 1:
            return ""
        scalar_repr = repr(self.scalar)
        if self.scalar < 0:
            # Parenthesise negative factors: "(-2)*A" instead of "-2*A".
            return f"({scalar_repr})"
        return scalar_repr

    def _composite_repr_(self) -> str:
        """Representation of a composite operator.

        Sums are joined with ``" + "``; products are joined with ``"*"`` and
        composite factors are wrapped in parentheses.

        :raises ValueError: If the arithmetic operation is not known.
        """
        assert self.is_composite(), "This method is for composite operators only."
        scalar_repr = self._scalar_repr_()

        if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
            op_repr = " + ".join(repr(op) for op in self._operands)
            if scalar_repr:
                return f"{scalar_repr}*({op_repr})"
            return op_repr

        if self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
            op_repr = "*".join(
                f"({op!r})" if op.is_composite() else repr(op)
                for op in self._operands
            )
            if scalar_repr:
                return f"{scalar_repr}*{op_repr}"
            return op_repr

        raise ValueError(f"Unknown arithmetic operation: {self._arithmetic_operation}")

    def _repr_name_(self) -> str:
        """Return the operator name used in string representations."""
        return self.name

    def __repr__(self) -> str:
        if self.is_composite():
            return self._composite_repr_()
        scalar_repr = self._scalar_repr_()
        name = self._repr_name_()
        if scalar_repr:
            return f"{scalar_repr}*{name}"
        return name

    def __eq__(self, other: Any) -> bool:
        """Check equality of two operators.

        Two operators are equal when they have exactly the same type and
        equal name, scalar, state, arithmetic operation, operands and shapes.
        Anything that is not an ``Operator`` compares unequal.
        """
        if not isinstance(other, Operator):
            return False
        # Short-circuit so that the (recursive) operand comparison is only
        # performed when the cheap attributes already agree.
        return bool(
            type(self) is type(other)
            and self.name == other.name
            and self.scalar == other.scalar
            and self.state == other.state
            and self._arithmetic_operation == other._arithmetic_operation
            and self._operands == other._operands
            and self.input_shape == other.input_shape
            and self.output_shape == other.output_shape
        )

    @property
    def scalar(self) -> Numeric:
        """Scalar factor of this operator."""
        return self._scalar

    @scalar.setter
    def scalar(self, value: Numeric):
        """Set the scalar value of the operator.

        For a leaf operator the value is stored.  For a composite the value
        is *multiplied into* the operands instead: into every summand of a
        sum, and into the first factor of a product.
        """
        if not self.is_composite():
            self._scalar = value
            return
        if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
            for child_operator in self._operands:
                child_operator.scalar = child_operator.scalar * value
        elif self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
            self._operands[0].scalar = self._operands[0].scalar * value

    def apply_to(
        self,
        argument: Any,
        output: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Application of this operator to some given argument.

        For composite sums, ``output`` (if given) is passed to the first
        summand only; the results of the remaining summands are added in
        place to whatever the first summand returned.

        For composite products, ``output`` is passed only to the final child
        operator. Intermediate results are computed normally while the final
        result can reuse the caller-provided array.

        :param argument: The object the operator is applied to.
        :param output: Optional object to receive the result.
        :param kwargs: Forwarded unchanged to every child operator.
        :return: The result of the application.
        :raises NotImplementedError: For operators that are not a sum or
            product composite; subclasses have to override this method.
        """
        if self.is_composite():
            if self._arithmetic_operation == OperatorArithmeticOperation.ADDITION:
                # Only forward ``output`` when the caller actually supplied one.
                first_kwargs = dict(kwargs)
                if output is not None:
                    first_kwargs["output"] = output
                result = self._operands[0].apply_to(argument, **first_kwargs)
                for child_op in self._operands[1:]:
                    result += child_op.apply_to(argument, **kwargs)
                return result

            if self._arithmetic_operation == OperatorArithmeticOperation.MULTIPLICATION:
                result = argument
                last_index = len(self._operands) - 1
                # The rightmost factor acts first.
                for index, child_op in enumerate(reversed(self._operands)):
                    child_output = output if index == last_index else None
                    result = child_op.apply_to(result, output=child_output, **kwargs)
                return result

        raise NotImplementedError(
            "apply_to needs to be implemented in specialized subclasses"
        )

    def is_composite(self) -> bool:
        """Check if the operator is composite."""
        return self._arithmetic_operation is not None

    def __add__(self, other: Operator) -> Operator:
        """Add another operator to this one.

        Both operands are deep-copied, and summands of existing sums are
        merged into a single flat sum.  Adding the zero operator returns
        ``self`` itself (no copy).

        :raises TypeError: If ``other`` is not an ``Operator``.
        :raises ValueError: If the operand shapes are incompatible.
        """
        if not isinstance(other, Operator):
            raise TypeError(f"Cannot add {type(other)} to {type(self)}")
        if isinstance(other, _ZeroOperator):
            return self

        operands = []
        for term in (deepcopy(self), deepcopy(other)):
            if (
                term.is_composite()
                and term._arithmetic_operation == OperatorArithmeticOperation.ADDITION
            ):
                # Flatten nested sums into the new operand list.
                for child_operator in term._operands:
                    child_operator.scalar *= term.scalar
                    operands.append(child_operator)
            else:
                operands.append(term)

        input_shape, output_shape = _compute_sum_shapes(operands)
        return Operator(
            name=None,
            scalar=1,
            arithmetic_operation=OperatorArithmeticOperation.ADDITION,
            operands=operands,
            input_shape=input_shape,
            output_shape=output_shape,
        )

    def __neg__(self) -> Operator:
        """Negate this operator."""
        return self.__rmul__(-1)

    def __sub__(self, other: Operator) -> Operator:
        """Subtract another operator from this one."""
        return self + (-other)

    def __rmul__(self, other: Numeric) -> Operator:  # type: ignore[misc]
        """Multiply this operator by a scalar from the left (``other * self``).

        :return: ``ZERO`` for a factor of 0, ``self`` itself (no copy) for a
            factor of 1, otherwise a scaled deep copy.
        """
        if other == 0:
            return ZERO
        if other == 1:
            return self
        operator_copy = deepcopy(self)
        operator_copy.scalar = operator_copy.scalar * other
        return operator_copy

    def __mul__(self, other: Operator | Any) -> Operator | Any:
        """Multiply this operator by another operator, or apply it to appropriate input.

        * A number scales the operator.
        * Any other non-operator is treated as an argument and passed to
          :meth:`apply_to`.
        * The zero operator is returned unchanged.
        * An identity operator yields this operator scaled by the identity's
          scalar (``self`` itself when that scalar is 1).
        * Otherwise a product composite is built from deep copies of both
          operands; factors of existing products are merged into it.

        :raises ValueError: If neighbouring factor shapes are incompatible.
        """
        if not isinstance(other, Operator):
            if isinstance(other, Number):
                return self.__rmul__(other)  # type: ignore[arg-type, operator]
            # attempt to apply the operator to the input
            return self.apply_to(other)

        if isinstance(other, _ZeroOperator):
            return other
        if isinstance(other, _IdentityOperator):
            # A scaled identity such as ``2 * IDENTITY`` must still contribute
            # its factor; for scalar 1 this returns ``self`` unchanged.
            return self.__rmul__(other.scalar)

        operands = []
        scalar: Numeric = 1
        for factor in (deepcopy(self), deepcopy(other)):
            if (
                factor.is_composite()
                and factor._arithmetic_operation
                == OperatorArithmeticOperation.MULTIPLICATION
            ):
                operands.extend(factor._operands)
            else:
                # Collect the factor's scalar; it is re-applied to the
                # product as a whole by the constructor below.
                scalar *= factor.scalar
                factor.scalar = 1
                operands.append(factor)

        input_shape, output_shape = _compute_product_shapes(operands)
        return Operator(
            name=None,
            scalar=scalar,
            arithmetic_operation=OperatorArithmeticOperation.MULTIPLICATION,
            operands=operands,
            input_shape=input_shape,
            output_shape=output_shape,
        )
class _IdentityOperator(Operator):
    """Identity operator, optionally carrying a scalar factor."""

    def __mul__(self, other: Operator | Any) -> Operator | Any:
        """Multiply the identity operator with another operator or an input.

        For an operator ``other`` the result is ``other`` scaled by this
        identity's scalar; with the default scalar of 1 that is ``other``
        itself.  Everything else is handled by ``Operator.__mul__``.
        """
        if isinstance(other, Operator):
            # Do not drop the factor of a scaled identity such as
            # ``2 * IDENTITY``; ``__rmul__`` returns ``other`` unchanged for 1.
            return other.__rmul__(self.scalar)
        return super().__mul__(other)

    def apply_to(self, argument: Any, output: Any | None = None, **kwargs: Any) -> Any:
        """Return the argument multiplied by this operator's scalar.

        :param argument: The object to apply the operator to.
        :param output: If given, the result is written into it via
            ``output[...] = result`` and ``output`` is returned.
        :param kwargs: Accepted and ignored.
        :return: ``output`` if it was given, otherwise ``scalar * argument``.
        """
        result = self.scalar * argument
        if output is not None:
            output[...] = result
            return output
        return result
class _ZeroOperator(Operator):
    """Zero operator: neutral element of operator addition.

    Its scalar cannot be changed; assignments to ``scalar`` are ignored.
    """

    def __add__(self, other: Operator) -> Operator:
        """Adding zero operator to any operator returns the other operator.

        :raises TypeError: If ``other`` is not an ``Operator``, matching the
            behaviour of ``Operator.__add__``.
        """
        if not isinstance(other, Operator):
            raise TypeError(f"Cannot add {type(other)} to {type(self)}")
        return other

    @property
    def scalar(self) -> Numeric:
        """Scalar factor of the zero operator."""
        return self._scalar

    @scalar.setter
    def scalar(self, value: Numeric):
        # Deliberately a no-op: scaling the zero operator leaves it unchanged.
        pass

    def apply_to(self, argument: Any, output: Any | None = None, **kwargs: Any) -> Any:
        """Applying the zero operator returns a zero-multiplied version of the input.

        :param argument: The object to apply the operator to.
        :param output: If given, it is filled with zeros via
            ``output[...] = 0`` and returned.
        :param kwargs: Accepted and ignored.
        :return: ``output`` if given; otherwise ``0 * argument``, or
            ``np.zeros_like(argument)`` if that multiplication raises
            ``TypeError``.
        :raises TypeError: If neither way of producing a zero result works.
        """
        if output is not None:
            output[...] = 0
            return output

        # First choice: let the argument's own multiplication build the zero.
        try:
            return 0 * argument  # type: ignore
        except TypeError:
            pass

        # Fallback for arguments that do not support multiplication by an int.
        try:
            return np.zeros_like(argument)
        except (ValueError, TypeError):
            pass

        raise TypeError(f"Cannot apply zero operator to {type(argument)}")
# Module-level identity operator instance; shown as "[Id]" in representations.
IDENTITY = _IdentityOperator(name="[Id]")
# Module-level zero operator instance; shown as "[0]" in representations.
# ``Operator.__rmul__`` returns this object when the scalar factor is 0.
ZERO = _ZeroOperator(name="[0]")