# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Helpers for discretization-related functionality.
Most functions deal with interpolation of arrays, sampling of functions and
providing a single interface for the sampler by wrapping functions or
arrays of functions appropriately.
"""
import inspect
import sys
from builtins import object
from functools import partial
from itertools import product
from typing import Callable
from odl.core.set.domain import IntervalProd
import numpy as np
from odl.core.array_API_support import asarray, lookup_array_backend, ArrayBackend, get_array_and_backend
from odl.core.array_API_support.utils import is_array_supported
from odl.core.util.npy_compat import AVOID_UNNECESSARY_COPY
from odl.core.util.dtype_utils import _universal_dtype_identifier, is_floating_dtype, real_dtype, is_int_dtype
from odl.core.util import (
dtype_repr, is_real_dtype, is_string, is_valid_input_array,
is_valid_input_meshgrid, out_shape_from_array, out_shape_from_meshgrid,
writable_array)
__all__ = (
'point_collocation',
'nearest_interpolator',
'linear_interpolator',
'per_axis_interpolator',
'sampling_function',
)
SUPPORTED_INTERP = ['nearest', 'linear']
def point_collocation(func, points, out=None, **kwargs):
    """Sample a function on a grid of points.

    This is the most basic form of discretization: the function is
    called on a single point or a collection of points, and whatever it
    produces is handed back to the caller.

    Parameters
    ----------
    func : callable
        Function to be sampled. It is called as ``func(points, **kwargs)``
        when ``out`` is ``None`` and as
        ``func(points, out=out, **kwargs)`` otherwise.
        Usually, ``func`` is the return value of `sampling_function`.
    points : point, meshgrid or array of points
        The point(s) where to sample.
    out : numpy.ndarray, optional
        Array to which the result should be written.
    kwargs :
        Additional arguments that are passed on to ``func``.

    Returns
    -------
    out : numpy.ndarray
        Values of ``func`` at ``points``. If ``out`` was given, the
        returned object is a reference to it.

    Examples
    --------
    Sample a 1D function:

    >>> from odl.core.discr.grid import sparse_meshgrid
    >>> domain = odl.IntervalProd(0, 5)
    >>> func = sampling_function(lambda x: x ** 2, domain, out_dtype=float)
    >>> mesh = sparse_meshgrid([1, 2, 3])
    >>> point_collocation(func, mesh)
    array([ 1.,  4.,  9.])

    By default, inputs are checked against ``domain`` to be in bounds. This
    can be switched off by passing ``bounds_check=False``:

    >>> mesh = sparse_meshgrid([-1, 0, 4])
    >>> point_collocation(func, mesh, bounds_check=False)
    array([  1.,   0.,  16.])

    In two or more dimensions, the function to be sampled can be written as
    if its arguments were the components of a point, and an implicit loop
    around the call would iterate over all points:

    >>> domain = odl.IntervalProd([0, 0], [5, 5])
    >>> xs = [1, 2]
    >>> ys = [3, 4, 5]
    >>> mesh = sparse_meshgrid(xs, ys)
    >>> func = sampling_function(lambda x: x[0] - x[1], domain, out_dtype=float)
    >>> point_collocation(func, mesh)
    array([[-2., -3., -4.],
           [-1., -2., -3.]])

    It is possible to return results that require broadcasting, and to use
    *optional* function parameters:

    >>> def f(x, c=0):
    ...     return x[0] + c
    >>> func = sampling_function(f, domain, out_dtype=float)
    >>> point_collocation(func, mesh)  # uses default c=0
    array([[ 1.,  1.,  1.],
           [ 2.,  2.,  2.]])
    >>> point_collocation(func, mesh, c=2)
    array([[ 3.,  3.,  3.],
           [ 4.,  4.,  4.]])

    Vector- and tensor-valued functions can be given as a single function
    returning an array-like of results:

    >>> domain = odl.IntervalProd([0, 0], [5, 5])
    >>> xs = [1, 2]
    >>> ys = [3, 4]
    >>> mesh = sparse_meshgrid(xs, ys)
    >>> def vec_valued(x):
    ...     return (x[0] - 1., 0., x[0] + x[1])  # broadcasting
    >>> func1 = sampling_function(
    ...     vec_valued, domain, out_dtype=float
    ... )
    >>> point_collocation(func1, mesh)
    [array([[ 0.,  0.],
           [ 1.,  1.]]), array([[ 0.,  0.],
           [ 0.,  0.]]), array([[ 4.,  5.],
           [ 5.,  6.]])]

    The same function written as an array-like of member functions:

    >>> list_of_funcs = [  # equivalent to `vec_valued`
    ...     lambda x: x[0] - 1,
    ...     0,                   # constants are allowed
    ...     lambda x: x[0] + x[1]
    ... ]

    Notes
    -----
    This function expects its input functions to be written in a
    vectorization-conforming manner to ensure fast evaluation.
    See the `ODL vectorization guide`_ for a detailed introduction.

    See Also
    --------
    sampling_function : wrap a function
    odl.core.discr.grid.RectGrid.meshgrid
    numpy.meshgrid

    References
    ----------
    .. _ODL vectorization guide:
       https://odlgroup.github.io/odl/guide/in_depth/vectorization_guide.html
    """
    if out is not None:
        # In-place mode: ``func`` fills the caller's buffer, which is
        # then returned as-is.
        func(points, out=out, **kwargs)
        return out

    return func(points, **kwargs)
def _normalize_interp(interp, ndim):
    """Expand ``interp`` into a tuple of lowercase names, one per axis.

    Parameters
    ----------
    interp : str or sequence of str
        A single scheme name (applied to every axis) or one name per axis.
    ndim : int
        Number of axes.

    Returns
    -------
    interp_byaxis : tuple of str
        Lowercased scheme names, of length ``ndim``.

    Raises
    ------
    ValueError
        If a sequence of the wrong length is given, or if any entry is
        not contained in ``SUPPORTED_INTERP``.
    """
    # Remember the caller's value for the error message below
    interp_in = interp

    if not is_string(interp):
        interp_byaxis = tuple(str(scheme).lower() for scheme in interp)
        if len(interp_byaxis) != ndim:
            raise ValueError(
                f"length of `interp` ({len(interp_byaxis)}) does not match number of axes ({ndim})"
            )
    else:
        interp_byaxis = (str(interp).lower(),) * ndim

    if any(scheme not in SUPPORTED_INTERP for scheme in interp_byaxis):
        raise ValueError(
            f"invalid `interp` {interp_in}; supported are: {SUPPORTED_INTERP}"
        )

    return interp_byaxis
def _check_interp_input(x, f):
    """Return transformed ``x``, its input type and whether it's scalar.

    Parameters
    ----------
    x : scalar, array-like or meshgrid
        Evaluation point(s) handed to an interpolator.
    f : array
        Array of values to be interpolated; only ``f.ndim`` is used.

    Returns
    -------
    x : meshgrid or array
        A valid meshgrid is returned unchanged. Anything else is
        converted to an array and reshaped to ``(f.ndim, n)``.
    x_type : {'meshgrid', 'array'}
        Kind of input that was detected.
    x_is_scalar : bool
        ``True`` if ``x`` represents a single point.

    Raises
    ------
    ValueError
        If ``x`` is neither a valid meshgrid nor a valid point array.
    """
    ndim = f.ndim

    if is_valid_input_meshgrid(x, ndim):
        return x, 'meshgrid', False

    # Keep the caller's object so that the error message shows the input
    # as it was given, not the converted/reshaped version.
    x_in = x

    if isinstance(x, (int, float, complex, list, tuple)):
        x = np.asarray(x)
    else:
        x, _ = get_array_and_backend(x)

    x_is_scalar = False
    if ndim == 1 and x.shape == ():
        x_is_scalar = True
        x = x.reshape((1, 1))
    elif ndim == 1 and x.ndim == 1:
        # ``x.shape[0]`` instead of ``x.size``: the latter is an attribute
        # for NumPy arrays but a method for PyTorch tensors.
        x = x.reshape((1, x.shape[0]))
    elif ndim > 1 and x.shape == (ndim,):
        x_is_scalar = True
        x = x.reshape((ndim, 1))

    if not is_valid_input_array(x, ndim):
        # The messages are built only on failure; formatting ``x`` on
        # every call would force a (possibly expensive) array repr.
        if ndim == 1:
            raise ValueError(
                f"bad input: expected scalar, array-like of shape (1,), (n,) or (1, n), or a meshgrid of length 1; got {x_in}"
            )
        raise ValueError(
            f"bad input: expected scalar, array-like of shape ({ndim},) or ({ndim}, n), or a meshgrid of length {ndim}; got {x_in}"
        )

    return x, 'array', x_is_scalar
def nearest_interpolator(f, coord_vecs):
    """Return the nearest neighbor interpolator for discrete values.

    Given points ``x[1] < x[2] < ... < x[N]``, and function values
    ``f[1], ..., f[N]``, nearest neighbor interpolation at ``x`` is defined
    as ::

        I(x) = f[j]

    with ``j`` such that ``|x - x[j]|`` is minimal.
    The ambiguity at the midpoints is resolved by preferring the right
    neighbor. In higher dimensions, this principle is applied per axis.
    The returned interpolator is the piecewise constant function ``x -> I(x)``.

    Parameters
    ----------
    f : numpy.ndarray
        Function values that should be interpolated.
    coord_vecs : sequence of numpy.ndarray
        Coordinate vectors of the rectangular grid on which interpolation
        should be based. They must be sorted in ascending order. Usually
        they are obtained as ``grid.coord_vectors`` from a `RectGrid`.

    Returns
    -------
    interpolator : function
        Python function that will interpolate the given values when called
        with a point or multiple points (vectorized). It accepts an
        optional ``out`` argument that is forwarded to the interpolation.

    Examples
    --------
    We interpolate a 1d function. If called with a single point, the
    interpolator returns a single value, and with multiple points at once,
    an array of values is returned:

    >>> part = odl.uniform_partition(0, 2, 5)
    >>> part.coord_vectors  # grid points
    (array([ 0.2,  0.6,  1. ,  1.4,  1.8]),)
    >>> f = odl.tensor_space(5, dtype=int).element([1, 2, 3, 4, 5])
    >>> interpolator = nearest_interpolator(f, part.coord_vectors)
    >>> interpolator(0.3)  # closest to 0.2 -> value 1
    1
    >>> interpolator([0.6, 1.3, 1.9])  # closest to [0.6, 1.4, 1.8]
    array([2, 4, 5], dtype=int32)

    In 2 dimensions, we can either use a (transposed) list of points or
    a meshgrid:

    >>> part = odl.uniform_partition([0, 0], [1, 5], shape=(2, 4))
    >>> part.coord_vectors  # grid points
    (array([ 0.25,  0.75]), array([ 0.625,  1.875,  3.125,  4.375]))
    >>> f = np.array([[1, 2, 3, 4],
    ...               [5, 6, 7, 8]],
    ...              dtype=float)
    >>> interpolator = nearest_interpolator(f, part.coord_vectors)
    >>> interpolator([1, 1])  # single point
    5.0
    >>> x = np.array([[0.5, 2.0],
    ...               [0.0, 4.5],
    ...               [0.0, 3.0]]).T  # 3 points at once
    >>> interpolator(x)
    array([ 6.,  4.,  3.])
    >>> from odl.core.discr.grid import sparse_meshgrid
    >>> mesh = sparse_meshgrid([0.0, 0.4, 1.0], [1.5, 3.5])
    >>> interpolator(mesh)  # 3x2 grid of points
    array([[ 2.,  3.],
           [ 2.,  3.],
           [ 6.,  7.]])

    See Also
    --------
    linear_interpolator : (bi-/tri-/...)linear interpolation
    per_axis_interpolator : potentially different interpolation in each axis
    """
    f, _ = get_array_and_backend(f)

    # TODO(kohr-h): pass reasonable options on to the interpolator
    def nearest_interp(x, out=None):
        """Interpolating function with vectorization."""
        points, kind, single_point = _check_interp_input(x, f)
        # The interpolator is built per call since the input type
        # (array vs. meshgrid) is only known once ``x`` is inspected.
        result = _NearestInterpolator(coord_vecs, f, input_type=kind)(
            points, out=out
        )
        return result.item() if single_point else result

    return nearest_interp
def linear_interpolator(f, coord_vecs):
    """Return the linear interpolator for discrete function values.

    Parameters
    ----------
    f : numpy.ndarray
        Function values that should be interpolated.
    coord_vecs : sequence of numpy.ndarray
        Coordinate vectors of the rectangular grid on which interpolation
        should be based. They must be sorted in ascending order. Usually
        they are obtained as ``grid.coord_vectors`` from a `RectGrid`.

    Returns
    -------
    interpolator : function
        Python function that will interpolate the given values when called
        with a point or multiple points (vectorized). It accepts an
        optional ``out`` argument that is forwarded to the interpolation.

    Examples
    --------
    We interpolate a 1d function. If called with a single point, the
    interpolator returns a single value, and with multiple points at once,
    an array of values is returned:

    >>> part = odl.uniform_partition(0, 2, 5)
    >>> part.coord_vectors  # grid points
    (array([ 0.2,  0.6,  1. ,  1.4,  1.8]),)
    >>> f = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    >>> interpolator = linear_interpolator(f, part.coord_vectors)
    >>> interpolator(0.3)  # 0.75 * 1 + 0.25 * 2 = 1.25
    1.25
    >>> # At 1.9, the value is interpolated between the last value 5.0 and
    >>> # 0.0. The extra interpolation node is placed at the same distance
    >>> # as the second-to-last, i.e., at 2.2. Hence, the interpolated value
    >>> # is 0.75 * 5.0 + 0.25 * 0.0 = 3.75.
    >>> interpolator([0.6, 1.3, 1.9])
    array([ 2.  ,  3.75,  3.75])

    In 2 dimensions, we can either use a (transposed) list of points or
    a meshgrid:

    >>> part = odl.uniform_partition([0, 0], [1, 5], shape=(2, 4))
    >>> part.coord_vectors  # grid points
    (array([ 0.25,  0.75]), array([ 0.625,  1.875,  3.125,  4.375]))
    >>> f = np.array([[1, 2, 3, 4],
    ...               [5, 6, 7, 8]],
    ...              dtype=float)
    >>> interpolator = linear_interpolator(f, part.coord_vectors)
    >>> interpolator([1, 1])  # single point
    2.65
    >>> x = np.array([[0.5, 2.0],
    ...               [0.0, 4.5],
    ...               [0.0, 3.0]]).T  # 3 points at once
    >>> interpolator(x)
    array([ 4.1 ,  1.8 ,  1.45])
    >>> from odl.core.discr.grid import sparse_meshgrid
    >>> mesh = sparse_meshgrid([0.0, 0.5, 1.0], [1.5, 3.5])
    >>> interpolator(mesh)  # 3x2 grid of points
    array([[ 0.85,  1.65],
           [ 3.7 ,  5.3 ],
           [ 2.85,  3.65]])
    """
    f = asarray(f)

    # TODO(kohr-h): pass reasonable options on to the interpolator
    def linear_interp(x, out=None):
        """Interpolating function with vectorization."""
        points, kind, single_point = _check_interp_input(x, f)
        # The interpolator is built per call since the input type
        # (array vs. meshgrid) is only known once ``x`` is inspected.
        result = _LinearInterpolator(coord_vecs, f, input_type=kind)(
            points, out=out
        )
        return result.item() if single_point else result

    return linear_interp
def per_axis_interpolator(f, coord_vecs, interp):
    """Return a per axis defined interpolator for discrete values.

    With this function, the interpolation scheme can be chosen for each axis
    separately.

    Parameters
    ----------
    f : numpy.ndarray
        Function values that should be interpolated.
    coord_vecs : sequence of numpy.ndarray
        Coordinate vectors of the rectangular grid on which interpolation
        should be based. They must be sorted in ascending order. Usually
        they are obtained as ``grid.coord_vectors`` from a `RectGrid`.
    interp : str or sequence of str
        Indicates which interpolation scheme to use for which axis.
        A single string is interpreted as a global scheme for all
        axes.

    Returns
    -------
    interpolator : function
        Python function that will interpolate the given values when called
        with a point or multiple points (vectorized). It accepts an
        optional ``out`` argument that is forwarded to the interpolation.

    Examples
    --------
    Choose linear interpolation in the first axis and nearest neighbor in
    the second:

    >>> part = odl.uniform_partition([0, 0], [1, 5], shape=(2, 4))
    >>> part.coord_vectors
    (array([ 0.25,  0.75]), array([ 0.625,  1.875,  3.125,  4.375]))
    >>> f = np.array([[1, 2, 3, 4],
    ...               [5, 6, 7, 8]],
    ...              dtype=float)
    >>> interpolator = per_axis_interpolator(
    ...     f, part.coord_vectors, ['linear', 'nearest']
    ... )
    >>> interpolator([1, 1])  # single point
    2.5
    >>> x = np.array([[0.5, 2.0],
    ...               [0.0, 4.5],
    ...               [0.0, 3.0]]).T  # 3 points at once
    >>> interpolator(x)
    array([ 4. ,  2. ,  1.5])
    >>> from odl.core.discr.grid import sparse_meshgrid
    >>> mesh = sparse_meshgrid([0.0, 0.5, 1.0], [1.5, 3.5])
    >>> interpolator(mesh)  # 3x2 grid of points
    array([[ 1. ,  1.5],
           [ 4. ,  5. ],
           [ 3. ,  3.5]])
    """
    f = asarray(f)
    # Validate and expand the scheme once, not on every evaluation
    interp = _normalize_interp(interp, f.ndim)

    def per_axis_interp(x, out=None):
        """Interpolating function with vectorization."""
        points, kind, single_point = _check_interp_input(x, f)
        evaluator = _PerAxisInterpolator(
            coord_vecs, f, interp=interp, input_type=kind
        )
        result = evaluator(points, out=out)
        return result.item() if single_point else result

    return per_axis_interp
class _Interpolator:
r"""Abstract interpolator class.
The code is adapted from SciPy's `RegularGridInterpolator
<http://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RegularGridInterpolator.html>`_
class.
The init method does not convert to floating point to
support arbitrary data type for nearest neighbor interpolation.
Subclasses need to override ``_evaluate`` for concrete
implementations.
"""
def __init__(self, coord_vecs, values, input_type):
"""Initialize a new instance.
coord_vecs : sequence of `numpy.ndarray`'s
Coordinate vectors defining the interpolation grid.
values : `array-like`
Grid values to use for interpolation.
input_type : {'array', 'meshgrid'}
Type of expected input values in ``__call__``.
"""
values, backend = get_array_and_backend(values)
typ_ = str(input_type).lower()
if typ_ not in ("array", "meshgrid"):
raise ValueError(f"`input_type` ({input_type}) not understood")
if len(coord_vecs) != values.ndim:
raise ValueError(
f"there are {len(coord_vecs)} point arrays, but `values` has {values.ndim} dimensions"
)
for i, p in enumerate(coord_vecs):
if not np.asarray(p).ndim == 1:
raise ValueError(
f"the points in dimension {i} must be " "1-dimensional"
)
if values.shape[i] != len(p):
raise ValueError(
f"there are {len(p)} points and {values.shape[i]} values in dimension {i}"
)
self.coord_vecs = tuple(np.asarray(p) for p in coord_vecs)
self.values = values
self.input_type = input_type
self.backend = backend
self.namespace = backend.array_namespace
self.device = values.device
def __call__(self, x, out=None):
"""Do the interpolation.
Parameters
----------
x : `meshgrid` (i.e., tuple of arrays) if `input_type` is meshgrid,
else `numpy.ndarray`.
Evaluation points of the interpolator.
out : `numpy.ndarray`, optional
Array to which the results are written. Needs to have
correct shape according to input ``x``.
Returns
-------
out : `numpy.ndarray`
Interpolated values. If ``out`` was given, the returned
object is a reference to it.
"""
if self.input_type == 'meshgrid':
# Given a meshgrid, the evaluation will be on a ragged array.
x = [get_array_and_backend(x_)[0] for x_ in x]
else:
x = get_array_and_backend(x)[0]
ndim = len(self.coord_vecs)
scalar_out = False
if self.input_type == 'array':
if ndim == 1:
scalar_out = x.ndim == 0
else:
scalar_out = x.shape == (ndim,)
# Make a (1, n) array from one with shape (n,)
x = x.reshape([ndim, -1])
out_shape = out_shape_from_array(x)
else:
if len(x) != ndim:
raise ValueError(
f"number of vectors in x is {len(x)} instead of the grid dimension {ndim}"
)
out_shape = out_shape_from_meshgrid(x)
if out is not None:
if not isinstance(out, self.backend.array_type):
raise TypeError(
f"The provided out argument is not an expected {type(self.backend.array_type)} but a {type(out)}"
)
if out.shape != out_shape:
raise ValueError(
f"output shape {out.shape} not equal to expected shape {out_shape}"
)
if out.dtype != self.values.dtype:
raise ValueError(
f"output dtype {out.dtype} not equal to expected dtype {self.values.dtype}"
)
indices, norm_distances = self._find_indices(x)
values = self._evaluate(indices, norm_distances, out)
if scalar_out:
return values.item()
else:
return values
def _find_indices(self, x):
"""Find indices and distances of the given nodes.
Can be overridden by subclasses to improve efficiency.
"""
# find relevant edges between which xi are situated
index_vecs = []
# compute distance to lower edge in unity units
norm_distances = []
# iterate through dimensions
for xi, cvec in zip(x, self.coord_vecs):
# try:
if is_floating_dtype(self.values.dtype):
dtype = real_dtype(self.values.dtype, backend=self.backend)
elif is_int_dtype(self.values.dtype):
dtype = real_dtype(float, backend=self.backend)
else:
raise ValueError(
f"Values can only be integers or float, not {type(self.values)}"
)
xi = self.backend.array_constructor(xi, dtype=dtype, device=self.device)
cvec = self.backend.array_constructor(cvec, dtype=dtype, device=self.device)
idcs = self.namespace.searchsorted(cvec, xi) - 1
idcs[idcs < 0] = 0
idcs[idcs > len(cvec) - 2] = len(cvec) - 2
index_vecs.append(idcs)
try:
norm_distances.append((xi - cvec[idcs]) / (cvec[idcs + 1] - cvec[idcs]))
except Exception as e:
print(f"{type(xi)=}, {type(cvec)=}")
raise e
return index_vecs, norm_distances
def _evaluate(self, indices, norm_distances, out=None):
"""Evaluation method, needs to be overridden."""
raise NotImplementedError("abstract method")
class _NearestInterpolator(_Interpolator):
    """Nearest neighbor interpolator.

    The code is adapted from SciPy's `RegularGridInterpolator
    <http://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.RegularGridInterpolator.html>`_
    class.

    This is a dedicated implementation that picks values by index only,
    without computing any interpolation weights.
    """

    def _evaluate(self, indices, norm_distances, out=None):
        """Evaluate nearest interpolation.

        Per axis, the lower neighbor is chosen where the normalized
        distance is below 0.5, and the upper neighbor otherwise.
        """
        where = self.namespace.where
        nearest = tuple(
            where(dist < .5, idx, idx + 1)
            for idx, dist in zip(indices, norm_distances)
        )
        picked = self.values[nearest]

        if out is None:
            return picked

        out[:] = picked
        return out
def _compute_nearest_weights_edge(idcs, ndist, backend):
"""Helper for nearest interpolation mimicing the linear case."""
# Get out-of-bounds indices from the norm_distances. Negative
# means "too low", larger than or equal to 1 means "too high"
lo = (ndist < 0)
hi = (ndist > 1)
# For "too low" nodes, the lower neighbor gets weight zero;
# "too high" gets 1.
w_lo = backend.array_namespace.where(ndist < 0.5, 1.0, 0.0)
w_lo[lo] = 0
w_lo[hi] = 1
# For "too high" nodes, the upper neighbor gets weight zero;
# "too low" gets 1.
w_hi = backend.array_namespace.where(ndist < 0.5, 0.0, 1.0)
w_hi[lo] = 1
w_hi[hi] = 0
# For upper/lower out-of-bounds nodes, we need to set the
# lower/upper neighbors to the last/first grid point
edge = [idcs, idcs + 1]
edge[0][hi] = -1
edge[1][lo] = 0
return w_lo, w_hi, edge
def _compute_linear_weights_edge(idcs, ndist, backend):
"""Helper for linear interpolation."""
assert isinstance(ndist, backend.array_type)
# Get out-of-bounds indices from the norm_distances. Negative
# means "too low", larger than or equal to 1 means "too high"
if backend.impl == 'numpy':
lo = backend.array_namespace.where(ndist < 0, ndist, 0).nonzero()
hi = backend.array_namespace.where(ndist > 1, ndist, 0).nonzero()
elif backend.impl == 'pytorch':
lo = backend.array_namespace.where(ndist < 0, ndist, 0).nonzero(as_tuple=True)
hi = backend.array_namespace.where(ndist > 1, ndist, 0).nonzero(as_tuple=True)
else:
raise NotImplementedError
# For "too low" nodes, the lower neighbor gets weight zero;
# "too high" gets 2 - yi (since yi >= 1)
w_lo = 1 - ndist
w_lo[lo] = 0
w_lo[hi] += 1
# For "too high" nodes, the upper neighbor gets weight zero;
# "too low" gets 1 + yi (since yi < 0)
w_hi = backend.array_constructor(ndist, copy=True)
w_hi[lo] += 1
w_hi[hi] = 0
# For upper/lower out-of-bounds nodes, we need to set the
# lower/upper neighbors to the last/first grid point
edge = [idcs, idcs + 1]
edge[0][hi] = -1
edge[1][lo] = 0
return w_lo, w_hi, edge
def _create_weight_edge_lists(indices, norm_distances, interp, backend):
    """Pre-calculate indices and weights (per axis).

    Parameters
    ----------
    indices, norm_distances : sequence of arrays
        Lower-neighbor indices and normalized distances, one per axis.
    interp : sequence of str
        Scheme per axis, each either ``'nearest'`` or ``'linear'``.
    backend : ArrayBackend
        Backend handed on to the per-scheme helpers.

    Returns
    -------
    low_weights, high_weights, edge_indices : list
        Per-axis results of the scheme-specific helper.

    Raises
    ------
    ValueError
        If an entry of ``interp`` is not a known scheme.
    """
    low_weights, high_weights, edge_indices = [], [], []

    for idcs, yi, s in zip(indices, norm_distances, interp):
        if s == 'nearest':
            compute = _compute_nearest_weights_edge
        elif s == 'linear':
            compute = _compute_linear_weights_edge
        else:
            raise ValueError(f"invalid `interp` {interp}")

        w_lo, w_hi, edge = compute(idcs, yi, backend=backend)
        low_weights.append(w_lo)
        high_weights.append(w_hi)
        edge_indices.append(edge)

    return low_weights, high_weights, edge_indices
class _PerAxisInterpolator(_Interpolator):
    """Interpolator where the scheme is set per axis.

    This allows to use e.g. nearest neighbor interpolation in the
    first dimension and linear in dimensions 2 and 3.
    """

    def __init__(self, coord_vecs, values, input_type, interp):
        """Initialize a new instance.

        Parameters
        ----------
        coord_vecs : sequence of `numpy.ndarray`'s
            Coordinate vectors defining the interpolation grid
        values : `array-like`
            Grid values to use for interpolation
        input_type : {'array', 'meshgrid'}
            Type of expected input values in ``__call__``
        interp : sequence of str
            Indicates which interpolation scheme to use for which axis
        """
        super().__init__(coord_vecs, values, input_type)
        self.interp = interp

    def _evaluate(self, indices, norm_distances, out=None):
        """Evaluate per-axis interpolation.

        Modified for in-place evaluation and treatment of out-of-bounds
        points by implicitly assuming 0 at the next node.
        """
        to_array = self.backend.array_constructor
        device = self.device
        n_axes = len(indices)

        # slice for broadcasting over trailing dimensions in self.values
        vslice = (slice(None),) + (None,) * (self.values.ndim - n_axes)

        if out is not None:
            out[:] = 0.0
        else:
            out = self.namespace.zeros(
                out_shape_from_meshgrid(norm_distances),
                dtype=self.values.dtype,
                device=device,
            )

        # Weights and indices (per axis)
        low_weights, high_weights, edge_indices = _create_weight_edge_lists(
            indices, norm_distances, self.interp, backend=self.backend)

        # Iterate over all possible combinations of [i, i+1] for each
        # axis, resulting in a loop of length 2**ndim
        corner_labels = product(*([['l', 'h']] * n_axes))
        corner_edges = product(*edge_indices)
        for lo_hi, edge in zip(corner_labels, corner_edges):
            weight = to_array(
                [1.0], dtype=self.values.dtype, device=device
            )
            # TODO(kohr-h): determine best summation order from array strides
            for lh, w_lo, w_hi in zip(lo_hi, low_weights, high_weights):
                # The weight is built up out-of-place on purpose: sizes
                # grow gradually, (n, 1, 1, ...) -> (n, m, 1, ...) -> ...,
                # so the early products are cheap compared to full-size
                # operations from the start.
                # NOTE: the per-factor conversion to the backend is known
                # to be slow; it works around the inhomogeneous shapes
                # returned by `_create_weight_edge_lists`.
                factor = w_lo if lh == 'l' else w_hi
                weight = weight * to_array(factor, device=device)

            out += to_array(self.values[edge], device=device) * weight[vslice]

        return to_array(out, copy=AVOID_UNNECESSARY_COPY, device=device)
class _LinearInterpolator(_PerAxisInterpolator):
    """Linear (i.e. bi-/tri-/multi-linear) interpolator.

    Convenience subclass that selects the ``'linear'`` scheme for
    every axis.
    """

    def __init__(self, coord_vecs, values, input_type):
        """Initialize a new instance.

        Parameters
        ----------
        coord_vecs : sequence of `numpy.ndarray`'s
            Coordinate vectors defining the interpolation grid
        values : `array-like`
            Grid values to use for interpolation
        input_type : {'array', 'meshgrid'}
            Type of expected input values in ``__call__``
        """
        schemes = ['linear'] * len(coord_vecs)
        super().__init__(coord_vecs, values, input_type, interp=schemes)
def _check_func_out_arg(func):
"""Check if ``func`` has an (optional) ``out`` argument.
Also verify that the signature of ``func`` has no ``*args`` since
they make argument propagation a huge hassle.
Note: this function only works for objects that can be inspected
with the ``inspect`` module, i.e., Python functions and callables,
but not, e.g., NumPy UFuncs.
Parameters
----------
func : callable
Object that should be inspected.
Returns
-------
has_out : bool
``True`` if the signature has an ``out`` argument, ``False``
otherwise.
out_is_optional : bool
``True`` if ``out`` is present and optional in the signature,
``False`` otherwise.
Raises
------
TypeError
If ``func``'s signature has ``*args``.
"""
if sys.version_info.major > 2:
spec = inspect.getfullargspec(func)
kw_only = spec.kwonlyargs
else:
spec = inspect.getargspec(func)
kw_only = ()
if spec.varargs is not None:
raise TypeError("*args not allowed in function signature")
pos_args = spec.args
pos_defaults = () if spec.defaults is None else spec.defaults
if 'out' in pos_args:
has_out = True
out_optional = (
pos_args.index('out') >= len(pos_args) - len(pos_defaults)
)
elif 'out' in kw_only:
has_out = out_optional = True
else:
has_out = out_optional = False
return has_out, out_optional
def _func_out_type(func):
"""Determine the output argument type (if any) of a function-like object.
This function is intended to work with all types of callables
that are used as input to `sampling_function`.
"""
# Numpy `UFuncs` and similar objects (e.g. Numba `DUFuncs`)
if hasattr(func, 'nin') and hasattr(func, 'nout'):
if func.nin != 1:
raise ValueError(
f"ufunc {func.__name__} takes {func.nin} input arguments, expected 1"
)
if func.nout > 1:
raise ValueError(
f"ufunc {func.__name__} returns {func.nout} outputs, expected 0 or 1"
)
has_out = out_optional = (func.nout == 1)
elif inspect.isfunction(func):
has_out, out_optional = _check_func_out_arg(func)
elif callable(func):
has_out, out_optional = _check_func_out_arg(func.__call__)
else:
raise TypeError(f"object {func} not callable")
return has_out, out_optional
def _broadcast_nested_list(arr_lists, element_shape, ndim, backend: ArrayBackend):
    """Broadcast every leaf of a nested list to ``element_shape``.

    A generalisation of `np.broadcast_to`, applied to an arbitrarily
    deep list (or tuple) eventually containing arrays or scalars.

    Parameters
    ----------
    arr_lists : array, scalar or nested sequence thereof
        Structure whose leaves are broadcast.
    element_shape : sequence of int
        Target shape of each leaf.
    ndim : int
        Number of domain dimensions; ``1`` enables the special
        treatment of a leading axis of length 1 described below.
    backend : ArrayBackend
        Backend providing ``array_type`` and
        ``array_namespace.broadcast_to``.

    Returns
    -------
    broadcast : array or nested list of arrays
        Same nesting as ``arr_lists`` (sequences become lists).
    """
    is_leaf = (
        isinstance(arr_lists, backend.array_type) or np.isscalar(arr_lists)
    )
    if not is_leaf:
        return [
            _broadcast_nested_list(row, element_shape, ndim, backend=backend)
            for row in arr_lists
        ]

    if ndim == 1:
        # As usual, 1d is tedious to deal with. Result components can
        # carry an extra leading axis that stems from using x instead
        # of x[0] in a function. It is dropped here since broadcasting
        # fails otherwise.
        shp = getattr(arr_lists, 'shape', ())
        if shp and shp[0] == 1:
            arr_lists = arr_lists.reshape(arr_lists.shape[1:])

    return backend.array_namespace.broadcast_to(arr_lists, element_shape)
def _send_nested_list_to_backend(
    arr_lists, backend : ArrayBackend, device, dtype
):
    """Convert the NumPy leaves of a nested list to ``backend`` arrays.

    For the ``'numpy'`` backend the input is returned unchanged.
    Otherwise NumPy arrays and scalars are passed through
    ``backend.array_constructor`` with the given ``device`` and
    ``dtype``, and lists or tuples are processed recursively (and
    returned as lists).

    Raises
    ------
    TypeError
        If an entry is neither a NumPy array, a scalar, a list nor a
        tuple.
    """
    if backend.impl == 'numpy':
        return arr_lists

    if isinstance(arr_lists, np.ndarray) or np.isscalar(arr_lists):
        return backend.array_constructor(arr_lists, device=device, dtype=dtype)

    if isinstance(arr_lists, (tuple, list)):
        return [
            _send_nested_list_to_backend(arr, backend, device, dtype)
            for arr in arr_lists
        ]

    raise TypeError(f"Type of input {type(arr_lists)} not supported.")
def sampling_function(
    func: Callable | list | tuple,
    domain: IntervalProd,
    out_dtype: str | None = None,
    impl: str = 'numpy',
    device: str = 'cpu'
):
    """Return a function that can be used for sampling.

    For examples on this function's usage, see `point_collocation`.

    Parameters
    ----------
    func : callable
        A single callable object (possibly with multiple output
        components). It must take a single input, return its result and
        must **not** have an ``out`` parameter. Lists or tuples of
        callables are recognized but currently not supported.
    domain : IntervalProd
        Set in which inputs to the function are assumed to lie. It is
        forwarded to the wrapper, which uses it to determine the type of
        input (point/meshgrid/array) and for the bounds check.
    out_dtype : optional
        Floating-point data type of the output of ``func``.
        If ``None``, ``'float64'`` is used.
    impl : str, optional
        Forwarded unchanged to the wrapper factory.
        NOTE(review): presumably the array backend name -- confirm.
    device : str, optional
        Forwarded unchanged to the wrapper factory.
        NOTE(review): presumably the device of the sampled arrays --
        confirm.

    Returns
    -------
    func : function
        Wrapper function created by ``_make_single_use_func``.

    Raises
    ------
    ValueError
        If ``out_dtype`` is given but is not a floating-point data type.
    NotImplementedError
        If ``func`` has an ``out`` argument, is a list or tuple, or is
        of any other unsupported type.
    """
    # 1) The dtype. An explicit exception is raised rather than using
    # `assert`, which is stripped when running with `-O`.
    if out_dtype is None:
        out_dtype = 'float64'
    elif not is_floating_dtype(out_dtype):
        raise ValueError(
            f"`out_dtype` must be a floating-point data type, got {out_dtype!r}"
        )

    # 2) The function. The callable check comes first so that callable
    # objects are accepted regardless of their other base classes.
    if callable(func):
        has_out, _ = _func_out_type(func)
        if has_out:
            # Functions taking `out` evaluate in-place; only
            # out-of-place functions are handled by the wrapper.
            raise NotImplementedError(
                "Currently, not implemented for in-place functions "
                "(functions with an `out` argument)")
    elif isinstance(func, (list, tuple)):
        raise NotImplementedError("The sampling function cannot be instantiated"
                                  + " with a list-like of callables.")
    else:
        raise NotImplementedError("The function to sample must be either a Callable"
                                  + " or an array-like (list, tuple) of callables.")

    # 3) The wrapper
    return _make_single_use_func(func, domain, out_dtype, impl, device)
def _make_single_use_func(
func_oop, domain, out_dtype, impl: str = 'numpy', device: str = 'cpu'):
"""Return a unifying wrapper function with optional ``out`` argument."""
# Default to `ndim=1` for unusual domains that do not define a dimension
# (like `Strings(3)`)
ndim = getattr(domain, 'ndim', 1)
if out_dtype is None:
# Don't let `np.dtype` convert `None` to `float64`
raise TypeError("`out_dtype` cannot be `None`")
out_dtype = np.dtype(out_dtype)
val_shape = out_dtype.shape
scalar_out_dtype = out_dtype.base
def single_use_func(x, **kwargs):
"""Wrapper function with optional ``out`` argument.
This function closes over two other functions, one for in-place,
the other for out-of-place evaluation. Its purpose is to unify their
interfaces to a single one with optional ``out`` argument, and to
automate all details of input/output checking, broadcasting and
type casting.
The closure also contains ``domain``, an `IntervalProd` where points
should lie, and the expected ``out_dtype``.
For usage examples, see `point_collocation`.
Parameters
----------
x : point, `meshgrid` or `numpy.ndarray`
Input argument for the function evaluation. Conditions
on ``x`` depend on its type:
- point: must be castable to an element of the enclosed ``domain``.
- meshgrid: length must be ``domain.ndim``, and the arrays must
be broadcastable against each other.
- array: shape must be ``(ndim, N)``, where ``ndim`` equals
``domain.ndim``.
out : `numpy.ndarray`, optional
Output argument holding the result of the function evaluation.
Its shape must be ``out_dtype.shape + np.broadcast(*x).shape``.
Other Parameters
----------------
bounds_check : bool, optional
If ``True``, check if all input points lie in ``domain``. This
requires ``domain`` to implement `Set.contains_all`.
Default: ``True``
Returns
-------
out : `numpy.ndarray`
Result of the function evaluation. If ``out`` was provided,
the returned object is a reference to it.
Raises
------
TypeError
If ``x`` is not a valid vectorized evaluation argument.
If ``out`` is neither ``None`` nor a `numpy.ndarray` of
adequate shape and data type.
ValueError
If ``bounds_check == True`` and some evaluation points fall
outside the valid domain.
"""
bounds_check = kwargs.pop('bounds_check', True)
if bounds_check and not hasattr(domain, 'contains_all'):
raise AttributeError(
f"bounds check not possible for domain {domain}, missing `contains_all()` method"
)
# Check for input type and determine output shape
if is_valid_input_meshgrid(x, ndim):
scalar_in = False
scalar_out_shape = out_shape_from_meshgrid(x)
# Avoid operations on tuples like x * 2 by casting to array
if ndim == 1:
x = x[0][None, ...]
elif is_valid_input_array(x, ndim):
x = np.asarray(x)
scalar_in = False
scalar_out_shape = out_shape_from_array(x)
elif x in domain:
x = np.atleast_2d(x).T # make a (d, 1) array
scalar_in = True
scalar_out_shape = (1,)
else:
# Unknown input
txt_1d = ' or (n,)' if ndim == 1 else ''
raise TypeError(
f"argument {x} not a valid function input. Expected an element of the domain {domain}, an array-like with shape ({domain.ndim}, n){txt_1d} or a length-{domain.ndim} meshgrid tuple."
)
# Check bounds if specified
if bounds_check and not domain.contains_all(x):
raise ValueError(f"input contains points outside the domain {domain}")
backend = lookup_array_backend(impl)
array_ns = backend.array_namespace
backend_scalar_out_dtype = backend.available_dtypes[
_universal_dtype_identifier(scalar_out_dtype)
]
x = _send_nested_list_to_backend(x, backend, device, backend_scalar_out_dtype)
if scalar_in:
out_shape = val_shape
else:
out_shape = val_shape + scalar_out_shape
if ndim == 1:
try:
out = func_oop(x, **kwargs)
except (TypeError, IndexError):
# TypeError is raised if a meshgrid was used but the
# function expected an array (1d only). In this case we try
# again with the first meshgrid vector.
# IndexError is raised in expressions like x[x > 0] since
# "x > 0" evaluates to 'True', i.e. 1, and that index is
# out of range for a meshgrid tuple of length 1 :-). To get
# the real errors with indexing, we check again for the
# same scenario (scalar output when not valid) as in the
# first case.
out = func_oop(x[0], **kwargs)
else:
# Here we don't catch exceptions since they are likely true
# errors
out = func_oop(x, **kwargs)
def _process_array(out):
if isinstance(out, backend.array_type) or np.isscalar(out):
# Cast to proper dtype if needed, also convert to array if out is a scalar.
out = backend.array_constructor(
out, dtype=backend_scalar_out_dtype, device=device
)
if scalar_in:
out = array_ns.squeeze(out, 0)
elif ndim == 1 and out.shape == (1,) + out_shape:
out = out.reshape(out_shape)
if out_shape != () and out.shape != out_shape:
# Broadcast the returned element, but not in the
# scalar case. The resulting array may be read-only,
# in which case we copy.
out = array_ns.broadcast_to(out, out_shape)
out = backend.array_constructor(out, copy=True)
return out
elif isinstance(out, (tuple, list)):
result = []
assert len(out) != 0
for sub_out in out:
result.append(_process_array(sub_out))
return result
return _process_array(out)
return single_use_func
# Script entry point: run this module's doctests when it is executed directly.
if __name__ == "__main__":
    from odl.core.util.testutils import run_doctests as _run_module_doctests

    _run_module_doctests()