Source code for odl.core.space.pspace

# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

# Necessary for inplace updates
# pylint: disable=inconsistent-return-statements
# Necessary for operator arithmetic
# pylint: disable=unnecessary-dunder-call
# xl and xr are okay variable name in a function helper
# pylint: disable=invalid-name
# Necessary lazy imports
# pylint: disable=import-outside-toplevel

"""Cartesian products of `LinearSpace` instances."""

from itertools import product
from numbers import Integral, Number
import warnings
from contextlib import contextmanager

import numpy as np

from odl.core.array_API_support.utils import get_array_and_backend
from odl.core.util import indent, is_real_dtype, signature_string
from odl.core.set import LinearSpace
from odl.core.set.space import (LinearSpaceElement,
    SupportedNumOperationParadigms, NumOperationParadigmSupport)
from .weightings.weighting import (
    ArrayWeighting, ConstWeighting, CustomDist, CustomInner, CustomNorm,
    Weighting)


__all__ = ('ProductSpace',)


class ProductSpace(LinearSpace):

    """Cartesian product of `LinearSpace`'s.

    A product space is the Cartesian product ``X_1 x ... x X_n`` of
    linear spaces ``X_i``. It is itself a linear space, where the linear
    combination is defined component-wise. Inner product, norm and
    distance can also be defined in natural ways from the corresponding
    functions in the individual components.
    """

[docs] def __init__(self, *spaces, **kwargs): r"""Initialize a new instance. Parameters ---------- space1,...,spaceN : `LinearSpace` or int The individual spaces ("factors / parts") in the product space. Can also be given as ``space, n`` with ``n`` integer, in which case the power space ``space ** n`` is created. exponent : non-zero float or ``float('inf')``, optional Order of the product distance/norm, i.e. ``dist(x, y) = np.linalg.norm(x-y, ord=exponent)`` ``norm(x) = np.linalg.norm(x, ord=exponent)`` Values ``0 <= exponent < 1`` are currently unsupported due to numerical instability. See ``Notes`` for further information about the interpretation of the values. Default: 2.0 field : `Field`, optional Scalar field of the resulting space. Default: ``spaces[0].field`` weighting : optional Use weighted inner product, norm, and dist. The following types are supported as ``weighting``: ``None`` : no weighting (default) `Weighting` : weighting class, used directly. Such a class instance can be retrieved from the space by the `ProductSpace.weighting` property. `array-like` : weigh each component with one entry from the array. The array must be one-dimensional and have the same length as the number of spaces. float : same weighting factor in each component Other Parameters ---------------- dist : callable, optional The distance function defining a metric on the space. It must accept two `ProductSpaceElement` arguments and fulfill the following mathematical conditions for any three space elements ``x, y, z``: - ``dist(x, y) >= 0`` - ``dist(x, y) = 0`` if and only if ``x = y`` - ``dist(x, y) = dist(y, x)`` - ``dist(x, y) <= dist(x, z) + dist(z, y)`` By default, ``dist(x, y)`` is calculated as ``norm(x - y)``. Cannot be combined with: ``weighting, norm, inner`` norm : callable, optional The norm implementation. It must accept an `ProductSpaceElement` argument, return a float and satisfy the following conditions for all space elements ``x, y`` and scalars ``s``: - ``||x|| >= 0`` - ``||x|| = 0`` if and only if ``x = 0`` - ``||s * x|| = |s| * ||x||`` - ``||x + y|| <= ||x|| + ||y||`` By default, ``norm(x)`` is calculated as ``inner(x, x)``. Cannot be combined with: ``weighting, dist, inner`` inner : callable, optional The inner product implementation. It must accept two `ProductSpaceElement` arguments, return a element from the field of the space (real or complex number) and satisfy the following conditions for all space elements ``x, y, z`` and scalars ``s``: - ``<x, y> = conj(<y, x>)`` - ``<s*x + y, z> = s * <x, z> + <y, z>`` - ``<x, x> = 0`` if and only if ``x = 0`` Cannot be combined with: ``weighting, dist, norm`` Examples -------- Product of two rn spaces >>> r2x3 = ProductSpace(odl.rn(2), odl.rn(3)) Powerspace of rn space >>> r2x2x2 = ProductSpace(odl.rn(2), 3) Notes ----- Inner product, norm and distance are evaluated by collecting the result of the corresponding operation in the individual components and reducing the resulting vector to a single number. The ``exponent`` parameter influences only this last part, not the computations in the individual components. We give the exact definitions in the following: Let :math:`\mathcal{X} = \mathcal{X}_1 \times \dots \times \mathcal{X}_d` be a product space, and :math:`\langle \cdot, \cdot\rangle_i`, :math:`\lVert \cdot \rVert_i`, :math:`d_i(\cdot, \cdot)` be inner products, norms and distances in the respective component spaces. **Inner product:** .. math:: \langle x, y \rangle = \sum_{i=1}^d \langle x_i, y_i \rangle_i **Norm:** - :math:`p < \infty`: .. math:: \lVert x\rVert = \left( \sum_{i=1}^d \lVert x_i \rVert_i^p \right)^{1/p} - :math:`p = \infty`: .. math:: \lVert x\rVert = \max_i \lVert x_i \rVert_i **Distance:** - :math:`p < \infty`: .. math:: d(x, y) = \left( \sum_{i=1}^d d_i(x_i, y_i)^p \right)^{1/p} - :math:`p = \infty`: .. math:: d(x, y) = \max_i d_i(x_i, y_i) To implement own versions of these functions, you can use the following snippet to gather the vector of norms (analogously for inner products and distances):: norms = np.fromiter( (xi.norm() for xi in x), dtype=np.float64, count=len(x)) See Also -------- ProductSpaceArrayWeighting ProductSpaceConstWeighting """ field = kwargs.pop('field', None) dist = kwargs.pop('dist', None) norm = kwargs.pop('norm', None) inner = kwargs.pop('inner', None) weighting = kwargs.pop('weighting', None) exponent = float(kwargs.pop('exponent', 2.0)) if kwargs: raise TypeError( f'got unexpected keyword arguments: {kwargs}' ) # Check validity of option combination (3 or 4 out of 4 must be None) if sum(x is None for x in (dist, norm, inner, weighting)) < 3: raise ValueError('invalid combination of options weighting, ' 'dist, norm and inner') if any(x is not None for x in (dist, norm, inner)) and exponent != 2.0: raise ValueError('`exponent` cannot be used together with ' 'inner, norm or dist') # Make a power space if the second argument is an integer. # For the case that the integer is 0, we already set the field here. if len(spaces) == 2 and isinstance(spaces[1], Integral): field = spaces[0].field spaces = [spaces[0]] * spaces[1] # Validate the space arguments if not all(isinstance(spc, LinearSpace) for spc in spaces): raise TypeError( "all arguments must be `LinearSpace` instances," + " or the first argument must be `LinearSpace`" + f" and the second integer; got {spaces}") if not all(spc.field == spaces[0].field for spc in spaces): raise ValueError('all spaces must have the same field') # Assign spaces and field self.__spaces = tuple(spaces) # Cache for efficiency self.__is_power_space = all(spc == self.spaces[0] for spc in self.spaces[1:]) # Assing or infer field if field is None: if len(self) == 0: raise ValueError('no spaces provided, cannot deduce field') else: field = self.spaces[0].field super().__init__(field) # Assign weighting if weighting is not None: if isinstance(weighting, Weighting): self.__weighting = weighting elif np.isscalar(weighting): self.__weighting = ProductSpaceConstWeighting( weighting, exponent) elif weighting is None: # Need to wait until dist, norm and inner are handled pass else: # last possibility: make a product space element arr = np.asarray(weighting) if arr.dtype == object: raise ValueError(f"invalid weighting argument {weighting}") if arr.ndim == 1: self.__weighting = ProductSpaceArrayWeighting( arr, exponent) else: raise ValueError( f"weighting array has {arr.ndim} dimensions, expected 1") elif dist is not None: self.__weighting = ProductSpaceCustomDist(dist) elif norm is not None: self.__weighting = ProductSpaceCustomNorm(norm) elif inner is not None: self.__weighting = ProductSpaceCustomInner(inner) else: # all None -> no weighing self.__weighting = ProductSpaceConstWeighting(1.0, exponent)
def __len__(self): """Return ``len(self)``. This length is the number of spaces at the top level only, and is equal to ``self.shape[0]``. """ return len(self.spaces) def _elementwise_num_operation(self, operation:str , x1: LinearSpaceElement | Number , x2: None | LinearSpaceElement | Number = None , out=None , namespace=None , **kwargs ): """ Internal helper function to implement the __magic_functions__ (such as __add__). Parameters ---------- x1 : ProductSpaceElement, TensorSpaceElement, int, float, complex Left operand x2 : ProductSpaceElement, TensorSpaceElement, int, float, complex Right operand operation: str Attribute of the array namespace out : ProductSpaceElement, Optional ProductSpaceElement for out-of-place operations Returns ------- ProductSpaceElement The result of the operation `operation` wrapped in a space with the right datatype. """ if self.field is None: raise NotImplementedError("The space has no field.") if out is not None: if not isinstance(out, ProductSpaceElement): raise TypeError(f"Output argument for ProductSpace arithmetic must be a product space. {type(out)=}") assert len(out.parts) == len(self) def _dtype_adaptive_wrapper(new_parts): if all(xln.space == spc for xln, spc in zip(new_parts, self)): return self.element(new_parts) else: # The `xl.space._elementwise_num_operation` may change the dtype, and thus the # part-space. For example, the `isfinite` function has boolean results. # In this case, the resulting product space also has the new dtype, which we # accomplish by creating the new space on the spot. new_space = ProductSpace(*[xln.space for xln in new_parts]) return new_space.element(new_parts) if x2 is None: if out is None: return _dtype_adaptive_wrapper([ xl.space._elementwise_num_operation(operation=operation, x1=xl, namespace=namespace, **kwargs) for xl in x1.parts ]) else: for i, xl in enumerate(x1.parts): xl.space._elementwise_num_operation( operation=operation, x1=xl, out=out.parts[i], namespace=namespace, **kwargs) return out from odl.core.operator import Operator if isinstance(x2, Operator): warnings.warn("The composition of a LinearSpaceElement and an Operator using the * operator is deprecated and will be removed in future ODL versions. Please replace * with @.") return x2.__rmul__(x1) if isinstance(x1, ProductSpaceElement) and isinstance(x2, ProductSpaceElement): assert len(x1.parts) == len(x2.parts) if out is None: return _dtype_adaptive_wrapper([ xl.space._elementwise_num_operation(operation=operation, x1=xl, x2=xr, namespace=namespace, **kwargs) for xl, xr in zip(x1.parts, x2.parts) ]) else: for i, xl in enumerate(x1.parts): xr = x2.parts[i] xl.space._elementwise_num_operation( operation=operation, x1=xl, x2=xr, out=out.parts[i], namespace=namespace, **kwargs) return out elif isinstance(x1, ProductSpaceElement): if out is None: return _dtype_adaptive_wrapper([ x.space._elementwise_num_operation(operation=operation, x1=x, x2=x2, namespace=namespace, **kwargs) for x in x1.parts ]) else: for i, x in enumerate(x1.parts): x.space._elementwise_num_operation( operation=operation, x1=x, x2=x2, out=out.parts[i], namespace=namespace, **kwargs) return out elif isinstance(x2, ProductSpaceElement): if out is None: return _dtype_adaptive_wrapper([ x.space._elementwise_num_operation(operation=operation, x1=x1, x2=x, namespace=namespace, **kwargs) for x in x2.parts ]) else: for i, x in enumerate(x2.parts): x.space._elementwise_num_operation( operation=operation, x1=x1, x2=x, out=out.parts[i], namespace=namespace, **kwargs) return out else: raise TypeError("At least one of the arguments to `ProductSpace._elementwise_num_operation`" + f" should be a `ProductSpaceElement`, but got {type(x1)=}, {type(x2)=}") def _element_reduction(self, operation:str , x: "ProductSpaceElement" , **kwargs ): assert x in self, f"the input {x} does not belong to self {self}" part_results = np.array([ xp.space._element_reduction(operation, xp, **kwargs) for xp in x.parts ]) return getattr(np, operation)(part_results).item() @property def nbytes(self): """Total number of bytes in memory used by an element of this space.""" return sum(spc.nbytes for spc in self.spaces) @property def shape(self): """Total spaces per axis, computed recursively. The recursion ends at the fist level that does not have a shape. Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> pspace = odl.ProductSpace(r2, r3) >>> pspace.shape (2,) >>> pspace2 = odl.ProductSpace(pspace, 3) >>> pspace2.shape (3, 2) If the space is a "pure" product space, shape recurses all the way into the components: >>> r2_2 = odl.ProductSpace(r2, 3) >>> r2_2.shape (3, 2) """ if len(self) == 0: return () elif self.is_power_space: try: sub_shape = self[0].shape except AttributeError: sub_shape = () else: sub_shape = () return (len(self),) + sub_shape @property def size(self): """Total number of involved spaces, computed recursively. The recursion ends at the fist level that does not comprise a *power* space, i.e., which is not made of equal spaces. Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> pspace = odl.ProductSpace(r2, r3) >>> pspace.size 2 >>> pspace2 = odl.ProductSpace(pspace, 3) >>> pspace2.size 6 """ return (0 if self.shape == () else int(np.prod(self.shape, dtype='int64'))) @property def spaces(self): """A tuple containing all spaces.""" return self.__spaces @property def is_power_space(self): """``True`` if all member spaces are equal.""" return self.__is_power_space @property def exponent(self): """Exponent of the product space norm/dist, ``None`` for custom.""" return self.weighting.exponent @property def weighting(self): """This space's weighting scheme.""" return self.__weighting @property def is_weighted(self): """Return ``True`` if the space is not weighted by constant 1.0.""" return not ( isinstance(self.weighting, ProductSpaceConstWeighting) and self.weighting.const == 1.0) @property def dtype(self): """The data type of this space. This is only well defined if all subspaces have the same dtype. Raises ------ AttributeError If any of the subspaces does not implement `dtype` or if the dtype of the subspaces does not match. """ dtypes = [space.dtype for space in self.spaces] if all(dtype == dtypes[0] for dtype in dtypes): return dtypes[0] else: raise AttributeError("`dtype`'s of subspaces not equal") @property def supported_num_operation_paradigms(self) -> NumOperationParadigmSupport: """Whether in-place operations an out-of-place operations are supported depends on the subspaces. Only operations that are supported on all the subspaces will be supported on the product space. The style that is preferred on most subspaces (if any) will be chosen as preferred on the product space.""" paradigms = [space.supported_num_operation_paradigms for space in self.spaces] ip_supported = True ip_prefers = 0 oop_supported = True oop_prefers = 0 # Check for all of the subspaces whether they support each paradigm, # and count how many of them prefer each. for parad in paradigms: if parad.in_place == NumOperationParadigmSupport.NOT_SUPPORTED: ip_supported = False elif parad.in_place == NumOperationParadigmSupport.PREFERRED: ip_prefers += 1 if parad.out_of_place == NumOperationParadigmSupport.NOT_SUPPORTED: oop_supported = False elif parad.out_of_place == NumOperationParadigmSupport.PREFERRED: oop_prefers += 1 # Support in-place updates if all subspaces support them. # Prefer them if a majority of the subspaces do. if ip_supported: if ip_prefers > oop_prefers: in_place_support = NumOperationParadigmSupport.PREFERRED else: in_place_support = NumOperationParadigmSupport.SUPPORTED else: in_place_support = NumOperationParadigmSupport.NOT_SUPPORTED # Support out-of-place calculations if all subspaces support them. # Prefer them if a majority of the subspaces do. if oop_supported: if oop_prefers > ip_prefers: oo_place_support = NumOperationParadigmSupport.PREFERRED else: oo_place_support = NumOperationParadigmSupport.SUPPORTED else: oo_place_support = NumOperationParadigmSupport.NOT_SUPPORTED return SupportedNumOperationParadigms( in_place=in_place_support, out_of_place=oo_place_support) @property def is_real(self): """True if this is a space of real valued vectors.""" return all(spc.is_real for spc in self.spaces) @property def is_complex(self): """True if this is a space of complex valued vectors.""" return all(spc.is_complex for spc in self.spaces) @property def real_space(self): """Variant of this space with real dtype.""" return ProductSpace(*[space.real_space for space in self.spaces]) @property def complex_space(self): """Variant of this space with complex dtype.""" return ProductSpace(*[space.complex_space for space in self.spaces])
[docs] def astype(self, dtype): """Return a copy of this space with new ``dtype``. Parameters ---------- dtype : Scalar data type of the returned space. Can be provided in any way the `numpy.dtype` constructor understands, e.g. as built-in type or as a string. Data types with non-trivial shapes are not allowed. Returns ------- newspace : `ProductSpace` Version of this space with given data type. """ if dtype is None: # Need to filter this out since Numpy iterprets it as 'float' raise ValueError('`None` is not a valid data type') dtype = np.dtype(dtype) current_dtype = getattr(self, 'dtype', object) if dtype == current_dtype: return self else: return ProductSpace(*[space.astype(dtype) for space in self.spaces])
[docs] def element(self, inp=None, copy=True): """Create an element in the product space. Parameters ---------- inp : optional If ``inp`` is ``None``, a new element is created from scratch by allocation in the spaces. If ``inp`` is already an element of this space, it is re-wrapped. Otherwise, a new element is created from the components by calling the ``element()`` methods in the component spaces. copy : bool, optional If ``True``, data may be copied from one representation to another in order to satisfy the requirements of the space and its subspaces. This is flexible but can cause poor performance. If ``False``, a ``TypeError`` is Returns ------- element : `ProductSpaceElement` The new element Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> vec_2, vec_3 = r2.element(), r3.element() >>> r2x3 = ProductSpace(r2, r3) >>> vec_2x3 = r2x3.element() >>> vec_2.space == vec_2x3[0].space True >>> vec_3.space == vec_2x3[1].space True Create an element of the product space >>> r2, r3 = odl.rn(2), odl.rn(3) >>> prod = ProductSpace(r2, r3) >>> x2 = r2.element([1, 2]) >>> x3 = r3.element([1, 2, 3]) >>> x = prod.element([x2, x3]) >>> x ProductSpace(rn(2), rn(3)).element([ [ 1., 2.], [ 1., 2., 3.] ]) """ # If data is given as keyword arg, prefer it over arg list if inp is None: inp = [space.element() for space in self.spaces] if inp in self: return inp if isinstance(inp, Number): inp = [space.element(inp) for space in self.spaces] if len(inp) != len(self): # Here, we handle the case where the user provides an input with a # single element that we will try to broadcast to all of the parts # of the ProductSpace. if len(inp) == 1 and copy: parts = [space.element(inp[0]) for space in self.spaces] else: raise ValueError(f"length of `inp` {len(inp)} does not match length of space {len(self)}") elif (all(isinstance(v, LinearSpaceElement) and v.space == space for v, space in zip(inp, self.spaces))): parts = list(inp) elif len(inp) == len(self): # Delegate constructors parts = [space.element(arg, copy=copy) for arg, space in zip(inp, self.spaces)] else: raise TypeError(f"input {inp} not a sequence of elements of the " + "component spaces") return self.element_type(self, parts)
@property def examples(self): """Return examples from all sub-spaces.""" for examples in product(*[spc.examples for spc in self.spaces]): name = ', '.join(name for name, _ in examples) element = self.element([elem for _, elem in examples]) yield (name, element)
[docs] def zero(self): """Create the zero element of the product space. The i-th component of the product space zero element is the zero element of the i-th space in the product. Parameters ---------- None Returns ------- zero : ProductSpaceElement The zero element in the product space. Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> zero_2, zero_3 = r2.zero(), r3.zero() >>> r2x3 = ProductSpace(r2, r3) >>> zero_2x3 = r2x3.zero() >>> zero_2 == zero_2x3[0] True >>> zero_3 == zero_2x3[1] True """ return self.element([space.zero() for space in self.spaces])
[docs] def one(self): """Create the one element of the product space. The i-th component of the product space one element is the one element of the i-th space in the product. Parameters ---------- None Returns ------- one : ProductSpaceElement The one element in the product space. Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> one_2, one_3 = r2.one(), r3.one() >>> r2x3 = ProductSpace(r2, r3) >>> one_2x3 = r2x3.one() >>> one_2 == one_2x3[0] True >>> one_3 == one_2x3[1] True """ return self.element([space.one() for space in self.spaces])
[docs] def _lincomb(self, a, x, b, y, out): """Linear combination ``out = a*x + b*y``.""" if out is None: return self.element([ space._lincomb(a, xp, b, yp, out=None) for space, xp, yp in zip(self.spaces, x.parts, y.parts)]) for space, xp, yp, outp in zip(self.spaces, x.parts, y.parts, out.parts): space._lincomb(a, xp, b, yp, outp)
[docs] def _dist(self, x1, x2): """Distance between two elements.""" return self.weighting.dist(x1, x2)
[docs] def _norm(self, x): """Norm of an element.""" return self.weighting.norm(x)
[docs] def _inner(self, x1, x2): """Inner product of two elements.""" return self.weighting.inner(x1, x2)
[docs] def _multiply(self, x1, x2, out): """Product ``out = x1 * x2``.""" if out is None: return self.element([ spc._multiply(xp, yp, out=None) for spc, xp, yp in zip(self.spaces, x1.parts, x2.parts)]) for spc, xp, yp, outp in zip(self.spaces, x1.parts, x2.parts, out.parts): spc._multiply(xp, yp, outp)
[docs] def _divide(self, x1, x2, out): """Quotient ``out = x1 / x2``.""" if out is None: return self.element([ spc._divide(xp, yp, out=None) for spc, xp, yp in zip(self.spaces, x1.parts, x2.parts)]) for spc, xp, yp, outp in zip(self.spaces, x1.parts, x2.parts, out.parts): spc._divide(xp, yp, outp)
[docs] def __eq__(self, other): """Return ``self == other``. Returns ------- equals : bool ``True`` if ``other`` is a `ProductSpace` instance, has the same length and the same factors. ``False`` otherwise. Examples -------- >>> r2, r3 = odl.rn(2), odl.rn(3) >>> rn, rm = odl.rn(2), odl.rn(3) >>> r2x3, rnxm = ProductSpace(r2, r3), ProductSpace(rn, rm) >>> r2x3 == rnxm True >>> r3x2 = ProductSpace(r3, r2) >>> r2x3 == r3x2 False >>> r5 = ProductSpace(*[odl.rn(1)]*5) >>> r2x3 == r5 False >>> r5 = odl.rn(5) >>> r2x3 == r5 False """ if other is self: return True else: return (isinstance(other, ProductSpace) and len(self) == len(other) and self.weighting == other.weighting and all(x == y for x, y in zip(self.spaces, other.spaces)))
def __hash__(self): """Return ``hash(self)``.""" return hash((type(self), self.spaces, self.weighting))
[docs] def __getitem__(self, indices): """Return ``self[indices]``. Examples -------- Integers are used to pick components, slices to pick ranges: >>> r2, r3, r4 = odl.rn(2), odl.rn(3), odl.rn(4) >>> pspace = odl.ProductSpace(r2, r3, r4) >>> pspace[1] rn(3) >>> pspace[1:] ProductSpace(rn(3), rn(4)) With lists, arbitrary components can be stacked together: >>> pspace[[0, 2, 1, 2]] ProductSpace(rn(2), rn(4), rn(3), rn(4)) Tuples, i.e. multi-indices, will recursively index higher-order product spaces. However, remaining indices cannot be passed down to component spaces that are not product spaces: >>> pspace2 = odl.ProductSpace(pspace, 3) # 2nd order product space >>> pspace2 ProductSpace(ProductSpace(rn(2), rn(3), rn(4)), 3) >>> pspace2[0] ProductSpace(rn(2), rn(3), rn(4)) >>> pspace2[1, 0] rn(2) >>> pspace2[:-1, 0] ProductSpace(rn(2), 2) """ if isinstance(indices, Integral): return self.spaces[indices] elif isinstance(indices, slice): return ProductSpace(*self.spaces[indices], field=self.field) elif isinstance(indices, tuple): # Use tuple indexing for recursive product spaces, i.e., # pspace[0, 0] == pspace[0][0] if not indices: return self idx = indices[0] if isinstance(idx, Integral): # Single integer in tuple, picking that space and passing # through the rest of the tuple. If the picked space # is not a product space and there are still indices left, # raise an error. space = self.spaces[idx] rest_indcs = indices[1:] if not rest_indcs: return space elif isinstance(space, ProductSpace): return space[rest_indcs] else: raise IndexError("too many indices for recursive product space:" + f" remaining indices {rest_indcs}") elif isinstance(idx, slice): # Doing the same as with single integer with all spaces # in the slice, but wrapping the result into a ProductSpace. spaces = self.spaces[idx] rest_indcs = indices[1:] if len(spaces) == 0 and rest_indcs: # Need to catch this situation since the code further # down doesn't trigger an error raise IndexError(f"too many indices for recursive product space: remaining indices {rest_indcs}") if not rest_indcs: return ProductSpace(*spaces) elif all(isinstance(space, ProductSpace) for space in spaces): return ProductSpace( *(space[rest_indcs] for space in spaces), field=self.field) else: raise IndexError("too many indices for recursive product space:" +f" remaining indices {rest_indcs}") else: raise TypeError("index tuple can only contain" + " integers or slices") elif isinstance(indices, list): return ProductSpace(*[self.spaces[i] for i in indices], field=self.field) raise TypeError(f"`indices` must be integer, slice, tuple or list, got {indices}")
def __str__(self): """Return ``str(self)``.""" if len(self) == 0: return '{}' if self.is_power_space: return f'({self.spaces[0]}) ** {len(self)}' return ' x '.join(str(space) for space in self.spaces) def __repr__(self): """Return ``repr(self)``.""" weight_str = self.weighting.repr_part edgeitems = np.get_printoptions()['edgeitems'] if len(self) == 0: posargs = [] posmod = '' optargs = [('field', self.field, None)] oneline = True elif self.is_power_space: posargs = [self.spaces[0], len(self)] posmod = '!r' optargs = [] oneline = True elif self.size <= 2 * edgeitems: posargs = self.spaces posmod = '!r' optargs = [] argstr = ', '.join(repr(s) for s in self.spaces) oneline = (len(argstr + weight_str) <= 40 and '\n' not in argstr + weight_str) else: posargs = (self.spaces[:edgeitems] + ('...',) + self.spaces[-edgeitems:]) posmod = ['!r'] * edgeitems + ['!s'] + ['!r'] * edgeitems optargs = [] oneline = False if oneline: inner_str = signature_string(posargs, optargs, sep=', ', mod=[posmod, '!r']) if weight_str: inner_str = ', '.join([inner_str, weight_str]) return f"{self.__class__.__name__}({inner_str})" else: inner_str = signature_string(posargs, optargs, sep=',\n', mod=[posmod, '!r']) if weight_str: inner_str = ',\n'.join([inner_str, weight_str]) return f"{self.__class__.__name__}(\n{indent(inner_str)}\n)" @property def element_type(self): """`ProductSpaceElement`""" return ProductSpaceElement
[docs] class ProductSpaceElement(LinearSpaceElement): """Elements of a `ProductSpace`."""
[docs] def __init__(self, space, parts): """Initialize a new instance.""" super().__init__(space) self.__parts = tuple(parts)
@property def parts(self): """Parts of this product space element.""" return self.__parts @property def shape(self): """Number of values per axis in ``self``, computed recursively. The recursion ends at the fist level that does not have a shape. Raises ------ ValueError If a `ProductSpace` is encountered that is not a power space. See Also -------- ProductSpace.shape Examples -------- >>> r4_3 = odl.ProductSpace(odl.rn(4), 3) >>> x = r4_3.element() >>> x.shape (3, 4) >>> r4_2_3 = odl.ProductSpace(r4_3, 2) >>> y = r4_2_3.element() >>> y.shape (2, 3, 4) """ return self.space.shape @property def ndim(self): """Number axes in ``self``, computed recursively. Raises ------ ValueError If a `ProductSpace` is encountered that is not a power space. See Also -------- shape Examples -------- >>> r4_3 = odl.ProductSpace(odl.rn(4), 3) >>> x = r4_3.element() >>> x.ndim 2 >>> r4_2_3 = odl.ProductSpace(r4_3, 2) >>> y = r4_2_3.element() >>> y.ndim 3 """ return len(self.shape) @property def size(self): """Total number of involved spaces, computed recursively. See Also -------- ProductSpace.size """ return int(np.prod(self.shape)) @property def dtype(self): """The data type of the space of this element.""" return self.space.dtype def _assign(self, other, avoid_deep_copy): """Assign the values of ``other``, which is assumed to be in the same product space, to ``self``.""" for tgt, src in zip(self.parts, other.parts): tgt.assign(src, avoid_deep_copy=avoid_deep_copy)
[docs] def set_zero(self): """Set this element to zero. See Also -------- LinearSpace.zero """ for tgt in self.parts: tgt.set_zero() return self
def __len__(self): """Return ``len(self)``.""" return len(self.space) @property def nbytes(self): """Total number of bytes in memory used by this element.""" return self.space.nbytes
[docs] def __eq__(self, other): """Return ``self == other``. Overrides the default `LinearSpace` method since it is implemented with the distance function, which is prone to numerical errors. This function checks equality per component. """ if other is self: return True elif other not in self.space: return False else: return all(sp == op for sp, op in zip(self.parts, other.parts))
[docs] def __getitem__(self, indices): """Return ``self[indices]``.""" if isinstance(indices, Integral): return self.parts[indices] elif isinstance(indices, slice): return self.space[indices].element(self.parts[indices]) elif isinstance(indices, list): out_parts = [self.parts[i] for i in indices] return self.space[indices].element(out_parts) elif isinstance(indices, tuple): if len(indices) == 0: return ProductSpace().element() elif len(indices) == 1: # Tuple with a single entry - we just unpack and delegate return self[indices[0]] else: # Tuple with multiple entries if isinstance(indices[0], Integral): # In case the first entry is an integer, we drop the # axis and return directly from `parts` return self.parts[indices[0]][indices[1:]] else: # indices[0] is a slice or list. We first retrieve the # parts indexed in this axis. # In any case we know that we want to keep this axis. if isinstance(indices[0], list): part = [self.parts[i] for i in indices[0]] else: part = self.parts[indices[0]] if (len(indices[1:]) == 1 and not all(isinstance(p, ProductSpaceElement) for p in part)): # This case means we have "hit the bottom", i.e., # there are non-ProductSpaces involved. In order # not to retrieve scalar values from these # elements, we use a slice of size 1. idx = indices[1] indexed = [p[idx:idx + 1] for p in part] else: # Here we're still in the "product space chain", # so we can use recursion to go on. indexed = [p[indices[1:]] for p in part] # Finally make a wrapping space for the indexed elements new_space = ProductSpace(*(p.space for p in indexed)) return new_space.element(indexed) else: raise TypeError(f"bad index type {type(indices)}")
[docs] def __setitem__(self, indices, values): """Implement ``self[indices] = values``.""" # Get the parts to which we assign values if isinstance(indices, Integral): indexed_parts = (self.parts[indices],) values = (values,) elif isinstance(indices, slice): indexed_parts = self.parts[indices] elif isinstance(indices, list): indexed_parts = tuple(self.parts[i] for i in indices) elif isinstance(indices, tuple): if len(indices) == 0: return else: # We need to explicitly use __setitem__ here, otherwise # __getitem__ is used and assigned to, which fails if # a space like rn(3) is indexed at the very end. part = self.parts[indices[0]] if isinstance(part, LinearSpaceElement): part.__setitem__(indices[1:], values) else: # part is a tuple for p in part: p.__setitem__(indices[1:], values) return else: raise TypeError(f"bad index type {type(indices)}") # Do the assignment, with broadcasting if desired try: iter(values) except TypeError: # `values` is not iterable, assume it can be assigned to # all indexed parts for p in indexed_parts: p[:] = values else: # `values` is iterable; it could still represent a single # element of a power space. if self.space.is_power_space and values in self.space[0]: # Broadcast a single element across a power space for p in indexed_parts: p[:] = values else: # Now we really have one assigned value per part if len(values) != len(indexed_parts): raise ValueError( f"length of iterable `values` not equal to number of indexed parts ({len(values)} != {len(indexed_parts)})" ) for p, v in zip(indexed_parts, values): p[:] = v
[docs] def asarray(self, out=None, must_be_contiguous=False): """Extract the data of this vector as a backend-specific array. Only available if `is_power_space` is True. The ordering is such that it commutes with indexing:: self[ind].asarray() == self.asarray()[ind] Parameters ---------- out : Arraylike, optional Array in which the result should be written in-place. Has to be contiguous and of the correct backend, dtype and shape. Raises ------ ValueError If `is_power_space` is false. Examples -------- >>> spc = odl.ProductSpace(odl.rn(3), 2) >>> x = spc.element([[ 1., 2., 3.], ... [ 4., 5., 6.]]) >>> x.asarray() array([[ 1., 2., 3.], [ 4., 5., 6.]]) """ if not self.space.is_power_space: raise ValueError('cannot use `asarray` if `space.is_power_space` ' 'is `False`') else: representative_array, representative_backend = get_array_and_backend(self.parts[0]) if out is None: # We are assuming that `empty` always produces a contiguous array, # so no need to ensure it separately. out = representative_backend.array_namespace.empty( shape=self.shape, dtype=self.dtype, device=representative_array.device) out[0] = representative_array for i in range(1, len(self)): self.parts[i].asarray(out = out[i]) return out
[docs] @contextmanager def writable_array(self, must_be_contiguous: bool =False): """ Expose the data underlying this element as a single array that can be modified in-place and the changes kept. Unlike in the `TensorSpace` case, which always uses a single array for storage, this is in general not possible for a product space, only for the special case of a power-space. Compare with `asarray`.""" arr = None try: arr = self.asarray(must_be_contiguous=must_be_contiguous) yield arr finally: if arr is not None: for i in range(1, len(self)): self.parts[i]._assign(self.parts[i].space.element(arr[i]))
@property def real(self): """Real part of the element. The real part can also be set using ``x.real = other``, where ``other`` is array-like or scalar. Examples -------- >>> space = odl.ProductSpace(odl.cn(3), odl.cn(2)) >>> x = space.element([[1 + 1j, 2, 3 - 3j], ... [-1 + 2j, -2 - 3j]]) >>> x.real ProductSpace(rn(3), rn(2)).element([ [ 1., 2., 3.], [-1., -2.] ]) The real part can also be set using different array-like types: >>> x.real = space.real_space.zero() >>> x ProductSpace(cn(3), cn(2)).element([ [ 0.+1.j, 0.+0.j, 0.-3.j], [ 0.+2.j, 0.-3.j] ]) >>> x.real = 1.0 >>> x ProductSpace(cn(3), cn(2)).element([ [ 1.+1.j, 1.+0.j, 1.-3.j], [ 1.+2.j, 1.-3.j] ]) >>> x.real = [[2, 3, 4], [5, 6]] >>> x ProductSpace(cn(3), cn(2)).element([ [ 2.+1.j, 3.+0.j, 4.-3.j], [ 5.+2.j, 6.-3.j] ]) """ real_part = [part.real for part in self.parts] return self.space.real_space.element(real_part) @real.setter def real(self, newreal): """Setter for the real part. This method is invoked by ``x.real = other``. Parameters ---------- newreal : array-like or scalar Values to be assigned to the real part of this element. """ try: iter(newreal) except TypeError: # `newreal` is not iterable, assume it can be assigned to # all indexed parts for part in self.parts: part.real = newreal return if self.space.is_power_space: try: # Set same value in all parts for part in self.parts: part.real = newreal except (AttributeError, ValueError, TypeError): # Iterate over all parts and set them separately for part, new_re in zip(self.parts, newreal): part.real = new_re elif len(newreal) == len(self): for part, new_re in zip(self.parts, newreal): part.real = new_re else: raise ValueError( f"dimensions of the new real part does not match the space," + f" got element {newreal} to set real part of {self}" ) @property def imag(self): """Imaginary part of the element. The imaginary part can also be set using ``x.imag = other``, where ``other`` is array-like or scalar. Examples -------- >>> space = odl.ProductSpace(odl.cn(3), odl.cn(2)) >>> x = space.element([[1 + 1j, 2, 3 - 3j], ... [-1 + 2j, -2 - 3j]]) >>> x.imag ProductSpace(rn(3), rn(2)).element([ [ 1., 0., -3.], [ 2., -3.] ]) The imaginary part can also be set using different array-like types: >>> x.imag = space.real_space.zero() >>> x ProductSpace(cn(3), cn(2)).element([ [ 1.+0.j, 2.+0.j, 3.+0.j], [-1.+0.j, -2.+0.j] ]) >>> x.imag = 1.0 >>> x ProductSpace(cn(3), cn(2)).element([ [ 1.+1.j, 2.+1.j, 3.+1.j], [-1.+1.j, -2.+1.j] ]) >>> x.imag = [[2, 3, 4], [5, 6]] >>> x ProductSpace(cn(3), cn(2)).element([ [ 1.+2.j, 2.+3.j, 3.+4.j], [-1.+5.j, -2.+6.j] ]) """ imag_part = [part.imag for part in self.parts] return self.space.real_space.element(imag_part) @imag.setter def imag(self, newimag): """Setter for the imaginary part. This method is invoked by ``x.imag = other``. Parameters ---------- newimag : array-like or scalar Values to be assigned to the imaginary part of this element. """ try: iter(newimag) except TypeError: # `newimag` is not iterable, assume it can be assigned to # all indexed parts for part in self.parts: part.imag = newimag return if self.space.is_power_space: try: # Set same value in all parts for part in self.parts: part.imag = newimag except (AttributeError, ValueError, TypeError): # Iterate over all parts and set them separately for part, new_im in zip(self.parts, newimag): part.imag = new_im elif len(newimag) == len(self): for part, new_im in zip(self.parts, newimag): part.imag = new_im else: raise ValueError("dimensions of the new imaginary part does not match the" + f" space, got element {newimag} to set real part of {self}")
[docs] def conj(self): """Complex conjugate of the element.""" complex_conj = [part.conj() for part in self.parts] return self.space.element(complex_conj)
def __str__(self): """Return ``str(self)``.""" return repr(self) def __repr__(self): """Return ``repr(self)``. Examples -------- >>> from odl import rn # need to import rn into namespace >>> r2, r3 = odl.rn(2), odl.rn(3) >>> r2x3 = ProductSpace(r2, r3) >>> x = r2x3.element([[1, 2], [3, 4, 5]]) >>> eval(repr(x)) == x True The result is readable: >>> x ProductSpace(rn(2), rn(3)).element([ [ 1., 2.], [ 3., 4., 5.] ]) Nestled spaces work as well: >>> X = ProductSpace(r2x3, r2x3) >>> x = X.element([[[1, 2], [3, 4, 5]],[[1, 2], [3, 4, 5]]]) >>> eval(repr(x)) == x True >>> x ProductSpace(ProductSpace(rn(2), rn(3)), 2).element([ [ [ 1., 2.], [ 3., 4., 5.] ], [ [ 1., 2.], [ 3., 4., 5.] ] ]) """ inner_str = '[\n' if len(self) < 5: inner_str += ',\n'.join(f"{_indent(_strip_space(part))}" for part in self.parts) else: inner_str += ',\n'.join(f"{_indent(_strip_space(part))}" for part in self.parts[:3]) inner_str += ',\n ...\n' inner_str += ',\n'.join(f"{_indent(_strip_space(part))}" for part in self.parts[-1:]) inner_str += '\n]' return f'{self.space}.element({inner_str})'
[docs] def show(self, title=None, indices=None, **kwargs): """Display the parts of this product space element graphically. Parameters ---------- title : string, optional Title of the figures indices : int, slice, tuple or list, optional Display parts of ``self`` in the way described in the following. A single list of integers selects the corresponding parts of this vector. For other tuples or lists, the first entry indexes the parts of this vector, and the remaining entries (if any) are used to slice into the parts. Handling those remaining indices is up to the ``show`` methods of the parts to be displayed. The types of the first entry trigger the following behaviors: - ``int``: take the part corresponding to this index - ``slice``: take a subset of the parts - ``None``: equivalent to ``slice(None)``, i.e., everything Typical use cases are displaying of selected parts, which can be achieved with a list, e.g., ``indices=[0, 2]`` for parts 0 and 2, and plotting of all parts sliced in a certain way, e.g., ``indices=[None, 20, None]`` for showing all parts sliced with indices ``[20, None]``. A single ``int``, ``slice``, ``list`` or ``None`` object indexes the parts only, i.e., is treated roughly as ``(indices, Ellipsis)``. In particular, for ``None``, all parts are shown with default slicing. in_figs : sequence of `matplotlib.figure.Figure`, optional Update these figures instead of creating new ones. Typically the return value of an earlier call to ``show`` is used for this parameter. kwargs Additional arguments passed on to the ``show`` methods of the parts. Returns ------- figs : tuple of `matplotlib.figure.Figure` The resulting figures. In an interactive shell, they are automatically displayed. See Also -------- odl.core.discr.discr_space.DiscretizedSpaceElement.show : Display of a discretized function odl.core.space.base_tensors.Tensor.show : Display of sequence type data odl.core.util.graphics.show_discrete_data : Underlying implementation """ if title is None: title = 'ProductSpaceElement' if indices is None: if len(self) < 5: indices = list(range(len(self))) else: indices = list(np.linspace(0, len(self) - 1, 4, dtype=int)) else: if (isinstance(indices, tuple) or (isinstance(indices, list) and not all(isinstance(idx, Integral) for idx in indices))): # Tuples or lists containing non-integers index by axis. # We use the first index for the current pspace and pass # on the rest. indices, kwargs['indices'] = indices[0], indices[1:] # Support `indices=[None, 0, None]` like syntax (`indices` is # the first entry as of now in that case) if indices is None: indices = slice(None) if isinstance(indices, slice): indices = list(range(*indices.indices(len(self)))) elif isinstance(indices, Integral): indices = [indices] else: # Use `indices` as-is pass in_figs = kwargs.pop('fig', None) in_figs = [None] * len(indices) if in_figs is None else in_figs figs = [] parts = self[indices] if len(parts) == 0: return () elif len(parts) == 1: # Don't extend the title if there is only one plot fig = parts[0].show(title=title, fig=in_figs[0], **kwargs) figs.append(fig) else: # Extend titles by indexed part to make them distinguishable for i, part, fig in zip(indices, parts, in_figs): fig = part.show(title=f'{title}. Part {i}', fig=fig, **kwargs) figs.append(fig) return tuple(figs)
[docs] class ProductSpaceArrayWeighting(ArrayWeighting): """Array weighting for `ProductSpace`. This class defines a weighting that has a different value for each index defined in a given space. See ``Notes`` for mathematical details. """
[docs] def __init__(self, array, exponent=2.0): r"""Initialize a new instance. Parameters ---------- array : 1-dim. `array-like` Weighting array of the inner product. exponent : positive float, optional Exponent of the norm. For values other than 2.0, no inner product is defined. Notes ----- - For exponent 2.0, a new weighted inner product with array :math:`w` is defined as .. math:: \langle x, y \rangle_w = \langle w \odot x, y \rangle with component-wise multiplication :math:`w \odot x`. For other exponents, only ``norm`` and ``dist`` are defined. In the case of exponent ``inf``, the weighted norm is .. math:: \|x\|_{w,\infty} = \|w \odot x\|_\infty, otherwise it is .. math:: \|x\|_{w,p} = \|w^{1/p} \odot x\|_p. - Note that this definition does **not** fulfill the limit property in :math:`p`, i.e., .. math:: \|x\|_{w,p} \not\to \|x\|_{w,\infty} \quad\text{for } p \to \infty unless :math:`w = (1,...,1)`. The reason for this choice is that the alternative with the limit property consists in ignoring the weights altogether. - The array may only have positive entries, otherwise it does not define an inner product or norm, respectively. This is not checked during initialization. """ super().__init__(array, impl=None, device=None, exponent=exponent)
[docs] def inner(self, x1, x2): """Calculate the array-weighted inner product of two elements. Parameters ---------- x1, x2 : `ProductSpaceElement` Elements whose inner product is calculated. Returns ------- inner : float or complex The inner product of the two provided elements. """ if self.exponent != 2.0: raise NotImplementedError(f"no inner product defined for exponent != 2 (got {self.exponent})") inners = np.fromiter( (x1i.inner(x2i) for x1i, x2i in zip(x1, x2)), dtype=x1[0].space.dtype_identifier, count=len(x1)) inner = np.dot(inners, self.array) if is_real_dtype(x1[0].dtype): return float(inner) else: return complex(inner)
[docs] def norm(self, x): """Calculate the array-weighted norm of an element. Parameters ---------- x : `ProductSpaceElement` Element whose norm is calculated. Returns ------- norm : float The norm of the provided element. """ if self.exponent == 2.0: norm_squared = self.inner(x, x).real # TODO: optimize?! return np.sqrt(norm_squared) else: norms = np.fromiter( (xi.norm() for xi in x), dtype=np.float64, count=len(x)) if self.exponent in (1.0, float('inf')): norms *= self.array else: norms *= self.array ** (1.0 / self.exponent) return float(np.linalg.norm(norms, ord=self.exponent))
[docs] class ProductSpaceConstWeighting(ConstWeighting): """Constant weighting for `ProductSpace`. """
[docs] def __init__(self, constant, exponent=2.0): r"""Initialize a new instance. Parameters ---------- constant : positive float Weighting constant of the inner product exponent : positive float, optional Exponent of the norm. For values other than 2.0, no inner product is defined. Notes ----- - For exponent 2.0, a new weighted inner product with constant :math:`c` is defined as .. math:: \langle x, y \rangle_c = c\, \langle x, y \rangle. For other exponents, only ``norm`` and ```dist`` are defined. In the case of exponent ``inf``, the weighted norm is .. math:: \|x\|_{c,\infty} = c\, \|x\|_\infty, otherwise it is .. math:: \|x\|_{c,p} = c^{1/p} \, \|x\|_p. - Note that this definition does **not** fulfill the limit property in :math:`p`, i.e., .. math:: \|x\|_{c,p} \not\to \|x\|_{c,\infty} \quad \text{for } p \to \infty unless :math:`c = 1`. The reason for this choice is that the alternative with the limit property consists in ignoring the weight altogether. - The constant must be positive, otherwise it does not define an inner product or norm, respectively. """ super().__init__( constant, impl=None, device=None, exponent=exponent)
[docs] def inner(self, x1, x2): """Calculate the constant-weighted inner product of two elements. Parameters ---------- x1, x2 : `ProductSpaceElement` Elements whose inner product is calculated. Returns ------- inner : float or complex The inner product of the two provided elements. """ if self.exponent != 2.0: raise NotImplementedError(f"no inner product defined for exponent != 2 (got {self.exponent})") accumulator = 0.0 # Manual loop, to avoid having to select a universally-applicable dtype for x1i, x2i in zip(x1, x2): accumulator = accumulator + x1i.inner(x2i) result = self.const * accumulator return x1.space.field.element(result)
[docs] def norm(self, x): """Calculate the constant-weighted norm of an element. Parameters ---------- x1 : `ProductSpaceElement` Element whose norm is calculated. Returns ------- norm : float The norm of the element. """ if self.exponent == 2.0: norm_squared = self.inner(x, x).real # TODO: optimize?! return np.sqrt(norm_squared) else: norms = np.fromiter( (xi.norm() for xi in x), dtype=np.float64, count=len(x)) if self.exponent in (1.0, float('inf')): return (self.const * float(np.linalg.norm(norms, ord=self.exponent))) else: return (self.const ** (1 / self.exponent) * float(np.linalg.norm(norms, ord=self.exponent)))
[docs] def dist(self, x1, x2): """Calculate the constant-weighted distance between two elements. Parameters ---------- x1, x2 : `ProductSpaceElement` Elements whose mutual distance is calculated. Returns ------- dist : float The distance between the elements. """ dnorms = np.fromiter( ((x1i - x2i).norm() for x1i, x2i in zip(x1, x2)), dtype=np.float64, count=len(x1)) if self.exponent == float('inf'): return self.const * np.linalg.norm(dnorms, ord=self.exponent) else: return (self.const ** (1 / self.exponent) * np.linalg.norm(dnorms, ord=self.exponent))
[docs] class ProductSpaceCustomInner(CustomInner): """Class for handling a user-specified inner products."""
[docs] def __init__(self, inner): """Initialize a new instance. Parameters ---------- inner : callable The inner product implementation. It must accept two `ProductSpaceElement` arguments, return a element from the field of the space (real or complex number) and satisfy the following conditions for all space elements ``x, y, z`` and scalars ``s``: - ``<x, y> = conj(<y, x>)`` - ``<s*x + y, z> = s * <x, z> + <y, z>`` - ``<x, x> = 0`` if and only if ``x = 0`` """ super().__init__( impl=None, inner=inner, device=None)
[docs] class ProductSpaceCustomNorm(CustomNorm): """Class for handling a user-specified norm on `ProductSpace`. Note that this removes ``inner``. """
[docs] def __init__(self, norm): """Initialize a new instance. Parameters ---------- norm : callable The norm implementation. It must accept a `ProductSpaceElement` argument, return a float and satisfy the following conditions for all space elements ``x, y`` and scalars ``s``: - ``||x|| >= 0`` - ``||x|| = 0`` if and only if ``x = 0`` - ``||s * x|| = |s| * ||x||`` - ``||x + y|| <= ||x|| + ||y||`` """ super().__init__(norm, impl=None, device=None)
[docs] class ProductSpaceCustomDist(CustomDist): """Class for handling a user-specified distance on `ProductSpace`. Note that this removes ``inner`` and ``norm``. """
[docs] def __init__(self, dist): """Initialize a new instance. Parameters ---------- dist : callable The distance function defining a metric on `ProductSpace`. It must accept two `ProductSpaceElement` arguments and fulfill the following mathematical conditions for any three space elements ``x, y, z``: - ``dist(x, y) >= 0`` - ``dist(x, y) = 0`` if and only if ``x = y`` - ``dist(x, y) = dist(y, x)`` - ``dist(x, y) <= dist(x, z) + dist(z, y)`` """ super(ProductSpaceCustomDist, self).__init__(dist, impl=None, device=None)
def _strip_space(x): """Strip the SPACE.element( ... ) part from a repr.""" r = repr(x) space_repr = f"{x.space}.element(" if r.startswith(space_repr) and r.endswith(')'): r = r[len(space_repr):-1] return r def _indent(x): """Indent a string by 4 characters.""" lines = x.splitlines() for i, line in enumerate(lines): lines[i] = ' ' + line return '\n'.join(lines) if __name__ == '__main__': from odl.core.util.testutils import run_doctests run_doctests()