From fb0ee2cd435e65faa73433b573a3deea67f03633 Mon Sep 17 00:00:00 2001
From: Zachary Rahn <zr66@drexel.edu>
Date: Sun, 29 Jan 2023 16:48:27 -0500
Subject: [PATCH] adding training notebook for Google Colab
---
ml/ML_Training_Colab.ipynb | 488 +++++++++++++++++++++++++++++++++++++
1 file changed, 488 insertions(+)
create mode 100644 ml/ML_Training_Colab.ipynb
diff --git a/ml/ML_Training_Colab.ipynb b/ml/ML_Training_Colab.ipynb
new file mode 100644
index 0000000..8b242e0
--- /dev/null
+++ b/ml/ML_Training_Colab.ipynb
@@ -0,0 +1,488 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "NLG5eyci_qGY",
+ "outputId": "d1df9059-6e7c-4f73-a420-60509cffcecf"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting awscli\n",
+ " Downloading awscli-1.27.59-py3-none-any.whl (4.0 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m91.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting botocore==1.29.59\n",
+ " Downloading botocore-1.29.59-py3-none-any.whl (10.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m43.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting PyYAML<5.5,>=3.10\n",
+ " Downloading PyYAML-5.4.1-cp38-cp38-manylinux1_x86_64.whl (662 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m662.4/662.4 KB\u001b[0m \u001b[31m59.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting colorama<0.4.5,>=0.2.5\n",
+ " Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)\n",
+ "Collecting s3transfer<0.7.0,>=0.6.0\n",
+ " Downloading s3transfer-0.6.0-py3-none-any.whl (79 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.6/79.6 KB\u001b[0m \u001b[31m12.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: docutils<0.17,>=0.10 in /usr/local/lib/python3.8/dist-packages (from awscli) (0.16)\n",
+ "Collecting rsa<4.8,>=3.1.2\n",
+ " Downloading rsa-4.7.2-py3-none-any.whl (34 kB)\n",
+ "Collecting urllib3<1.27,>=1.25.4\n",
+ " Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m140.6/140.6 KB\u001b[0m \u001b[31m18.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.8/dist-packages (from botocore==1.29.59->awscli) (2.8.2)\n",
+ "Collecting jmespath<2.0.0,>=0.7.1\n",
+ " Downloading jmespath-1.0.1-py3-none-any.whl (20 kB)\n",
+ "Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.8/dist-packages (from rsa<4.8,>=3.1.2->awscli) (0.4.8)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil<3.0.0,>=2.1->botocore==1.29.59->awscli) (1.15.0)\n",
+ "Installing collected packages: urllib3, rsa, PyYAML, jmespath, colorama, botocore, s3transfer, awscli\n",
+ " Attempting uninstall: urllib3\n",
+ " Found existing installation: urllib3 1.24.3\n",
+ " Uninstalling urllib3-1.24.3:\n",
+ " Successfully uninstalled urllib3-1.24.3\n",
+ " Attempting uninstall: rsa\n",
+ " Found existing installation: rsa 4.9\n",
+ " Uninstalling rsa-4.9:\n",
+ " Successfully uninstalled rsa-4.9\n",
+ " Attempting uninstall: PyYAML\n",
+ " Found existing installation: PyYAML 6.0\n",
+ " Uninstalling PyYAML-6.0:\n",
+ " Successfully uninstalled PyYAML-6.0\n",
+ "Successfully installed PyYAML-5.4.1 awscli-1.27.59 botocore-1.29.59 colorama-0.4.4 jmespath-1.0.1 rsa-4.7.2 s3transfer-0.6.0 urllib3-1.26.14\n",
+ "/content/drive/My Drive/awscli.ini\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install -q awscli\n",
+ "import os\n",
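+ "# NOTE: assumes Google Drive is already mounted at /content/drive\n",
+ "# (e.g. via: from google.colab import drive; drive.mount('/content/drive'))\n",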
+ "# '!export VAR=...' only affects a throwaway subshell, so the credentials path is set from Python below\n",
+ "path = \"/content/drive/My Drive/awscli.ini\"\n",
+ "os.environ['AWS_SHARED_CREDENTIALS_FILE'] = path\n",
+ "print(os.environ['AWS_SHARED_CREDENTIALS_FILE'])\n",
+ "!mkdir -p data\n",
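+ "# pull the pre-split image chips (data/{train,val,test}/{Mild,Severe}) from S3\n",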
+ "!aws s3 cp s3://digpath-chips/ data/ --recursive --quiet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import glob\n",
+ "\n",
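+ "# count the chips in each split/class to gauge the Mild vs Severe imbalance\n",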
+ "mild_train_files = glob.glob('data/train/Mild/*.png')\n",
+ "mild_val_files = glob.glob('data/val/Mild/*.png')\n",
+ "mild_test_files = glob.glob('data/test/Mild/*.png')\n",
+ "\n",
+ "severe_train_files = glob.glob('data/train/Severe/*.png')\n",
+ "severe_val_files = glob.glob('data/val/Severe/*.png')\n",
+ "severe_test_files = glob.glob('data/test/Severe/*.png')\n",
+ "simulated_files = glob.glob('data/train/Severe/*sim*.png')\n",
+ "\n",
+ "print('mild_train_files', len(mild_train_files))\n",
+ "print('mild_val_files', len(mild_val_files))\n",
+ "print('mild_test_files', len(mild_test_files))\n",
+ "print()\n",
+ "print('severe_train_files', len(severe_train_files))\n",
+ "print('severe_val_files', len(severe_val_files))\n",
+ "print('severe_test_files', len(severe_test_files))\n",
+ "print('simulated_files', len(simulated_files))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "DaRvl_fdPpNo",
+ "outputId": "53bdf005-5da7-4371-c7ad-b19af4f68ce0"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "mild_train_files 12744\n",
+ "mild_val_files 1661\n",
+ "mild_test_files 2490\n",
+ "\n",
+ "severe_train_files 4344\n",
+ "severe_val_files 851\n",
+ "severe_test_files 1413\n",
+ "simulated_files 965\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import random\n",
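+ "# randomly downsample the Mild class so each split holds the same number of Mild and Severe chips\n",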
+ "random.shuffle(mild_train_files)\n",
+ "random.shuffle(mild_val_files)\n",
+ "random.shuffle(mild_test_files)\n",
+ "for filename in mild_train_files[len(severe_train_files):]:\n",
+ " os.remove(filename)\n",
+ "for filename in mild_val_files[len(severe_val_files):]:\n",
+ " os.remove(filename)\n",
+ "for filename in mild_test_files[len(severe_test_files):]:\n",
+ " os.remove(filename)\n",
+ "\n",
+ "mild_train_files = glob.glob('data/train/Mild/*.png')\n",
+ "mild_val_files = glob.glob('data/val/Mild/*.png')\n",
+ "mild_test_files = glob.glob('data/test/Mild/*.png')\n",
+ "\n",
+ "severe_train_files = glob.glob('data/train/Severe/*.png')\n",
+ "severe_val_files = glob.glob('data/val/Severe/*.png')\n",
+ "severe_test_files = glob.glob('data/test/Severe/*.png')\n",
+ "simulated_files = glob.glob('data/train/Severe/*sim*.png')\n",
+ "\n",
+ "print('mild_train_files', len(mild_train_files))\n",
+ "print('mild_val_files', len(mild_val_files))\n",
+ "print('mild_test_files', len(mild_test_files))\n",
+ "print()\n",
+ "print('severe_train_files', len(severe_train_files))\n",
+ "print('severe_val_files', len(severe_val_files))\n",
+ "print('severe_test_files', len(severe_test_files))"
+ ],
+ "metadata": {
+ "id": "tcYLGTfvNc4r",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "fe4a5701-208f-49c1-f8fb-d9d7ceda9776"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "mild_train_files 4344\n",
+ "mild_val_files 851\n",
+ "mild_test_files 1413\n",
+ "\n",
+ "severe_train_files 4344\n",
+ "severe_val_files 851\n",
+ "severe_test_files 1413\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import uuid\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "import torchvision\n",
+ "import torchvision.transforms as transforms\n",
+ "import torchvision.models as models\n",
+ "\n",
+ "from datetime import datetime\n",
+ "from tqdm import tqdm\n",
+ "from torch.utils.data import DataLoader\n",
+ "from torch.utils.tensorboard import SummaryWriter\n",
+ "from torchvision.datasets import ImageFolder\n",
+ "\n",
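+ "# simple early-stopping helper: __call__ returns True once the tracked metric\n",
+ "# has failed to improve for `tolerance` consecutive epochs\n",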
+ "class EarlyStopping():\n",
+ " def __init__(self, metric_name, tolerance=5):\n",
+ " self.metric_name = metric_name\n",
+ " self.tolerance = tolerance\n",
+ " self.min_loss = np.inf\n",
+ " self.counter = 0\n",
+ "\n",
+ " def __call__(self, loss):\n",
+ " if loss < self.min_loss:\n",
+ " print(f'Best {self.metric_name} updated from {self.min_loss:.3f} to {loss:.3f}')\n",
+ " self.counter = 0\n",
+ " self.min_loss = loss\n",
+ " else:\n",
+ " self.counter += 1\n",
+ " print(f'{self.metric_name} {loss:.3f} not better than current best ({self.min_loss:.3f}) - counter: {self.counter}')\n",
+ " if self.counter >= self.tolerance: \n",
+ " return True\n",
+ " return False\n",
+ "\n",
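+ "# inverse-class-frequency sample weights for an optional WeightedRandomSampler\n",
+ "# (unused below, since the classes are downsampled to parity instead)\n",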
+ "def make_weights_for_balanced_classes(images, nclasses): \n",
+ " count = [0] * nclasses \n",
+ " for item in images: \n",
+ " count[item[1]] += 1 \n",
+ " weight_per_class = [0.] * nclasses \n",
+ " N = float(sum(count)) \n",
+ " for i in range(nclasses): \n",
+ " weight_per_class[i] = N/float(count[i]) \n",
+ " weight = [0] * len(images) \n",
+ " for idx, val in enumerate(images): \n",
+ " weight[idx] = weight_per_class[val[1]] \n",
+ " return weight \n",
+ "\n",
+ "DATA_DIR = './data'\n",
+ "EPOCHS = 500\n",
+ "BATCH_SIZE = 4\n",
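+ "# per-channel RGB statistics of the chips; note these look like raw 0-255 values,\n",
+ "# while ToTensor() rescales images to [0, 1], so dividing them by 255 may be the intended scale\n",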
+ "MEAN = [191.738, 137.17, 172.24]\n",
+ "STD = [46.31, 61.81, 49.34]\n",
+ "SAVE_FREQ = 5\n",
+ "LOG_DIR = f'./data/training_runs/{str(uuid.uuid4().hex)}'\n",
+ "CLASSES = ('Mild', 'Severe')\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "!mkdir -p {LOG_DIR}/checkpoints\n",
+ "!mkdir -p drive/MyDrive/models"
+ ],
+ "metadata": {
+ "id": "1GjWBztySy9w"
+ },
+ "execution_count": 12,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# the training transforms\n",
+ "train_transform = transforms.Compose([\n",
+ " transforms.RandomHorizontalFlip(p=0.5),\n",
+ " transforms.RandomVerticalFlip(p=0.5),\n",
+ " transforms.AugMix(),\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(mean=MEAN, std=STD)\n",
+ "])\n",
+ "# the test/validation transforms\n",
+ "test_transform = transforms.Compose([\n",
+ " transforms.ToTensor(),\n",
+ " transforms.Normalize(mean=MEAN, std=STD)\n",
+ "])\n",
+ "\n",
+ "# datasets\n",
+ "#weights = [0.1, 0.9]\n",
+ "#sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))\n",
+ "trainset = ImageFolder(f'{DATA_DIR}/train', transform=train_transform)\n",
+ "valset = ImageFolder(f'{DATA_DIR}/val', transform=test_transform)\n",
+ "testset = ImageFolder(f'{DATA_DIR}/test', transform=test_transform)\n",
+ "\n",
+ "# dataloaders\n",
+ "trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2) #sampler=sampler\n",
+ "valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)\n",
+ "testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)"
+ ],
+ "metadata": {
+ "id": "0dRFG7ImTp-B"
+ },
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "LOG_DIR = f'./data/training_runs/{str(uuid.uuid4().hex)}'\n",
+ "!mkdir -p {LOG_DIR}/checkpoints\n",
+ "\n",
+ "#Load pretrained model\n",
+ "model = models.efficientnet_b0(weights='EfficientNet_B0_Weights.DEFAULT')\n",
+ "#Update first layer\n",
+ "model_layers = [\n",
+ " nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True)\n",
+ "]\n",
+ "model_layers.extend(list(model.features)) \n",
+ "model.features = nn.Sequential(*model_layers)\n",
+ "#Update last layer\n",
+ "number_features = model.classifier[-1].in_features\n",
+ "features = list(model.classifier.children())[:-1] # Remove last layer\n",
+ "features.extend([\n",
+ " nn.Linear(number_features, len(CLASSES)),\n",
+ " nn.Softmax(1)\n",
+ "])\n",
+ "model.classifier = nn.Sequential(*features)\n",
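+ "# note: nn.CrossEntropyLoss applies log-softmax to raw logits internally,\n",
+ "# so an explicit Softmax layer before it is normally omitted\n",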
+ "\n",
+ "#checkpoint = torch.load(LOAD_MODEL_PATH)\n",
+ "#model.load_state_dict(checkpoint['model'])\n",
+ "#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+ "#loss = checkpoint['loss']\n",
+ "\n",
+ "start_epoch = 0 #checkpoint['epoch']\n",
+ "\n",
+ "model.to(device)\n",
+ "model.train()\n",
+ "\n",
+ "#Update to Adam\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "learning_rate = 0.0001\n",
+ "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
+ "\n",
+ "early_stop = EarlyStopping(metric_name='validation loss', tolerance=5)\n",
+ "\n",
+ "# default `log_dir` is \"runs\" - we'll be more specific here\n",
+ "writer = SummaryWriter(f'{LOG_DIR}/logs')\n",
+ "\n",
+ "timestamp = datetime.now().isoformat('T', 'seconds').replace(':', '')\n",
+ "for epoch in range(start_epoch, start_epoch + EPOCHS):\n",
+ " training_loss = 0.0\n",
+ " correct = 0.0\n",
+ " total = 0.0\n",
+ "\n",
+ " with tqdm(trainloader, unit=\"batch\") as tepoch:\n",
+ " tepoch.set_description(f\"Epoch {epoch + 1}\")\n",
+ " for inputs, labels in tepoch:\n",
+ " inputs = inputs.to(device)\n",
+ " labels = labels.to(device)\n",
+ "\n",
+ " # zero the parameter gradients\n",
+ " optimizer.zero_grad()\n",
+ "\n",
+ " # forward + backward + optimize\n",
+ " outputs = model(inputs)\n",
+ " loss = criterion(outputs, labels)\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ "\n",
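+ "                # running loss/accuracy; counting BATCH_SIZE per step assumes the train split divides evenly by the batch size\n",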
+ " total += BATCH_SIZE\n",
+ " training_loss += loss.item()\n",
+ " predictions = outputs.argmax(dim=1, keepdim=True).squeeze()\n",
+ " correct += (predictions == labels).sum().item()\n",
+ " accuracy = correct / total\n",
+ "\n",
+ " tepoch.set_postfix(loss=training_loss / total, accuracy=100. * accuracy)\n",
+ "\n",
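+ "        # validation pass; wrapping this in torch.no_grad() and toggling model.eval()/model.train()\n",
+ "        # would skip gradient tracking and put dropout/batch-norm layers in eval mode\n",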
+ " validation_loss = 0.0\n",
+ " val_correct = 0.0\n",
+ " for j, val_data in enumerate(valloader, 0):\n",
+ " val_inputs, val_labels = val_data\n",
+ " val_inputs = val_inputs.to(device)\n",
+ " val_labels = val_labels.to(device)\n",
+ " val_outputs = model(val_inputs)\n",
+ " val_loss = criterion(val_outputs, val_labels)\n",
+ " validation_loss += val_loss.item()\n",
+ " val_predictions = val_outputs.argmax(dim=1, keepdim=True).squeeze()\n",
+ " val_correct += (val_predictions == val_labels).sum().item()\n",
+ " val_accuracy = val_correct / (BATCH_SIZE * len(valloader))\n",
+ "\n",
+ " # log the training loss\n",
+ " writer.add_scalar('training loss',\n",
+ " training_loss / (BATCH_SIZE * len(trainloader)),\n",
+ " epoch * len(trainloader))\n",
+ "\n",
+ " # log the training accuracy\n",
+ " writer.add_scalar('training accuracy',\n",
+ " 100. * accuracy,\n",
+ " epoch * len(trainloader))\n",
+ "\n",
+ " # log the validation loss\n",
+ " writer.add_scalar('validation loss',\n",
+ " validation_loss / (BATCH_SIZE * len(valloader)),\n",
+ " epoch * len(trainloader))\n",
+ "\n",
+ "    # log the validation accuracy\n",
+ " writer.add_scalar('validation accuracy',\n",
+ " 100. * val_accuracy,\n",
+ " epoch * len(trainloader))\n",
+ "\n",
+ " print(f\"Validation Accuracy: {100. * val_accuracy}\")\n",
+ " # save a checkpoint\n",
+ " if (epoch + 1) % SAVE_FREQ == 0:\n",
+ " print(\"Saving Model\")\n",
+ " model_filename = f'{LOG_DIR}/checkpoints/model_{timestamp}_{(epoch + 1)}.pth'\n",
+ " torch.save(model, model_filename)\n",
+ " !cp {model_filename} drive/MyDrive/\n",
+ "\n",
+ " if early_stop(validation_loss / (BATCH_SIZE * len(valloader))):\n",
+ " break\n",
+ "\n",
+ "print('Training Complete')\n",
+ "torch.save(model, f'{LOG_DIR}/checkpoints/model_{timestamp}_{(epoch + 1)}.pth')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ASL2fVEUTxya",
+ "outputId": "ab38349b-f15c-44ca-d442-ee5e34b96702"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Epoch 1: 100%|██████████| 2172/2172 [08:11<00:00, 4.42batch/s, accuracy=70.6, loss=0.146]\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Validation Accuracy: 71.88967136150235\n",
+ "Best validation loss updated from inf to 0.145\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Epoch 2: 100%|██████████| 2172/2172 [08:12<00:00, 4.41batch/s, accuracy=74.7, loss=0.137]\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Validation Accuracy: 72.76995305164318\n",
+ "Best validation loss updated from 0.145 to 0.142\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "Epoch 3: 85%|████████▍ | 1840/2172 [07:00<01:20, 4.12batch/s, accuracy=76.7, loss=0.133]"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!aws s3 cp ./data/training_runs s3://digpath-models/training_runs --recursive"
+ ],
+ "metadata": {
+ "id": "PfMaBA57VGl_"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
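+ {
+ "cell_type": "code",
+ "source": [
+ "# Hedged sketch: evaluate the trained model on the held-out test split using the\n",
+ "# testloader defined above; assumes `model` and `device` are still in scope.\n",
+ "model.eval()\n",
+ "test_correct = 0\n",
+ "test_total = 0\n",
+ "with torch.no_grad():\n",
+ "    for test_inputs, test_labels in testloader:\n",
+ "        test_inputs = test_inputs.to(device)\n",
+ "        test_labels = test_labels.to(device)\n",
+ "        test_outputs = model(test_inputs)\n",
+ "        test_predictions = test_outputs.argmax(dim=1)\n",
+ "        test_correct += (test_predictions == test_labels).sum().item()\n",
+ "        test_total += test_labels.size(0)\n",
+ "print(f'Test Accuracy: {100. * test_correct / test_total:.2f}%')"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },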
+ {
+ "cell_type": "code",
+ "source": [
+ " "
+ ],
+ "metadata": {
+ "id": "S_1PX1jgYcnN"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
--
GitLab