# Copyright 2014-2025 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
"""Interface for sparse matrices"""
# We import to initialize the backends
# pylint: disable=unused-import
# We want to import if the backends are actually available
# pylint: disable=import-outside-toplevel
# We want to use a global statement here
# pylint: disable=global-statement
from typing import Optional
import importlib.util
from odl.core.sparse.sparse_template import (
SparseMatrixFormat,
_registered_sparse_formats,
)
# Module-level flag: flipped to True once the backend modules have been imported.
IS_INITIALIZED = False


def _initialize_if_needed():
    """Populate ``_registered_sparse_formats`` on first use.

    The backend modules are imported only for their side effect of
    registering their sparse formats. The SciPy backend is always imported.
    The PyTorch backend is imported only when ``torch`` can be located, and a
    ``ModuleNotFoundError`` raised while importing it is ignored. Subsequent
    calls return immediately.
    """
    global IS_INITIALIZED
    if IS_INITIALIZED:
        return
    import odl.backends.sparse.scipy_backend
    if importlib.util.find_spec("torch") is not None:
        try:
            import odl.backends.sparse.pytorch_backend
        except ModuleNotFoundError:
            # Best effort: skip the PyTorch backend and keep the others.
            pass
    IS_INITIALIZED = True
class SparseMatrix:
    """
    SparseMatrix is the ODL interface to the sparse matrix support of different backends.

    ``SparseMatrix(format, impl, *args, **kwargs)`` looks up the constructor
    registered for backend ``impl`` and sparse format ``format``. It returns
    the object produced by that constructor.

    Note:
        The user is responsible for passing the *args and **kwargs expected
        by the respective backend:

        Pytorch:
            -> COO: https://docs.pytorch.org/docs/stable/generated/torch.sparse_coo_tensor.html
        Scipy:
            -> COO: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html

    Examples:
        SparseMatrix('COO', 'pytorch',
                     [[0, 1, 1], [2, 0, 2]], [3, 4, 5],
                     device='cuda:0')
        SparseMatrix('COO', 'scipy',
                     (3, 4))
    """

    def __new__(cls, format: str, impl: str, *args, **kwargs):
        """Build a sparse matrix with the requested backend and format.

        Args:
            format (str): Sparse format identifier (e.g. 'COO').
            impl (str): Backend identifier (e.g. 'scipy', 'pytorch').
            *args, **kwargs: Forwarded unchanged to the backend constructor.

        Returns:
            object: Whatever the registered backend constructor returns.

        Raises:
            AssertionError: If ``format`` or ``impl`` is not a ``str``.
            ValueError: If ``impl`` is not a registered backend, or ``format``
                is not registered for that backend.
        """
        _initialize_if_needed()

        # Explicit checks instead of ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if not isinstance(format, str):
            raise AssertionError(
                f"The sparse data format can only be a string, got {type(format)}"
            )
        if not isinstance(impl, str):
            raise AssertionError(
                f"The impl argument can only be a str, got {type(impl)}"
            )

        # Getting the backend (scipy, Pytorch...)
        backend_formats = _registered_sparse_formats.get(impl)
        if backend_formats is None:
            raise ValueError(
                f"The backend {impl} is not supported. Only "
                f"{list(_registered_sparse_formats.keys())} are registered backends."
            )

        # Getting the format (COO, CSR...)
        sparse_impl = backend_formats.get(format)
        if sparse_impl is None:
            # Report the missing *format* (previously misreported ``impl``).
            raise ValueError(
                f"No format {format} for backend {impl}. Only "
                f"{list(backend_formats.keys())} are registered formats."
            )
        return sparse_impl.constructor(*args, **kwargs)
def is_sparse(matrix: object) -> bool:
    """Tell whether ``matrix`` is in one of the sparse formats known to ODL.

    Args:
        matrix (object): Candidate object to inspect.

    Returns:
        bool: ``True`` if ``lookup_sparse_format`` finds a format for
        ``matrix``, ``False`` otherwise.
    """
    # NOTE(review): ``lookup_sparse_format`` is not imported or defined in the
    # visible part of this module -- confirm it is defined elsewhere in it.
    sparse_format = lookup_sparse_format(matrix)
    return sparse_format is not None
def get_sparse_matrix_impl(matrix: object) -> str:
    """Gets the implementation string name of a matrix (which
    must be in one of the sparse formats known to ODL).

    Args:
        matrix (object): matrix

    Returns:
        str: The implementation string identifier ('pytorch', 'scipy', ...)

    Raises:
        AssertionError: If ``lookup_sparse_format`` finds no format for
            ``matrix``.
    """
    instance = lookup_sparse_format(matrix)
    # Explicit raise instead of ``assert``: an assert is stripped under
    # ``python -O``, which would turn this into an AttributeError on ``None``.
    # AssertionError is kept so existing callers see the same exception type.
    if instance is None:
        raise AssertionError("The matrix is not a supported sparse matrix")
    return instance.impl
if __name__ == "__main__":
    # Smoke check: build a COO matrix through the SciPy backend, passing
    # (3, 4) to its constructor, and print the result.
    demo_matrix = SparseMatrix("COO", "scipy", (3, 4))
    print(demo_matrix)