The source code is here, in Colab. You'll need a GPU to run it, but everything else should be ready to go. If something breaks, just email me: ben@beneverman.com.
Doctors Can't Keep Up
Healthcare is constantly changing and evolving. Research moves fast: in a 2011 publication, Densen estimated that by 2020, medical knowledge would double every 73 days, a marked departure from the 1950s, when doubling took 50 years (Densen, 2011). The average physician, though required to complete some continuing education to maintain licensure and board certification, must fight to keep pace with the rate of new research while juggling heavy caseloads, administrative responsibilities, and their own self-care.
"This growing aggregation of advances and reversals presents a significant challenge to physicians attempting to stay up to date. Historically, there has been an average 17-year lag between medical discoveries and implementation into clinical practice. With the acceleration of changes in clinical medicine, coupled with normal lag times in dissemination, there is a higher probability than ever before that physicians, within just a few years of leaving their training, may not be practicing contemporary standards of care." Laiteerapong, Neda, and Elbert S Huang. (2015)
With new technology like Machine Learning (ML) and Artificial Intelligence (AI), what are the implications for the average physician? Can we reduce the 17-year lag?
I think we can. How? By building clinical Decision Support Systems (DSS) (Sutton et al., 2020). While older decision support systems date back to the British in World War II, a modern example is NASA's mission control. While astronauts pilot aircraft and spacecraft, mission control gathers and manages vast data sets, runs advanced models for analysis, offers user-friendly interfaces for astronaut interaction, and provides crucial support for real-time decision-making. Consider NASA's scale: throughout the Apollo program, 32 astronauts were assigned to fly missions. At a 1:100 ratio, you might estimate as many as 3,200 mission control staff. The actual number? Roughly 400,000.
So how can we use ML and AI to support clinical decision making? One way is using computer vision. With neural networks that take images as input, we can batch-process large amounts of medical imagery and use the network to return the much smaller subset that requires actual physician review. The goal is not to diagnose, but to identify what needs physician review while providing the physician with as much supporting information as possible.
Using Math to See Disease
This project was inspired by the work of Ophthalytics. In their recent paper (Bajwa et al., 2023), they trained a Convolutional Neural Network (CNN), an architecture commonly used in computer vision, to detect Diabetic Retinopathy (DR), a complication of diabetes that affects the retina and can lead to vision loss if not detected and treated early.
In the paper, the model takes close-up retinal fundus images as input and classifies them as DR-Positive or DR-Negative. See the figure below.
(a) DR Negative and (b) DR Positive
I decided I wanted to build a similar eye disease classifier. My goals for this project were twofold:
Train a classifier that can detect multiple eye diseases
Train the model using minimal compute, as fast as possible
Implementing a CNN for Eye Disease Classification
Since my goal was to minimize compute costs, I chose the EfficientNet architecture. The original EfficientNet paper (Tan and Le, 2019) was published in 2019. The EfficientNet architecture uses a compound scaling method that uniformly scales all architectural dimensions of the CNN (depth/width/resolution). The big benefit of EfficientNet is that it's optimized for performance per unit of compute. At the time of publication, their largest model, EfficientNet-B7, was deemed "8.4x smaller and 6.1x faster on inference than the best existing ConvNet" (Tan and Le, 2019). Here's a figure from the paper:
Notice the Y axis is accuracy and the X axis is FLOPs, a direct measure of compute cost. FLOPs stands for Floating Point Operations: the total number of floating-point arithmetic operations performed by the model during inference, or required to train it. For context, the speculated number of FLOPs used to train GPT-4 (aka ChatGPT) is ~2.15e25, or 21,500,000 EFLOPs, which at $1 per A100-hour works out to an estimated $63 million.
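As a rough sanity check on that figure, here's the back-of-the-envelope math. The constants are my own assumptions, not part of the original estimate: an A100 peaks around 312 TFLOPS in BF16, and training sustains maybe 30% of peak.

# back-of-the-envelope GPT-4 training cost (assumed: 312 TFLOPS peak BF16, ~30% sustained utilization)
total_flops = 2.15e25 # speculated GPT-4 training FLOPs
flops_per_a100_hour = 312e12 * 3600 * 0.30 # FLOPs one A100 actually delivers in an hour
a100_hours = total_flops / flops_per_a100_hour
print(f"{a100_hours:,.0f} A100-hours -> ~${a100_hours:,.0f} at $1/hour") # roughly 64 million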
For this project, I used an A100 on Google Colab, which costs me about $9.99 a month. The isolated cost of this project would be much lower (<$2 maybe?). It turns out that being poor does force you to be more creative.
More Architecture Decisions
I realized I could leverage transfer learning by taking a pre-trained EfficientNet and continuing its training on my eye disease dataset (AKA fine-tuning). This has a few benefits:
Reduced compute cost (time and GPU-hours) compared to pre-training a CNN from scratch
Less risk of overfitting than training from scratch on my small dataset
Faster iteration, so I could finish this project sooner
My goal here was not to train the most performant classifier. Maybe I'll return to that in the future; for now, it's less relevant. The easiest way to boost performance would be to use a larger base model, which I did play around with.
Update 3/31: Since writing this, I've read much more literature on retinal fundus imaging and computer vision. The gold-standard model architecture seems to be a U-Net with a pre-trained (ImageNet) encoder for segmentation, connected to a pre-trained (ImageNet) classifier. If I go any further with this project, this is the direction I'll take. Lack of annotated data is the biggest issue.
The Code, Explained
💡 Again, here is the link to the code. I've set it up so that you should be able to download the data, train/fine-tune the model, and test it pretty easily. NOTE: You'll probably need at least an A100, which means you need Colab Pro. You could probably get away with a T4 (free) for the smallest model, EfficientNet-B0. Email me if you have any issues or questions! ben@beneverman.com
Setup and Configuration
To set up the project we'll need to do a few things:
Make sure we're using a GPU runtime (Runtime -> change runtime type)
Get the data from the cloud onto the runtime, and unzip it
Install missing dependencies
!nvidia-smi # make sure you've got a GPU runtime selected (I used an A100)
This will show you the GPU you're using. Here's what an example output looks like:
For downloading the data, I've included a pre-signed S3 link that you can use to make things painless. Just run the cell.
!curl {URL IS IN THE COLAB} -o data
Unzip the data
!unzip data
Install the necessary libraries. Everything we need comes preloaded in Colab, except for timm, which is a library that provides pre-trained image models.
!pip install timm
💡 Tip! In Colab you can click the folder icon on the left to see the files in the runtime. Or, you can use the !ls command to list the files in the current directory and the %cd magic to change directories.
Training the Model
First things first, let's import all our dependencies.
# (the import cell in the Colab may be ordered differently; these cover everything used below)
import os # filesystem checks
from time import perf_counter # to time stuff

import torch # PyTorch
import torch.nn as nn # layers and loss functions
import torch.optim as optim # optimizers
from torch.utils.data import DataLoader, Subset # batching and dataset subsetting
from torchvision import datasets, transforms # ImageFolder and image preprocessing
from PIL import Image # pillow, for viewing images
import timm # pre-trained image models
from sklearn.model_selection import train_test_split # SciKit Learn for splitting the data
Let's do a sanity check to make sure the data is where we expect it to be.
DATA_DIR = 'dataset' # the root directory that the data was unzipped to
print(os.listdir(DATA_DIR)) # list the contents of the directory
The output should be ['glaucoma', 'cataract', 'diabetic_retinopathy', 'normal'].
Next, let's grab a handle to the GPU and store it in a variable we can use later to move the model and tensors onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # set the device (GPU) to move model and tensors to later
Okay, now let's preprocess the data into datasets. We'll do some minimal preprocessing before we split the data into train, validation, and test sets. Before we do this, it can be helpful to look at the data and see what we're working with. Andrej Karpathy (AKA the GOAT 🐐) has a fantastic blog post that includes how and why you should manually inspect your data.
Here's a way we can view some of the images inside this notebook, using pillow. Try to experiment with the data (look at the directory structure, view some images, etc.) to get a sense of what you're working with.
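Something like this works (a quick sketch; the class folder and the "first image in the folder" choice are just for illustration, so swap in whatever files are in your runtime):

# quick-and-dirty image preview (illustrative sketch)
import os
from PIL import Image

label = 'diabetic_retinopathy' # pick any of the four class folders
folder = os.path.join(DATA_DIR, label)
sample = os.listdir(folder)[0] # grab the first image in the folder
img = Image.open(os.path.join(folder, sample))
print(sample, img.size, img.mode) # filename, original dimensions, color mode
display(img) # Colab/Jupyter renders PIL images inline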
Now we can define a function to make the datasets.
def make_datasets(path: str):
    print(f"Making datasets from {path}")
    if not os.path.exists(path): raise FileNotFoundError(f"Path {path} does not exist") # check if the path exists
    assert set(os.listdir(path)) == set(['cataract', 'glaucoma', 'diabetic_retinopathy', 'normal']), "wrong dataset maybe??" # check if the dataset is correct

    transform = transforms.Compose([
        # the following line resizes the images from their original dimensions to 384x384. Each model has a different input size.
        transforms.Resize((384, 384)), # for B4, see all sizes here: https://discuss.pytorch.org/t/input-size-for-efficientnet-versions-from-torchvision-models/140525
        transforms.ToTensor(), # convert to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # normalize with the ImageNet stats the pre-trained weights expect
    ])

    # build one ImageFolder dataset over the four class folders, then split it into train/val/test
    # (the split below is a reconstruction; the exact ratios and seed live in the Colab)
    full_dataset = datasets.ImageFolder(path, transform=transform)
    indices = list(range(len(full_dataset)))
    train_idx, rest_idx = train_test_split(indices, test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(rest_idx, test_size=0.5, random_state=42)
    return Subset(full_dataset, train_idx), Subset(full_dataset, val_idx), Subset(full_dataset, test_idx)
Then we call make_datasets and wrap each split in a DataLoader (the batch size here is my reconstruction; the value in the Colab may differ):

BATCH_SIZE = 32 # images per batch
train_dataset, val_dataset, test_dataset = make_datasets(DATA_DIR) # build the three splits
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) # shuffle the training set each epoch
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the validation set (for reproducibility across architecture/other changes)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # don't shuffle the test set
Let's set our model hyperparams.
EPOCHS = 10 # the number of epochs to train for
LR = 0.0001 # the learning rate (how much the model adjusts its weights during training/backpropagation)
Next, we load one of a few different models to mess around with. Optionally, you can list all the available models with timm.list_models().
example_models = ['efficientnet_b0', 'efficientnet_b3', 'efficientnet_b4'] # a few timm model names to try (reconstructed; the exact list is in the Colab)
model = timm.create_model(example_models[2], pretrained=True, num_classes=4) # we're using the EfficientNet-B4 model
model = model.to(device) # move the model to the GPU
Let's define the loss function and optimizer. We're using CrossEntropyLoss, which is commonly used for multi-class classification problems. The optimizer is Adam, a popular choice for deep learning models. You can read more about how deep learning works in this fantastic book by François Fleuret, The Little Book of Deep Learning. I would recommend getting the code to work first, and then trying to reverse engineer it and understand how/why it works.
criterion = nn.CrossEntropyLoss() # categorical cross-entropy loss for multi-class classification
optimizer = optim.Adam(model.parameters(), lr=LR) # Adam optimizer
Okay, this next cell is pretty important. This is the training loop, where we do a few things while iterating over a number of epochs (EPOCHS):
Set the model to training mode (this is important because we need to make a backward pass during training, to update the model weights)
Get the batch of inputs and targets (images and labels), move them to the GPU
Zero the gradients (this is important because PyTorch accumulates gradients by default)
Forward pass the inputs through the model: output = model(inputs)
Measure how close the model got to the correct answer: loss = criterion(output, targets)
Update the model weights based on the gradient of the loss: loss.backward(); optimizer.step()
Print the epoch and training loss.
Now into the validation loop. This is similar to the training loop, but we don't update the model weights. We're just measuring how well the model is doing on unseen data.
Set the model to evaluation mode (this is important because we don't want to update the model weights)
with torch.no_grad() tells PyTorch not to calculate gradients (we don't need them) which saves compute
Forward pass the inputs through the model: output = model(inputs)
Measure how close the model got to the correct answer: loss = criterion(output, targets)
Print the epoch and validation loss.
Important: I've added early stopping. This is a technique to prevent overfitting (the bane of every ML engineer's existence). If the validation loss stops improving, we stop training so as not to overfit the model.
def train(early_stopping=False, es_tol=.05):
    start = perf_counter()
    best_val_loss = float("inf") # arbitrarily high value
    for epoch in range(EPOCHS):
        model.train() # training mode (grads)
        running_train_loss = 0.0
        for i, (inputs, targets) in enumerate(train_loader):
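            # --- sketch of the rest of the loop, following the steps above (the exact code lives in the Colab) ---
            inputs, targets = inputs.to(device), targets.to(device) # move the batch to the GPU
            optimizer.zero_grad() # zero the accumulated gradients
            output = model(inputs) # forward pass
            loss = criterion(output, targets) # how far off were we?
            loss.backward() # backward pass: gradients of the loss w.r.t. the weights
            optimizer.step() # update the weights
            running_train_loss += loss.item() * inputs.size(0) # accumulate batch loss
        print(f"Epoch {epoch+1}: train loss {running_train_loss / len(train_loader.dataset):.4f}")

        # validation pass: same idea, but no gradient tracking and no weight updates
        model.eval()
        running_val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                running_val_loss += criterion(model(inputs), targets).item() * inputs.size(0)
        val_loss = running_val_loss / len(val_loader.dataset)
        print(f"Epoch {epoch+1}: val loss {val_loss:.4f}")

        # early stopping: if validation loss stops improving (within es_tol), bail out before we overfit
        # (the Colab's exact tolerance/patience logic may differ)
        if early_stopping:
            if val_loss < best_val_loss - es_tol:
                best_val_loss = val_loss
            else:
                print(f"Early stopping at epoch {epoch+1} after {perf_counter() - start:.1f}s")
                break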
Now we can train the model. This will take a minute, so go grab a coffee or something.
🚫 If you are NOT using a GPU runtime (!nvidia-smi should show a GPU), this will take a long time (hours or days) and might never finish. If you're on a CPU, you need to restart the runtime and select a GPU runtime before working back to this point (you'll need to re-run all previous cells).
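Kicking off training is just one call (the exact flags used in the Colab may differ; I'd turn early stopping on):

train(early_stopping=True) # fine-tune the model, stopping early if validation loss plateaus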
Now we need to test the model on a completely separate dataset. Even though the training loss looks good, the model could be overfitting, meaning it's memorizing the training data and not generalizing well to unseen data. This test loop is basically the same as the validation loop, except we also track accuracy against the ground-truth labels.
def test_model():
    model.eval()
    running_test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            # (the next few lines are reconstructed from context; the exact cell is in the Colab)
            inputs, targets = inputs.to(device), targets.to(device) # move the batch to the GPU
            outputs = model(inputs) # forward pass
            loss = criterion(outputs, targets) # loss against the ground-truth labels
            running_test_loss += loss.item() * inputs.size(0) # accumulate, weighted by batch size

            _, predicted = torch.max(outputs, 1) # get the index of the highest logit (argmax)
            total += targets.size(0) # add the number of targets in this batch
            correct += (predicted == targets).sum().item() # add the number of correct predictions in this batch

    test_loss = running_test_loss / len(test_loader.dataset) # divide by total number of samples
    test_acc = correct / total
    print(f"Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")
Moment of truth....
test_model()
Output:
Test loss: 0.4025, test accuracy: 0.8942
Not bad! Almost 90% accuracy on the test set. If we can do this with public data, almost no compute, and an hour or two of work, imagine what we could do with a team of engineers, a budget, and a few months.
Let's visualize the model outputs to get a sense of what's happening under the hood.
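The Colab has a plotting cell for this; here's a minimal sketch of the idea. The helper name plot_predictions is my own, and CLASS_NAMES assumes ImageFolder's alphabetical folder ordering.

import matplotlib.pyplot as plt
import torch.nn.functional as F

CLASS_NAMES = ['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal'] # ImageFolder sorts class folders alphabetically

def plot_predictions(n=4): # hypothetical helper, just for illustration
    model.eval()
    inputs, targets = next(iter(test_loader)) # grab one batch of test images
    with torch.no_grad():
        probs = F.softmax(model(inputs[:n].to(device)), dim=1).cpu() # per-class probabilities
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 4))
    for ax, img, p, t in zip(axes, inputs[:n], probs, targets[:n]):
        ax.imshow(img.permute(1, 2, 0).clamp(0, 1)) # CHW -> HWC for plotting (still normalized, so colors look odd)
        pred = p.argmax().item()
        ax.set_title(f"pred: {CLASS_NAMES[pred]} ({p[pred]:.2f})\ntrue: {CLASS_NAMES[t.item()]}")
        ax.axis("off")
    plt.show()

plot_predictions()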
Let's put the model in the context of a Clinical DSS. To determine which images need to be reviewed by the physician, we can set a confidence threshold t that corresponds to the required output probability to bypass manual clinical review.
If the model is confident enough, don't review the image. If the model is not confident enough, send the image to a physician for manual review.
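In code, that routing rule is only a few lines. This is a sketch reusing F and CLASS_NAMES from the visualization snippet above, with the 0.85 threshold from the example below:

T = 0.85 # confidence threshold: below this, a human looks at the image

def route(image_tensor):
    """Return the model's call if it's confident enough, otherwise flag for physician review."""
    with torch.no_grad():
        probs = F.softmax(model(image_tensor.unsqueeze(0).to(device)), dim=1).squeeze(0)
    confidence, pred = probs.max(dim=0)
    if confidence.item() >= T:
        return CLASS_NAMES[pred.item()], confidence.item() # auto-screened, no manual review
    return "needs physician review", confidence.item() # not confident enough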
Notice that the model was incorrect on the last image. The greatest probability in the distribution was cataract (at 0.7998), but the image was actually diabetic retinopathy. If our threshold is set to 0.85, this image would correctly be tagged for physician review.
Note that in this test environment, the True labels are available (we already know what disease is depicted, ahead of time). In production, this will be unknown.
Images that get tagged for review can, once reviewed, be added back to the dataset and used to further improve model performance down the road.
Final Thoughts
To reiterate, Machine Learning and Artificial Intelligence have many applications in healthcare. A great way to implement this new technology is as part of a larger Clinical Decision Support System that aims to help medical personnel like physicians with decision-making. With the low cost of inference, it's now possible to screen large numbers of patients for pathologies like cataracts, glaucoma, and DR.
The next steps for this project could include optimizing model performance, implementing a model into production, or exploring other use cases like more pathologies or different types of imaging.
If you have any questions, comments, or suggestions, feel free to reach out to me at ben@beneverman.com. Thanks for reading!