From fb0ee2cd435e65faa73433b573a3deea67f03633 Mon Sep 17 00:00:00 2001
From: Zachary Rahn <zr66@drexel.edu>
Date: Sun, 29 Jan 2023 16:48:27 -0500
Subject: [PATCH] adding training notebook for Google Colab

---
 ml/ML_Training_Colab.ipynb | 488 +++++++++++++++++++++++++++++++++++++
 1 file changed, 488 insertions(+)
 create mode 100644 ml/ML_Training_Colab.ipynb

diff --git a/ml/ML_Training_Colab.ipynb b/ml/ML_Training_Colab.ipynb
new file mode 100644
index 0000000..8b242e0
--- /dev/null
+++ b/ml/ML_Training_Colab.ipynb
@@ -0,0 +1,488 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU",
+    "gpuClass": "standard"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "NLG5eyci_qGY",
+        "outputId": "d1df9059-6e7c-4f73-a420-60509cffcecf"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+            "Collecting awscli\n",
+            "  Downloading awscli-1.27.59-py3-none-any.whl (4.0 MB)\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m91.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25hCollecting botocore==1.29.59\n",
+            "  Downloading botocore-1.29.59-py3-none-any.whl (10.4 MB)\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m43.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25hCollecting PyYAML<5.5,>=3.10\n",
+            "  Downloading PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m662.4/662.4 KB\u001b[0m \u001b[31m59.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25hCollecting colorama<0.4.5,>=0.2.5\n",
+            "  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)\n",
+            "Collecting s3transfer<0.7.0,>=0.6.0\n",
+            "  Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.6/79.6 KB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25hRequirement already satisfied: docutils<0.17,>=0.10 in /usr/local/lib/python3.8/dist-packages (from awscli) (0.16)\n",
+            "Collecting rsa<4.8,>=3.1.2\n",
+            "  Downloading rsa-4.7.2-py3-none-any.whl (34 kB)\n",
+            "Collecting urllib3<1.27,>=1.25.4\n",
+            "  Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)\n",
+            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.6/140.6 KB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+            "\u001b[?25hRequirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.8/dist-packages (from botocore==1.29.59->awscli) (2.8.2)\n",
+            "Collecting jmespath<2.0.0,>=0.7.1\n",
+            "  Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)\n",
+            "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.8/dist-packages (from rsa<4.8,>=3.1.2->awscli) (0.4.8)\n",
+            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil<3.0.0,>=2.1->botocore==1.29.59->awscli) (1.15.0)\n",
+            "Installing collected packages: urllib3, rsa, PyYAML, jmespath, colorama, botocore, s3transfer, awscli\n",
+            "  Attempting uninstall: urllib3\n",
+            "    Found existing installation: urllib3 1.24.3\n",
+            "    Uninstalling urllib3-1.24.3:\n",
+            "      Successfully uninstalled urllib3-1.24.3\n",
+            "  Attempting uninstall: rsa\n",
+            "    Found existing installation: rsa 4.9\n",
+            "    Uninstalling rsa-4.9:\n",
+            "      Successfully uninstalled rsa-4.9\n",
+            "  Attempting uninstall: PyYAML\n",
+            "    Found existing installation: PyYAML 6.0\n",
+            "    Uninstalling PyYAML-6.0:\n",
+            "      Successfully uninstalled PyYAML-6.0\n",
+            "Successfully installed PyYAML-5.4.1 awscli-1.27.59 botocore-1.29.59 colorama-0.4.4 jmespath-1.0.1 rsa-4.7.2 s3transfer-0.6.0 urllib3-1.26.14\n",
+            "/content/drive/My Drive/awscli.ini\n"
+          ]
+        }
+      ],
+      "source": [
+        "!pip install -q awscli\n",
+        "import os\n",
+        "!export AWS_SHARED_CREDENTIALS_FILE=/content/drive/My\\ Drive/config/awscli.ini\n",
+        "path = \"/content/drive/My Drive/awscli.ini\"\n",
+        "os.environ['AWS_SHARED_CREDENTIALS_FILE'] = path\n",
+        "print(os.environ['AWS_SHARED_CREDENTIALS_FILE'])\n",
+        "!mkdir -p data\n",
+        "!aws s3 cp s3://digpath-chips/ data/ --recursive --quiet"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import glob\n",
+        "\n",
+        "mild_train_files = glob.glob('data/train/Mild/*.png')\n",
+        "mild_val_files = glob.glob('data/val/Mild/*.png')\n",
+        "mild_test_files = glob.glob('data/test/Mild/*.png')\n",
+        "\n",
+        "severe_train_files = glob.glob('data/train/Severe/*.png')\n",
+        "severe_val_files = glob.glob('data/val/Severe/*.png')\n",
+        "severe_test_files = glob.glob('data/test/Severe/*.png')\n",
+        "simulated_files = glob.glob('data/train/Severe/*sim*.png')\n",
+        "\n",
+        "print('mild_train_files', len(mild_train_files))\n",
+        "print('mild_val_files', len(mild_val_files))\n",
+        "print('mild_test_files', len(mild_test_files))\n",
+        "print()\n",
+        "print('severe_train_files', len(severe_train_files))\n",
+        "print('severe_val_files', len(severe_val_files))\n",
+        "print('severe_test_files', len(severe_test_files))\n",
+        "print('simulated_files', len(simulated_files))"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "DaRvl_fdPpNo",
+        "outputId": "53bdf005-5da7-4371-c7ad-b19af4f68ce0"
+      },
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "mild_train_files 12744\n",
+            "mild_val_files 1661\n",
+            "mild_test_files 2490\n",
+            "\n",
+            "severe_train_files 4344\n",
+            "severe_val_files 851\n",
+            "severe_test_files 1413\n",
+            "simulated_files 965\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import random\n",
+        "random.shuffle(mild_train_files)\n",
+        "random.shuffle(mild_val_files)\n",
+        "random.shuffle(mild_test_files)\n",
+        "for filename in mild_train_files[len(severe_train_files):]:\n",
+        "  os.remove(filename)\n",
+        "for filename in mild_val_files[len(severe_val_files):]:\n",
+        "  os.remove(filename)\n",
+        "for filename in mild_test_files[len(severe_test_files):]:\n",
+        "  os.remove(filename)\n",
+        "\n",
+        "mild_train_files = glob.glob('data/train/Mild/*.png')\n",
+        "mild_val_files = glob.glob('data/val/Mild/*.png')\n",
+        "mild_test_files = glob.glob('data/test/Mild/*.png')\n",
+        "\n",
+        "severe_train_files = glob.glob('data/train/Severe/*.png')\n",
+        "severe_val_files = glob.glob('data/val/Severe/*.png')\n",
+        "severe_test_files = glob.glob('data/test/Severe/*.png')\n",
+        "simulated_files = glob.glob('data/train/Severe/*sim*.png')\n",
+        "\n",
+        "print('mild_train_files', len(mild_train_files))\n",
+        "print('mild_val_files', len(mild_val_files))\n",
+        "print('mild_test_files', len(mild_test_files))\n",
+        "print()\n",
+        "print('severe_train_files', len(severe_train_files))\n",
+        "print('severe_val_files', len(severe_val_files))\n",
+        "print('severe_test_files', len(severe_test_files))"
+      ],
+      "metadata": {
+        "id": "tcYLGTfvNc4r",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "fe4a5701-208f-49c1-f8fb-d9d7ceda9776"
+      },
+      "execution_count": 4,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "mild_train_files 4344\n",
+            "mild_val_files 851\n",
+            "mild_test_files 1413\n",
+            "\n",
+            "severe_train_files 4344\n",
+            "severe_val_files 851\n",
+            "severe_test_files 1413\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import uuid\n",
+        "import numpy as np\n",
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "import torch.optim as optim\n",
+        "import torchvision\n",
+        "import torchvision.transforms as transforms\n",
+        "import torchvision.models as models\n",
+        "\n",
+        "from datetime import datetime\n",
+        "from tqdm import tqdm\n",
+        "from torch.utils.data import DataLoader\n",
+        "from torch.utils.tensorboard import SummaryWriter\n",
+        "from torchvision.datasets import ImageFolder\n",
+        "\n",
+        "class EarlyStopping():\n",
+        "    def __init__(self, metric_name, tolerance=5):\n",
+        "        self.metric_name = metric_name\n",
+        "        self.tolerance = tolerance\n",
+        "        self.min_loss = np.inf\n",
+        "        self.counter = 0\n",
+        "\n",
+        "    def __call__(self, loss):\n",
+        "        if loss < self.min_loss:\n",
+        "            print(f'Best {self.metric_name} updated from {self.min_loss:.3f} to {loss:.3f}')\n",
+        "            self.counter = 0\n",
+        "            self.min_loss = loss\n",
+        "        else:\n",
+        "            self.counter += 1\n",
+        "            print(f'{self.metric_name} {loss:.3f} not better than current best ({self.min_loss:.3f}) - counter: {self.counter}')\n",
+        "            if self.counter >= self.tolerance:  \n",
+        "                return True\n",
+        "        return False\n",
+        "\n",
+        "def make_weights_for_balanced_classes(images, nclasses):                        \n",
+        "    count = [0] * nclasses                                                      \n",
+        "    for item in images:                                                         \n",
+        "        count[item[1]] += 1                                                     \n",
+        "    weight_per_class = [0.] * nclasses                                      \n",
+        "    N = float(sum(count))                                                   \n",
+        "    for i in range(nclasses):                                                   \n",
+        "        weight_per_class[i] = N/float(count[i])                                 \n",
+        "    weight = [0] * len(images)                                              \n",
+        "    for idx, val in enumerate(images):                                          \n",
+        "        weight[idx] = weight_per_class[val[1]]                                  \n",
+        "    return weight  \n",
+        "\n",
+        "DATA_DIR = './data'\n",
+        "EPOCHS = 500\n",
+        "BATCH_SIZE = 4\n",
+        "MEAN = [191.738, 137.17, 172.24]\n",
+        "STD = [46.31, 61.81, 49.34]\n",
+        "SAVE_FREQ = 5\n",
+        "LOG_DIR = f'./data/training_runs/{str(uuid.uuid4().hex)}'\n",
+        "CLASSES = ('Mild', 'Severe')\n",
+        "\n",
+        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+        "\n",
+        "!mkdir -p {LOG_DIR}/checkpoints\n",
+        "!mkdir -p drive/MyDrive/models"
+      ],
+      "metadata": {
+        "id": "1GjWBztySy9w"
+      },
+      "execution_count": 12,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "# the training transforms\n",
+        "train_transform = transforms.Compose([\n",
+        "    transforms.RandomHorizontalFlip(p=0.5),\n",
+        "    transforms.RandomVerticalFlip(p=0.5),\n",
+        "    transforms.AugMix(),\n",
+        "    transforms.ToTensor(),\n",
+        "    transforms.Normalize(mean=MEAN, std=STD)\n",
+        "])\n",
+        "# the test/validation transforms\n",
+        "test_transform = transforms.Compose([\n",
+        "    transforms.ToTensor(),\n",
+        "    transforms.Normalize(mean=MEAN, std=STD)\n",
+        "])\n",
+        "\n",
+        "# datasets\n",
+        "#weights = [0.1, 0.9]\n",
+        "#sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))\n",
+        "trainset = ImageFolder(f'{DATA_DIR}/train', transform=train_transform)\n",
+        "valset = ImageFolder(f'{DATA_DIR}/val', transform=test_transform)\n",
+        "testset = ImageFolder(f'{DATA_DIR}/test', transform=test_transform)\n",
+        "\n",
+        "# dataloaders\n",
+        "trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) #sampler=sampler\n",
+        "valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)\n",
+        "testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)"
+      ],
+      "metadata": {
+        "id": "0dRFG7ImTp-B"
+      },
+      "execution_count": 13,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "LOG_DIR = f'./data/training_runs/{str(uuid.uuid4().hex)}'\n",
+        "!mkdir -p {LOG_DIR}/checkpoints\n",
+        "\n",
+        "#Load pretrained model\n",
+        "model = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')\n",
+        "#Update first layer\n",
+        "model_layers = [\n",
+        "    nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)\n",
+        "]\n",
+        "model_layers.extend(list(model.features))  \n",
+        "model.features = nn.Sequential(*model_layers)\n",
+        "#Update last layer\n",
+        "number_features = model.classifier[-1].in_features\n",
+        "features = list(model.classifier.children())[:-1] # Remove last layer\n",
+        "features.extend([\n",
+        "    nn.Linear(number_features, len(CLASSES)),\n",
+        "    nn.Softmax(1)\n",
+        "])\n",
+        "model.classifier = nn.Sequential(*features)\n",
+        "\n",
+        "#checkpoint = torch.load(LOAD_MODEL_PATH)\n",
+        "#model.load_state_dict(checkpoint['model'])\n",
+        "#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+        "#loss = checkpoint['loss']\n",
+        "\n",
+        "start_epoch = 0 #checkpoint['epoch']\n",
+        "\n",
+        "model.to(device)\n",
+        "model.train()\n",
+        "\n",
+        "#Update to Adam\n",
+        "criterion = nn.CrossEntropyLoss()\n",
+        "learning_rate = 0.0001\n",
+        "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
+        "\n",
+        "early_stop = EarlyStopping(metric_name='validation loss', tolerance=5)\n",
+        "\n",
+        "# default `log_dir` is \"training_runs\" - we'll be more specific here\n",
+        "writer = SummaryWriter(f'{LOG_DIR}/logs')\n",
+        "\n",
+        "timestamp = datetime.now().isoformat('T', 'seconds').replace(':', '')\n",
+        "for epoch in range(start_epoch, start_epoch + EPOCHS):\n",
+        "    training_loss = 0.0\n",
+        "    correct = 0.0\n",
+        "    total = 0.0\n",
+        "\n",
+        "    with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
+        "        tepoch.set_description(f\"Epoch {epoch + 1}\")\n",
+        "        for inputs, labels in tepoch:\n",
+        "            inputs = inputs.to(device)\n",
+        "            labels = labels.to(device)\n",
+        "\n",
+        "            # zero the parameter gradients\n",
+        "            optimizer.zero_grad()\n",
+        "\n",
+        "            # forward + backward + optimize\n",
+        "            outputs = model(inputs)\n",
+        "            loss = criterion(outputs, labels)\n",
+        "            loss.backward()\n",
+        "            optimizer.step()\n",
+        "\n",
+        "            total += BATCH_SIZE\n",
+        "            training_loss += loss.item()\n",
+        "            predictions = outputs.argmax(dim=1, keepdim=True).squeeze()\n",
+        "            correct += (predictions == labels).sum().item()\n",
+        "            accuracy = correct / total\n",
+        "\n",
+        "            tepoch.set_postfix(loss=training_loss / total, accuracy=100. * accuracy)\n",
+        "\n",
+        "        validation_loss = 0.0\n",
+        "        val_correct = 0.0\n",
+        "        for j, val_data in enumerate(valloader, 0):\n",
+        "            val_inputs, val_labels = val_data\n",
+        "            val_inputs = val_inputs.to(device)\n",
+        "            val_labels = val_labels.to(device)\n",
+        "            val_outputs = model(val_inputs)\n",
+        "            val_loss = criterion(val_outputs, val_labels)\n",
+        "            validation_loss += val_loss.item()\n",
+        "            val_predictions = val_outputs.argmax(dim=1, keepdim=True).squeeze()\n",
+        "            val_correct += (val_predictions == val_labels).sum().item()\n",
+        "            val_accuracy = val_correct / (BATCH_SIZE * len(valloader))\n",
+        "\n",
+        "        # log the training loss\n",
+        "        writer.add_scalar('training loss',\n",
+        "                        training_loss / (BATCH_SIZE * len(trainloader)),\n",
+        "                        epoch * len(trainloader))\n",
+        "\n",
+        "        # log the training accuracy\n",
+        "        writer.add_scalar('training accuracy',\n",
+        "                        100. * accuracy,\n",
+        "                        epoch * len(trainloader))\n",
+        "\n",
+        "        # log the validation loss\n",
+        "        writer.add_scalar('validation loss',\n",
+        "                        validation_loss / (BATCH_SIZE * len(valloader)),\n",
+        "                        epoch * len(trainloader))\n",
+        "\n",
+        "        # log the training accuracy\n",
+        "        writer.add_scalar('validation accuracy',\n",
+        "                        100. * val_accuracy,\n",
+        "                        epoch * len(trainloader))\n",
+        "\n",
+        "        print(f\"Validation Accuracy: {100. * val_accuracy}\")\n",
+        "        # save a checkpoint\n",
+        "        if (epoch + 1) % SAVE_FREQ == 0:\n",
+        "            print(\"Saving Model\")\n",
+        "            model_filename = f'{LOG_DIR}/checkpoints/model_{timestamp}_{(epoch + 1)}.pth'\n",
+        "            torch.save(model, model_filename)\n",
+        "            !cp {model_filename} drive/MyDrive/\n",
+        "\n",
+        "        if early_stop(validation_loss / (BATCH_SIZE * len(valloader))):\n",
+        "            break\n",
+        "\n",
+        "print('Training Complete')\n",
+        "torch.save(model, f'{LOG_DIR}/checkpoints/model_{timestamp}_{(epoch + 1)}.pth')"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "ASL2fVEUTxya",
+        "outputId": "ab38349b-f15c-44ca-d442-ee5e34b96702"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "Epoch 1: 100%|██████████| 2172/2172 [08:11<00:00,  4.42batch/s, accuracy=70.6, loss=0.146]\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Validation Accuracy: 71.88967136150235\n",
+            "Best validation loss updated from inf to 0.145\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "Epoch 2: 100%|██████████| 2172/2172 [08:12<00:00,  4.41batch/s, accuracy=74.7, loss=0.137]\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Validation Accuracy: 72.76995305164318\n",
+            "Best validation loss updated from 0.145 to 0.142\n"
+          ]
+        },
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "Epoch 3:  85%|████████▍ | 1840/2172 [07:00<01:20,  4.12batch/s, accuracy=76.7, loss=0.133]"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        " !aws s3 cp ./data/training_runs s3://digpath-models/training_runs --recursive"
+      ],
+      "metadata": {
+        "id": "PfMaBA57VGl_"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        " "
+      ],
+      "metadata": {
+        "id": "S_1PX1jgYcnN"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file
-- 
GitLab