Menu

Module providing common Brain MRI Augmentation Methods for PyTorch.

Source code for brainrise.utils

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
A module with common functions.
"""

# Import
import numbers


[docs]class Transform(object): """ A base class for transformations. """
[docs] @classmethod def apply(cls, arr, fct, *args, **kwargs): """ Apply transformation to data. Parameters ---------- arr: array or list of array the input data. fct: callable or str the transformation function. kwargs: dict the function parameters. Returns ------- transformed: array or list of array the transformed input data. """ if isinstance(fct, str): transformed = [ getattr(_arr, fct)(*args, **kwargs) for _arr in listify(arr)] else: transformed = [ fct(_arr, *args, **kwargs) for _arr in listify(arr)] return flatten(transformed)
[docs] @classmethod def shape(cls, arr): """ Return the shape of an array. Parameters ---------- arr: array or list of array input array. Returns ------- shape: tuple of int the elements of the shape tuple give the lengths of the corresponding array dimensions. """ _arr = listify(arr)[0] return _arr.shape
[docs] @classmethod def ndim(cls, arr): """ Number of array dimensions. Parameters ---------- arr: array or list of array input array. Returns ------- ndim: int the array number of dimensions. """ _arr = listify(arr)[0] return _arr.ndim
[docs] @classmethod def max(cls, arr, axis=None): """ Return the maximum along a given axis. Parameters ---------- arr: array or list of array input array. Returns ------- ndim: int the array number of dimensions. """ _arr = listify(arr)[0] return _arr.max(axis=axis)
[docs]def listify(data): """ Ensure that the input is a list or tuple. Parameters ---------- arr: list or array the input data. Returns ------- out: list the liftify input data. """ if isinstance(data, list) or isinstance(data, tuple): return data else: return [data]
[docs]def flatten(data): """ Ensure that the list contains more than one element. Parameters ---------- arr: list the listify input data. Returns ------- out: list or array the output data. """ if len(data) == 1: return data[0] else: return data
[docs]def interval(obj, lower=None): """ Listify an object. Parameters ---------- obj: 2-uplet or number the object used to build the interval. lower: number, default None the lower bound of the interval. If not specified, a symetric interval is generated. Returns ------- interval: 2-uplet an interval. """ if isinstance(obj, numbers.Number): if obj < 0: raise ValueError("Specified interval value must be positive.") if lower is None: lower = -obj return (lower, obj) if len(obj) != 2: raise ValueError("Interval must be specified with 2 values.") min_val, max_val = obj if min_val > max_val: raise ValueError("Wrong interval boudaries.") return tuple(obj)

Follow us

© 2023, brainrise developers