# Source code for odl.core.space.entry_points

# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Entry points for adding more spaces to ODL using external packages.

External packages can add an implementation of `TensorSpace` by hooking
into the setuptools entry point ``'odl.core.space'`` and exposing the functions
``tensor_space_impl`` and ``tensor_space_impl_names``.

This is used with functions such as `rn`, `cn`, `tensor_space` or
`uniform_discr` in order to allow arbitrary implementations.

See Also
--------
NumpyTensorSpace : Numpy-based implementation of `TensorSpace`
"""

# Backends are imported inside functions so they are only loaded if actually available
# pylint: disable=import-outside-toplevel
# We want to use a global statement here
# pylint: disable=global-statement
# The global variable TENSOR_SPACE_IMPLS is modified in a condition, which triggers the pylint warning
# pylint: disable=global-variable-not-assigned

from odl.backends.arrays.npy_tensors import NumpyTensorSpace

# Nothing from this module is re-exported into ``odl.core.space``.
__all__ = ()

# Set to ``True`` by ``_initialize_if_needed`` once the optional backends
# have been probed.
IS_INITIALIZED = False

# Registry from implementation name to tensor space class. NumPy is always
# present; optional backends are added by ``_initialize_if_needed``.
TENSOR_SPACE_IMPLS = {
    'numpy': NumpyTensorSpace,
}


def _initialize_if_needed():
    """Register optional tensor space backends in ``TENSOR_SPACE_IMPLS``, once.

    On the first call, this checks whether ``torch`` can be located. If it
    can, ``PyTorchTensorSpace`` is imported and registered under the name
    ``'pytorch'``. ``IS_INITIALIZED`` is then set to ``True``, so later
    calls return immediately.

    Registration of the optional backend is best effort. If ``torch`` is
    not found, or if importing the PyTorch backend raises any
    ``ImportError``, the backend is simply not registered.
    """
    global IS_INITIALIZED, TENSOR_SPACE_IMPLS
    if IS_INITIALIZED:
        return

    import importlib.util

    # ``find_spec`` locates ``torch`` without importing it, so the PyTorch
    # backend module is never touched when torch is absent.
    if importlib.util.find_spec("torch") is not None:
        try:
            from odl.backends.arrays.pytorch_tensors import PyTorchTensorSpace
        except ImportError:
            # ``ImportError`` also covers ``ModuleNotFoundError``. This
            # function also runs at module import time, so an installed but
            # unusable torch must not make this module unimportable.
            pass
        else:
            TENSOR_SPACE_IMPLS['pytorch'] = PyTorchTensorSpace
    IS_INITIALIZED = True


def tensor_space_impl_names():
    """Return the names of all registered tensor space implementations.

    Returns
    -------
    names : tuple of str
        Keys of the implementation registry, in insertion order. These are
        the accepted values for the ``impl`` argument of
        `tensor_space_impl`.
    """
    _initialize_if_needed()
    names = tuple(TENSOR_SPACE_IMPLS)
    return names
def tensor_space_impl(impl):
    """Tensor space class corresponding to the given impl name.

    Parameters
    ----------
    impl : str
        Name of the implementation, see `tensor_space_impl_names` for the
        full list.

    Returns
    -------
    tensor_space_impl : type
        Class inheriting from `TensorSpace`.

    Raises
    ------
    KeyError
        If ``impl`` is not a valid name of a tensor space implementation.
    """
    # Make sure optional backends are registered before the lookup, just like
    # `tensor_space_impl_names` does; this is a no-op once initialized.
    _initialize_if_needed()
    try:
        return TENSOR_SPACE_IMPLS[impl]
    except KeyError:
        # Suppress the context: the bare lookup KeyError adds nothing beyond
        # this message, which also lists the valid choices.
        raise KeyError(
            f"`impl` {impl!r} does not correspond to a valid tensor "
            "space implementation; valid names are "
            f"{tuple(TENSOR_SPACE_IMPLS)}"
        ) from None
# Probe and register the optional backends eagerly at import time, so that
# code reading ``TENSOR_SPACE_IMPLS`` directly already sees them.
_initialize_if_needed()