Building Neural Networks to Detect Eye Diseases

💡

The source code is here, in Colab. You'll need a GPU to run it, but everything else should be ready to go. If something breaks, just email me: ben@beneverman.com📧.

Doctors Can't Keep Up

Healthcare is constantly changing and evolving. Research moves fast: in a 2011 publication, Densen estimated that by 2020 medical knowledge would double every 73 days, a marked departure from the 1950s, when doubling took 50 years (Densen, 2011). The average physician, though required to complete continuing education to maintain licensure and board certification, must fight to keep pace with the rate of new research while juggling heavy caseloads, administrative responsibilities, and personal care.

"This growing aggregation of advances and reversals presents a significant challenge to physicians attempting to stay up to date. Historically, there has been an average 17-year lag between medical discoveries and implementation into clinical practice. With the acceleration of changes in clinical medicine, coupled with normal lag times in dissemination, there is a higher probability than ever before that physicians, within just a few years of leaving their training, may not be practicing contemporary standards of care." Laiteerapong, Neda, and Elbert S Huang. (2015)

With new technology like Machine Learning (ML) and Artificial Intelligence (AI), what are the implications for the average physician? Can we reduce the 17-year lag?

I think we can. How? By building clinical Decision Support Systems (DSS) (Sutton et al., 2020). While decision support systems date back to the British in World War II, a modern example is NASA's mission control. While astronauts pilot aircraft and spacecraft, mission control gathers and manages vast data sets, runs advanced models for analysis, offers user-friendly interfaces for astronaut interaction, and provides crucial support for real-time decision-making. To get a sense of the scale: throughout the Apollo program, 32 astronauts were assigned to fly missions. At a 1:100 support ratio, you would estimate as many as 3,200 staff in mission control. The actual number? 400,000.

So how can we use ML and AI to support clinical decision making? One way is using computer vision. With neural networks that take images as input, we can batch-process large amounts of medical imagery and use the network to return the much smaller subset that requires actual physician review. The goal is not to diagnose, but to identify what needs physician review while providing the physician with as much supporting information as possible.

Using Math to See Disease

This project was inspired by the work of Ophthalytics. In a recent paper (Bajwa et al., 2023), they trained a Convolutional Neural Network (CNN), an architecture commonly used in computer vision, to detect Diabetic Retinopathy (DR), a complication of diabetes that affects the retina and can lead to vision loss if not detected and treated early. In the paper, the model takes close-up retinal fundus images as input and classifies them as DR-Positive or DR-Negative. See the figure below.

Figure 3 from Bajwa et al. (2023): (a) DR-Negative and (b) DR-Positive.

I decided I wanted to build a similar eye disease classifier. My goals for this project were twofold:

  1. Train a classifier that can detect multiple eye diseases
  2. Train the model using minimal compute, as fast as possible

Implementing a CNN for Eye Disease Classification

Since my goal was to minimize compute costs, I chose the EfficientNet architecture, introduced by Tan and Le (2019). EfficientNet uses a compound scaling method that uniformly scales all architectural dimensions of the CNN (depth, width, and resolution) with a single coefficient. The big benefit of EfficientNet is that it's optimized for the performance/compute trade-off. At the time of publication, their largest model, EfficientNet-B7, was deemed "8.4x smaller and 6.1x faster on inference than the best existing ConvNet" (Tan and Le, 2019). Here's a figure from the paper:

Figure from Tan and Le (2019): ImageNet accuracy vs. FLOPs for the EfficientNet family compared to other ConvNets.

Notice that the Y-axis is accuracy and the X-axis is FLOPs, a direct measure of compute cost. FLOPs stands for Floating Point Operations: the total number of floating-point arithmetic operations performed by the model during inference, or required to train it. For context, the speculated number of FLOPs used to train GPT-4 (aka ChatGPT) is ~2.15e25, or 21,500,000 EFLOPs, which at $1 per A100-hour works out to an estimated $63 million.
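To see roughly where a number like that comes from, here's a back-of-envelope calculation. The peak throughput and utilization figures below are my own assumptions for illustration, not numbers from any source, so treat this as a sketch rather than a real cost model.

# Rough, assumption-heavy estimate of training cost from a FLOP count
TOTAL_FLOPS = 2.15e25        # speculated GPT-4 training compute
A100_PEAK_FLOPS = 312e12     # ~312 TFLOP/s peak for an A100 (assumed)
UTILIZATION = 0.35           # assumed fraction of peak actually achieved
PRICE_PER_GPU_HOUR = 1.00    # the $1/A100-hour figure from above

gpu_seconds = TOTAL_FLOPS / (A100_PEAK_FLOPS * UTILIZATION)
gpu_hours = gpu_seconds / 3600
print(f"{gpu_hours:,.0f} A100-hours -> ${gpu_hours * PRICE_PER_GPU_HOUR:,.0f}")
# With these assumptions, the result lands in the tens of millions of dollars,
# the same ballpark as the ~$63M estimate above.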

For this project, I used an A100 on Google Colab, which costs me about $9.99 a month. The isolated cost of this project would be much lower (<$2 maybe?). It turns out that being poor does force you to be more creative.

More Architecture Decisions

I realized I could leverage transfer learning by taking a pre-trained EfficientNet and continuing its training on my eye disease dataset (AKA fine-tuning). This has a few benefits:

  1. Reduced compute cost (time and throughput) compared to pre-training a CNN from scratch
  2. Less risk of overfitting on my small dataset
  3. Faster iteration, so I could finish this project sooner

My goal here was not to train the most performant classifier. Maybe I'll return to that in the future; for now, it's less relevant. The easiest way to boost performance would be to use a larger base model, which I did play around with.
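As a reference point, here's a minimal sketch of one way to fine-tune a pre-trained timm model: freeze the backbone and train only the new classification head. This is just an option for tight compute budgets, not what the notebook below does (there I leave all weights trainable), and it assumes timm's EfficientNet models expose their head under the name classifier.

import timm
import torch.nn as nn
import torch.optim as optim

# Load an ImageNet-pretrained EfficientNet and swap in a 4-class head
model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=4)

# Freeze everything except the classifier head (optional; the notebook trains all weights)
for name, param in model.named_parameters():
    param.requires_grad = "classifier" in name

# Only the trainable (head) parameters go to the optimizer
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)
criterion = nn.CrossEntropyLoss()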

Update 3/31: Since writing this, I've read much more literature on retinal fundus imaging and computer vision. The gold-standard architecture seems to be a U-Net with a pre-trained encoder (ImageNet) for segmentation, connected to a pre-trained classifier (ImageNet). If I go any further with this project, this is the direction I'll take. Lack of annotated data is the biggest issue.
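Purely as an illustration of that direction (nothing here is trained or validated, and the segmentation_models_pytorch usage is my assumption about how I'd wire it up), the setup could look something like this:

# Illustrative sketch of the "segment, then classify" direction; not part of this project's notebook
import segmentation_models_pytorch as smp
import timm
import torch

# U-Net with an ImageNet-pretrained encoder, e.g. for vessel/lesion segmentation
seg_model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,  # single-channel mask
)

# Separate ImageNet-pretrained classifier for the disease label
clf_model = timm.create_model("efficientnet_b0", pretrained=True, num_classes=4)

x = torch.randn(1, 3, 384, 384)     # dummy fundus image
mask = torch.sigmoid(seg_model(x))  # predicted segmentation mask
logits = clf_model(x)               # disease logits (could also be conditioned on the mask)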

The Code, Explained

💡

Again, here is the link to the code. I've set it up so that you should be able to download the data, train/fine-tune the model, and test it pretty easily. NOTE: You'll probably need at least an A100, which means you need Colab Pro. You could probably get away with a T4 (free) for the smallest model, EfficientNet-B0. Email me if you have any issues or questions! ben@beneverman.com📧

Setup and Configuration

To set up the project we'll need to do a few things:

  1. Make sure we're using a GPU runtime (Runtime -> change runtime type)
  2. Get the data from the cloud onto the runtime, and unzip it
  3. Install missing dependencies
!nvidia-smi # make sure you've got a GPU runtime selected (I used an A100)

This will show you the GPU you're using. Here's what an example output looks like:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 |
| N/A 32C P0 44W / 400W | 2MiB / 40960MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+

For downloading the data, I've included a pre-signed S3 link that you can use to make things painless. Just run the cell.

!curl {URL IS IN THE COLAB} -o data

Unzip the data

!unzip data

Install the necessary libraries. Everything we need comes preloaded in Colab, except for timm, which is a library that provides pre-trained image models.

!pip install timm
💡

Tip! In Colab you can click the folder icon on the left to see the files in the runtime. Or, you can use the !ls command to list the files in the current directory and %cd to change directories.

Training the Model

First things first, let's import all our dependencies.

import os # filesystem
import timm # pretrained image models
import torch # PyTorch, deep learning framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms # PyTorch vision utils
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split # SciKit Learn for splitting the data
from time import perf_counter # to time stuff

Let's do a sanity check to make sure the data is where we expect it to be.

DATA_DIR = 'dataset' # the root directory that the data was unzipped to
print(os.listdir(DATA_DIR)) # list the contents of the directory

The output should be ['glaucoma', 'cataract', 'diabetic_retinopathy', 'normal'].

Next, let's store the device in a variable so we can move the model and tensors onto the GPU later.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # set the device (GPU) to move model and tensors to later

Okay, now let's preprocess the data into datasets. We'll do some minimal preprocessing before we split the data into train, validation, and test sets. Before we do this, it can be helpful to look at the data and see what we're working with. Andrej Karpathy (AKA the GOAT 🐐) has a fantastic blog post that covers how and why you should manually inspect your data.

Here's a way we can view some of the images inside this notebook, using Pillow. Try to experiment with the data (look at the directory structure, view some images, etc.) to get a sense of what you're working with.

from PIL import Image
from IPython.display import display

folders = ["cataract", "glaucoma", "diabetic_retinopathy", "normal"]
for folder in folders:
    files = os.listdir(os.path.join(DATA_DIR, folder))
    if len(files) > 0:
        print(f"{folder}: {files[0]}")
        img = Image.open(os.path.join(DATA_DIR, folder, files[0]))
        display(img)

Now we can define a function to make the datasets.

def make_datasets(path: str):
    print(f"Making datasets from {path}")
    if not os.path.exists(path): raise FileNotFoundError(f"Path {path} does not exist") # check if the path exists
    assert set(os.listdir(path)) == set(['cataract', 'glaucoma', 'diabetic_retinopathy', 'normal']), "wrong dataset maybe??" # check if the dataset is correct

    transform = transforms.Compose([
        # resize the images from their original dimensions to 384x384; each model has a different input size
        transforms.Resize((384, 384)), # for B4, see all sizes here: https://discuss.pytorch.org/t/input-size-for-efficientnet-versions-from-torchvision-models/140525
        transforms.ToTensor(), # convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # normalize with the ImageNet mean/std the pretrained model expects
    ])
    full_dataset = datasets.ImageFolder(path, transform=transform)

    # time to split the data into train, validation, and test sets
    # here we're using the sklearn train_test_split function to split the data into 70% train, 30% val/test
    train_indices, temp_indices = train_test_split(
        range(len(full_dataset)),
        test_size=0.3, # will split this into validation and test later
        stratify=full_dataset.targets, # stratify the split (maintain the same distribution of classes) (targets are class indices)
        random_state=42 # for reproducibility (the randomness will be the same each time)
    )

    # further split the leftover data into validation and test (15% each)
    val_indices, test_indices = train_test_split(
        temp_indices,
        test_size=0.5, # split the temp_indices into validation and test
        stratify=[full_dataset.targets[i] for i in temp_indices], # stratify the split (maintain the same distribution of classes)
        random_state=42 # for reproducibility
    )

    # now let's create the datasets for torch
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
    test_dataset = torch.utils.data.Subset(full_dataset, test_indices)

    return train_dataset, val_dataset, test_dataset

Now we can make the datasets and dataloaders. Usually a batch size of 32 or 64 works best. Here's a detailed explanation from Weights & Biases

train_dataset, val_dataset, test_dataset = make_datasets(DATA_DIR)

BATCH_SIZE = 32

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the validation set (for reproducibility across architecture/other changes)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the test set

Let's set our model hyperparams.

EPOCHS = 10 # the number of epochs to train for
LR = 0.0001 # the learning rate (how much the model adjusts its weights during training/backpropagation)

Loading some different models to mess around with. Optionally, you can list all the available models with timm.list_models().

example_models = ["efficientnet_b0", "efficientnet_b3", "efficientnet_b4", "vit_large_patch16_224"]
model = timm.create_model(example_models[2], pretrained=True, num_classes=4) # we're using the EfficientNet-B4 model
model = model.to(device) # move the model to the GPU

Let's define the loss function and optimizer. We're using the CrossEntropyLoss function, which is commonly used for multi-class classification problems. The optimizer is Adam, which is a popular choice for deep learning models. You can read more about how deep learning works in this fantastic book by François Fleuret, The Little Book of Deep Learning. I would recommend getting the code to work first, and then trying to reverse-engineer it to understand how and why it works.

criterion = nn.CrossEntropyLoss() # categorical cross-entropy loss for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=LR) # Adam optimizer

Okay, this next cell is pretty important. This is the training loop, where we do a few things while iterating over a number of epochs (EPOCHS).

  1. Set the model to training mode (this is important because we need to make a backward pass during training, to update the model weights)
  2. Get the batch of inputs and targets (images and labels), move them to the GPU
  3. Zero the gradients (this is important because PyTorch accumulates gradients by default)
  4. Forward pass the inputs through the model output = model(inputs)
  5. Measure how close the model got to the correct answer loss = criterion(output, targets)
  6. Update the model weights based on the gradient of the loss loss.backward(); optimizer.step()
  7. Print the epoch and training loss.

Now into the validation loop. This is similar to the training loop, but we don't update the model weights. We're just measuring how well the model is doing on unseen data.

  1. Set the model to evaluation mode (this is important because we don't want to update the model weights)

  2. with torch.no_grad() tells PyTorch not to calculate gradients (we don't need them) which saves compute

  3. Forward pass the inputs through the model output = model(inputs)

  4. Measure how close the model got to the correct answer loss = criterion(output, targets)

  5. Print the epoch and validation loss.

  6. Important: I've added early stopping. This is a technique to prevent overfitting (the bane of every ML engineer's existence). If the validation loss rises more than a set tolerance above the best value seen so far, we stop training so as not to overfit the model.

def train(early_stopping=False, es_tol=.05):
    start = perf_counter()
    best_val_loss = float("inf") # arbitrarily high value
    for epoch in range(EPOCHS):
        model.train() # training mode (grads)
        running_train_loss = 0.0
        for i, (inputs, targets) in enumerate(train_loader):
            print(f"Batch {i+1}/{len(train_loader)}", end="\r")
            inputs, targets = inputs.to(device), targets.to(device) # move to device
            optimizer.zero_grad() # reset gradients

            outputs = model(inputs) # forward pass
            loss = criterion(outputs, targets) # compute loss
            loss.backward()
            optimizer.step()

            running_train_loss += loss.item() * inputs.size(0) # multiply by batch size

        train_loss = running_train_loss / len(train_loader.dataset) # divide by total number of samples
        print(f"Epoch {epoch+1}/{EPOCHS} - train loss: {train_loss:.4f}")

        model.eval()
        running_validation_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for i, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_validation_loss += loss.item() * inputs.size(0)

                _, predicted = torch.max(outputs, 1) # get the index of the max logit (argmax)
                total += targets.size(0) # add the number of targets in this batch
                correct += (predicted == targets).sum().item() # add the number of correct predictions in this batch

        val_loss = running_validation_loss / len(val_loader.dataset) # divide by total number of samples
        val_acc = correct / total
        print(f"Epoch {epoch+1}/{EPOCHS} - validation loss: {val_loss:.4f}, validation accuracy: {val_acc:.4f}")

        if early_stopping and val_loss - best_val_loss > es_tol: # if val loss is more than es_tol above the best so far, stop
            print(f"Early stopping at epoch {epoch}")
            break

        best_val_loss = min(best_val_loss, val_loss)

    print(f"Total Time: {perf_counter()-start:0.2f}s")

Now we can train the model. This will take a minute, so go grab a coffee or something.

🚫

If you are NOT using a GPU runtime (!nvidia-smi should show a GPU), this will take a long time (hours or days) and might never finish. If you're on a CPU, you need to restart the runtime and select a GPU runtime before working back to this point (you'll need to re-run all previous cells).

train(early_stopping=True)

Here's an example output from the training loop:

Epoch 1/10 - train loss: 0.7035
Epoch 1/10 - validation loss: 0.3994, validation accuracy: 0.8468
Epoch 2/10 - train loss: 0.2524
Epoch 2/10 - validation loss: 0.2609, validation accuracy: 0.9052
Epoch 3/10 - train loss: 0.1327
Epoch 3/10 - validation loss: 0.2327, validation accuracy: 0.9163
Epoch 4/10 - train loss: 0.0651
Epoch 4/10 - validation loss: 0.2136, validation accuracy: 0.9305
Epoch 5/10 - train loss: 0.0430
Epoch 5/10 - validation loss: 0.2282, validation accuracy: 0.9289
Epoch 6/10 - train loss: 0.0216
Epoch 6/10 - validation loss: 0.2334, validation accuracy: 0.9321
Epoch 7/10 - train loss: 0.0159
Epoch 7/10 - validation loss: 0.2580, validation accuracy: 0.9321
Epoch 8/10 - train loss: 0.0156
Epoch 8/10 - validation loss: 0.2603, validation accuracy: 0.9336
Epoch 9/10 - train loss: 0.0081
Epoch 9/10 - validation loss: 0.2560, validation accuracy: 0.9352
Epoch 10/10 - train loss: 0.0087
Epoch 10/10 - validation loss: 0.2601, validation accuracy: 0.9352
Total Time: 600.57s

Now we need to test the model on a completely separate dataset. Even though the training loss looks good, the model could be overfitting, meaning it's memorizing the training data rather than generalizing to unseen data. The test loop is basically the same as the validation loop, except it runs on the held-out test set, which wasn't used for training or for decisions like early stopping.

def test_model():
    model.eval()
    running_test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_test_loss += loss.item() * inputs.size(0)

            _, predicted = torch.max(outputs, 1) # get the index of the max log-probability (argmax)
            total += targets.size(0) # add the number of targets in this batch
            correct += (predicted == targets).sum().item() # add the number of correct predictions in this batch

    test_loss = running_test_loss / len(test_loader.dataset) # divide by total number of samples
    test_acc = correct / total
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")

Moment of truth....

test_model()

Output:

Test loss: 0.4025, test accuracy: 0.8942

Not bad! Almost 90% accuracy on the test set. If we can do this with public data, almost no compute, and an hour or two of work, imagine what we could do with a team of engineers, a budget, and a few months.

Let's visualize the model outputs to get a sense of what's happening under the hood.

import matplotlib.pyplot as plt
import numpy as np
import random

def visualize_predictions(model, dataloader, class_names, num_images=5):
    model.eval() # Set model to evaluation mode

    images, labels = next(iter(dataloader))
    indices = random.sample(range(len(images)), num_images)
    images, labels = images[indices], labels[indices]
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(images)
        probs = F.softmax(outputs, dim=1) # get actual probs
        _, preds = torch.max(probs, 1) # using probs has the same effect as outputs

    # Move the images to CPU and convert them to numpy for visualization
    images = images.cpu().numpy().transpose((0, 2, 3, 1))

    # Unnormalize for visualization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    images = std * images + mean
    images = np.clip(images, 0, 1)

    # Plot the images with labels
    plt.figure(figsize=(20, 4))
    for idx in range(num_images):
        ax = plt.subplot(1, num_images, idx + 1)
        plt.imshow(images[idx])
        true_label = class_names[labels[idx]]
        pred_label = class_names[preds[idx]]
        prob_dist = ", ".join(f"{class_names[i]}: {probs[idx, i]:.4f}" for i in range(len(class_names)))
        print(prob_dist)
        ax.title.set_text(f"True: {true_label}\nPred: {pred_label}")
        plt.axis("off")
    plt.show()
visualize_predictions(model, test_loader, class_names=["cataract", "glaucoma", "diabetic_retinopathy", "normal"], num_images=8) # TODO change class names

Here is an example model output:

cataract: 0.0006, glaucoma: 0.0000, diabetic_retinopathy: 0.0011, normal: 0.9983
cataract: 0.9999, glaucoma: 0.0000, diabetic_retinopathy: 0.0001, normal: 0.0000
cataract: 0.8247, glaucoma: 0.0001, diabetic_retinopathy: 0.1748, normal: 0.0004
cataract: 0.7998, glaucoma: 0.0000, diabetic_retinopathy: 0.1971, normal: 0.0031

example model output

Let's put the model in the context of a clinical DSS. To determine which images need to be reviewed by the physician, we can set a confidence threshold t: the minimum output probability required to bypass manual clinical review.

If the model is confident enough, don't review the image. If the model is not confident enough, send the image to a physician for manual review.

Notice that the model was incorrect on the last image. The greatest probability in the distribution was cataract (at 0.7998), but the image actually showed diabetic retinopathy. If our threshold is set to 0.85, this image would correctly be flagged for physician review.
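Here's a minimal sketch of what that triage rule could look like in code. The triage helper and the 0.85 threshold are illustrative assumptions (they're not in the notebook); it reuses the model, device, and class names from above.

import torch
import torch.nn.functional as F

REVIEW_THRESHOLD = 0.85  # assumed threshold t; tune it to trade off review workload vs. risk

def triage(model, images, class_names, threshold=REVIEW_THRESHOLD):
    # Returns (predicted label, confidence, needs_review) for each image in the batch
    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(images.to(device)), dim=1)
    confidences, preds = torch.max(probs, dim=1)
    results = []
    for conf, pred in zip(confidences.tolist(), preds.tolist()):
        needs_review = conf < threshold  # not confident enough -> send to a physician
        results.append((class_names[pred], conf, needs_review))
    return results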

Note that in this test environment the true labels are available (we already know ahead of time which disease is depicted). In production, this will be unknown.

For images that get flagged for review, once they've been reviewed they can be added to a dataset that can be used to further improve model performance down the road.
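Operationally, that feedback loop can be as simple as filing each reviewed image into the same ImageFolder-style directory structure used for training. The helper below is hypothetical (it's not in the notebook), just to show the idea.

import os
import shutil

def add_reviewed_image(image_path: str, physician_label: str, data_dir: str = "dataset"):
    # After a physician reviews a flagged image, store it under its class folder
    # so it gets picked up the next time the model is fine-tuned.
    assert physician_label in {"cataract", "glaucoma", "diabetic_retinopathy", "normal"}
    dest_dir = os.path.join(data_dir, physician_label)
    os.makedirs(dest_dir, exist_ok=True)
    shutil.copy(image_path, dest_dir)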

Final Thoughts

To reiterate, Machine Learning and Artificial Intelligence have many applications in healthcare. A great way to implement this new technology is as part of a larger Clinical Decision Support System that aims to help medical personnel, like physicians, with decision-making. With the low cost of inference, it's now possible to screen large numbers of patients for pathologies like cataracts, glaucoma, and DR.

The next steps for this project could include optimizing model performance, implementing a model into production, or exploring other use cases like more pathologies or different types of imaging.

If you have any questions, comments, or suggestions, feel free to reach out to me at ben@beneverman.com📧. Thanks for reading!