Menu

Module providing common Brain MRI Augmentation Methods for PyTorch.

Source code for brainrise.datasets

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Create the MRI Toy dataset.
"""

# Imports
import os
import shutil
import requests
import nibabel
import numpy as np
from torch.utils.data import Dataset


[docs]class MRIToyDataset(Dataset): """ Create the MRI Toy dataset. """ lesion_url = ( "https://raw.github.com/muschellij2/open_ms_data/master/" "cross_sectional/coregistered_resampled/patient01/consensus_gt.nii.gz") t1w_url = ( "https://raw.github.com/muschellij2/open_ms_data/master/" "cross_sectional/coregistered_resampled/patient01/T1W.nii.gz") t2w_url = ( "https://raw.github.com/muschellij2/open_ms_data/master/" "cross_sectional/coregistered_resampled/patient01/T2W.nii.gz") flair_url = ( "https://raw.github.com/muschellij2/open_ms_data/master/" "cross_sectional/coregistered_resampled/patient01/FLAIR.nii.gz") mask_url = ( "https://raw.github.com/muschellij2/open_ms_data/master/" "cross_sectional/coregistered_resampled/patient01/brainmask.nii.gz")
[docs] def __init__(self, root, transform=None): """ Init class. Parameters ---------- root: str root directory of dataset where data will be saved. transform: callable, default None optional transform to be applied on a sample. """ super(MRIToyDataset).__init__() self.root = root self.transform = transform self.data_file = os.path.join(root, "mritoy.npz") self.download() self.data = np.load(self.data_file, mmap_mode="r")
[docs] def download(self): """ Download data. """ if not os.path.isfile(self.data_file): # Fetch data dataset = {} for name, url in (("t1w", self.t1w_url), ("t2w", self.t2w_url), ("flair", self.flair_url), ("lesion", self.lesion_url), ("mask", self.mask_url)): basename = url.split("/")[-1] path = os.path.join(self.root, basename) if not os.path.isfile(path): print("Downloading {0}.".format(url)) response = requests.get(url, stream=True) with open(path, "wb") as out_file: response.raw.decode_content = False shutil.copyfileobj(response.raw, out_file) del response dataset[name] = nibabel.load(path).get_fdata() # Save dataset np.savez(self.data_file, **dataset)
def __len__(self): return 1 def __getitem__(self, idx): input_data = [self.data["t1w"], self.data["t2w"], self.data["flair"]] label_data = [self.data["lesion"], self.data["mask"]] if self.transform is not None: input_data, label_data = self.transform(input_data, label_data) return input_data, label_data

Follow us

© 2023, brainrise developers