Menu

Module providing common Brain MRI Augmentation Methods for PyTorch.

Source code for brainrise.core

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Core transformations dealing with a label target.
"""

# Imports
from inspect import signature
import numpy as np
import torch


class Compose(object):
    """ Composes several transforms together.

    The number of parameters of each transform is inspected once, at
    construction time. Transforms declaring exactly two parameters receive
    ``(image, label)`` when a label is supplied and must return both; all
    other transforms receive only the image.

    Parameters
    ----------
    transforms: iterable of callable
        the transforms to apply, in order. The iterable is materialised into
        a list so that one-shot iterators (e.g. generators) are not exhausted
        by the parameter inspection.
    """
    def __init__(self, transforms):
        # list(): the transforms are traversed here AND at every call, so a
        # generator must not be consumed by this first pass.
        self.transforms = list(transforms)
        self.n_params = [
            len(signature(trf).parameters) for trf in self.transforms]

    def __call__(self, image, label=None):
        """ Apply the transforms sequentially.

        Parameters
        ----------
        image: object
            the input image.
        label: object, default None
            an optional label, forwarded only to two-parameter transforms.

        Returns
        -------
        image or (image, label)
            the transformed image, paired with the (possibly transformed)
            label when a label was given.
        """
        for trf, size in zip(self.transforms, self.n_params):
            if label is not None and size == 2:
                image, label = trf(image, label)
            else:
                image = trf(image)
        if label is None:
            return image
        return image, label

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for trf in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(trf)
        format_string += "\n)"
        return format_string
class ToTensor(object):
    """ Convert array-like image (and optional label) data to tensors.

    The image is returned as a ``torch.float32`` tensor. When a label is
    given, it is returned alongside as a ``torch.int64`` tensor. Both inputs
    go through ``np.array`` first, which copies the data by default.
    """
    def __call__(self, image, label=None):
        float_image = torch.as_tensor(np.array(image), dtype=torch.float32)
        if label is not None:
            long_label = torch.as_tensor(np.array(label), dtype=torch.int64)
            return float_image, long_label
        return float_image
class RandomApply(object):
    """ Apply randomly a list of transformations with a given probability.

    On each call a single uniform draw decides whether the whole list of
    transforms is applied (with probability ``p``) or the inputs are
    returned unchanged.

    Parameters
    ----------
    transforms: iterable of callable
        the transforms to apply, in order. Transforms declaring exactly two
        parameters receive ``(image, label)`` when a label is supplied; all
        others receive only the image. The iterable is materialised into a
        list so that one-shot iterators are not exhausted by the parameter
        inspection.
    p: float, default 0.5
        the probability of applying the transforms.
    """
    def __init__(self, transforms, p=0.5):
        # list(): transforms are traversed here AND at every call.
        self.transforms = list(transforms)
        self.p = p
        self.n_params = [
            len(signature(trf).parameters) for trf in self.transforms]

    def __call__(self, image, label=None):
        """ Randomly apply the transforms.

        Parameters
        ----------
        image: object
            the input image.
        label: object, default None
            an optional label, forwarded only to two-parameter transforms.

        Returns
        -------
        image or (image, label)
            the (possibly) transformed image, paired with the label when a
            label was given.
        """
        # u ~ U[0, 1): apply only when u < p, so P(apply) == p exactly and
        # p=0 never applies (the former `p < u` test let u == 0.0 through).
        if torch.rand(1).item() >= self.p:
            if label is None:
                return image
            return image, label
        for trf, size in zip(self.transforms, self.n_params):
            if label is not None and size == 2:
                image, label = trf(image, label)
            else:
                image = trf(image)
        if label is None:
            return image
        return image, label

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        format_string += "\n p={}".format(self.p)
        for trf in self.transforms:
            format_string += "\n"
            format_string += " {0}".format(trf)
        format_string += "\n)"
        return format_string

Follow us

© 2023, brainrise developers