{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Data augmentation usage\n\nCredit: A Grigis\n\nA simple example of how to use data augmentation. More specifically,\nlearn how to use a set of tools to efficiently augment 3D MRI images. They\ninclude random affine and non-linear transformations, as well as simulated\nintensity artifacts caused by MRI magnetic field inhomogeneity and simulated\nk-space motion artifacts.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pprint import pprint\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport torch\nfrom torch.utils.data import DataLoader\nimport torchvision.transforms as transforms\nfrom torchvision.utils import make_grid\nimport brainrise\nfrom brainrise.datasets import MRIToyDataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Available augmentation methods\n\nFirst list all available augmentation methods.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# List every augmentation method available in brainrise.\n",
        "trfs = brainrise.get_augmentations()\n",
        "pprint(trfs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
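        "## Toy MRI dataset\n\nLoad the toy MRI dataset and display, for one sample, the middle slice (taken\nalong the last axis) of each input and output channel.\n\n"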
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def show(imgs):\n",
        "    \"\"\"Display one or more image tensors stacked vertically in one figure.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    imgs: torch.Tensor or list of torch.Tensor\n",
        "        image(s) accepted by `transforms.functional.to_pil_image`, e.g.\n",
        "        the grids returned by `make_grid`. An empty list draws nothing.\n",
        "    \"\"\"\n",
        "    if not isinstance(imgs, list):\n",
        "        imgs = [imgs]\n",
        "    if not imgs:\n",
        "        # plt.subplots(nrows=0) raises a ValueError: nothing to draw.\n",
        "        return\n",
        "    _, axs = plt.subplots(nrows=len(imgs), squeeze=False)\n",
        "    for idx, img in enumerate(imgs):\n",
        "        img = transforms.functional.to_pil_image(img.detach())\n",
        "        axs[idx, 0].imshow(np.asarray(img))\n",
        "        axs[idx, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
        "    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)\n",
        "\n",
        "\n",
        "transform = brainrise.Compose([\n",
        "    brainrise.Rescale(dynamic=(0, 1), percentiles=(5, 97)),\n",
        "    brainrise.ToTensor()])\n",
        "dataset = MRIToyDataset(root=\"/tmp\", transform=transform)\n",
        "dataloader = DataLoader(dataset, batch_size=1)\n",
        "batch_input, batch_output = next(iter(dataloader))\n",
        "# Concatenate inputs and outputs (cast to float32) along dim 1, then swap\n",
        "# dims 0 and 1 so that make_grid, which tiles along dim 0, draws each\n",
        "# concatenated channel as its own single-channel image (batch_size=1).\n",
        "batch_data = torch.cat((batch_input, batch_output.type(torch.float32)), dim=1)\n",
        "batch_data = torch.transpose(batch_data, dim0=0, dim1=1)\n",
        "# Keep the middle index of the last axis to obtain 2D slices.\n",
        "mid_slice = batch_data.shape[-1] // 2\n",
        "grid = make_grid(batch_data[..., mid_slice], nrow=5)\n",
        "show(grid)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
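        "## Data augmentation\n\nFirst perform a simple random anterior/posterior (A/P) flip, then a random\naffine transformation combined with randomly applied noise. Each pipeline is\nrun over the dataset for five epochs so that the random variations are visible.\n\n"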
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Seed the global NumPy and torch RNGs so that Restart & Run All reproduces\n",
        "# the figures (assumes brainrise draws from one of them -- TODO confirm).\n",
        "np.random.seed(0)\n",
        "torch.manual_seed(0)\n",
        "\n",
        "\n",
        "def collect_mid_slices(transform, n_epochs=5):\n",
        "    \"\"\"Run the toy dataset through `transform` and grid its middle slices.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    transform: callable\n",
        "        augmentation pipeline passed to `MRIToyDataset`.\n",
        "    n_epochs: int, default 5\n",
        "        number of passes over the dataset; since the pipeline is random,\n",
        "        repeated passes presumably yield different augmentations.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    imgs: list of torch.Tensor\n",
        "        one `make_grid` image per batch, built from the middle slice (last\n",
        "        axis) of every concatenated input/output channel.\n",
        "    \"\"\"\n",
        "    dataset = MRIToyDataset(root=\"/tmp\", transform=transform)\n",
        "    dataloader = DataLoader(dataset, batch_size=1)\n",
        "    imgs = []\n",
        "    for _ in range(n_epochs):\n",
        "        for batch_input, batch_output in dataloader:\n",
        "            batch_data = torch.cat((\n",
        "                batch_input, batch_output.type(torch.float32)), dim=1)\n",
        "            batch_data = torch.transpose(batch_data, dim0=0, dim1=1)\n",
        "            mid_slice = batch_data.shape[-1] // 2\n",
        "            imgs.append(make_grid(batch_data[..., mid_slice]))\n",
        "    return imgs\n",
        "\n",
        "\n",
        "# Random A/P flip, applied with probability 0.5.\n",
        "flip_transform = brainrise.Compose([\n",
        "    brainrise.RandomApply([brainrise.RandomFlip(axis=1)], p=0.5),\n",
        "    brainrise.Rescale(dynamic=(0, 1), percentiles=(5, 97)),\n",
        "    brainrise.ToTensor()])\n",
        "show(collect_mid_slices(flip_transform))\n",
        "\n",
        "# Random noise (probability 0.5) followed by a random affine transformation.\n",
        "affine_noise_transform = brainrise.Compose([\n",
        "    brainrise.RandomApply([brainrise.RandomNoise(snr=20)], p=0.5),\n",
        "    brainrise.RandomAffine(rotation=3, translation=4, zoom=0.05, order=1),\n",
        "    brainrise.Rescale(dynamic=(0, 1), percentiles=(5, 97)),\n",
        "    brainrise.ToTensor()])\n",
        "show(collect_mid_slices(affine_noise_transform))\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}