Building Neural Networks to Detect Eye Diseases

Summary

The source code is here, in Colab. You'll need a GPU to run it.

If you have any questions, feel free to reach out to me on X or email me: ben@beneverman.com

Doctors Can't Keep Up

Healthcare is constantly changing and evolving. Research moves fast: in a 2011 publication, Densen estimated that by 2020, medical knowledge would double every 73 days, a marked departure from the 1950s, when doubling took 50 years1. The average physician, though required to complete some continuing education to maintain licensure and board certification, must fight to keep pace with the rate of new research while juggling heavy caseloads, administrative responsibilities, and their own personal care.

"This growing aggregation of advances and reversals presents a significant challenge to physicians attempting to stay up to date. Historically, there has been an average 17-year lag between medical discoveries and implementation into clinical practice. With the acceleration of changes in clinical medicine, coupled with normal lag times in dissemination, there is a higher probability than ever before that physicians, within just a few years of leaving their training, may not be practicing contemporary standards of care." 2

With new technology like Machine Learning (ML) and Artificial Intelligence (AI), what are the implications for the average physician? Can we reduce the 17-year lag?

I think we can. How? By building clinical Decision Support Systems (DSS)3. Decision support systems date back at least to the British in World War II, but a modern example is NASA's mission control. While astronauts pilot air and spacecraft, mission control gathers and manages vast data sets, utilizes advanced models for analysis, offers user-friendly interfaces for astronaut interaction, and provides crucial support for real-time decision-making. To get a sense of the scale: throughout the Apollo program, 32 astronauts were assigned to fly missions, and at a 1:100 ratio, one would estimate as many as 3,200 staff in mission control. 4

So how can we use ML and AI to support clinical decision making? One way is using computer vision. With neural networks that take images as input, we can batch-process large amounts of medical imagery and use the network to return the much smaller subset that requires actual physician review. The goal is not to diagnose, but to identify what needs physician review while providing the physician with as much supporting information as possible.

Using Math to See Disease

This project was inspired by the work of Ophthalytics. In their recent paper5, they trained a Convolutional Neural Network (CNN), an architecture commonly used in computer vision, to detect Diabetic Retinopathy (DR). DR is a complication of diabetes that affects the retina and can lead to vision loss if not detected and treated early. In the paper, the model takes close-up retinal fundus images as input and classifies them as DR-Positive or DR-Negative. See the figure below.

Figure 3 from the Ophthalytics paper: (a) DR-Negative and (b) DR-Positive

I decided I wanted to build a similar eye disease classifier. My goals for this project were twofold:

  1. Train a classifier that can detect multiple eye diseases
  2. Train the model using minimal compute, as fast as possible

Implementing a CNN for Eye Disease Classification

Since my goal was to minimize compute costs, I chose the EfficientNet architecture. The original EfficientNet paper6 was published in 2019. The EfficientNet architecture uses a compound scaling method that uniformly scales all architectural dimensions of the CNN (depth/width/resolution). The big benefit of EfficientNet is that it's optimized for the performance/compute-cost trade-off. At the time of publication, their largest model, EfficientNet-B7, was deemed "8.4x smaller and 6.1x faster on inference than the best existing ConvNet"7. Here's a figure from the paper:

Figure from the EfficientNet paper: accuracy vs. FLOPs

Notice the Y axis is accuracy and the X axis is FLOPs, a direct measure of compute cost. FLOPs stands for Floating Point Operations: the total number of floating-point arithmetic operations performed by the model during inference, or required to train it. 8
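
If you want to get a feel for these numbers yourself, here's a minimal sketch that counts parameters and estimates the FLOPs of a single forward pass. It assumes the fvcore library (pip install fvcore, which is not preloaded in Colab), and note that different tools use slightly different FLOP-counting conventions, so treat the result as a rough estimate.

import torch
import timm
from fvcore.nn import FlopCountAnalysis # assumes `pip install fvcore`

model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=4)
model.eval()

n_params = sum(p.numel() for p in model.parameters()) # total parameter count
flops = FlopCountAnalysis(model, torch.randn(1, 3, 224, 224)).total() # one 224x224 image (B0's native size)

print(f"params: {n_params/1e6:.1f}M, approx FLOPs per forward pass: {flops/1e9:.2f} GFLOPs")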

For this project, I used an A100 on Google Colab, which costs me about $9.99 a month. The isolated cost of this project would be much lower (<$2 maybe?). It turns out that being poor does force you to be more creative.

More Architecture Decisions

I realized I could leverage transfer learning by taking a pre-trained EfficientNet and continuing its training on my eye disease dataset (AKA fine-tuning). This has a few benefits:

  1. Reduced compute cost (time and throughput), since I skip pre-training a CNN from scratch
  2. Reduced risk of overfitting on my small dataset
  3. Faster iteration, so I could finish this project sooner

My goal here was not to train the most performant classifier. Maybe I'll return to that in the future; for now, it's less relevant. The easiest way to boost performance would be to use a larger base model, which I did play around with.
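
One more knob worth mentioning, since compute is the constraint here: instead of fine-tuning every weight (which is what the training loop below does), you can freeze the pretrained backbone and train only the new classification head. Here's a minimal sketch using timm's get_classifier() accessor; I didn't use this for the final run, but it's the cheaper option.

import timm

model = timm.create_model("efficientnet_b4", pretrained=True, num_classes=4)

for param in model.parameters():
    param.requires_grad = False # freeze the whole network
for param in model.get_classifier().parameters():
    param.requires_grad = True # un-freeze only the new 4-class head

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable params: {trainable:,}") # a small fraction of the full model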

Update 3/31/24: Since writing this, I've read much more literature on retinal fundus imaging and computer vision. The gold standard model architecture seems to be a U-Net with a pretrained encoder (ImageNet) for segmentation, connected to a pre-trained classifier (ImageNet). If I go any further with this project, this is the direction I'll take. Lack of annotated data is the biggest issue.
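
For concreteness, here's a rough sketch of what that segmentation front end might look like. This is not something I've built; it assumes the segmentation_models_pytorch library (pip install segmentation-models-pytorch) and an ImageNet-pretrained encoder, and the downstream pretrained classifier is omitted.

import segmentation_models_pytorch as smp # assumed dependency, not used elsewhere in this project

# U-Net with an ImageNet-pretrained encoder; outputs a 1-channel segmentation mask
seg_model = smp.Unet(
    encoder_name="efficientnet-b4", # pretrained encoder backbone
    encoder_weights="imagenet",
    in_channels=3, # RGB fundus images
    classes=1, # binary mask (e.g. vessels or lesions)
)
# The mask (or encoder features) would then feed a separate ImageNet-pretrained classifier.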

The Code, Explained

Again, here is the link to the code. I've set it up so that you should be able to download the data, train/fine-tune the model, and test it pretty easily. NOTE: You'll probably need an A100 at least, which means you need Colab Pro. You could probably get away with a T4 (free) for the smallest model, EfficientNet-B0. Email me if you have any issues or questions! ben@beneverman.com

Setup and Configuration

To set up the project we'll need to do a few things:

  1. Make sure we're using a GPU runtime (Runtime -> change runtime type)
  2. Get the data from the cloud onto the runtime, and unzip it
  3. Install missing dependencies
!nvidia-smi # make sure you've got a GPU runtime selected (I used an A100)

This will show you the GPU you're using. Here's what an example output looks like:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0              44W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

For downloading the data, I've included a pre-signed S3 link that you can use to make things painless. Just run the cell.

!curl {URL IS IN THE COLAB} -o data

Unzip the data

!unzip data

Install the necessary libraries. Everything we need comes preloaded in Colab, except for timm, which is a library that provides pre-trained image models.

!pip install timm

Tip! In Colab you can click the folder icon on the left to see the files in the runtime. Or, you can use the !ls command to list the files in the current directory and cd to change directories.

Training the Model

First things first, let's import all our dependencies.

import os # filesystem
import timm # pretrained image models
import torch #PyTorch, deep learning framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms # PyTorch vision utils
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split # SciKit Learn for splitting the data
from time import perf_counter # to time stuff

Let's do a sanity check to make sure the data is where we expect it to be.

DATA_DIR = 'dataset' # the root directory that the data was unzipped to
print(os.listdir(DATA_DIR)) # list the contents of the directory

The output should be ['glaucoma', 'cataract', 'diabetic_retinopathy', 'normal']

Next, let's store the device (the GPU, if one is available) in a variable so we can move the model and tensors onto it later.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # set the device (GPU) to move model and tensors to later

Okay, now let's preprocess the data into datasets. We'll do some minimal preprocessing before we split the data into train, validation, and test sets. Before we do this, it can be helpful to look at the data and see what we're working with.9

Here's a way we can view some of the images inside this notebook, using pillow. Try to experiment with the data (look at the directory structure, view some images, etc.) to get a sense of what you're working with.

from PIL import Image
from IPython.display import display
 
folders = ["cataract", "glaucoma", "diabetic_retinopathy", "normal"]
for folder in folders:
  files = os.listdir(os.path.join(DATA_DIR, folder))
  if len(files) > 0:
    print(f"{folder}: {files[0]}")
    img = Image.open(os.path.join(DATA_DIR, folder, files[0]))
    display(img)

Now we can define a function to make the datasets.

def make_datasets(path: str):
    print(f"Making datasets from {path}")
    if not os.path.exists(path): raise FileNotFoundError(f"Path {path} does not exist") # check if the path exists
    assert set(os.listdir(path)) == set(['cataract', 'glaucoma', 'diabetic_retinopathy', 'normal']), "wrong dataset maybe??" # check if the dataset is correct
 
    transform = transforms.Compose([
        # the following line resizes the images from their original dimensions dimX x dimY to 384x384. Each model has different input sizes.
        transforms.Resize((384, 384)), # for B4, see all sizes here: https://discuss.pytorch.org/t/input-size-for-efficientnet-versions-from-torchvision-models/140525
        transforms.ToTensor(), # convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # normalize with the ImageNet mean/std that the pretrained model expects
    ])
    full_dataset = datasets.ImageFolder(path, transform=transform)
 
    # time to split the data into train, validation, and test sets
    # here we're using the sklearn train_test_split function to split the data into 70% train, 30% val/test
    train_indices, temp_indices = train_test_split(
        range(len(full_dataset)),
        test_size=0.3, # will split this into validation and test later
        stratify=full_dataset.targets, # stratify the split (maintain the same distribution of classes) (targets are class indices)
        random_state=42 # for reproducibility (the randomness will be the same each time)
    )
 
    # further split the leftover data into validation and test (15% each)
    val_indices, test_indices = train_test_split(
        temp_indices,
        test_size=0.5, # split the temp_indices into validation and test
        stratify=[full_dataset.targets[i] for i in temp_indices], # stratify the split (maintain the same distribution of classes)
        random_state=42 # for reproducibility
    )
 
    # now lets create the datasets for torch
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
    val_dataset = torch.utils.data.Subset(full_dataset, val_indices)
    test_dataset = torch.utils.data.Subset(full_dataset, test_indices)
 
    return train_dataset, val_dataset, test_dataset

Now we can make the datasets and dataloaders. 10

train_dataset, val_dataset, test_dataset = make_datasets(DATA_DIR)
 
BATCH_SIZE = 32
 
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the validation set (for reproducibility across architecture/other changes)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the test set

Let's set our model hyperparams.

EPOCHS = 10 # the number of epochs to train for
LR = 0.0001 # the learning rate (how much the model adjusts its weights during training/backpropagation)

Next, let's load a model to mess around with. Optionally, you can list all the available models with timm.list_models().

example_models = ["efficientnet_b0", "efficientnet_b3", "efficientnet_b4", "vit_large_patch16_224"]
model = timm.create_model(example_models[2], pretrained=True, num_classes=4) # we're using the EfficientNet-B4 model
model = model.to(device) # move the model to the GPU

Lets define the loss function and optimizer. We're using the CrossEntropyLoss function, which is commonly used for multi-class classification problems. The optimizer is Adam, which is a popular choice for deep learning models. 11

criterion = nn.CrossEntropyLoss() # categorical cross-entropy loss for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=LR) # Adam optimizer

Okay, this next cell is pretty important. This is the training loop, where we do a few things while iterating over a number of epochs (EPOCHS):

  1. Set the model to training mode (this is important because we need to make a backward pass during training, to update the model weights)
  2. Get the batch of inputs and targets (images and labels), move them to the GPU
  3. Zero the gradients (this is important because PyTorch accumulates gradients by default)
  4. Forward pass the inputs through the model output = model(inputs)
  5. Measure how close the model got to the correct answer loss = criterion(output, targets)
  6. Update the model weights based on the gradient of the loss loss.backward(); optimizer.step()
  7. Print the epoch and training loss.

Now into the validation loop. This is similar to the training loop, but we don't update the model weights. We're just measuring how well the model is doing on unseen data.

  1. Set the model to evaluation mode (this is important because we don't want to update the model weights)

  2. with torch.no_grad() tells PyTorch not to calculate gradients (we don't need them) which saves compute

  3. Forward pass the inputs through the model output = model(inputs)

  4. Measure how close the model got to the correct answer loss = criterion(output, targets)

  5. Print the epoch and validation loss.

  6. Important: I've added early stopping. This is a technique to prevent overfitting (the bane of every ML engineer's existence). If the validation loss stops improving (rises more than a small tolerance above its best value so far), we stop training so as not to overfit the model.

def train(early_stopping=False, es_tol=.05):
    start = perf_counter()
    best_val_loss = float("inf") # arbitrarily high value
    for epoch in range(EPOCHS):
        model.train() # training mode (grads)
        running_train_loss = 0.0
        for i, (inputs, targets) in enumerate(train_loader):
            print(f"Batch {i+1}/{len(train_loader)}", end="\r")
            inputs, targets = inputs.to(device), targets.to(device) # move to device
            optimizer.zero_grad() # reset gradients
 
            outputs = model(inputs) # forward pass
            loss = criterion(outputs, targets) # compute loss
            loss.backward()
            optimizer.step()
 
            running_train_loss += loss.item() * inputs.size(0) # multiply by batch size
 
        train_loss = running_train_loss / len(train_loader.dataset) # divide by total number of samples
        print(f"Epoch {epoch+1}/{EPOCHS} - train loss: {train_loss:.4f}")
 
        model.eval()
        running_validation_loss = 0.0
        correct = 0
        total = 0
 
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(val_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                running_validation_loss += loss.item() * inputs.size(0)
 
                _, predicted = torch.max(outputs, 1) # get the index of the logprobs
                total += targets.size(0) # add the number of targets in this batch
                correct += (predicted == targets).sum().item() # add the number of correct predictions in this batch
 
            val_loss = running_validation_loss / len(val_loader.dataset) # divide by total number of samples
            val_acc = correct / total
            print(f"Epoch {epoch+1}/{EPOCHS} - validation loss: {val_loss:.4f}, validation accuracy: {val_acc:.4f}")
 
            if early_stopping and val_loss - best_val_loss > es_tol: # if val loss is more than es_tol worse than the best so far, stop
                print(f"Early stopping at epoch {epoch}")
                break
 
            best_val_loss = min(best_val_loss, val_loss)
 
    print(f"Total Time: {perf_counter()-start:0.2f}s")

Now we can train the model. This will take a minute, so go grab a coffee or something.

If you are NOT using a GPU runtime (!nvidia-smi should show a GPU), this will take a long time (hours or days) and might never finish. If you're on a CPU, you'll need to restart the runtime, select a GPU runtime, and re-run all previous cells before working back to this point.

train(early_stopping=True)

Here's an example output from the training loop:

Epoch 1/10 - train loss: 0.7035
Epoch 1/10 - validation loss: 0.3994, validation accuracy: 0.8468
Epoch 2/10 - train loss: 0.2524
Epoch 2/10 - validation loss: 0.2609, validation accuracy: 0.9052
Epoch 3/10 - train loss: 0.1327
Epoch 3/10 - validation loss: 0.2327, validation accuracy: 0.9163
Epoch 4/10 - train loss: 0.0651
Epoch 4/10 - validation loss: 0.2136, validation accuracy: 0.9305
Epoch 5/10 - train loss: 0.0430
Epoch 5/10 - validation loss: 0.2282, validation accuracy: 0.9289
Epoch 6/10 - train loss: 0.0216
Epoch 6/10 - validation loss: 0.2334, validation accuracy: 0.9321
Epoch 7/10 - train loss: 0.0159
Epoch 7/10 - validation loss: 0.2580, validation accuracy: 0.9321
Epoch 8/10 - train loss: 0.0156
Epoch 8/10 - validation loss: 0.2603, validation accuracy: 0.9336
Epoch 9/10 - train loss: 0.0081
Epoch 9/10 - validation loss: 0.2560, validation accuracy: 0.9352
Epoch 10/10 - train loss: 0.0087
Epoch 10/10 - validation loss: 0.2601, validation accuracy: 0.9352
Total Time: 600.57s
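
One thing the notebook doesn't do is save the fine-tuned weights, so they're lost when the Colab runtime recycles. If you want to keep them, here's a minimal sketch (the filename is just an example):

torch.save(model.state_dict(), "efficientnet_b4_eye_disease.pt") # persist the fine-tuned weights

# later / elsewhere: rebuild the same architecture and load the weights back in
restored = timm.create_model("efficientnet_b4", pretrained=False, num_classes=4)
restored.load_state_dict(torch.load("efficientnet_b4_eye_disease.pt", map_location=device))
restored = restored.to(device).eval()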

Now we need to test the model on a completely separate dataset. Even though the training loss looks good, the model could be overfitting, meaning it's memorizing the training data and not generalizing well to unseen data. This test loop is basically the same as the validation loop, except it runs on the held-out test set, which the model never saw during training.

def test_model():
    model.eval()
    running_test_loss = 0.0
    correct = 0
    total = 0
 
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_test_loss += loss.item() * inputs.size(0)
 
            _, predicted = torch.max(outputs, 1) # get the index of the max log-probability (argmax)
            total += targets.size(0) # add the number of targets in this batch
            correct += (predicted == targets).sum().item() # add the number of correct predictions in this batch
 
        test_loss = running_test_loss / len(test_loader.dataset) # divide by total number of samples
        test_acc = correct / total
        print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")

Moment of truth....

test_model()

Output:

Test loss: 0.4025, test accuracy: 0.8942

Not bad! Almost 90% accuracy on the test set. If we can do this with public data, almost no compute, and an hour or two of work, imagine what we could do with a team of engineers, a budget, and a few months.

Let's visualize the model outputs to get a sense of what's happening under the hood.

import matplotlib.pyplot as plt
import numpy as np
import random
 
def visualize_predictions(model, dataloader, class_names, num_images=5):
    model.eval()  # Set model to evaluation mode
 
    images, labels = next(iter(dataloader))
    indices = random.sample(range(len(images)), num_images)
    images, labels = images[indices], labels[indices]
    images, labels = images.to(device), labels.to(device)
 
    with torch.no_grad():
        outputs = model(images)
        probs = F.softmax(outputs, dim=1) # get actual probs
        _, preds = torch.max(probs, 1) # using probs has the same effect as outputs
 
    # Move the images to CPU and convert them to numpy for visualization
    images = images.cpu().numpy().transpose((0, 2, 3, 1))
 
    # Unnormalize for visualization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    images = std * images + mean
    images = np.clip(images, 0, 1)
 
    # Plot the images with labels
    plt.figure(figsize=(20, 4))
    for idx in range(num_images):
        ax = plt.subplot(1, num_images, idx + 1)
        plt.imshow(images[idx])
        true_label = class_names[labels[idx]]
        pred_label = class_names[preds[idx]]
        prob_dist = ", ".join(f"{class_names[i]}: {probs[idx, i]:.4f}" for i in range(len(class_names)))
        print(prob_dist)
        ax.title.set_text(f"True: {true_label}\nPred: {pred_label}")
        plt.axis("off")
    plt.show()

# NOTE: torchvision's ImageFolder assigns class indices alphabetically, so the names must be in that order
class_names = ["cataract", "diabetic_retinopathy", "glaucoma", "normal"]
visualize_predictions(model, test_loader, class_names=class_names, num_images=8)

Here is an example model output:

cataract: 0.0006, glaucoma: 0.0000, diabetic_retinopathy: 0.0011, normal: 0.9983
cataract: 0.9999, glaucoma: 0.0000, diabetic_retinopathy: 0.0001, normal: 0.0000
cataract: 0.8247, glaucoma: 0.0001, diabetic_retinopathy: 0.1748, normal: 0.0004
cataract: 0.7998, glaucoma: 0.0000, diabetic_retinopathy: 0.1971, normal: 0.0031

Figure: example model output (test images plotted with true and predicted labels)

Let's put the model in the context of a Clinical DSS. To determine which images need to be reviewed by the physician, we can set a confidence threshold $t$ that corresponds to the required output probability to bypass manual clinical review.

If the model is confident enough, don't review the image. If the model is not confident enough, send the image to a physician for manual review.

Notice that the model was incorrect on the last image. The greatest probability in the distribution was cataract (at .7998), but the image was actually diabetic retinopathy. If our threshold is set to .85, this image would correctly be tagged for physician review.

Note that in this test environment, the True labels are available (we already know what disease is depicted, ahead of time). In production, this will be unknown.

For images that get tagged for review, once they are reviewed they can be added to a dataset which can be used to further improve model performance down the road.
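
To make the triage idea concrete, here's a minimal sketch of how a confidence threshold could be applied to the model's softmax outputs. The threshold value and the routing logic are illustrative assumptions, not a validated clinical policy:

T = 0.85 # confidence threshold; anything below this goes to a physician

def triage(model, images):
    """Return (predicted class, max probability, needs_review flag) for a batch of images."""
    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(images.to(device)), dim=1)
        max_probs, preds = torch.max(probs, dim=1)
    needs_review = max_probs < T # low-confidence images get flagged for manual review
    return preds.cpu(), max_probs.cpu(), needs_review.cpu()

# example: route one batch from the test loader
images, _ = next(iter(test_loader))
preds, confidences, review_flags = triage(model, images)
print(f"{int(review_flags.sum())} of {len(images)} images flagged for physician review")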

Final Thoughts

To reiterate, Machine Learning and Artificial Intelligence have many applications in healthcare. A great way to implement this new technology is as part of a larger Clinical Decision Support System that aims to help medical personnel like physicians with decision-making. With the low cost of inference, it's now possible to screen large numbers of patients for pathologies like cataracts, glaucoma, and DR.

The next steps for this project could include optimizing model performance, implementing a model into production, or exploring other use cases like more pathologies or different types of imaging.

If you have any questions, comments, or suggestions, feel free to reach out to me at ben@beneverman.com. Thanks for reading!