Source code for odl.core.util.scipy_compatibility

# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

"""Scipy compatibility module"""

import warnings
from os import environ
if 'SCIPY_ARRAY_API' in environ and environ['SCIPY_ARRAY_API']=='1':
        pass
else:
    warnings.warn("The environment variable SCIPY_ARRAY_API must be set to 1."
                + " It should be by default when importing odl, but it seems that scipy was imported before odl."
                + " If not set, the array API support of scipy will be disabled,"
                + " meaning that function calls such as ``xlogy`` on GPU will error"
                + " and throw back pytorch Type errors."
                + " Please add the following lines before your first scipy import. \n"
                + 'from os import environ \n' \
                + 'environ["SCIPY_ARRAY_API"]=="1" \n ' \
                + '********End of Warning********',
              stacklevel=2)

import scipy
import numpy

__all__ = (
    'lambertw',
    'scipy_lambertw',
    'xlogy',
    )

def _helper(operation: str, x1, x2=None, out=None, namespace=scipy.special, **kwargs):
    """Proxy for calling a scipy function on the data underlying `x1` and
    (if provided) `x2` and wrapping the result as an element of the suitable space.

    Args:
        operation (str): Name of the function
        x1  (Tensor): _description_
        x2  (Tensor, optional): _description_. Defaults to None.
        out (Tensor, optional): Out argument for in-place operations. Defaults to None.
        namespace: scipy namespace to get the operation from. Defaults to scipy.special.

    Returns:
        Tensor
    """
    return x1.space._elementwise_num_operation(
        operation=operation, x1=x1, x2=x2, out=out, namespace=namespace, **kwargs)

[docs] def lambertw(x, k=0, tol=1e-8): """ Lambert W function. The Lambert W function W(z) is defined as the inverse function of w * exp(w). In other words, the value of W(z) is such that z = W(z) * exp(W(z)) for any complex number z. The Lambert W function is a multivalued function with infinitely many branches. Each branch gives a separate solution of the equation z = w exp(w). Here, the branches are indexed by the integer k. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.lambertw.html#scipy.special.lambertw """ return _helper('lambertw', x, k=k, tol=tol)
[docs] def scipy_lambertw(x, k=0, tol=1e-8): """ Lambert W function. The Lambert W function W(z) is defined as the inverse function of w * exp(w). In other words, the value of W(z) is such that z = W(z) * exp(W(z)) for any complex number z. The Lambert W function is a multivalued function with infinitely many branches. Each branch gives a separate solution of the equation z = w exp(w). Here, the branches are indexed by the integer k. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.lambertw.html#scipy.special.lambertw Note: This function is a direct call to scipy.special.lambertw on a Numpy Array! """ assert isinstance( x, numpy.ndarray ), ("Can only call scipy_lambertw on nd_array. For ODL Tensors, please use the function scipy_compatibility.lambertw" + f". Got {type(x)}") return scipy.special.lambertw(x, k, tol)
[docs] def xlogy(x1, x2, out=None): """Compute x*log(y) so that the result is 0 if x = 0. See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.xlogy.html """ return _helper('xlogy', x1=x1, x2=x2, out=out)