# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Utilities mainly for internal use."""
import contextlib
from collections import OrderedDict
from contextlib import contextmanager
from itertools import product
from packaging.requirements import Requirement
import numpy as np
from odl.core.util.print_utils import is_string
# Public API of this module, i.e. the names exported by ``import *``.
# NOTE(review): ``pkg_supports`` is defined in this module but is not listed
# here -- confirm that it is intentionally kept internal.
__all__ = (
    'nd_iterator',
    'conj_exponent',
    'nullcontext',
    'writable_array',
    'run_from_ipython',
    'npy_random_seed',
    'unique',
)
[docs]
def nd_iterator(shape):
    """Iterate over all index tuples of an n-dimensional grid.

    Parameters
    ----------
    shape : sequence of int
        Number of points along each axis.

    Returns
    -------
    nd_iterator : iterator
        Iterator yielding integer tuples of length ``len(shape)``, with
        the last index varying fastest.

    Examples
    --------
    >>> for pt in nd_iterator([2, 2]):
    ...     print(pt)
    (0, 0)
    (0, 1)
    (1, 0)
    (1, 1)
    """
    axis_ranges = [range(num_pts) for num_pts in shape]
    return product(*axis_ranges)
[docs]
def conj_exponent(exp):
    """Return the conjugate exponent ``q`` with ``1/exp + 1/q = 1``.

    Parameters
    ----------
    exp : positive float or inf
        Exponent whose conjugate is computed. Must be at least 1.0.

    Returns
    -------
    conj : positive float or inf
        ``float('inf')`` for ``exp == 1``, ``1.0`` for
        ``exp == float('inf')``, and ``exp / (exp - 1)`` otherwise.
    """
    infinity = float('inf')
    if exp == 1.0:
        # 1/1 + 1/q = 1 has the limit q = inf
        return infinity
    if exp == infinity:
        return 1.0
    return exp / (exp - 1.0)
try:
    # Available in the standard library since Python 3.7.
    nullcontext = contextlib.nullcontext
except AttributeError:
    @contextmanager
    def nullcontext(enter_result=None):
        """Backport of the Python >=3.7 trivial context manager.

        It does nothing except hand ``enter_result`` to the ``with``
        statement. See `the Python documentation
        <https://docs.python.org/3/library/contextlib.html#contextlib.nullcontext>`_
        for details.
        """
        yield enter_result
[docs]
@contextmanager
def writable_array(obj, must_be_contiguous: bool = False):
    """Context manager that casts `obj` to a backend-specific array and
    saves changes made on that array back into `obj`.

    Parameters
    ----------
    obj : `numpy.ndarray` or object with a ``writable_array`` method
        Object that should be made available as writable array.
        NumPy arrays are handled directly. Any other object must provide
        a ``writable_array(must_be_contiguous=...)`` context manager, to
        which this function delegates.
    must_be_contiguous : bool, optional
        Whether the writable array should guarantee standard C order.
        For a NumPy array that is not C-contiguous, a C-contiguous copy
        is yielded and written back into ``obj`` on exit.

    Yields
    ------
    arr : array
        ``obj`` itself, a contiguous copy of it, or whatever the
        delegated ``obj.writable_array`` yields.

    Examples
    --------
    Usage with ODL vectors:

    >>> space = odl.uniform_discr(0, 1, 3)
    >>> x = space.element([1, 2, 3])
    >>> with writable_array(x) as arr:
    ...     arr += [1, 1, 1]
    >>> x
    uniform_discr(0.0, 1.0, 3).element([ 2.,  3.,  4.])

    Note that the changes are in general only saved upon exiting the
    context manager. Before, the input object may remain unchanged.
    """
    if not isinstance(obj, np.ndarray):
        # Non-NumPy objects implement the write-back logic themselves.
        with obj.writable_array(must_be_contiguous=must_be_contiguous) as arr:
            yield arr
        return

    # Use `flags` rather than `obj.data.c_contiguous`: `obj.data` exports a
    # buffer, which fails for some dtypes (e.g. datetime64).
    if must_be_contiguous and not obj.flags.c_contiguous:
        arr = np.ascontiguousarray(obj)
        try:
            yield arr
        finally:
            # Copy the (possibly modified) contiguous array back into `obj`.
            obj[...] = arr
    else:
        # `obj` already satisfies the requirements and is modified in place.
        yield obj
[docs]
def run_from_ipython():
    """Return ``True`` if the process is run from IPython, else ``False``.

    IPython marks its presence by defining ``__IPYTHON__`` in the
    `builtins` module, so that module is checked. This module's
    ``globals()`` never contains that name.
    """
    import builtins
    return hasattr(builtins, '__IPYTHON__')
[docs]
def pkg_supports(feature, pkg_version, pkg_feat_dict):
    """Return bool indicating whether a package supports ``feature``.

    Parameters
    ----------
    feature : str
        Name of a potential feature of a package.
    pkg_version : str
        Version of the package that should be checked for presence of the
        feature.
    pkg_feat_dict : dict
        Specification of features of a package. Each item has the
        following form::

            feature_name: version_specification

        Here, ``feature_name`` is a string that is matched against
        ``feature``, and ``version_specification`` is a string or a
        sequence of strings that specifies version sets. Each
        specification is a version specifier as used in requirement
        strings (e.g. ``'>0.6, <=0.9'``), just without the package name.
        A ``None`` entry signals "no support in any version", i.e.,
        always ``False``.
        If a sequence of requirements are given, they are OR-ed together.
        See ``Examples`` for details.

    Returns
    -------
    supports : bool
        ``True`` if ``pkg_version`` of the package in question supports
        ``feature``, ``False`` otherwise.

    Examples
    --------
    >>> feat_dict = {
    ...     'feat1': '==0.5.1',
    ...     'feat2': '>0.6, <=0.9',  # both required simultaneously
    ...     'feat3': ['>0.6', '<=0.9'],  # only one required, i.e. always True
    ...     'feat4': ['==0.5.1', '>0.6, <=0.9'],
    ...     'feat5': None
    ... }
    >>> pkg_supports('feat1', '0.5.1', feat_dict)
    True
    >>> pkg_supports('feat1', '0.4', feat_dict)
    False
    >>> pkg_supports('feat2', '0.5.1', feat_dict)
    False
    >>> pkg_supports('feat2', '0.6.1', feat_dict)
    True
    >>> pkg_supports('feat2', '0.9', feat_dict)
    True
    >>> pkg_supports('feat2', '1.0', feat_dict)
    False
    >>> pkg_supports('feat3', '0.4', feat_dict)
    True
    >>> pkg_supports('feat3', '1.0', feat_dict)
    True
    >>> pkg_supports('feat4', '0.5.1', feat_dict)
    True
    >>> pkg_supports('feat4', '0.6', feat_dict)
    False
    >>> pkg_supports('feat4', '0.6.1', feat_dict)
    True
    >>> pkg_supports('feat4', '1.0', feat_dict)
    False
    >>> pkg_supports('feat5', '0.6.1', feat_dict)
    False
    >>> pkg_supports('feat5', '1.0', feat_dict)
    False
    """
    feature = str(feature)
    pkg_version = str(pkg_version)

    supp_versions = pkg_feat_dict.get(feature, None)
    if supp_versions is None:
        return False

    # Make sequence from single string
    if is_string(supp_versions):
        supp_versions = [supp_versions]

    # Prefix a dummy package name so that each specification becomes a
    # valid requirement string. All entries are parsed up front, so a
    # malformed entry is reported even if an earlier one would match.
    ver_reqs = [Requirement('pkg' + supp_ver) for supp_ver in supp_versions]

    # Specifications are OR-ed: one satisfied specification is enough.
    return any(req.specifier.contains(pkg_version, prereleases=True)
               for req in ver_reqs)
[docs]
@contextmanager
def npy_random_seed(seed):
    """Temporarily seed NumPy's global random generator.

    On exit, including exit via an exception, the generator state that
    was active before entering is restored.

    Parameters
    ----------
    seed : int or None
        Seed value for the random number generator.
        ``None`` leaves the generator untouched, i.e., the current
        state is kept and is not restored afterwards.

    Examples
    --------
    Use this to make drawing pseudo-random numbers repeatable:

    >>> with npy_random_seed(42):
    ...     rand_int = np.random.randint(10)
    >>> with npy_random_seed(42):
    ...     same_rand_int = np.random.randint(10)
    >>> rand_int == same_rand_int
    True
    """
    if seed is None:
        # Nothing to seed, nothing to restore.
        yield
        return

    saved_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        np.random.set_state(saved_state)
[docs]
def unique(seq):
    """Return the unique values in a sequence.

    Parameters
    ----------
    seq : sequence or iterable
        Elements to deduplicate (possibly with duplicates). One-shot
        iterators such as generators are supported as well.

    Returns
    -------
    unique : list
        Unique elements of ``seq``, in order of first occurrence.

    Examples
    --------
    Determine unique elements in list

    >>> unique([1, 2, 3, 3])
    [1, 2, 3]
    >>> unique((1, 'str', 'str'))
    [1, 'str']

    The utility also works with unhashable types:

    >>> unique((1, [1], [1]))
    [1, [1]]
    """
    # Materialize first. Otherwise the hashable fast path could partially
    # consume a one-shot iterator before hitting an unhashable element, and
    # the fallback would then only see the remaining elements.
    items = list(seq)

    # If all elements are hashable, this O(n) path succeeds.
    try:
        return list(OrderedDict.fromkeys(items))
    except TypeError:
        # Some element is unhashable: fall back to O(n^2) equality checks.
        unique_values = []
        for item in items:
            if item not in unique_values:
                unique_values.append(item)
        return unique_values
if __name__ == '__main__':
    # Executed as a script: run this module's doctests via ODL's helper.
    from odl.core.util.testutils import run_doctests as _run_doctests
    _run_doctests()