In this tutorial, you will learn about the PyTorch deep learning library, including:
What PyTorch is
How to install PyTorch on your machine
Important PyTorch features, including tensors and autograd
How PyTorch supports GPUs
Why PyTorch is so popular among researchers
Whether or not PyTorch is better than Keras/TensorFlow
Whether you should be using PyTorch or Keras/TensorFlow in your projects
Additionally, this tutorial is part one in our five-part series on PyTorch fundamentals:
What is PyTorch? (today’s tutorial)
Intro to PyTorch: Training your first neural network using PyTorch (next week’s tutorial)
PyTorch: Training your first Convolutional Neural Network
PyTorch image classification with pre-trained networks
PyTorch object detection with pre-trained networks
By the end of this tutorial, you’ll have a good introduction to the PyTorch library and be able to discuss the pros and cons of the library with other deep learning practitioners.
To learn about the PyTorch deep learning library, just keep reading.
PyTorch is an open source machine learning library that specializes in tensor computations, automatic differentiation, and GPU acceleration. For those reasons, PyTorch is one of the most popular deep learning libraries, competing with both Keras and TensorFlow for the prize of “most used” deep learning package:
Figure 1: PyTorch is second only to Keras/TensorFlow in terms of deep learning library popularity.
PyTorch tends to be especially popular among the research community due to its Pythonic nature and ease of extendability (i.e., implementing custom layer types, network architectures, etc.).
In this tutorial, we’ll discuss the basics of the PyTorch deep learning library. Starting next week, you’ll gain hands-on experience using PyTorch to train neural networks, perform image classification, and apply object detection to both images and real-time video.
Let’s get started learning about PyTorch!
PyTorch, deep learning, and neural networks
Figure 2: PyTorch is a scientific computing library primarily focused on deep learning and neural networks.
PyTorch is based on Torch, a scientific computing framework for Lua. Prior to both PyTorch and Keras/TensorFlow, deep learning packages such as Caffe and Torch tended to be the most popular.
However, as deep learning started to revolutionize nearly all areas of computer science, developers and researchers wanted an efficient, easy to use library to construct, train, and evaluate neural networks in the Python programming language.
Python and R are the two most popular programming languages among data scientists and machine learning practitioners, so it's only natural that researchers wanted deep learning algorithms available inside their Python ecosystem.
François Chollet, a Google AI researcher, developed and released Keras in March 2015, an open source library that provides a Python API for training neural networks. Keras quickly gained popularity due to its easy to use API which modeled much of how scikit-learn, the de facto standard machine learning library for Python, works.
Soon after, Google released its first version of TensorFlow in November 2015. TensorFlow not only became the default backend/engine for the Keras library, but also implemented a number of lower-level features that advanced deep learning practitioners and researchers needed to create state-of-the-art networks and perform novel research.
However, there was a problem — the TensorFlow v1.x API wasn’t very Pythonic, nor was it intuitive and easy to use. To solve that problem PyTorch, sponsored by Facebook and endorsed by Yann LeCun (one of the grandfathers of the modern neural network resurgence, and AI researcher at Facebook), was released in September 2016.
PyTorch solved many of the problems researchers were having with Keras and TensorFlow. While Keras is incredibly easy to use, by its very nature and design Keras does not expose some of the low-level functions and customization that researchers needed.
On the other hand, TensorFlow certainly gave access to these types of functions, but they weren’t Pythonic and it was often hard to comb the TensorFlow documentation to find out exactly what function was needed. In short, Keras didn’t offer the low-level API that researchers needed and TensorFlow’s API wasn’t all that friendly.
PyTorch solved those problems by creating an API that was both Pythonic and easy to customize, allowing new layer types, optimizers, and novel architectures to be implemented. Research groups slowly started embracing PyTorch, switching over from TensorFlow. In essence, that is why you see so many researchers using PyTorch in their labs today.
That said, since the release of PyTorch 1.x and TensorFlow 2.x, the APIs for the respective libraries have essentially converged (pun intended). Both PyTorch and TensorFlow now implement much the same functionality and provide APIs and function calls to accomplish the same things.
This statement is even backed by Eli Stevens, Luca Antiga, and Thomas Viehmann, who quite literally wrote the book on PyTorch:
Interestingly, with the advent of TorchScript and eager mode, both PyTorch and TensorFlow have seen their feature sets start to converge with the other’s, though the presentation of these features and the overall experience is still quite different between the two.
My point here is to not get too caught up in the debate over whether PyTorch or Keras/TensorFlow is “better” — both libraries implement very similar features, just using different function calls and different training paradigms.
Avoiding the (sometimes hostile) debate over which library is better is especially important if you are a beginner to deep learning. As I discuss later in this tutorial, it's far better for you to just pick one and learn it. The fundamentals of deep learning are the same, regardless of whether you use PyTorch or Keras/TensorFlow.
How do I install PyTorch?
Figure 3: PyTorch can be installed via “pip,” Python’s package manager.
The PyTorch library can be installed using pip, Python’s package manager:
$ pip install torch torchvision
From there, you should fire up a Python shell and verify that you can import both torch and torchvision:
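A minimal check at the Python prompt looks something like the following (the exact version strings depend on what pip installed):

$ python
>>> import torch
>>> import torchvision
>>> print(torch.__version__, torchvision.__version__)

If both imports complete without an ImportError, PyTorch is ready to use.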
Figure 4: PyTorch represents multi-dimensional arrays as “tensors.” Tensors form the fundamental building blocks of a neural network (image source).
PyTorch represents data as multi-dimensional, NumPy-like arrays called tensors. Tensors store inputs to your neural network, hidden layer representations, and the outputs.
Here is an example of initializing an array with NumPy, followed by the equivalent tensor in PyTorch:
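A minimal sketch of that comparison (the values here are purely illustrative):

# initialize a 2x3 array with NumPy
import numpy as np
npArray = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(npArray.shape)      # (2, 3)

# the equivalent multi-dimensional tensor in PyTorch
import torch
torchTensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(torchTensor.shape)  # torch.Size([2, 3])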
This doesn’t seem like a big deal, but under the hood, PyTorch can dynamically generate a graph from these tensors and then apply automatic differentiation on top of them:
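A tiny, illustrative example of what that looks like in practice: mark a tensor with requires_grad=True, build a computation from it, and ask PyTorch for the gradient.

import torch

# create a tensor and tell PyTorch to track operations performed on it
x = torch.tensor([2.0, 3.0], requires_grad=True)

# build a (tiny) computation graph: y = sum(x^2)
y = (x ** 2).sum()

# automatic differentiation: compute dy/dx
y.backward()
print(x.grad)  # tensor([4., 6.]), i.e., 2 * x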
PyTorch’s Autograd feature
Figure 5: We can easily train neural networks using PyTorch thanks to PyTorch’s “autograd” module (image source).
Speaking of automatic differentiation, PyTorch makes it super easy to train neural networks using torch.autograd.
Under the hood, PyTorch is able to:
Assemble a graph of a neural network
Perform a forward pass (i.e., make predictions)
Compute the loss/error
Traverse the network backwards (i.e., backpropagation) and adjust the parameters of the network such that it (ideally) makes more accurate predictions based on the computed loss/output
Step #4 is always the most tedious and time consuming step to implement by hand. Luckily for us, PyTorch takes care of that step automatically.
Note: Keras users typically just call model.fit to train a network while TensorFlow users utilize the GradientTape class. PyTorch requires us to implement our training loop by hand, so the fact that torch.autograd works for us under the hood is a huge help. Be thankful to the PyTorch developers for implementing automatic differentiation so you didn’t have to.
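To make the note above concrete, here is a heavily simplified sketch of what "implementing the training loop by hand" means in PyTorch (model, lossFunc, opt, and dataLoader are placeholders; a complete, working loop is covered later in this series):

# simplified sketch of a hand-written PyTorch training step
for (batchX, batchY) in dataLoader:
    opt.zero_grad()                  # reset gradients from the previous step
    predictions = model(batchX)      # forward pass
    loss = lossFunc(predictions, batchY)
    loss.backward()                  # backpropagation via autograd
    opt.step()                       # update the model parameters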
PyTorch and GPU support
Figure 6: PyTorch can be used to train neural networks using GPUs (predominantly NVIDIA CUDA-based GPUs).
The PyTorch library primarily supports NVIDIA CUDA-based GPUs. GPU acceleration allows you to train neural networks in a fraction of the time.
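In practice, selecting a device and moving data (or a model) onto it only takes a couple of lines; a minimal sketch:

import torch

# use a CUDA-capable GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# tensors and models are moved to the device via .to(device)
x = torch.randn(64, 4).to(device)
print(x.device)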
Furthermore, PyTorch supports distributed training that can allow you to train your models even faster.
Why is PyTorch popular among researchers?
Figure 7: PyTorch tends to be extremely popular amongst deep learning researchers due to its flexibility and customizability.
PyTorch gained a foothold in the research community between 2016 (when PyTorch was released) and 2019 (prior to TensorFlow 2.x being officially released).
The reasons PyTorch was able to obtain this foothold are many, but the predominant ones are:
Keras, while incredibly easy to use, didn’t provide access to low-level functions that researchers needed to perform novel deep learning research
In the same vein, Keras made it hard for researchers to implement their own custom optimizers, layer types, and model architectures
TensorFlow 1.x did provide this low-level access and custom implementation; however, the API was hard to use and not very Pythonic
PyTorch, and specifically its autograd support, helped resolve much of the issues with TensorFlow 1.x, making it easier for researchers to implement their own custom methods
Furthermore, PyTorch gives deep learning practitioners complete control over the training loop
There is of course a dichotomy between the two. Keras makes it trivial to train a neural network using a single call to model.fit, similar to how we train a standard machine learning model inside scikit-learn.
The downside is that researchers could not (easily) modify this model.fit call, so they had to use TensorFlow’s lower-level functions. But these methods didn’t make it easy for them to implement their training routines.
PyTorch solved that problem, which is good in the sense that we have complete control, but bad because we can easily shoot ourselves in the foot with PyTorch (every PyTorch user has forgotten to zero their gradients before).
All that said, much of the debate over whether PyTorch or TensorFlow is “better” for research is starting to settle down. The PyTorch 1.x and TensorFlow 2.x APIs implement very similar features, they just go about it in a different way, sort of like learning one programming language versus another. Each programming language has its benefits, but both implement the same types of statements and controls (i.e., “if” statements, “for” loops, etc.).
Is PyTorch better than TensorFlow and Keras?
Figure 8: Neither PyTorch nor TensorFlow/Keras is better than the other — in fact, it’s the wrong question to be asking (image source).
This is the wrong question to ask, especially if you are a novice in deep learning. Neither is better than the other. Keras and TensorFlow have specific uses, just as PyTorch does.
For example, you wouldn’t make a blanket statement saying that Java is unequivocally better than Python. When working with machine learning and data science, there is a strong argument for Python over Java. But if you intend on developing enterprise applications running on multiple architectures with high reliability, then Java is likely a better choice.
Unfortunately, we humans tend to get “entrenched” in our thinking once we become loyal to a particular camp or group. The entrenchment surrounding PyTorch versus Keras/TensorFlow can sometimes get ugly, once prompting François Chollet, the creator of Keras, to ask PyTorch users to stop sending him hate mail:
Figure 9: Deep learning practitioners can sometimes become too entrenched in their views and choice of libraries, prompting François Chollet, the creator of Keras, to ask PyTorch users to stop sending him hate mail (original tweet).
The hate mail isn't limited to François, either. I've used Keras and TensorFlow in a good number of my deep learning tutorials here on PyImageSearch, and I'm saddened to report that I've received hate mail criticizing me for using Keras/TensorFlow, calling me stupid/dumb, telling me to shut down PyImageSearch, and claiming that I'm not a "real" deep learning practitioner (whatever that means).
I'm sure other educators have experienced similar treatment, regardless of whether they wrote tutorials using Keras/TensorFlow or PyTorch. Both sides can get ugly; it's not limited to PyTorch users.
My point here is that you shouldn’t become so entrenched that you attack others based on what deep learning library they use. Seriously, there are more important issues in the world that deserve your attention — and you really don’t need to use the reply button on your email client or social media platform to instigate and catalyze more hate into our already fragile world.
Secondly, if you are new to deep learning, it truly doesn’t matter which library you start with. The APIs of both PyTorch 1.x and TensorFlow 2.x have converged — both implement similar functionality, just done in different ways.
What you learn in one library will transfer to the other, just like learning a new programming language. The first language you learn is often the hardest since you are not only learning the syntax of the language, but also the control structures and program design.
Your second programming language is often an order of magnitude easier to learn since by that point you already understand the basics of control and program design.
The same is true for deep learning libraries. Just pick one and learn it. If you have trouble picking, flip a coin — it genuinely doesn’t matter, your experience will transfer regardless.
Should I use PyTorch instead of TensorFlow/Keras?
Figure 10: If you’re brand new to deep learning, just pick either PyTorch or Keras/TensorFlow. The APIs are essentially converged at this point.
As I’ve mentioned multiple times in this post, choosing between Keras/TensorFlow and PyTorch doesn’t involve making blanket statements such as:
“If you are doing research, you should absolutely use PyTorch.”
“If you’re a beginner, you should use Keras.”
“If you’re developing an industry application, use TensorFlow and Keras.”
The feature sets of PyTorch and Keras/TensorFlow have largely converged — both contain essentially the same set of features, just accomplished in different ways.
If you are brand new to deep learning, just pick one and learn it. Personally, I do think Keras is the most suitable for teaching budding deep learning practitioners. I also think that Keras is the best choice to rapidly prototype and deploy deep learning models.
That said, PyTorch does make it easier for more advanced practitioners to implement custom training loops, layer types, and architectures. This argument is somewhat diminished now that the TensorFlow 2.x API is out, but I believe it’s still worth mentioning.
Most importantly, whatever deep learning library you use or choose to learn, don’t become a fanatic, don’t troll message boards, and in general, don’t cause problems. There’s enough hate in this world already — as a scientific community we should be above the hate mail and hair pulling.
In this tutorial, you learned about the PyTorch deep learning library, including:
What PyTorch is
How to install PyTorch on your machine
PyTorch GPU support
Why PyTorch is popular in the research community
Whether to use PyTorch or Keras/TensorFlow in your projects
Next week, you’ll gain some hands-on experience with PyTorch by implementing and training your first neural network.
Intro to PyTorch: Training your first neural network using PyTorch (today’s tutorial)
PyTorch: Training your first Convolutional Neural Network (next week’s tutorial)
PyTorch image classification with pre-trained networks
PyTorch object detection with pre-trained networks
By the end of this guide, you will have learned:
How to define a basic neural network architecture with PyTorch
How to define your loss function and optimizer
How to properly zero your gradient, perform backpropagation, and update your model parameters — most deep learning practitioners new to PyTorch make a mistake in this step
To learn how to train your first neural network with PyTorch, just keep reading.
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch's documentation is comprehensive and will have you up and running quickly.
Having problems configuring your development environment?
Figure 1: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
To follow along with this tutorial, be sure to access the “Downloads” section of this guide to retrieve the source code.
You’ll then be presented with the following directory structure.
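The exact layout ships with the code download, but based on the files referenced in this tutorial it looks roughly like this (the download may also include an __init__.py inside the pyimagesearch module):

.
├── pyimagesearch
│   └── mlp.py
└── train.py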
The mlp.py file will store our implementation of a basic multi-layer perceptron (MLP).
We’ll then implement train.py which will be used to train our MLP on an example dataset.
Implementing our neural network with PyTorch
Figure 2: Implementing a basic multi-layer perceptron with PyTorch.
You are now about ready to implement your first neural network with PyTorch!
This network is a very simple feedforward neural network called a multi-layer perceptron (MLP), meaning that it has one or more hidden layers. You'll learn how to build more advanced neural network architectures in next week's tutorial.
To get started building our PyTorch neural network, open the mlp.py file in the pyimagesearch module of your project directory structure, and let’s get to work:
# import the necessary packages
from collections import OrderedDict
import torch.nn as nn
def get_training_model(inFeatures=4, hiddenDim=8, nbClasses=3):
    # construct a shallow, sequential neural network
    mlpModel = nn.Sequential(OrderedDict([
        ("hidden_layer_1", nn.Linear(inFeatures, hiddenDim)),
        ("activation_1", nn.ReLU()),
        ("output_layer", nn.Linear(hiddenDim, nbClasses))
    ]))

    # return the sequential model
    return mlpModel
Lines 2 and 3 import our required Python packages:
OrderedDict: A dictionary object that remembers the order in which objects were added — we use this ordered dictionary to provide human-readable names to each layer in the network
nn: PyTorch’s neural network implementations
We then define the get_training_model function (Line 5) which accepts three parameters:
The number of input nodes to the neural network
The number of nodes in the hidden layer of the network
The number of output nodes (i.e., dimensionality of the output prediction)
Based on the default values provided, you can see that we are building a 4-8-3 neural network, meaning that the input layer has 4 nodes, the hidden layer 8 nodes, and the output of the neural network will consist of 3 values.
The actual neural network architecture is then constructed on Lines 7-11 by first initializing a nn.Sequential object (very similar to Keras/TensorFlow’s Sequential class).
Inside the Sequential class we build an OrderedDict where each entry in the dictionary consists of two values:
A string containing the human-readable name for the layer (which is very useful when debugging neural network architectures using PyTorch)
The PyTorch layer definition itself
The Linear class is our fully connected layer definition, meaning that each of the inputs connects to each of the outputs in the layer. The Linear class accepts two required arguments:
The number of inputs to the layer
The number of outputs
On Line 8, we define hidden_layer_1 which consists of a fully connected layer accepting inFeatures (4) inputs and then producing an output of hiddenDim (8).
From there, we apply a ReLU activation function (Line 9) followed by another Linear layer which serves as our output (Line 10).
Notice that the second Linear definition contains the same number of inputs as the previous Linear layer did outputs — this is not by accident! The output dimensions of the previous layer must match the input dimensions of the next layer, otherwise PyTorch will error out (and then you'll have the quite tedious task of debugging the layer dimensions yourself).
PyTorch is not as forgiving in this regard (as opposed to Keras/TensorFlow), so be extra cautious when specifying your layer dimensions.
The resulting PyTorch neural network is then returned to the calling function.
Creating our PyTorch training script
With our neural network architecture implemented, we can move on to training the model using PyTorch.
To accomplish this task, we’ll need to implement a training script which:
Creates an instance of our neural network architecture
Builds our dataset
Determines whether or not we are training our model on a GPU
Defines a training loop (the hardest part of our script)
Open train.py, and let's get started:
# import the necessary packages
from pyimagesearch import mlp
from torch.optim import SGD
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs
import torch.nn as nn
import torch
Lines 2-7 import our required Python packages:
mlp: Our definition of the multi-layer perceptron architecture, implemented in PyTorch
SGD: The stochastic gradient descent optimizer we'll use to train the network
make_blobs: Builds a synthetic dataset of example data
train_test_split: Splits our dataset into a training and testing split
nn: PyTorch's neural network functionality
torch: The base PyTorch library
When training a neural network, we do so in batches of data (as you’ve previously learned). The following function, next_batch, yields such batches to our training loop:
def next_batch(inputs, targets, batchSize):
    # loop over the dataset
    for i in range(0, inputs.shape[0], batchSize):
        # yield a tuple of the current batched data and labels
        yield (inputs[i:i + batchSize], targets[i:i + batchSize])
The next_batch function accepts three arguments:
inputs: Our input data to the neural network
targets: Our target output values (i.e., what we want our neural network to accurately predict)
batchSize: Size of data batch
We then loop over our input data in batchSize chunks (Line 11) and yield them to the calling function (Line 13).
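As a quick sanity check (not part of train.py, and assuming the next_batch function above is in scope), you could run it on a small dummy tensor and inspect the batch shapes:

import torch

dummyX = torch.randn(10, 4)
dummyY = torch.randint(0, 3, (10,))

# yields batches of 4, 4, and a final partial batch of 2
for (bx, by) in next_batch(dummyX, dummyY, batchSize=4):
    print(bx.shape, by.shape)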
Next, we have some important initializations to take care of:
# specify our batch size, number of epochs, and learning rate
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-2
# determine the device we will be using for training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("[INFO] training using {}...".format(DEVICE))
When training our neural network with PyTorch we’ll use a batch size of 64, train for 10 epochs, and use a learning rate of 1e-2 (Lines 16-18).
We set our training device (either CPU or GPU) on Line 21. A GPU will certainly speed up training but is not required for this example.
Next, we need an example dataset to train our neural network on. We’ll learn how to load images from disk and train a neural network on image data in the next tutorial in this series, but for now, let’s use scikit-learn’s make_blobs function to create a synthetic dataset for us:
# generate a 3-class classification problem with 1000 data points,
# where each data point is a 4D feature vector
print("[INFO] preparing data...")
(X, y) = make_blobs(n_samples=1000, n_features=4, centers=3,
cluster_std=2.5, random_state=95)
# create training and testing splits, and convert them to PyTorch
# tensors
(trainX, testX, trainY, testY) = train_test_split(X, y,
test_size=0.15, random_state=95)
trainX = torch.from_numpy(trainX).float()
testX = torch.from_numpy(testX).float()
trainY = torch.from_numpy(trainY).float()
testY = torch.from_numpy(testY).float()
Lines 27 and 28 build our dataset, consisting of:
Three class labels (centers=3)
Four total features/inputs to the neural network (n_features=4)
A total of 1000 data points (n_samples=1000)
Essentially, the make_blobs function is generating Gaussian blobs of clustered data points. For 2D data, the make_blobs function would create data similar to the following:
Figure 3: An example 2D dataset generated using scikit-learn’s “make_blobs” function (image source).
Notice there are three clusters of data here. We are doing the same thing, but instead of two dimensions we have four dimensions (meaning we cannot easily visualize it).
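If you want to reproduce a 2D plot like the one in Figure 3 yourself, a quick, illustrative snippet (not part of train.py) would be:

# illustrative only: visualize 2D blobs similar to Figure 3
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

(X2d, y2d) = make_blobs(n_samples=1000, n_features=2, centers=3,
    cluster_std=2.5, random_state=95)
plt.scatter(X2d[:, 0], X2d[:, 1], c=y2d, s=10)
plt.show()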
Once our data is generated, we apply the train_test_split function (Lines 32 and 33) to create our training split, 85% for training and 15% for evaluation.
From there, the training and testing data is converted to PyTorch tensors from NumPy arrays, and then converted to the floating point data type (Lines 34-37).
Let’s now instantiate our PyTorch neural network architecture:
# initialize our model and display its architecture
mlp = mlp.get_training_model().to(DEVICE)
print(mlp)
# initialize optimizer and loss function
opt = SGD(mlp.parameters(), lr=LR)
lossFunc = nn.CrossEntropyLoss()
Line 40 initializes our MLP and pushes it to whatever DEVICE we are using for training (either CPU or GPU).
Line 44 defines our SGD optimizer, which accepts two arguments:
The MLP model parameters, obtained by simply calling mlp.parameters()
The learning rate
Finally, we initialize our categorical cross-entropy loss function, which is the standard loss method you’ll use when performing classification with > 2 classes.
We now arrive at our most important code block, the training loop. Unlike Keras/TensorFlow, which allow you to simply call model.fit to train your model, PyTorch requires that you implement your training loop by hand.
There are pros and cons of having to implement the training loop by hand.
On one side of the spectrum, you have complete and total control over the training procedure, which makes it easier to implement custom training loops.
But on the other side of the spectrum, implementing a training loop by hand requires more code, and worst of all, makes it far easier to shoot yourself in the foot (which can be especially true for budding deep learning practitioners).
My suggestion: You’ll want to read the explanations to the following code blocks multiple times so that you understand the intricacies of the training loop. You’ll especially want to pay close attention to how we zero the gradient, perform backpropagation, and then update the model parameters — failing to do so in that exact order will lead to erroneous results!
Let’s review our training loop:
# create a template to summarize current training progress
trainTemplate = "epoch: {} test loss: {:.3f} test accuracy: {:.3f}"

# loop through the epochs
for epoch in range(0, EPOCHS):
    # initialize tracker variables and set our model to trainable
    print("[INFO] epoch: {}...".format(epoch + 1))
    trainLoss = 0
    trainAcc = 0
    samples = 0
    mlp.train()

    # loop over the current batch of data
    for (batchX, batchY) in next_batch(trainX, trainY, BATCH_SIZE):
        # flash data to the current device, run it through our
        # model, and calculate loss
        (batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))
        predictions = mlp(batchX)
        loss = lossFunc(predictions, batchY.long())

        # zero the gradients accumulated from the previous steps,
        # perform backpropagation, and update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # update training loss, accuracy, and the number of samples
        # visited
        trainLoss += loss.item() * batchY.size(0)
        trainAcc += (predictions.max(1)[1] == batchY).sum().item()
        samples += batchY.size(0)

    # display model progress on the current training batch
    trainTemplate = "epoch: {} train loss: {:.3f} train accuracy: {:.3f}"
    print(trainTemplate.format(epoch + 1, (trainLoss / samples),
        (trainAcc / samples)))
Line 48 initializes trainTemplate, a string that will allow us to conveniently display the epoch number, along with the loss and accuracy at each step.
We then loop over our number of desired training epochs on Line 51. Immediately inside this for loop we:
Show the epoch number, which is useful for debugging purposes (Line 53)
Initialize our training loss and accuracy (Lines 54 and 55)
Initialize the total number of data points used inside the current iteration of the training loop (Line 56)
Put the PyTorch model in training mode (Line 57)
Calling the train() method of the PyTorch model puts it into training mode, enabling training-specific behavior for layers such as dropout and batch normalization.
In our next code block, you'll see that we put the model into eval() mode so that we can evaluate the loss and accuracy on our testing set. If we forgot to then call train() at the top of the next epoch, those layers would continue to behave as if we were evaluating, which can quietly degrade training.
The outer for loop (Line 51) loops over our number of epochs. Line 60 then starts an inner for loop that loops over each of our batches in the training set. Nearly every training procedure you write using PyTorch will consist of an outer loop (over the number of epochs) and an inner loop (over the data batches).
Within the inner loop (i.e., the batch loop), we proceed to:
Move the batchX and batchY data to our CPU or GPU (depending on our DEVICE)
Pass the batchX data through the neural network and make predictions on it
Use our loss function to compute our loss by comparing the output predictions to our ground-truth class labels
Now that we have our loss, we can update our model parameters — this is the most important step in the PyTorch training procedure and often the one most beginners mess up.
To update the parameters of our model, we must call Lines 69-71 in the exact order specified:
opt.zero_grad(): Zeros the gradients accumulated from the previous batch/step of the model
loss.backward(): Performs backpropagation
opt.step(): Updates the weights in our neural network based on the results of backpropagation
Again, I want to stress that you must apply zeroing the gradients, performing a backward pass, and then updating the model parameters in the exact order that I’ve indicated.
As I’ve mentioned, PyTorch gives you a lot of control over your training loop … but it also makes it very easy to shoot yourself in the foot. Every single deep learning practitioner, whether brand new to the world of deep learning or a seasoned expert, has at one time or another messed up these steps.
The most common mistake is forgetting to zero the gradient. If you don’t zero the gradient then you’ll accumulate gradients across multiple batches and over multiple epochs. That will mess up your backpropagation and lead to erroneous weight updates.
Seriously, don’t mess up these steps. Write them on a sticky note and put them on your monitor if you need to.
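To see why this matters, here is a tiny, self-contained demonstration (not part of train.py) of gradients accumulating when the zeroing step is skipped:

import torch

w = torch.tensor([1.0], requires_grad=True)

(w * 2).sum().backward()
print(w.grad)   # tensor([2.])

(w * 2).sum().backward()
print(w.grad)   # tensor([4.]) -- the old gradient was NOT cleared

w.grad.zero_()  # roughly what opt.zero_grad() does for every model parameter
(w * 2).sum().backward()
print(w.grad)   # tensor([2.]) again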
After we’ve updated the weights to our model, we compute our train loss, train accuracy, and number of samples examined (i.e., number of data points in the batch) on Lines 75-77.
We then apply our trainTemplate to display our epoch number, training loss, and training accuracy. Note how we divide our accumulated loss and accuracy by the total number of samples seen so far in the epoch to obtain an average.
At this point, we’ve trained our PyTorch model on all data points in an epoch — now we need to evaluate it on our testing set:
    # initialize tracker variables for testing, then set our model to
    # evaluation mode
    testLoss = 0
    testAcc = 0
    samples = 0
    mlp.eval()

    # initialize a no-gradient context
    with torch.no_grad():
        # loop over the current batch of test data
        for (batchX, batchY) in next_batch(testX, testY, BATCH_SIZE):
            # flash the data to the current device
            (batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))

            # run data through our model and calculate loss
            predictions = mlp(batchX)
            loss = lossFunc(predictions, batchY.long())

            # update test loss, accuracy, and the number of
            # samples visited
            testLoss += loss.item() * batchY.size(0)
            testAcc += (predictions.max(1)[1] == batchY).sum().item()
            samples += batchY.size(0)

    # display model progress on the current test batch
    testTemplate = "epoch: {} test loss: {:.3f} test accuracy: {:.3f}"
    print(testTemplate.format(epoch + 1, (testLoss / samples),
        (testAcc / samples)))
    print("")
Similar to how we initialized our training loss, training accuracy, and number of samples in a batch, we do the same thing for our testing set on Lines 86-88. Here, we initialize variables to store our testing loss, testing accuracy, and number of samples in the testing set.
We also put our model into eval() model on Line 89. We are required to put our model in evaluation mode when we need to compute losses/accuracies on the testing or validation set.
But what does eval() mode actually do? You can think of evaluation mode as a switch that turns off training-specific layer behavior, such as applying dropout, and instead uses the running statistics that batch normalization accumulated during training.
Secondly, you typically use eval() in conjunction with a torch.no_grad() context, which turns off gradient computation during evaluation (Line 92), saving both memory and compute.
From there, we loop over all batches in our testing set (Line 94), similar to how we looped over our training batches in the previous code block.
For each batch (Line 96), we make predictions using our model and then compute the loss (Lines 99 and 100).
We then update our testLoss, testAcc, and number of samples (Lines 104-106).
Finally, we display our epoch number, testing loss, and testing accuracy on our terminal (Lines 109-112).
In general, the evaluation portion of our training loop is very similar to the training portion, aside from a few minor but very significant changes:
We put our model into evaluation mode using eval()
We use a torch.no_grad() context to ensure no gradient computation is performed
From there, we can make predictions using our model and compute the accuracy/loss on the testing set.
PyTorch training results
We are now ready to train our neural network with PyTorch!
Be sure to access the “Downloads” section of this tutorial to retrieve the source code.
To launch the PyTorch training process, simply execute the train.py script:
$ python train.py
[INFO] training on cuda...
[INFO] preparing data...
Sequential(
(hidden_layer_1): Linear(in_features=4, out_features=8, bias=True)
(activation_1): ReLU()
(output_layer): Linear(in_features=8, out_features=3, bias=True)
)
[INFO] training in epoch: 1...
epoch: 1 train loss: 0.971 train accuracy: 0.580
epoch: 1 test loss: 0.737 test accuracy: 0.827
[INFO] training in epoch: 2...
epoch: 2 train loss: 0.644 train accuracy: 0.861
epoch: 2 test loss: 0.590 test accuracy: 0.893
[INFO] training in epoch: 3...
epoch: 3 train loss: 0.511 train accuracy: 0.916
epoch: 3 test loss: 0.495 test accuracy: 0.900
[INFO] training in epoch: 4...
epoch: 4 train loss: 0.425 train accuracy: 0.941
epoch: 4 test loss: 0.423 test accuracy: 0.933
[INFO] training in epoch: 5...
epoch: 5 train loss: 0.359 train accuracy: 0.961
epoch: 5 test loss: 0.364 test accuracy: 0.953
[INFO] training in epoch: 6...
epoch: 6 train loss: 0.302 train accuracy: 0.975
epoch: 6 test loss: 0.310 test accuracy: 0.960
[INFO] training in epoch: 7...
epoch: 7 train loss: 0.252 train accuracy: 0.984
epoch: 7 test loss: 0.259 test accuracy: 0.967
[INFO] training in epoch: 8...
epoch: 8 train loss: 0.209 train accuracy: 0.987
epoch: 8 test loss: 0.215 test accuracy: 0.980
[INFO] training in epoch: 9...
epoch: 9 train loss: 0.174 train accuracy: 0.988
epoch: 9 test loss: 0.180 test accuracy: 0.980
[INFO] training in epoch: 10...
epoch: 10 train loss: 0.147 train accuracy: 0.991
epoch: 10 test loss: 0.153 test accuracy: 0.980
Our first few lines of output show the simple 4-8-3 MLP architecture, meaning that there are four inputs to the neural network, a single hidden layer with eight nodes, and a final output layer with three nodes.
We then train our network for a total of ten epochs. By the end of the training process, we are obtaining 99.1% accuracy on our training set and 98% accuracy on our testing set.
We can therefore conclude that our neural network is doing a good job making accurate predictions.
Congrats on training your first neural network with PyTorch!
How do I train a PyTorch model on my own custom dataset?
This tutorial showed you how to train a PyTorch neural network on an example dataset generated by scikit-learn’s make_blobs function.
While this was a great example to learn the basics of PyTorch, it’s admittedly not very interesting from a real-world scenario perspective.
Next week, you’ll learn how to train a PyTorch model on a dataset of handwritten characters, which has many practical applications, including handwriting recognition, OCR, and more!
Stay tuned for next week’s tutorial to learn more about PyTorch and image classification.
In this tutorial, you learned how to train your first neural network using the PyTorch deep learning library. This example was admittedly simple, but demonstrated the fundamentals of the PyTorch framework.
The biggest mistake I see with deep learning practitioners new to the PyTorch library is forgetting and/or mixing up the following steps:
Zeroing out gradients from the previous steps (opt.zero_grad())
Performing backpropagation (loss.backward())
Updating model parameters (opt.step())
Failure to perform these steps in this exact order is a surefire way to shoot yourself in the foot when using PyTorch, and worse, PyTorch doesn’t report an error if you mix up these steps, so you may not even know you shot yourself!
The PyTorch library is super powerful, but you’ll need to get used to the fact that training a neural network with PyTorch is like taking off your bicycle’s training wheels — there’s no safety net to catch you if you mix up important steps (unlike with Keras/TensorFlow which allow you to encapsulate entire training procedures into a single model.fit call).
That’s not to say that Keras/TensorFlow are “better” than PyTorch — it’s just a difference between the two deep learning libraries of which you need to be aware.
In this tutorial, we will be building a complete end-to-end application that can detect smiles in a video stream in real-time using deep learning along with traditional computer vision techniques.
To accomplish this task, we'll be training the LeNet architecture on a dataset of images that contain faces of people who are smiling and not smiling. Once our network is trained, we'll create a separate Python script — this one will detect faces in images via OpenCV's built-in Haar cascade face detector, extract the face region of interest (ROI) from the image, and then pass the ROI through LeNet for smile detection.
To learn how to detect a smile with OpenCV, Keras, and TensorFlow, just keep reading.
Smile detection with OpenCV, Keras, and TensorFlow
When developing real-world applications for image classification, you’ll often have to mix traditional computer vision and image processing techniques with deep learning. I’ve done my best to ensure this tutorial stands on its own in terms of algorithms, techniques, and libraries you need to understand in order to be successful when studying and applying deep learning.
Configuring your development environment
To follow this guide, you need to have the OpenCV library installed on your system.
Luckily, OpenCV is pip-installable:
$ pip install opencv-contrib-python
If you need help configuring your development environment for OpenCV, I highly recommend that you read my pip install OpenCV guide — it will have you up and running in a matter of minutes.
The SMILES Dataset
The SMILES dataset consists of images of faces that are either smiling or not smiling (Hromada, 2010). In total, there are 13,165 grayscale images in the dataset, with each image having a size of 64×64 pixels.
As Figure 2 demonstrates, images in this dataset are tightly cropped around the face, which will make the training process easier as we’ll be able to learn the “smiling” or “not smiling” patterns directly from the input images.
Figure 2: Top: Examples of "smiling" faces. Bottom: Samples of "not smiling" faces. In this tutorial, we will be training a Convolutional Neural Network to distinguish between smiling and not smiling faces in real-time video streams.
However, the close cropping poses a problem during testing — since our input images will not only contain a face but the background of the image as well, we first need to localize the face in the image and extract the face ROI before we can pass it through our network for detection. Luckily, using traditional computer vision methods such as Haar cascades, this is a much easier task than it sounds.
A second issue we need to handle in the SMILES dataset is class imbalance. While there are 13,165 images in the dataset, 9,475 of these examples are not smiling, while only 3,690 belong to the smiling class. Given that there are over 2.5x the number of “not smiling” images to “smiling” examples, we need to be careful when devising our training procedure.
Our network may naturally pick the “not smiling” label since (1) the distributions are uneven and (2) it has more examples of what a “not smiling” face looks like. Later, you will see how we can combat class imbalance by computing a “weight” for each class during training time.
Training the Smile CNN
The first step in building our smile detector is to train a CNN on the SMILES dataset to distinguish between a face that is smiling versus not smiling. To accomplish this task, let’s create a new file named train_model.py. From there, insert the following code:
# import the necessary packages
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.utils import to_categorical
from pyimagesearch.nn.conv import LeNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import imutils
import cv2
import os
Lines 2-14 import our required Python packages. We’ve used all of the packages before, but I want to call your attention to Line 7, where we import the LeNet (LeNet Tutorial) class — this is the architecture we’ll be using when creating our smile detector.
Next, let’s parse our command line arguments:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
help="path to input dataset of faces")
ap.add_argument("-m", "--model", required=True,
help="path to output model")
args = vars(ap.parse_args())
# initialize the list of data and labels
data = []
labels = []
Our script will require two command line arguments, each of which I’ve detailed below:
--dataset: The path to the SMILES directory residing on disk.
--model: The path to where the serialized LeNet weights will be saved after training.
We are now ready to load the SMILES dataset from disk and store it in memory:
# loop over the input images
for imagePath in sorted(list(paths.list_images(args["dataset"]))):
    # load the image, pre-process it, and store it in the data list
    image = cv2.imread(imagePath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = imutils.resize(image, width=28)
    image = img_to_array(image)
    data.append(image)

    # extract the class label from the image path and update the
    # labels list
    label = imagePath.split(os.path.sep)[-3]
    label = "smiling" if label == "positives" else "not_smiling"
    labels.append(label)
On Line 29, we loop over all images in the --dataset input directory. For each of these images, we:
Load it from disk (Line 31).
Convert it to grayscale (Line 32).
Resize it to have a fixed input size of 28×28 pixels (Line 33).
Convert the image to an array compatible with Keras and its channel ordering (Line 34).
Add the image to the data list that LeNet will be trained on.
Lines 39-41 handle extracting the class label from the imagePath and updating the labels list. The SMILES dataset stores smiling faces in the SMILES/positives/positives7 subdirectory, while not smiling faces live in the SMILES/negatives/negatives7 subdirectory.
Therefore, given the path to an image:
SMILEs/positives/positives7/10007.jpg
We can extract the class label by splitting on the image path separator and grabbing the third-to-last subdirectory: positives. In fact, this is exactly what Line 39 accomplishes.
Now that our data and labels are constructed, we can scale the raw pixel intensities to the range [0, 1] and then apply one-hot encoding to the labels:
# scale the raw pixel intensities to the range [0, 1]
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# convert the labels from integers to vectors
le = LabelEncoder().fit(labels)
labels = to_categorical(le.transform(labels), 2)
Our next code block handles our data imbalance issue by computing the class weights:
# calculate the total number of training images in each class and
# initialize a dictionary to store the class weights
classTotals = labels.sum(axis=0)
classWeight = dict()
# loop over all classes and calculate the class weight
for i in range(0, len(classTotals)):
    classWeight[i] = classTotals.max() / classTotals[i]
Line 53 computes the total number of examples per class. In this case, classTotals will be an array: [9475, 3690] for “not smiling” and “smiling,” respectively.
We then scale these totals on Lines 57 and 58 to obtain the classWeight used to handle the class imbalance, yielding the array: [1, 2.56]. This weighting implies that our network will treat every instance of “smiling” as 2.56 instances of “not smiling” and helps combat the class imbalance issue by amplifying the per-instance loss by a larger weight when seeing “smiling” examples.
Now that we’ve computed our class weights, we can move on to partitioning our data into training and testing splits, using 80% of the data for training and 20% for testing:
# partition the data into training and testing splits using 80% of
# the data for training and the remaining 20% for testing
(trainX, testX, trainY, testY) = train_test_split(data,
labels, test_size=0.20, stratify=labels, random_state=42)
Finally, we are ready to train LeNet:
# initialize the model
print("[INFO] compiling model...")
model = LeNet.build(width=28, height=28, depth=1, classes=2)
model.compile(loss="binary_crossentropy", optimizer="adam",
metrics=["accuracy"])
# train the network
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY),
class_weight=classWeight, batch_size=64, epochs=15, verbose=1)
Line 67 initializes the LeNet architecture that will accept 28×28 single channel images. Given that there are only two classes (smiling versus not smiling), we set classes=2.
We’ll also be using binary_crossentropy rather than categorical_crossentropy as our loss function. Again, categorical cross-entropy is only used when the number of classes is more than two.
Up until this point, we’ve been using the SGD optimizer to train our network. Here, we’ll be using Adam (Kingma and Ba, 2014) (Line 68).
Again, the optimizer and associated parameters are often considered hyperparameters that you need to tune when training your network. When I put this example together, I found that Adam performed substantially better than SGD.
Lines 73 and 74 train LeNet for a total of 15 epochs using our supplied classWeight to combat class imbalance.
Once our network is trained, we can evaluate it and serialize the weights to disk:
# evaluate the network
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=64)
print(classification_report(testY.argmax(axis=1),
predictions.argmax(axis=1), target_names=le.classes_))
# save the model to disk
print("[INFO] serializing network...")
model.save(args["model"])
We’ll also construct a learning curve for our network so we can visualize performance:
# plot the training + testing loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 15), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 15), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 15), H.history["accuracy"], label="acc")
plt.plot(np.arange(0, 15), H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.show()
To train our smile detector, execute the following command:
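Assuming the SMILES dataset lives in a SMILEs directory and you want to serialize the model to an output directory (both paths here are hypothetical; adjust them to your setup), the command looks something like:

$ python train_model.py --dataset SMILEs --model output/lenet.hdf5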
After 15 epochs, we can see that our network is obtaining 93% classification accuracy. Figure 3 plots our learning curve:
Figure 3: A plot of the learning curve for the LeNet architecture trained on the SMILES dataset. After fifteen epochs we are obtaining ≈93% classification accuracy on our testing set.
Past epoch six our validation loss starts to stagnate — further training past epoch 15 would result in overfitting. If desired, we could improve the accuracy of our smile detector by using more training data, either by:
Gathering additional training data.
Applying data augmentation to randomly translate, rotate, and shift our existing training set.
Running the Smile CNN in Real-time
Now that we’ve trained our model, the next step is to build the Python script to access our webcam/video file and apply smile detection to each frame. To accomplish this step, open a new file, name it detect_smile.py, and we’ll get to work.
# import the necessary packages
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import numpy as np
import argparse
import imutils
import cv2
Lines 2-7 import our required Python packages. The img_to_array function will be used to convert each individual frame from our video stream to a properly channel ordered array. The load_model function will be used to load the weights of our trained LeNet model from disk.
The detect_smile.py script requires two command line arguments followed by a third optional one:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-c", "--cascade", required=True,
help="path to where the face cascade resides")
ap.add_argument("-m", "--model", required=True,
help="path to pre-trained smile detector CNN")
ap.add_argument("-v", "--video",
help="path to the (optional) video file")
args = vars(ap.parse_args())
The first argument, --cascade, is the path to a Haar cascade used to detect faces in images. First published in 2001, Paul Viola and Michael Jones detail the Haar cascade in their work, Rapid Object Detection using a Boosted Cascade of Simple Features. This publication has become one of the most cited papers in the computer vision literature.
The Haar cascade algorithm is capable of detecting objects in images, regardless of their location and scale. Perhaps most intriguing (and relevant to our application), the detector can run in real-time on modern hardware. In fact, the motivation behind Viola and Jones’ work was to create a face detector.
The second command line argument, --model, specifies the path to our serialized LeNet weights on disk. Our script will default to reading frames from a built-in/USB webcam; however, if we instead want to read frames from a file, we can specify the file via the optional --video switch.
Before we can detect smiles, we first need to perform some initializations:
# load the face detector cascade and smile detector CNN
detector = cv2.CascadeClassifier(args["cascade"])
model = load_model(args["model"])
# if a video path was not supplied, grab the reference to the webcam
if not args.get("video", False):
    camera = cv2.VideoCapture(0)

# otherwise, load the video
else:
    camera = cv2.VideoCapture(args["video"])
Lines 20 and 21 load the Haar cascade face detector and the pre-trained LeNet model, respectively. If a video path was not supplied, we grab a pointer to our webcam (Lines 24 and 25). Otherwise, we open a pointer to the video file on disk (Lines 28 and 29).
We have now reached the main processing pipeline of our application:
# keep looping
while True:
    # grab the current frame
    (grabbed, frame) = camera.read()
    # if we are viewing a video and we did not grab a frame, then we
    # have reached the end of the video
    if args.get("video") and not grabbed:
        break
    # resize the frame, convert it to grayscale, and then clone the
    # original frame so we can draw on it later in the program
    frame = imutils.resize(frame, width=300)
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    frameClone = frame.copy()
Line 32 starts a loop that will continue until (1) we stop the script or (2) we reach the end of the video file (provided a --video path was supplied).
Line 34 grabs the next frame from the video stream. If the frame could not be grabbed, then we have reached the end of the video file. Otherwise, we pre-process the frame for face detection by resizing it to have a width of 300 pixels (Line 43) and converting it to grayscale (Line 44).
The .detectMultiScale method handles detecting the bounding box (x, y)-coordinates of faces in the frame:
# detect faces in the input frame, then clone the frame so that
# we can draw on it
rects = detector.detectMultiScale(gray, scaleFactor=1.1,
minNeighbors=5, minSize=(30, 30),
flags=cv2.CASCADE_SCALE_IMAGE)
Here, we pass in our grayscale image and indicate that, for a given region to be considered a face, it must have a minimum size of 30×30 pixels. The minNeighbors attribute helps prune false positives, while the scaleFactor controls the number of image pyramid (http://pyimg.co/rtped) levels generated.
Again, a detailed review of Haar cascades for object detection is outside the scope of this tutorial.
The .detectMultiScale method returns a list of 4-tuples, one for each face detected in the frame. The first two values in each tuple are the starting (x, y)-coordinates of the bounding box. The last two values are the width and height of the bounding box, respectively.
We loop over each set of bounding boxes below:
# loop over the face bounding boxes
for (fX, fY, fW, fH) in rects:
    # extract the ROI of the face from the grayscale image,
    # resize it to a fixed 28x28 pixels, and then prepare the
    # ROI for classification via the CNN
    roi = gray[fY:fY + fH, fX:fX + fW]
    roi = cv2.resize(roi, (28, 28))
    roi = roi.astype("float") / 255.0
    roi = img_to_array(roi)
    roi = np.expand_dims(roi, axis=0)
For each of the bounding boxes, we use NumPy array slicing to extract the face ROI (Line 58). Once we have the ROI, we preprocess it and prepare it for classification via LeNet by resizing it, scaling it, converting it to a Keras-compatible array, and adding an extra batch dimension (Lines 59-62).
Once the roi is preprocessed, it can be passed through LeNet for classification:
# determine the probabilities of both "smiling" and "not
# smiling", then set the label accordingly
(notSmiling, smiling) = model.predict(roi)[0]
label = "Smiling" if smiling > notSmiling else "Not Smiling"
A call to .predict on Line 66 returns the probabilities of “not smiling” and “smiling,” respectively. Line 67 sets the label depending on which probability is larger.
Once we have the label, we can draw it, along with the corresponding bounding box, on the frame:
# display the label and bounding box rectangle on the output
# frame
cv2.putText(frameClone, label, (fX, fY - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 0, 255), 2)
cv2.rectangle(frameClone, (fX, fY), (fX + fW, fY + fH),
(0, 0, 255), 2)
Our final code block handles displaying the output frame on our screen:
# show our detected faces along with smiling/not smiling labels
cv2.imshow("Face", frameClone)
# if the 'q' key is pressed, stop the loop
if cv2.waitKey(1) & 0xFF == ord("q"):
break
# cleanup the camera and close any open windows
camera.release()
cv2.destroyAllWindows()
If the q key is pressed, we exit the script.
To run detect_smile.py using your webcam, execute the following command:
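Assuming you are using OpenCV's haarcascade_frontalface_default.xml face detector and your serialized smile detector was saved as output/lenet.hdf5 (adjust both paths to match your setup), the command would look like:
$ python detect_smile.py --cascade haarcascade_frontalface_default.xml --model output/lenet.hdf5
To apply the detector to a video file instead of your webcam, add the optional --video switch pointing to the file.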
In this tutorial, we learned how to build an end-to-end computer vision and deep learning application to perform smile detection. To do so, we first trained the LeNet architecture on the SMILES dataset. Because the SMILES dataset is class imbalanced, we also learned how to compute class weights to help mitigate the problem.
Once trained, we evaluated LeNet on our testing set and found the network obtained a respectable 93% classification accuracy. Higher classification accuracy can be obtained by gathering more training data or applying data augmentation to existing training data.
We then created a Python script to read frames from a webcam/video file, detect faces, and then apply our pre-trained network. To detect faces, we used OpenCV’s Haar cascades. Once a face was detected it was extracted from the frame and then passed through LeNet to determine if the person was smiling or not smiling. As a whole, our smile detection system can easily run in real-time on the CPU using modern hardware.
In the past, we’ve worked with datasets that have been pre-compiled and labeled for us — but what if we wanted to go about creating our own custom dataset and then training a CNN on it? In this tutorial, I’ll present a complete deep learning case study that will give you an example of:
Downloading a set of images.
Labeling and annotating your images for training.
Training a CNN on your custom dataset.
Evaluating and testing the trained CNN.
The dataset of images we’ll be downloading is a set of captcha images used to prevent bots from automatically registering or logging in to a given website (or worse, trying to brute force their way into someone’s account).
Once we’ve downloaded a set of captcha images we’ll need to manually label each of the digits in the captcha. As we’ll find out, obtaining and labeling a dataset can be half (if not more) the battle. Depending on how much data you need, how easy it is to obtain, and whether or not you need to label the data (i.e., assign a ground-truth label to the image), it can be a costly process, both in terms of time and/or finances (if you pay someone else to label the data).
Therefore, whenever possible we try to use traditional computer vision techniques to speed up the labeling process. If we were to use image processing software such as Photoshop or GIMP to manually extract digits in a captcha image to create our training set, it might take us days of non-stop work to complete the task.
However, by applying some basic computer vision techniques, we can download and label our training set in less than an hour. This is one of the many reasons why I encourage deep learning practitioners to also invest in their computer vision education.
To learn how to break captchas with deep learning, Keras, and TensorFlow, just keep reading.
Breaking captchas with deep learning, Keras, and TensorFlow
I’d also like to mention that real-world datasets are not like benchmark datasets such as MNIST, CIFAR-10, and ImageNet, where images are neatly labeled and organized and our goal is only to train a model on the data and evaluate it. These benchmark datasets may be challenging, but in the real world, the struggle is often obtaining the (labeled) data itself, and in many instances the labeled data is worth a lot more than the deep learning model obtained from training a network on it.
For example, if you were running a company responsible for creating a custom Automatic License Plate Recognition (ALPR) system for the United States government, you might invest years building a robust, massive dataset, while at the same time evaluating various deep learning approaches to recognizing license plates. Accumulating such a massive labeled dataset would give you a competitive edge over other companies — and in this case, the data itself is worth more than the end product.
Your company would be more likely to be acquired simply because of the exclusive rights you have to the massive, labeled dataset. Building an amazing deep learning model to recognize license plates would only increase the value of your company, but again, labeled data is expensive to obtain and replicate, so if you own the keys to a dataset that is hard (if not impossible) to replicate, make no mistake: your company’s primary asset is the data, not the deep learning.
Let’s look at how we can obtain a dataset of images, label them, and then apply deep learning to break a captcha system.
Configuring your development environment
To follow this guide, you need to have the OpenCV library installed on your system.
Luckily, OpenCV is pip-installable:
$ pip install opencv-contrib-python
If you need help configuring your development environment for OpenCV, I highly recommend that you read my pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having problems configuring your development environment?
Figure 1: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Breaking Captchas with a CNN
Here’s how to think about breaking captchas. Keep in mind the concept of responsible disclosure, something you should always practice when computer security is involved.
The process starts when we create a Python script to automatically download a set of images that we’ll be using for training and evaluation.
After downloading our images, we’ll need to use a bit of computer vision to aid us in labeling the images, making the process much easier and substantially faster than simply cropping and labeling inside photo software like GIMP or Photoshop. Once we have labeled our data, we’ll train the LeNet architecture — as we’ll find out, we’re able to break the captcha system and obtain 100% accuracy in less than 15 epochs.
A Note on Responsible Disclosure
Living in the northeastern/midwestern part of the United States, it’s hard to travel on major highways without an E-ZPass. E-ZPass is an electronic toll collection system used on many bridges, interstates, and tunnels. Travelers simply purchase an E-ZPass transponder, place it on the windshield of their car, and enjoy the ability to quickly travel through tolls without stopping, as a credit card attached to their E-ZPass account is charged for any tolls.
E-ZPass has made tolls a much more “enjoyable” process (if there is such a thing). Instead of waiting in interminable lines where a physical transaction needs to take place (i.e., hand the cashier money, receive your change, get a printed receipt for reimbursement, etc.), you can simply blaze through in the fast lane without stopping — it saves a bunch of time when traveling and is much less of a hassle (you still have to pay the toll though).
I spend much of my time traveling between Maryland and Connecticut, two states along the I-95 corridor of the United States. The I-95 corridor, especially in New Jersey, contains a plethora of toll booths, so an E-ZPass pass was a no-brainer decision for me. About a year ago, the credit card I had attached to my E-ZPass account expired, and I needed to update it. I went to the E-ZPass New York website (the state I bought my E-ZPass in) to log in and update my credit card, but I stopped dead in my tracks (Figure 2).
Figure 2: The E-Z Pass New York login form. Can you spot the flaw in their login system?
Can you spot the flaw in this system? Their “captcha” is nothing more than four digits on a plain white background which is a major security risk — someone with even basic computer vision or deep learning experience could develop a piece of software to break this system.
This is where the concept of responsible disclosure comes in. Responsible disclosure is a computer security term describing how to disclose a vulnerability. Instead of posting it on the internet for everyone to see immediately after the threat is detected, you try to contact the stakeholders first to ensure they know there is an issue. The stakeholders can then attempt to patch the software and resolve the vulnerability.
Simply ignoring the vulnerability and hiding the issue provides a false sense of security, something that should be avoided. In an ideal world, the vulnerability is resolved before it is publicly disclosed.
However, when stakeholders do not acknowledge the issue or do not fix the problem in a reasonable amount of time it creates an ethical conundrum — do you hide the issue and pretend it doesn’t exist? Or do you disclose it, bringing more attention to the problem in an effort to bring a fix to the problem faster? Responsible disclosure states that you first bring the problem to the stakeholders (responsible) — if it’s not resolved, then you need to disclose the issue (disclosure).
To demonstrate how the E-ZPass NY system was at risk, I trained a deep learning model to recognize the digits in the captcha. I then wrote a second Python script to (1) auto-fill my login credentials and (2) break the captcha, allowing my script access to my account.
In this case, I was only auto-logging into my account. Using this “feature,” I could auto-update a credit card, generate reports on my tolls, or even add a new car to my E-ZPass. But someone nefarious may use this as a method to brute force their way into a customer’s account.
I contacted E-ZPass over email, phone, and Twitter regarding the issue one year before I wrote this. They acknowledged receipt of my messages; however, nothing has been done to fix the issue, despite multiple contacts.
In the rest of this tutorial, I’ll discuss how we can use the E-ZPass system to obtain a captcha dataset which we’ll then label and train a deep learning model on. I will not be sharing the Python code to auto-login to an account — that is outside the boundaries of responsible disclosure so please do not ask me for this code.
Keep in mind that with all knowledge comes responsibility. This knowledge, under no circumstance, should be used for nefarious or unethical reasons. This case study exists as a method to demonstrate how to obtain and label a custom dataset, followed by training a deep learning model on top of it.
I am required to say that I am not responsible for how this code is used — use this as an opportunity to learn, not an opportunity to be nefarious.
The Captcha Breaker Directory Structure
To build the captcha breaker system, we’ll need to update the pyimagesearch.utils submodule and include a new file named captchahelper.py:
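The updated module layout would look something like the following (a sketch; your pyimagesearch module may contain other files from earlier chapters, such as the nn.conv submodule that houses LeNet):
pyimagesearch
|--- __init__.py
|--- nn
|    |--- __init__.py
|    |--- conv
|--- utils
|    |--- __init__.py
|    |--- captchahelper.py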
This file will store a utility function named preprocess to help us process digits before feeding them into our deep neural network.
We’ll also create a second directory, this one named captcha_breaker, outside of our pyimagesearch module, and include the following files and subdirectories:
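Based on the file descriptions that follow, the project layout looks roughly like this (a sketch; only the names explicitly mentioned below are assumed):
captcha_breaker
|--- dataset
|--- downloads
|--- output
|--- annotate.py
|--- download_images.py
|--- test_model.py
|--- train_model.py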
The captcha_breaker directory is where all our project code will be stored to break image captchas. The dataset directory is where we will store our labeled digits which we’ll be hand-labeling. I prefer to keep my datasets organized using the following directory structure template:
root_directory/class_name/image_filename.jpg
Therefore, our dataset directory will have the structure:
dataset/{1-9}/example.jpg
where dataset is the root directory, {1-9} are the possible digit names, and example.jpg will be an example of the given digit.
The downloads directory will store the raw captcha .jpg files downloaded from the E-ZPass website. Inside the output directory, we’ll store our trained LeNet architecture.
The download_images.py script, as the name suggests, will be responsible for actually downloading the example captchas and saving them to disk. Once we’ve downloaded a set of captchas we’ll need to extract the digits from each image and hand-label every digit — this will be accomplished by annotate.py.
The train_model.py script will train LeNet on the labeled digits, while test_model.py will apply LeNet to captcha images themselves.
Automatically Downloading Example Images
The first step in building our captcha breaker is to download the example captcha images themselves.
If you copy and paste “https://www.e-zpassny.com/vector/jcaptcha.do” into your web browser and hit refresh multiple times, you’ll notice that this is a dynamic program that generates a new captcha each time you refresh. Therefore, to obtain our example captcha images we need to request this image a few hundred times and save the resulting image.
To automatically fetch new captcha images and save them to disk we can use download_images.py:
# import the necessary packages
import argparse
import requests
import time
import os
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-o", "--output", required=True,
help="path to output directory of images")
ap.add_argument("-n", "--num-images", type=int,
default=500, help="# of images to download")
args = vars(ap.parse_args())
Lines 2-5 import our required Python packages. The requests library makes working with HTTP connections easy and is heavily used in the Python ecosystem. If you do not already have requests installed on your system, you can install it via:
$ pip install requests
We then parse our command line arguments on Lines 8-13. We’ll require a single command line argument, --output, which is the path to the output directory that will store our raw captcha images (we’ll later hand-label each of the digits in the images).
A second optional switch, --num-images, controls the number of captcha images we’re going to download. We’ll default this value to 500 total images. Since there are four digits in each captcha, this value of 500 will give us 500×4 = 2,000 total digits that we can use for training our network.
Our next code block initializes the URL of the captcha image we are going to download along with the total number of images generated thus far:
# initialize the URL that contains the captcha images that we will
# be downloading along with the total number of images downloaded
# thus far
url = "https://www.e-zpassny.com/vector/jcaptcha.do"
total = 0
We are now ready to download the captcha images:
# loop over the number of images to download
for i in range(0, args["num_images"]):
    try:
        # try to grab a new captcha image
        r = requests.get(url, timeout=60)
        # save the image to disk
        p = os.path.sep.join([args["output"], "{}.jpg".format(
            str(total).zfill(5))])
        f = open(p, "wb")
        f.write(r.content)
        f.close()
        # update the counter
        print("[INFO] downloaded: {}".format(p))
        total += 1
    # handle if any exceptions are thrown during the download process
    except:
        print("[INFO] error downloading image...")
    # insert a small sleep to be courteous to the server
    time.sleep(0.1)
On Line 22, we start looping over the --num-images that we wish to download. A request is made on Line 25 to download the image. We then save the image to disk on Lines 28-32. If there was an error downloading the image, our try/except block on Lines 39 and 40 catches it and allows our script to continue. Finally, we insert a small sleep on Line 43 to be courteous to the web server we are requesting.
You can execute download_images.py using the following command:
$ python download_images.py --output downloads
This script will take a while to run since we (1) are making a network request to download each image and (2) have inserted a 0.1-second pause after each download.
Once the program finishes executing, you’ll see that your downloads directory is filled with images:
$ ls -l downloads/*.jpg | wc -l
500
However, these are just the raw captcha images — we need to extract and label each of the digits in the captchas to create our training set. To accomplish this, we’ll use a bit of OpenCV and image processing techniques to make our life easier.
Annotating and Creating Our Dataset
So, how do we go about labeling and annotating each of our captcha images? Do we open Photoshop or GIMP and use the "select/marquee" tool to copy out a given digit, save it to disk, and then repeat ad nauseam? If we did, it might take us days of non-stop work to complete the task.
Instead, a better approach would be to use basic image processing techniques inside the OpenCV library to help us out. To see how we can label our dataset more efficiently, open a new file, name it annotate.py, and insert the following code:
# import the necessary packages
from imutils import paths
import argparse
import imutils
import cv2
import os
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--input", required=True,
help="path to input directory of images")
ap.add_argument("-a", "--annot", required=True,
help="path to output directory of annotations")
args = vars(ap.parse_args())
Lines 2-6 import our required Python packages, while Lines 9-14 parse our command line arguments. This script requires two arguments:
--input: The input path to our raw captcha images (i.e., the downloads directory).
--annot: The output path to where we’ll be storing the labeled digits (i.e., the dataset directory).
Our next code block grabs the paths to all images in the --input directory and initializes a dictionary named counts that will store the total number of times a given digit (the key) has been labeled (the value):
# grab the image paths then initialize the dictionary of character
# counts
imagePaths = list(paths.list_images(args["input"]))
counts = {}
The actual annotation process starts below:
# loop over the image paths
for (i, imagePath) in enumerate(imagePaths):
    # display an update to the user
    print("[INFO] processing image {}/{}".format(i + 1,
        len(imagePaths)))
    try:
        # load the image and convert it to grayscale, then pad the
        # image to ensure digits caught on the border of the image
        # are retained
        image = cv2.imread(imagePath)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray = cv2.copyMakeBorder(gray, 8, 8, 8, 8,
            cv2.BORDER_REPLICATE)
On Line 22, we start looping over each of the individual imagePaths. For each image, we load it from disk (Line 31), convert it to grayscale (Line 32), and pad the borders of the image with eight pixels in every direction (Lines 33 and 34). Figure 3 shows the difference between the original image (left) and the padded image (right).
Figure 3: Left: The original image loaded from disk. Right: Padding the image to ensure we can extract the digits just in case any of the digits are touching the border of the image.
We perform this padding just in case any of our digits are touching the border of the image. If the digits were touching the border, we wouldn’t be able to extract them from the image. Thus, to prevent this situation, we purposely pad the input image so it’s not possible for a given digit to touch the border.
We are now ready to binarize the input image via Otsu’s thresholding method:
# threshold the image to reveal the digits
thresh = cv2.threshold(gray, 0, 255,
cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
This function call automatically thresholds our image such that our image is now binary — black pixels represent the background while white pixels are our foreground as shown in Figure 4.
Figure 4: Thresholding the image ensures the foreground is white while the background is black. This is a typical assumption/requirement when working with many image processing functions with OpenCV.
Thresholding the image is a critical step in our image processing pipeline as we now need to find the outlines of each of the digits:
# find contours in the image, keeping only the four largest
# ones
cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if imutils.is_cv2() else cnts[1]
cnts = sorted(cnts, key=cv2.contourArea, reverse=True)[:4]
Lines 42 and 43 find the contours (i.e., outlines) of each of the digits in the image. Just in case there is "noise" in the image, we sort the contours by their area, keeping only the four largest ones (i.e., our digits themselves).
Given our contours we can extract each of them by computing the bounding box:
# loop over the contours
for c in cnts:
    # compute the bounding box for the contour then extract
    # the digit
    (x, y, w, h) = cv2.boundingRect(c)
    roi = gray[y - 5:y + h + 5, x - 5:x + w + 5]
    # display the character, making it large enough for us
    # to see, then wait for a keypress
    cv2.imshow("ROI", imutils.resize(roi, width=28))
    key = cv2.waitKey(0)
On Line 48, we loop over each of the contours found in the thresholded image. We call cv2.boundingRect to compute the bounding box (x, y)-coordinates of the digit region. This region of interest (ROI) is then extracted from the grayscale image on Line 52. I have included a sample of example digits extracted from their raw captcha images as a montage in Figure 5.
Figure 5: A sample of the digit ROIs extracted from our captcha images. Our goal will be to label these images in such a way that we can train a custom Convolutional Neural Network on them.
Line 56 displays the digit ROI to our screen, resizing it to be large enough for us to see easily. Line 57 then waits for a keypress on your keyboard — but choose your keypress wisely! The key you press will be used as the label for the digit.
To see how the labeling process works via the cv2.waitKey call, take a look at the following code block:
# if the '`' key is pressed, then ignore the character
if key == ord("`"):
    print("[INFO] ignoring character")
    continue
# grab the key that was pressed and construct the path to
# the output directory
key = chr(key).upper()
dirPath = os.path.sep.join([args["annot"], key])
# if the output directory does not exist, create it
if not os.path.exists(dirPath):
    os.makedirs(dirPath)
If the "`" (backtick) key is pressed, we’ll ignore the character (Lines 60 and 62). Needing to ignore a character may happen if our script accidentally detects "noise" (i.e., anything but a digit) in the input image or if we are not sure what the digit is. Otherwise, we assume that the key pressed was the label for the digit (Line 66) and use the key to construct the directory path to our output label (Line 67).
For example, if I pressed the 7 key on my keyboard, the dirPath would be:
dataset/7
Therefore, all images containing the digit “7” will be stored in the dataset/7 subdirectory. Lines 70 and 71 make a check to see if the dirPath directory does not exist — if it doesn’t, we create it.
Once we have ensured that dirPath properly exists, we simply have to write the example digit to file:
# write the labeled character to file
count = counts.get(key, 1)
p = os.path.sep.join([dirPath, "{}.png".format(
str(count).zfill(6))])
cv2.imwrite(p, roi)
# increment the count for the current key
counts[key] = count + 1
Line 74 grabs the total number of examples written to disk thus far for the current digit. We then construct the output path to the example digit using the dirPath. After executing Lines 75 and 76, our output path p may look like:
dataset/7/000001.png
Again, notice how all example ROIs that contain the number seven will be stored in the dataset/7 subdirectory — this is an easy, convenient way to organize your datasets when labeling images.
Our final code block handles if we want to control-c out of the script to exit or if there is an error processing an image:
# we are trying to control-c out of the script, so break from the
# loop (you still need to press a key for the active window to
# trigger this)
except KeyboardInterrupt:
    print("[INFO] manually leaving script")
    break
# an unknown error has occurred for this particular image
except:
    print("[INFO] skipping image...")
If we wish to control-c and quit the script early, Line 85 detects this and allows our Python program to exit gracefully. Line 90 catches all other errors and simply ignores them, allowing us to continue with the labeling process.
The last thing you want when labeling a dataset is for a random error to occur due to an image encoding problem, causing your entire program to crash. If this happens, you’ll have to restart the labeling process all over again. You can obviously build in extra logic to detect where you left off.
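For instance, one minimal way to support resuming (a sketch, not part of the original annotate.py; the helper name is hypothetical) is to initialize the counts dictionary from the digits already written to the --annot directory instead of starting it empty:
# sketch: rebuild the counts dictionary from digits already labeled on disk
# so a restarted labeling session continues numbering where it left off
import os

def load_existing_counts(annotDir):
    counts = {}
    # if nothing has been labeled yet, start from scratch
    if not os.path.exists(annotDir):
        return counts
    # each subdirectory of the annotations directory is a class label
    for label in os.listdir(annotDir):
        labelDir = os.path.join(annotDir, label)
        if os.path.isdir(labelDir):
            # the next file index is the number of existing examples plus one
            counts[label] = len(os.listdir(labelDir)) + 1
    return counts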
To label the images you downloaded from the E-ZPass NY website, just execute the following command:
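Assuming your raw captchas live in the downloads directory and you want the labeled digits written to the dataset directory, that command is:
$ python annotate.py --input downloads --annot dataset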
Here, you can see that the number 7 is displayed on my screen in Figure 6.
Figure 6: When annotating our dataset of digits, a given digit ROI will display on our screen. We then need to press the corresponding key on our keyboard to label the image and save the ROI to disk.
I then press the 7 key on my keyboard to label it, and the digit is written to file in the dataset/7 subdirectory.
The annotate.py script then proceeds to the next digit for me to label. You can then proceed to label all of the digits in the raw captcha images. You’ll quickly realize that labeling a dataset can be a very tedious, time-consuming process. Labeling all 2,000 digits should take you less than half an hour — but you’ll likely become bored within the first five minutes.
Remember, actually obtaining your labeled dataset is half the battle. From there the actual work can start. Luckily, I have already labeled the digits for you! If you check the dataset directory included in the accompanying downloads of this tutorial you’ll find the entire dataset ready to go:
$ ls dataset/
1 2 3 4 5 6 7 8 9
$ ls -l dataset/1/*.png | wc -l
232
Here, you can see nine subdirectories, one for each of the digits that we wish to recognize. Inside each subdirectory, there are example images of the particular digit. Now that we have our labeled dataset, we can proceed to training our captcha breaker using the LeNet architecture.
Preprocessing the Digits
As we know, our Convolutional Neural Networks require an image with a fixed width and height to be passed in during training. However, our labeled digit images are of various sizes — some are taller than they are wide, others are wider than they are tall. Therefore, we need a method to pad and resize our input images to a fixed size without distorting their aspect ratio.
We can resize and pad our images while preserving the aspect ratio by defining a preprocess function inside captchahelper.py:
# import the necessary packages
import imutils
import cv2
def preprocess(image, width, height):
    # grab the dimensions of the image, then initialize
    # the padding values
    (h, w) = image.shape[:2]
    # if the width is greater than the height then resize along
    # the width
    if w > h:
        image = imutils.resize(image, width=width)
    # otherwise, the height is greater than the width so resize
    # along the height
    else:
        image = imutils.resize(image, height=height)
Our preprocess function requires three parameters:
image: The input image that we are going to pad and resize.
width: The target output width of the image.
height: The target output height of the image.
On Lines 12 and 13, we make a check to see if the width is greater than the height, and if so, we resize the image along the larger dimension (the width). Otherwise, if the height is greater than the width, we resize along the height (Lines 17 and 18). In both cases, one dimension of the image is now fixed at its target size.
However, the opposite dimension is smaller than it should be. To fix this issue, we can “pad” the image along the shorter dimension to obtain our fixed size:
    # determine the padding values for the width and height to
    # obtain the target dimensions
    padW = int((width - image.shape[1]) / 2.0)
    padH = int((height - image.shape[0]) / 2.0)
    # pad the image then apply one more resizing to handle any
    # rounding issues
    image = cv2.copyMakeBorder(image, padH, padH, padW, padW,
        cv2.BORDER_REPLICATE)
    image = cv2.resize(image, (width, height))
    # return the pre-processed image
    return image
Lines 22 and 23 compute the required amount of padding to reach the target width and height. Lines 27 and 28 apply the padding to the image. Applying this padding should bring our image to our target width and height; however, there may be cases where we are one pixel off in a given dimension. The easiest way to resolve this discrepancy is to simply call cv2.resize (Line 29) to ensure all images are the same width and height.
The reason we do not immediately call cv2.resize at the top of the function is that we first need to consider the aspect ratio of the input image and attempt to pad it correctly first. If we do not maintain the image aspect ratio, then our digits will become distorted.
Training the Captcha Breaker
Now that our preprocess function is defined, we can move on to training LeNet on the image captcha dataset. Open the train_model.py file and insert the following code:
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.optimizers import SGD
from pyimagesearch.nn.conv import LeNet
from pyimagesearch.utils.captchahelper import preprocess
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import argparse
import cv2
import os
Lines 2-14 import our required Python packages. Notice that we’ll be using the SGD optimizer along with the LeNet architecture to train a model on the digits. We’ll also be using our newly defined preprocess function on each digit before passing it through our network.
Next, let’s review our command line arguments:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
help="path to input dataset")
ap.add_argument("-m", "--model", required=True,
help="path to output model")
args = vars(ap.parse_args())
The train_model.py script requires two command line arguments:
--dataset: The path to the input dataset of labeled captcha digits (i.e., the dataset directory on disk).
--model: Here we supply the path to where our serialized LeNet weights will be saved after training.
We can now load our data and corresponding labels from disk:
# initialize the data and labels
data = []
labels = []
# loop over the input images
for imagePath in paths.list_images(args["dataset"]):
    # load the image, pre-process it, and store it in the data list
    image = cv2.imread(imagePath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    image = preprocess(image, 28, 28)
    image = img_to_array(image)
    data.append(image)
    # extract the class label from the image path and update the
    # labels list
    label = imagePath.split(os.path.sep)[-2]
    labels.append(label)
On Lines 25 and 26, we initialize our data and labels lists, respectively. We then loop over every image in our labeled --dataset on Line 29. For each image in the dataset, we load it from disk, convert it to grayscale, and preprocess it such that it has a width of 28 pixels and a height of 28 pixels (Lines 31-33). The image is then converted to a Keras-compatible array and added to the data list (Lines 34 and 35).
One of the primary benefits of organizing your dataset directory structure in the format of:
root_directory/class_label/image_filename.jpg
is that you can easily extract the class label by grabbing the second-to-last component from the filename (Line 39). For example, given the input path dataset/7/000001.png, the label would be 7, which is then added to the labels list (Line 40).
Our next code block handles normalizing raw pixel intensity values to the range [0, 1], followed by constructing the training and testing splits, along with one-hot encoding the labels:
# scale the raw pixel intensities to the range [0, 1]
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# partition the data into training and testing splits using 75% of
# the data for training and the remaining 25% for testing
(trainX, testX, trainY, testY) = train_test_split(data,
labels, test_size=0.25, random_state=42)
# convert the labels from integers to vectors
lb = LabelBinarizer().fit(trainY)
trainY = lb.transform(trainY)
testY = lb.transform(testY)
We can then initialize the LeNet model and SGD optimizer:
# initialize the model
print("[INFO] compiling model...")
model = LeNet.build(width=28, height=28, depth=1, classes=9)
opt = SGD(lr=0.01)
model.compile(loss="categorical_crossentropy", optimizer=opt,
metrics=["accuracy"])
Our input images will have a width of 28 pixels, a height of 28 pixels, and a single channel. There are a total of 9 digit classes we are recognizing (there is no 0 class).
Given the initialized model and optimizer we can train the network for 15 epochs, evaluate it, and serialize it to disk:
# train the network
print("[INFO] training network...")
H = model.fit(trainX, trainY, validation_data=(testX, testY),
batch_size=32, epochs=15, verbose=1)
# evaluate the network
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
predictions.argmax(axis=1), target_names=lb.classes_))
# save the model to disk
print("[INFO] serializing network...")
model.save(args["model"])
Our last code block will handle plotting the accuracy and loss for both the training and testing sets over time:
# plot the training + testing loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 15), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 15), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 15), H.history["accuracy"], label="acc")
plt.plot(np.arange(0, 15), H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.show()
To train the LeNet architecture using the SGD optimizer on our custom captcha dataset, just execute the following command:
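Assuming your labeled digits live in the dataset directory and you want the serialized weights written to output/lenet.hdf5 (the file we inspect below), that command is:
$ python train_model.py --dataset dataset --model output/lenet.hdf5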
As we can see, after only 15 epochs our network is obtaining 100% classification accuracy on both the training and validation sets. This is not a case of overfitting either — when we investigate the training and validation curves in Figure 7 we can see that by epoch 5 the validation and training loss/accuracy match each other.
Figure 7: Using the LeNet architecture on our custom digits datasets enables us to obtain 100% classification accuracy after only fifteen epochs. Furthermore, there are no signs of overfitting.
If you check the output directory, you’ll also see the serialized lenet.hdf5 file:
$ ls -l output/
total 9844
-rw-rw-r-- 1 adrian adrian 10076992 May 3 12:56 lenet.hdf5
We can then use this model on new input images.
Testing the Captcha Breaker
Now that our captcha breaker is trained, let’s test it out on some example images. Open the test_model.py file and insert the following code:
# import the necessary packages
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
from pyimagesearch.utils.captchahelper import preprocess
from imutils import contours
from imutils import paths
import numpy as np
import argparse
import imutils
import cv2
As usual, our Python script starts with importing our Python packages. We’ll again be using the preprocess function to prepare digits for classification.
Next, we’ll parse our command line arguments:
# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--input", required=True,
help="path to input directory of images")
ap.add_argument("-m", "--model", required=True,
help="path to input model")
args = vars(ap.parse_args())
The --input switch controls the path to the input captcha images that we wish to break. We could download a new set of captchas from the E-ZPass NY website, but for simplicity, we’ll sample images from our existing raw captcha files. The --model argument is simply the path to the serialized weights residing on disk.
We can now load our pre-trained CNN and randomly sample ten captcha images to classify:
# load the pre-trained network
print("[INFO] loading pre-trained network...")
model = load_model(args["model"])
# randomly sample a few of the input images
imagePaths = list(paths.list_images(args["input"]))
imagePaths = np.random.choice(imagePaths, size=(10,),
replace=False)
Here comes the fun part — actually breaking the captcha:
# loop over the image paths
for imagePath in imagePaths:
    # load the image and convert it to grayscale, then pad the image
    # to ensure digits caught near the border of the image are
    # retained
    image = cv2.imread(imagePath)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    gray = cv2.copyMakeBorder(gray, 20, 20, 20, 20,
        cv2.BORDER_REPLICATE)
    # threshold the image to reveal the digits
    thresh = cv2.threshold(gray, 0, 255,
        cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
On Line 30, we start looping over each of our sampled imagePaths. Just like in the annotate.py example, we need to extract each of the digits in the captcha. This extraction is accomplished by loading the image from disk, converting it to grayscale, and padding the border such that a digit cannot touch the boundary of the image (Lines 34-37). We add extra padding here so we have enough room to actually draw and visualize the correct prediction on the image.
Lines 40 and 41 threshold the image such that the digits appear as a white foreground against a black background.
We now need to find the contours of the digits in the thresh image:
# find contours in the image, keeping only the four largest ones,
# then sort them from left-to-right
cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)
cnts = cnts[0] if imutils.is_cv2() else cnts[1]
cnts = sorted(cnts, key=cv2.contourArea, reverse=True)[:4]
cnts = contours.sort_contours(cnts)[0]
# initialize the output image as a "grayscale" image with 3
# channels along with the output predictions
output = cv2.merge([gray] * 3)
predictions = []
We can find the digits by calling cv2.findContours on the thresh image. This function returns a list of contours, where each contour is the set of (x, y)-coordinates that specify the outline of an individual digit.
We then perform two stages of sorting. The first stage sorts the contours by their size, keeping only the largest four outlines. We (correctly) assume that the four contours with the largest size are the digits we want to recognize. However, there is no guaranteed spatial ordering imposed on these contours — the third digit we wish to recognize may be first in the cnts list. Since we read digits from left-to-right, we need to sort the contours from left-to-right. This is accomplished via the sort_contours function (http://pyimg.co/sbm9p).
Line 53 takes our gray image and converts it to a three-channel image by replicating the grayscale channel three times (one each for the Red, Green, and Blue channels). We then initialize the list of predictions made by the CNN on Line 54.
Given the contours of the digits in the captcha, we can now break it:
# loop over the contours
for c in cnts:
    # compute the bounding box for the contour then extract the
    # digit
    (x, y, w, h) = cv2.boundingRect(c)
    roi = gray[y - 5:y + h + 5, x - 5:x + w + 5]
    # pre-process the ROI and then classify it
    roi = preprocess(roi, 28, 28)
    roi = np.expand_dims(img_to_array(roi), axis=0) / 255.0
    pred = model.predict(roi).argmax(axis=1)[0] + 1
    predictions.append(str(pred))
    # draw the prediction on the output image
    cv2.rectangle(output, (x - 2, y - 2),
        (x + w + 4, y + h + 4), (0, 255, 0), 1)
    cv2.putText(output, str(pred), (x - 5, y - 5),
        cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 0), 2)
On Line 57, we loop over each of the outlines (which have been sorted from left-to-right) of the digits. We then extract the ROI of the digit on Lines 60 and 61 followed by preprocessing it on Lines 64 and 65.
Line 66 calls the .predict method of our model. The index with the largest probability returned by .predict will be our class label. We add 1 to this value since index values start at zero, but there is no zero class (only classes exist for the digits 1-9). This prediction is then appended to the predictions list on Line 67.
Lines 70 and 71 draw a bounding box surrounding the current digit, while Lines 72 and 73 draw the predicted digit on the output image itself.
Our last code block handles writing the broken captcha as a string to our terminal as well as displaying the output image:
# show the output image
print("[INFO] captcha: {}".format("".join(predictions)))
cv2.imshow("Output", output)
cv2.waitKey()
To see our captcha breaker in action, simply execute the following command:
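Assuming you are sampling from the raw captchas in the downloads directory and using the weights serialized to output/lenet.hdf5, that command is:
$ python test_model.py --input downloads --model output/lenet.hdf5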
In Figure 8, I have included four samples generated from my run of test_model.py. In every case, we have correctly predicted the digit string and broken the image captcha using a simple network architecture trained on a small amount of training data.
Figure 8: Examples of captchas that have been correctly classified and broken by our LeNet model.
In summary, in this tutorial we learned how to:
Download a set of raw captcha images.
Label and annotate the digits in those images for training.
Train a custom Convolutional Neural Network on our labeled dataset.
Test and evaluate our model on example images.
To accomplish this, we scraped 500 example captcha images from the E-ZPass NY website. We then wrote a Python script that aids us in the labeling process, enabling us to quickly label the entire dataset and store the resulting images in an organized directory structure.
After our dataset was labeled, we trained the LeNet architecture using the SGD optimizer on the dataset using categorical cross-entropy loss — the resulting model obtained 100% accuracy on the testing set with zero overfitting. Finally, we visualized results of the predicted digits to confirm that we have successfully devised a method to break the captcha.
Again, I want to remind you that this tutorial serves as only an example of how to obtain an image dataset and label it. Under no circumstances should you use this dataset or resulting model for nefarious reasons. If you are ever in a situation where you find that computer vision or deep learning can be used to exploit a vulnerability, be sure to practice responsible disclosure and attempt to report the issue to the proper stakeholders; failure to do so is unethical (as is misuse of this code, which, legally, I must say I cannot take responsibility for).
Secondly, this tutorial (like the next one on smile detection with deep learning) leverages computer vision and the OpenCV library to facilitate building a complete application. If you are planning on becoming a serious deep learning practitioner, I highly recommend that you learn the fundamentals of image processing and the OpenCV library — having even a rudimentary understanding of these concepts will enable you to:
Appreciate deep learning at a higher level.
Develop more robust applications that use deep learning for image classification.
Leverage image processing techniques to achieve your goals more quickly.
A great example of using basic image processing techniques to our advantage can be found in the Annotating and Creating Our Dataset section above, where we were able to quickly annotate and label our dataset. Without using simple computer vision techniques, we would have been stuck manually cropping and saving the example digits to disk using image editing software such as Photoshop or GIMP. Instead, we were able to write a quick-and-dirty application that automatically extracted each digit from the captcha — all we had to do was press the proper key on our keyboard to label the image.
In this tutorial, you will receive a gentle introduction to training your first Convolutional Neural Network (CNN) using the PyTorch deep learning library. This network will be able to recognize handwritten Hiragana characters.
Today’s tutorial is part three in our five-part series on PyTorch fundamentals.
Today, we will take the next step and learn how to train a CNN to recognize handwritten Hiragana characters using the Kuzushiji-MNIST (KMNIST) dataset.
As you’ll see, training a CNN on an image dataset isn’t all that different from training a basic multi-layer perceptron (MLP) on numerical data. We still need to:
Define our model architecture
Load our dataset from disk
Loop over our epochs and batches
Make predictions and compute our loss
Properly zero our gradient, perform backpropagation, and update our model parameters
Furthermore, this post will also give you some experience with PyTorch’s DataLoader implementation which makes it super easy to work with datasets — becoming proficient with PyTorch’s DataLoader is a critical skill you’ll want to develop as a deep learning practitioner (and it’s a topic that I’ve dedicated an entire course to inside PyImageSearch University).
To learn how to train your first CNN with PyTorch, just keep reading.
PyTorch: Training your first Convolutional Neural Network (CNN)
Throughout the remainder of this tutorial, you will learn how to train your first CNN using the PyTorch framework.
We’ll start by configuring our development environment to install both torch and torchvision, followed by reviewing our project directory structure.
I’ll then show you the KMNIST dataset (a drop-in replacement for the MNIST digits dataset) that contains Hiragana characters. Later in this tutorial, you’ll learn how to train a CNN to recognize each of the Hiragana characters in the KMNIST dataset.
We’ll then implement three Python scripts with PyTorch, including our CNN architecture, training script, and a final script used to make predictions on input images.
By the end of this tutorial, you’ll be comfortable with the steps required to train a CNN with PyTorch.
Let’s get started!
Configuring your development environment
To follow this guide, you need to have PyTorch, OpenCV, and scikit-learn installed on your system.
Luckily, all three are extremely easy to install using pip:
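The standard pip package names are torch and torchvision for PyTorch, opencv-contrib-python for OpenCV, and scikit-learn:
$ pip install torch torchvision
$ pip install opencv-contrib-python
$ pip install scikit-learn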
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
Having problems configuring your development environment?
Figure 1: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
The KMNIST dataset
Figure 2: The KMNIST dataset is a drop-in replacement for the standard MNIST dataset. The KMNIST dataset contains examples of handwritten Hiragana characters (image source).
The dataset we are using today is the Kuzushiji-MNIST dataset, or KMNIST, for short. This dataset is meant to be a drop-in replacement for the standard MNIST digits recognition dataset.
The KMNIST dataset consists of 70,000 images and their corresponding labels (60,000 for training and 10,000 for testing).
There are a total of 10 classes (meaning 10 Hiragana characters) in the KMNIST dataset, each equally distributed and represented. Our goal is to train a CNN that can accurately classify each of these 10 characters.
And lucky for us, the KMNIST dataset is built into PyTorch, making it super easy for us to work with!
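As a quick illustration of how little code the built-in dataset requires, here is a minimal sketch (we will walk through the actual training script later in this tutorial) that downloads KMNIST via torchvision and wraps it in a DataLoader:
# minimal sketch: load the built-in KMNIST dataset with torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import KMNIST
from torchvision.transforms import ToTensor

# download the training split to ./data and convert each image to a tensor
trainData = KMNIST(root="data", train=True, download=True, transform=ToTensor())

# wrap the dataset in a DataLoader so we can iterate over batches of 64 images
trainLoader = DataLoader(trainData, batch_size=64, shuffle=True)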
Project structure
Before we start implementing any PyTorch code, let’s first review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and pre-trained model.
You’ll then be presented with the following directory structure:
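Roughly, the layout looks like this (a sketch reconstructed from the file descriptions below):
.
|--- pyimagesearch
|    |--- __init__.py
|    |--- lenet.py
|--- output
|    |--- model.pth
|    |--- plot.png
|--- predict.py
|--- train.py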
lenet.py: Our PyTorch implementation of the famous LeNet architecture
train.py: Trains LeNet on the KMNIST dataset using PyTorch, then serializes the trained model to disk (i.e., model.pth)
predict.py: Loads our trained model from disk, makes predictions on testing images, and displays the results on our screen
The output directory will be populated with plot.png (a plot of our training/validation loss and accuracy) and model.pth (our trained model file) once we run train.py.
With our project directory structure reviewed, we can move on to implementing our CNN with PyTorch.
Implementing a Convolutional Neural Network (CNN) with PyTorch
Figure 3: The LeNet architecture. We’ll be implementing LeNet with PyTorch (image source).
The Convolutional Neural Network (CNN) we are implementing here with PyTorch is the seminal LeNet architecture, first proposed by one of the grandfathers of deep learning, Yann LeCun.
By today’s standards, LeNet is a very shallow neural network, consisting of the following layers:
(CONV => RELU => POOL) * 2 => FC => RELU => FC => SOFTMAX
As you’ll see, we’ll be able to implement LeNet with PyTorch in only 60 lines of code (including comments).
The best way to learn about CNNs with PyTorch is to implement one, so with that said, open the lenet.py file in the pyimagesearch module, and let’s get to work:
# import the necessary packages
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch import flatten
Lines 2-8 import our required packages. Let’s break each of them down:
Module: Rather than using the Sequential PyTorch class to implement LeNet, we’ll instead subclass the Module object so you can see how PyTorch implements neural networks using classes
LogSoftmax: Used when building our softmax classifier to return the predicted probabilities of each class
flatten: Flattens the output of a multi-dimensional volume (e.g., a CONV or POOL layer) such that we can apply fully connected layers to it
With our imports taken care of, we can implement our LeNet class using PyTorch:
class LeNet(Module):
def __init__(self, numChannels, classes):
# call the parent constructor
super(LeNet, self).__init__()
# initialize first set of CONV => RELU => POOL layers
self.conv1 = Conv2d(in_channels=numChannels, out_channels=20,
kernel_size=(5, 5))
self.relu1 = ReLU()
self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
# initialize second set of CONV => RELU => POOL layers
self.conv2 = Conv2d(in_channels=20, out_channels=50,
kernel_size=(5, 5))
self.relu2 = ReLU()
self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
# initialize first (and only) set of FC => RELU layers
self.fc1 = Linear(in_features=800, out_features=500)
self.relu3 = ReLU()
# initialize our softmax classifier
self.fc2 = Linear(in_features=500, out_features=classes)
self.logSoftmax = LogSoftmax(dim=1)
Line 10 defines the LeNet class. Notice how we are subclassing the Module object — by building our model as a class we can easily:
Reuse variables
Implement custom functions to generate subnetworks/components (used very often when implementing more complex networks, such as ResNet, Inception, etc.)
Define our own forward pass function
Best of all, when defined correctly, PyTorch can automatically apply its autograd module to perform automatic differentiation — backpropagation is taken care of for us by virtue of the PyTorch library!
The constructor to LeNet accepts two variables:
numChannels: The number of channels in the input images (1 for grayscale or 3 for RGB)
classes: Total number of unique class labels in our dataset
Line 13 calls the parent constructor (i.e., Module) which performs a number of PyTorch-specific operations.
From there, we start defining the actual LeNet architecture.
Lines 16-19 initialize our first set of CONV => RELU => POOL layers. Our first CONV layer learns a total of 20 filters, each of which are 5×5. A ReLU activation function is then applied, followed by a 2×2 max-pooling layer with a 2×2 stride to reduce the spatial dimensions of our input image.
We then have a second set of CONV => RELU => POOL layers on Lines 22-25. We increase the number of filters learned in the CONV layer to 50, but maintain the 5×5 kernel size. Again, a ReLU activation is applied, followed by max-pooling.
Next comes our first and only set of fully connected layers (Lines 28 and 29). We define the number of inputs to the layer (800) along with our desired number of output nodes (500). With 28×28 KMNIST inputs, the second POOL layer outputs a 50×4×4 volume, which flattens to exactly 800 values, hence the in_features of 800. A ReLU activation follows the FC layer.
Finally, we apply our softmax classifier (Lines 32 and 33). The number of in_features is set to 500, which is the output dimensionality from the previous layer. We then apply LogSoftmax such that we can obtain predicted probabilities during evaluation.
It’s important to understand that at this point all we have done is initialized variables. These variables are essentially placeholders. PyTorch has absolutely no idea what the network architecture is, just that some variables exist inside the LeNet class definition.
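If you want to see this for yourself, a quick sanity check (not part of the tutorial’s scripts) is to instantiate the class and print it; PyTorch will happily list the registered layers, but it says nothing about how they connect:
# instantiate LeNet for single-channel inputs and 10 classes, then print it;
# the output lists the registered submodules, not the forward-pass wiring
model = LeNet(numChannels=1, classes=10)
print(model)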
To build the network architecture itself (i.e., what layer is input to some other layer), we need to override the forward method of the Module class.
The forward function serves a number of purposes:
It connects layers/subnetworks together from variables defined in the constructor (i.e., __init__) of the class
It defines the network architecture itself
It allows the forward pass of the model to be performed, resulting in our output predictions
And, thanks to PyTorch’s autograd module, it allows us to perform automatic differentiation and update our model weights
Let’s inspect the forward function now:
def forward(self, x):
# pass the input through our first set of CONV => RELU =>
# POOL layers
x = self.conv1(x)
x = self.relu1(x)
x = self.maxpool1(x)
# pass the output from the previous layer through the second
# set of CONV => RELU => POOL layers
x = self.conv2(x)
x = self.relu2(x)
x = self.maxpool2(x)
# flatten the output from the previous layer and pass it
# through our only set of FC => RELU layers
x = flatten(x, 1)
x = self.fc1(x)
x = self.relu3(x)
# pass the output to our softmax classifier to get our output
# predictions
x = self.fc2(x)
output = self.logSoftmax(x)
# return the output predictions
return output
The forward method accepts a single parameter, x, which is the batch of input data to the network.
We then connect our conv1, relu1, and maxpool1 layers together to form the first CONV => RELU => POOL layer of the network (Lines 38-40).
A similar operation is performed on Lines 44-46, this time building the second set of CONV => RELU => POOL layers.
At this point, the variable x is a multi-dimensional tensor; however, in order to create our fully connected layers, we need to “flatten” this tensor into what essentially amounts to a 1D list of values — the flatten function on Line 50 takes care of this operation for us.
From there, we connect the fc1 and relu3 layers to the network architecture (Lines 51 and 52), followed by attaching the final fc2 and logSoftmax (Lines 56 and 57).
The output of the network is then returned to the calling function.
Again, I want to reiterate the importance of initializing variables in the constructor versus building the network itself in the forward function:
The constructor to your Module only initializes your layer types. PyTorch keeps track of these variables, but it has no idea how the layers connect to each other.
For PyTorch to understand the network architecture you’re building, you define the forward function.
Inside the forward function you take the variables initialized in your constructor and connect them.
PyTorch can then make predictions using your network and perform automatic backpropagation, thanks to the autograd module
Congrats on implementing your first CNN with PyTorch!
Creating our CNN training script with PyTorch
With our CNN architecture implemented, we can move on to creating our training script with PyTorch.
Open the train.py file in your project directory structure, and let’s get to work:
# set the matplotlib backend so figures can be saved in the background
import matplotlib
matplotlib.use("Agg")
# import the necessary packages
from pyimagesearch.lenet import LeNet
from sklearn.metrics import classification_report
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import KMNIST
from torch.optim import Adam
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import time
Lines 2 and 3 import matplotlib and set the appropriate backend so figures can be saved in the background.
From there, we import a number of notable packages:
LeNet: Our PyTorch implementation of the LeNet CNN from the previous section
classification_report: Used to display a detailed classification report on our testing set
random_split: Constructs a random training/testing split from an input set of data
DataLoader: PyTorch’s awesome data loading utility that allows us to effortlessly build data pipelines to train our CNN
ToTensor: A preprocessing function that converts input data into a PyTorch tensor for us automatically
KMNIST: The Kuzushiji-MNIST dataset loader built into the PyTorch library
Adam: The optimizer we’ll use to train our neural network
nn: PyTorch’s neural network implementations
Let’s now parse our command line arguments:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", type=str, required=True,
help="path to output trained model")
ap.add_argument("-p", "--plot", type=str, required=True,
help="path to output loss/accuracy plot")
args = vars(ap.parse_args())
We have two command line arguments that need parsing:
--model: The path to our output serialized model after training (we save this model to disk so we can use it to make predictions in our predict.py script)
--plot: The path to our output training history plot
Moving on, we now have some important initializations to take care of:
# define training hyperparameters
INIT_LR = 1e-3
BATCH_SIZE = 64
EPOCHS = 10
# define the train and val splits
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT
# set the device we will be using to train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Lines 29-31 set our initial learning rate, batch size, and number of epochs to train for, while Lines 34 and 35 define our training and validation split size (75% for training, 25% for validation).
Line 38 then determines our device (i.e., whether we’ll be using our CPU or GPU).
Lines 42-45 load the KMNIST dataset using PyTorch’s built-in KMNIST class.
For our trainData, we set train=True, while our testData is loaded with train=False. These Booleans come in handy when working with datasets built into the PyTorch library.
The download=True flag indicates that PyTorch will automatically download and cache the KMNIST dataset to disk for us if we haven’t previously downloaded it.
Also take note of the transform parameter — here we can apply a number of data transformations (outside the scope of this tutorial but covered soon). The only transform we need is to convert the NumPy array loaded by PyTorch into a tensor data type.
With our training and testing sets loaded, we derive our training and validation splits on Lines 49-53. Using PyTorch’s random_split function, we can easily split our data.
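The dataset loading code those line numbers refer to isn’t reproduced in this excerpt; here is a sketch of what it does, reusing the variable names from the rest of train.py (the fixed random seed is an assumption that simply makes the split reproducible):
# load the KMNIST dataset, converting each image to a PyTorch tensor
print("[INFO] loading the KMNIST dataset...")
trainData = KMNIST(root="data", train=True, download=True,
	transform=ToTensor())
testData = KMNIST(root="data", train=False, download=True,
	transform=ToTensor())
# calculate the train/validation split sizes and carve out the two splits
print("[INFO] generating the train-val split...")
numTrainSamples = int(len(trainData) * TRAIN_SPLIT)
numValSamples = int(len(trainData) * VAL_SPLIT)
(trainData, valData) = random_split(trainData,
	[numTrainSamples, numValSamples],
	generator=torch.Generator().manual_seed(42))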
We now have three sets of data:
Training
Validation
Testing
The next step is to create a DataLoader for each one:
# initialize the train, validation, and test data loaders
trainDataLoader = DataLoader(trainData, shuffle=True,
batch_size=BATCH_SIZE)
valDataLoader = DataLoader(valData, batch_size=BATCH_SIZE)
testDataLoader = DataLoader(testData, batch_size=BATCH_SIZE)
# calculate steps per epoch for training and validation set
trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE
valSteps = len(valDataLoader.dataset) // BATCH_SIZE
Building the DataLoader objects is accomplished on Lines 56-59. We set shuffle=True only for our trainDataLoader since our validation and testing sets do not require shuffling.
We also derive the number of training steps and validation steps per epoch (Lines 62 and 63).
At this point our data is ready for training; however, we don’t have a model to train yet!
Let’s initialize LeNet now:
# initialize the LeNet model
print("[INFO] initializing the LeNet model...")
model = LeNet(
numChannels=1,
classes=len(trainData.dataset.classes)).to(device)
# initialize our optimizer and loss function
opt = Adam(model.parameters(), lr=INIT_LR)
lossFn = nn.NLLLoss()
# initialize a dictionary to store training history
H = {
"train_loss": [],
"train_acc": [],
"val_loss": [],
"val_acc": []
}
# measure how long training is going to take
print("[INFO] training the network...")
startTime = time.time()
Lines 67-69 initialize our model. Since the KMNIST dataset is grayscale, we set numChannels=1. We can easily set the number of classes by calling dataset.classes of our trainData.
We also call to(device) to move the model to either our CPU or GPU.
Lines 72 and 73 initialize our optimizer and loss function. We’ll use the Adam optimizer for training and the negative log-likelihood for our loss function.
When we combine the nn.NLLLoss class with LogSoftmax in our model definition, we arrive at categorical cross-entropy loss (which is equivalent to training a model with a raw Linear output layer and an nn.CrossEntropyLoss loss). Basically, PyTorch allows you to implement categorical cross-entropy in two separate ways.
Get used to seeing both methods as some deep learning practitioners (almost arbitrarily) prefer one over the other.
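To make that equivalence concrete, here is a minimal, self-contained sketch (not taken from the tutorial’s scripts) showing that the two formulations produce the same loss value:
import torch
from torch import nn
# fake logits for a batch of 4 samples and 10 classes, plus ground-truth labels
logits = torch.randn(4, 10)
labels = torch.tensor([0, 3, 7, 2])
# option 1: LogSoftmax inside the model followed by NLLLoss (this tutorial's approach)
logProbs = nn.LogSoftmax(dim=1)(logits)
lossA = nn.NLLLoss()(logProbs, labels)
# option 2: raw Linear outputs passed straight to CrossEntropyLoss
lossB = nn.CrossEntropyLoss()(logits, labels)
# both values match (up to floating point precision)
print(lossA.item(), lossB.item())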
We then initialize H, our training history dictionary (Lines 76-81). After every epoch we’ll update this dictionary with our training loss, training accuracy, validation loss, and validation accuracy for the given epoch.
Finally, we start a timer to measure how long training takes (Line 85).
At this point, all of our initializations are complete, so it’s time to train our model.
# loop over our epochs
for e in range(0, EPOCHS):
# set the model in training mode
model.train()
# initialize the total training and validation loss
totalTrainLoss = 0
totalValLoss = 0
# initialize the number of correct predictions in the training
# and validation step
trainCorrect = 0
valCorrect = 0
# loop over the training set
for (x, y) in trainDataLoader:
# send the input to the device
(x, y) = (x.to(device), y.to(device))
# perform a forward pass and calculate the training loss
pred = model(x)
loss = lossFn(pred, y)
# zero out the gradients, perform the backpropagation step,
# and update the weights
opt.zero_grad()
loss.backward()
opt.step()
# add the loss to the total training loss so far and
# calculate the number of correct predictions
totalTrainLoss += loss
trainCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
On Line 88, we loop over our desired number of epochs.
We then proceed to:
Put the model in train() mode
Initialize our training loss and validation loss for the current epoch
Initialize our number of correct training and validation predictions for the current epoch
Line 102 shows the benefit of using PyTorch’s DataLoader class — all we have to do is start a for loop over the DataLoader object. PyTorch automatically yields a batch of training data. Under the hood, the DataLoader is also shuffling our training data (and if we were doing any additional preprocessing or data augmentation, it would happen here as well).
For each batch of data (Line 104) we perform a forward pass, obtain our predictions, and compute the loss (Lines 107 and 108).
Next comes the all important step of:
Zeroing our gradient
Performing backpropagation
Updating the weights of our model
Seriously, don’t forget these steps! Failing to perform them in that exact order will lead to erroneous training results. Whenever you write a training loop with PyTorch, I highly recommend inserting those three lines of code before anything else so you don’t accidentally omit them or place them incorrectly.
We wrap up the code block by updating our totalTrainLoss and trainCorrect bookkeeping variables.
At this point, we’ve looped over all batches of data in our training set for the current epoch — now we can evaluate our model on the validation set:
# switch off autograd for evaluation
with torch.no_grad():
# set the model in evaluation mode
model.eval()
# loop over the validation set
for (x, y) in valDataLoader:
# send the input to the device
(x, y) = (x.to(device), y.to(device))
# make the predictions and calculate the validation loss
pred = model(x)
totalValLoss += lossFn(pred, y)
# calculate the number of correct predictions
valCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
When evaluating a PyTorch model on a validation or testing set, you need to first:
Use the torch.no_grad() context to turn off gradient tracking and computation
Put the model in eval() mode
From there, you loop over all batches in the validation DataLoader (Line 128), move the data to the correct device (Line 130), and use the data to make predictions (Line 133) and compute your loss (Line 134).
You can then derive your total number of correct predictions (Lines 137 and 138).
We round out our training loop by computing a number of statistics:
# calculate the average training and validation loss
avgTrainLoss = totalTrainLoss / trainSteps
avgValLoss = totalValLoss / valSteps
# calculate the training and validation accuracy
trainCorrect = trainCorrect / len(trainDataLoader.dataset)
valCorrect = valCorrect / len(valDataLoader.dataset)
# update our training history
H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
H["train_acc"].append(trainCorrect)
H["val_loss"].append(avgValLoss.cpu().detach().numpy())
H["val_acc"].append(valCorrect)
# print the model training and validation information
print("[INFO] EPOCH: {}/{}".format(e + 1, EPOCHS))
print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
avgTrainLoss, trainCorrect))
print("Val loss: {:.6f}, Val accuracy: {:.4f}\n".format(
avgValLoss, valCorrect))
Lines 141 and 142 compute our average training and validation loss. The next two lines do the same thing, but for our training and validation accuracy.
We then take these values and update our training history dictionary (Lines 149-152).
Finally, we display the training loss, training accuracy, validation loss, and validation accuracy in our terminal.
We’re almost there!
Now that training is complete, we need to evaluate our model on the testing set (previously we’ve only used the training and validation sets):
# finish measuring how long training took
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
endTime - startTime))
# we can now evaluate the network on the test set
print("[INFO] evaluating network...")
# turn off autograd for testing evaluation
with torch.no_grad():
# set the model in evaluation mode
model.eval()
# initialize a list to store our predictions
preds = []
# loop over the test set
for (x, y) in testDataLoader:
# send the input to the device
x = x.to(device)
# make the predictions and add them to the list
pred = model(x)
preds.extend(pred.argmax(axis=1).cpu().numpy())
# generate a classification report
print(classification_report(testData.targets.cpu().numpy(),
np.array(preds), target_names=testData.classes))
Lines 162-164 stop our training timer and show how long training took.
We then set up another torch.no_grad() context and put our model in eval() mode (Lines 170 and 172).
Evaluation is performed by:
Initializing a list to store our predictions (Line 175)
Looping over our testDataLoader (Line 178)
Sending the current batch of data to the appropriate device (Line 180)
Making predictions on the current batch of data (Line 183)
Updating our preds list with the top predictions from the model (Line 184)
Finally, we display a detailed classification_report.
The last step we’ll do here is plot our training and validation history, followed by serializing our model weights to disk:
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(args["plot"])
# serialize the model to disk
torch.save(model, args["model"])
Lines 191-201 generate a matplotlib figure for our training history.
We then call torch.save to save our PyTorch model weights to disk so that we can load them from disk and make predictions from a separate Python script.
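One detail worth knowing: torch.save(model, ...) pickles the entire model object, so predict.py must be able to import the LeNet class definition when it later calls torch.load. An alternative many practitioners prefer, shown here only as a sketch rather than what this tutorial does, is to save just the weights via the state dict:
# save only the learned parameters rather than the whole pickled model
torch.save(model.state_dict(), args["model"])
# later, in a separate script: rebuild the architecture, then load the weights
model = LeNet(numChannels=1, classes=10)
model.load_state_dict(torch.load(args["model"]))
model.eval()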
As a whole, reviewing this script shows you how much more control PyTorch gives you over the training loop — this is both a good and a bad thing:
It’s good if you want full control over the training loop and need to implement custom procedures
It’s bad when your training loop is simple and a Keras/TensorFlow equivalent to model.fit would suffice
As I mentioned in part one of this series, What is PyTorch, neither PyTorch nor Keras/TensorFlow is better than the other; each library simply has different caveats and use cases.
Training our CNN with PyTorch
We are now ready to train our CNN using PyTorch.
Be sure to access the “Downloads” section of this tutorial to retrieve the source code to this guide.
From there, you can train your PyTorch CNN by executing the following command:
$ python train.py --model output/model.pth --plot output/plot.png
[INFO] loading the KMNIST dataset...
[INFO] generating the train-val split...
[INFO] initializing the LeNet model...
[INFO] training the network...
[INFO] EPOCH: 1/10
Train loss: 0.362849, Train accuracy: 0.8874
Val loss: 0.135508, Val accuracy: 0.9605
[INFO] EPOCH: 2/10
Train loss: 0.095483, Train accuracy: 0.9707
Val loss: 0.091975, Val accuracy: 0.9733
[INFO] EPOCH: 3/10
Train loss: 0.055557, Train accuracy: 0.9827
Val loss: 0.087181, Val accuracy: 0.9755
[INFO] EPOCH: 4/10
Train loss: 0.037384, Train accuracy: 0.9882
Val loss: 0.070911, Val accuracy: 0.9806
[INFO] EPOCH: 5/10
Train loss: 0.023890, Train accuracy: 0.9930
Val loss: 0.068049, Val accuracy: 0.9812
[INFO] EPOCH: 6/10
Train loss: 0.022484, Train accuracy: 0.9930
Val loss: 0.075622, Val accuracy: 0.9816
[INFO] EPOCH: 7/10
Train loss: 0.013171, Train accuracy: 0.9960
Val loss: 0.077187, Val accuracy: 0.9822
[INFO] EPOCH: 8/10
Train loss: 0.010805, Train accuracy: 0.9966
Val loss: 0.107378, Val accuracy: 0.9764
[INFO] EPOCH: 9/10
Train loss: 0.011510, Train accuracy: 0.9960
Val loss: 0.076585, Val accuracy: 0.9829
[INFO] EPOCH: 10/10
Train loss: 0.009648, Train accuracy: 0.9967
Val loss: 0.082116, Val accuracy: 0.9823
[INFO] total time taken to train the model: 159.99s
[INFO] evaluating network...
precision recall f1-score support
o 0.93 0.98 0.95 1000
ki 0.96 0.95 0.96 1000
su 0.96 0.90 0.93 1000
tsu 0.95 0.97 0.96 1000
na 0.94 0.94 0.94 1000
ha 0.97 0.95 0.96 1000
ma 0.94 0.96 0.95 1000
ya 0.98 0.95 0.97 1000
re 0.95 0.97 0.96 1000
wo 0.97 0.96 0.97 1000
accuracy 0.95 10000
macro avg 0.95 0.95 0.95 10000
weighted avg 0.95 0.95 0.95 10000
Figure 4: Plotting our training history with PyTorch.
Training our CNN took ≈160 seconds on my CPU. Using my GPU, training time drops to ≈82 seconds.
At the end of the final epoch we have obtained 99.67% training accuracy and 98.23% validation accuracy.
When we evaluate on our testing set we reach ≈95% accuracy, which is quite good given the complexity of the Hiragana characters and the simplicity of our shallow network architecture (a deeper, VGG- or ResNet-inspired network would allow us to obtain even higher accuracy, but those architectures are more complex than we need for an introduction to CNNs with PyTorch).
Furthermore, as Figure 4 shows, our training history plot is smooth, demonstrating that little to no overfitting is happening.
Before moving to the next section, take a look at your output directory:
$ ls output/
model.pth plot.png
Note the model.pth file — this is our trained PyTorch model saved to disk. We will load this model from disk and use it to make predictions in the following section.
Implementing our PyTorch prediction script
The final script we are reviewing here will show you how to make predictions with a PyTorch model that has been saved to disk.
Open the predict.py file in your project directory structure, and we’ll get started:
# set the numpy seed for better reproducibility
import numpy as np
np.random.seed(42)
# import the necessary packages
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.transforms import ToTensor
from torchvision.datasets import KMNIST
import argparse
import imutils
import torch
import cv2
Lines 2-13 import our required Python packages. We set the NumPy random seed at the top of the script for better reproducibility across machines.
We then import:
DataLoader: Used to load our KMNIST testing data
Subset: Builds a subset of the testing data
ToTensor: Converts our input data to a PyTorch tensor data type
KMNIST: The Kuzushiji-MNIST dataset loader built into the PyTorch library
cv2: Our OpenCV bindings which we’ll use for basic drawing and displaying output images on our screen
Next comes our command line arguments:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", type=str, required=True,
help="path to the trained PyTorch model")
args = vars(ap.parse_args())
We only need a single argument here, --model, the path to our trained PyTorch model saved to disk. Presumably, this switch will point to output/model.pth.
Moving on, let’s set our device:
# set the device we will be using to test the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the KMNIST dataset and randomly grab 10 data points
print("[INFO] loading the KMNIST test dataset...")
testData = KMNIST(root="data", train=False, download=True,
transform=ToTensor())
idxs = np.random.choice(range(0, len(testData)), size=(10,))
testData = Subset(testData, idxs)
# initialize the test data loader
testDataLoader = DataLoader(testData, batch_size=1)
# load the model and set it to evaluation mode
model = torch.load(args["model"]).to(device)
model.eval()
Line 22 determines if we will be performing inference on our CPU or GPU.
We then load the testing data from the KMNIST dataset on Lines 26 and 27. We randomly sample a total of 10 images from this dataset on Lines 28 and 29 using the Subset class (which creates a smaller “view” of the full testing data).
A DataLoader is created to pass our subset of testing data through the model on Line 32.
We then load our serialized PyTorch model from disk on Line 35, passing it to the appropriate device.
Finally, the model is placed into evaluation mode (Line 36).
Let’s now make predictions on a sample of our testing set:
# switch off autograd
with torch.no_grad():
# loop over the test set
for (image, label) in testDataLoader:
# grab the original image and ground truth label
origImage = image.numpy().squeeze(axis=(0, 1))
gtLabel = testData.dataset.classes[label.numpy()[0]]
# send the input to the device and make predictions on it
image = image.to(device)
pred = model(image)
# find the class label index with the largest corresponding
# probability
idx = pred.argmax(axis=1).cpu().numpy()[0]
predLabel = testData.dataset.classes[idx]
Line 39 turns off gradient tracking, while Line 41 loops over all images in our subset of the test set.
For each image, we:
Grab the current image and convert it to a NumPy array (so we can draw on it later with OpenCV)
Extract the ground-truth class label
Send the image to the appropriate device
Use our trained LeNet model to make predictions on the current image
Extract the class label with the top predicted probability
All that’s left is a bit of visualization:
# convert the image from grayscale to RGB (so we can draw on
# it) and resize it (so we can more easily see it on our
# screen)
origImage = np.dstack([origImage] * 3)
origImage = imutils.resize(origImage, width=128)
# draw the predicted class label on it
color = (0, 255, 0) if gtLabel == predLabel else (0, 0, 255)
cv2.putText(origImage, gtLabel, (2, 25),
cv2.FONT_HERSHEY_SIMPLEX, 0.95, color, 2)
# display the result in terminal and show the input image
print("[INFO] ground truth label: {}, predicted label: {}".format(
gtLabel, predLabel))
cv2.imshow("image", origImage)
cv2.waitKey(0)
Each image in the KMNIST dataset is a single channel grayscale image; however, we want to use OpenCV’s cv2.putText function to draw the predicted class label and ground-truth label on the image.
To draw RGB colors on a grayscale image, we first need to create an RGB representation of the grayscale image by stacking the grayscale image depth-wise a total of three times (Line 58).
Additionally, we resize the origImage so that we can more easily see it on our screen (by default, KMNIST images are only 28×28 pixels, which can be hard to see, especially on a high resolution monitor).
From there, we determine the text color and draw the label on the output image.
We wrap up the script by displaying the output origImage on our screen.
Making predictions with our trained PyTorch model
We are now ready to make predictions using our trained PyTorch model!
Be sure to access the “Downloads” section of this tutorial to retrieve the source code and pre-trained PyTorch model.
From there, you can execute the predict.py script:
$ python predict.py --model output/model.pth
[INFO] loading the KMNIST test dataset...
[INFO] Ground truth label: ki, Predicted label: ki
[INFO] Ground truth label: ki, Predicted label: ki
[INFO] Ground truth label: ki, Predicted label: ki
[INFO] Ground truth label: ha, Predicted label: ha
[INFO] Ground truth label: tsu, Predicted label: tsu
[INFO] Ground truth label: ya, Predicted label: ya
[INFO] Ground truth label: tsu, Predicted label: tsu
[INFO] Ground truth label: na, Predicted label: na
[INFO] Ground truth label: ki, Predicted label: ki
[INFO] Ground truth label: tsu, Predicted label: tsu
Figure 5: Making predictions on handwritten characters using PyTorch and our trained CNN.
As our output demonstrates, we have been able to successfully recognize each of the Hiragana characters using our PyTorch model.
In this tutorial, you learned how to train your first Convolutional Neural Network (CNN) using the PyTorch deep learning library.
You also learned how to:
Save our trained PyTorch model to disk
Load it from disk in a separate Python script
Use the PyTorch model to make predictions on images
This sequence of saving a model after training, and then loading it and using the model to make predictions, is a process you should become comfortable with — you’ll be doing it often as a PyTorch deep learning practitioner.
Speaking of loading saved PyTorch models from disk, next week you will learn how to use pre-trained PyTorch networks to recognize 1,000 image classes that you often encounter in everyday life. These models can save you a bunch of time and hassle — they are highly accurate and don’t require you to manually train them.
In this blog post, I interview Askat Kuzdeuov, a computer vision and deep learning researcher at the Institute of Smart Systems and Artificial Intelligence (ISSAI).
Askat is not only a stellar researcher, but he’s an avid PyImageSearch reader as well.
I was first introduced to Askat a couple weeks ago. We had just announced a $500 prize to the first PyImageSearch University member to complete all courses in the program. Less than 24 hours later, Askat had completed all courses and won the prize.
We started an email conversation from there and he shared some of his latest research. I was incredibly impressed to say the least.
What I really like about his SpeakingFaces dataset is the sensor fusion component, consisting of:
High-resolution thermal images
Visual spectra image streams (i.e., what the human eye can see)
Synchronized audio recordings
Data was collected from 142 subjects, yielding over 13,000 instances of synchronized data (∼3.8 TB) … but collecting the data was just the easy part!
The hard part came afterward — preprocessing all the data, ensuring it was synchronized, and packaging it in a way that computer vision/deep learning researchers could use in their own work.
From there, Askat and his colleagues trained a GAN to take the input thermal images and then generate an RGB image from them, which as Askat can attest, was no easy task!
Inside the rest of this interview, you will learn how Askat and his colleagues built the SpeakingFaces dataset, including the preprocessing techniques they used to clean and prepare the dataset for distribution.
You’ll also learn how they trained a GAN to generate RGB images from thermal camera inputs.
If you have any interest in learning about how to perform publication-worthy work in the computer vision community, then definitely make sure you read this interview!
An interview with Askat Kuzdeuov, computer vision and deep learning researcher
Adrian: Hi Askat! Thank you for taking the time to do this interview. It’s a pleasure to have you on the PyImageSearch blog.
Askat: Hi Adrian! Thank you for having me. It’s a great honor to be a guest on the PyImageSearch.
Figure 1: Askat works at the Institute of Smart Systems and Artificial Intelligence (ISSAI), performing research in computer vision and deep learning.
Adrian: Before we get started, can you tell us a bit about yourself? Where do you work, and what is your role?
Figure 2: Askat has authored a number of papers on computer vision and deep learning.
Adrian: I really enjoyed going through your Google Scholar profile and reading some of your publications. What got you interested in doing research in computer vision and deep learning?
Askat: Thank you! I am glad to hear you liked them! I discovered computer vision during my graduate studies when I took an image processing course, on the recommendation of my advisor, Dr. Varol. It quickly became my favourite subject. I had a lot of fun operating with images in a game format, and by the end of the course I realized it is exactly what I want to pursue.
Figure 3: Snapshots from the visual and thermal streams with 0.5-second intervals during the utterance of “stop the kitchen fan from turning” (image source: Figure 6 of Abdrakhmanova and Kuzdeuov et al.).
Askat: Two of the major research interests of our Institute are multisensory data and speech data. The thermal camera was one of the sensors that we experimented with, and we were interested in examining how it might be used to accompany voice recordings. When we began the research, we realized that there are no existing datasets suitable to the task, so we decided to design one. Since most works primarily associate audio with visual data, we opted to combine all three modalities to widen the range of potential of applications.
Figure 4: Top: Data acquisition setup for SpeakingFaces. Bottom: Back (a) and front (b) 3D diagram of the nine camera positions with respect to a subject (image sources: Figures 2 and 3 from Abdrakhmanova and Kuzdeuov et al.).
Adrian: Can you describe the dataset acquisition process? I’m sure that was quite the undertaking. How did you accomplish it so efficiently?
Askat: It was indeed an undertaking. Once we designed the acquisition protocol, we had to go through an ethics committee to prove that our study bears minimal or no risk to our potential participants. We were able to proceed only after their approval. Because of the great number of participants and their busy schedules, it was quite tricky to schedule recording sessions. We asked each participant to donate 30 minutes of their time and come back to repeat the whole shooting process on another day. The population of participants was another challenge, especially so under pandemic conditions. We aimed for a diverse and gender balanced representation, so we had to keep that in mind while scouting for potential participants.
As with many people across the globe, we unexpectedly got hit by the pandemic and had to adjust our plans, and shift towards data processing tasks that could be performed remotely, optimizing online collaboration.
I am thankful to my co-authors and other ISSAI members who helped with the dataset. In the end, as they say, it took a village to acquire and process the data!
Adrian: After gathering the raw data you would have needed to preprocess it and then organize it into a logical structure that other researchers could utilize. Working with multi-modal data makes that a non-trivial process. How did you solve the problem?
Askat: Yes, the preprocessing step had its own challenges, as it revealed some problems that occurred during the data collection process. In fact, we found your tutorials to be instrumental in resolving these issues and cited them in our paper.
For instance, visual images were blurred in some cases because of the autofocus mode in the web camera. Considering that we recorded millions of frames, a manual search for blurred frames would have been an extremely difficult task. Fortunately, we had read your blog post “Blur Detection with OpenCV” which helped us to automate the process.
Also, we noticed that the thermal camera had frozen in some cases, such that the affected frames were not updated properly. We detected these cases by comparing consecutive frames and utilizing the Structural Similarity Index method, which was well explained in your “How-To: Python Compare Two Images” post.
Figure 5: Thermal-to-visual image translation results using CycleGAN. For each column, left to right: real thermal image; generated visual image by CUT; generated visual image by CycleGAN; and real visual image (image source: Figure 9 from Abdrakhmanova and Kuzdeuov et al.).
Adrian: I noticed that you used GANs to translate thermal images to “normal” RGB representations of the faces — can you tell us a bit more about that process and why you needed to use GANs here?
Askat: The state-of-the-art face recognition and facial landmark prediction models were trained using only visual images. Therefore, we can’t apply them directly to thermal images. There are several possible options to solve this problem.
The first one is to collect a large amount of annotated thermal images and train the models from scratch. This process takes an enormous amount of time and human labour.
The second option is to combine visual and thermal images and use a transfer learning method. However, this method also requires annotating thermal images.
The option that we went for in the paper is to translate images from thermal to the visible domain using deep generative models. We opted for GANs because the state of the art shows that GANs are superior compared to autoencoders in generating realistic images. However, the downside is that GANs are extremely difficult to train.
Adrian: What are some of the more practical applications for which you envision using the SpeakingFaces dataset?
Askat: Overall we believe the dataset can be used for a wide range of human-computer interaction purposes. We hope our work will encourage others to integrate multimodal data into different recognition systems and make them more robust.
Adrian: What computer vision and deep learning tools, libraries, and packages did you utilize in your research?
Askat: The first step was data collection. We used a Logitech C920 Pro camera with dual mics and FLIR T540 thermal camera.
The official API for the thermal camera was written in MATLAB. Thus, I had to use MATLAB’s Computer Vision and ROS Toolboxes to acquire frames simultaneously from both cameras.
The next step was cleaning and preprocessing data. We mainly utilized Python and OpenCV, and heavily used your imutils package. It was very useful!
In the final step, we used PyTorch to build the baseline models for thermal-visual facial image translation and multimodal gender classification.
Adrian: If you had to pick the most challenging issue you faced when building the SpeakingFaces dataset, what would it be?
Askat: Definitely operating during the lockdown. When we shifted to the preprocessing stage, we realized that some of the collected image frames were corrupted. Unfortunately, due to quarantine measures, we couldn’t reshoot them, so we safely removed what we could and documented the rest.
Adrian: What are your next steps as a researcher? Are you going to continue working on SpeakingFaces or are you moving on to other projects?
Askat: Currently I am involved in several projects. Some of them are related to the SpeakingFaces. For instance, I am working on building a thermal-visual pose invariant face verification system based on a siamese neural network. In another project, we collected additional data using the same thermal camera but in the wild. Our goal is to build robust thermal face detection and landmark prediction models.
Adrian: What advice would you give to someone who wants to follow in your footsteps and become a successful published researcher, but doesn’t know how to get started?
Askat: I recommend finding an environment where you can work with people who share similar interests and enthusiasm. If you are passionate about the work, and you are working with interesting people, then it doesn’t really feel like a job!
Adrian: A few weeks ago we offered a $500 prize to the first member of PyImageSearch University who completed all courses, and you were the winner! Congratulations! Can you tell us about that experience? What did you learn inside PyImageSearch University and how were you able to complete all lessons so quickly?
Askat: Thank you again for the prize! It was a nice bonus to the great course.
I’ve been following the blog for some time: at the beginning of my journey into the field of computer vision, I went through most of the freely available lessons at PyImageSearch. I rewrote each piece of code, line by line. I strongly recommend this method, especially for beginners, because it not only allows you to understand the content, but also teaches you to program properly.
When you provided one week free access to the PyImageSearch University, I joined to watch the videos, because it is always good to refresh your knowledge. It was a great decision because the experiences I got from reading the blogs and watching the videos were completely different. I went through almost all topics in one week and had only 1 or 2 left. Nevertheless, I decided to purchase the annual membership because I liked the format, and also I wanted to support your hard work.
I definitely recommend PyImageSearch University for researchers and practitioners, especially in the areas of computer vision, deep learning, and OpenCV. The courses provide a strong baseline which is imperative if you wish to understand advanced concepts.
Adrian: If a PyImageSearch reader wants to connect with you, how can they do so?
Askat: If you wish to discuss some interesting ideas with me, you can contact me via Linkedin.
If you are interested in learning more about the SpeakingFaces dataset, here is a good overview video:
And here are some examples of our image translation model:
If you are interested to learn more about other active projects, we have pretty comprehensive material available on the website, and on the YouTube channel:
Today we interviewed Askat Kuzdeuov, a computer vision and deep learning researcher at the Institute of Smart Systems and Artificial Intelligence (ISSAI).
Askat’s latest work includes the SpeakingFaces dataset which can be used for human–computer interaction, biometric authentication, recognition systems, domain transfer, and speech recognition.
Perhaps one of the most notable contributions from the work is the accuracy of using GANs to generate RGB images from thermal camera inputs — an accurate GAN model would allow deep learning practitioners to reduce the number of sensors in a real-world application, potentially relying on just the thermal camera.
Make sure you give Askat’s work a read; it’s a wonderfully done, high-quality piece.
In this tutorial, you will learn how to perform image classification with pre-trained networks using PyTorch. Utilizing these networks, you can accurately classify 1,000 common object categories in only a few lines of code.
Today’s tutorial is part four in our five part series on PyTorch fundamentals:
PyTorch image classification with pre-trained networks (today’s tutorial)
August 2nd: PyTorch object detection with pre-trained networks (next week’s tutorial)
Throughout the rest of this tutorial, you’ll gain experience using PyTorch to classify input images using seminal, state-of-the-art image classification networks, including VGG, Inception, DenseNet, and ResNet.
To learn how to perform image classification with pre-trained PyTorch networks, just keep reading.
PyTorch image classification with pre-trained networks
In the first part of this tutorial, we’ll discuss what pre-trained image classification networks are, including those that are built into the PyTorch library.
From there, we’ll configure our development environment and review our project directory structure.
I’ll then show you how to implement a Python script that can accurately classify input images using pre-trained PyTorch networks.
We’ll wrap up this tutorial with a discussion of our results.
What are pre-trained image classification networks?
Figure 1: Most popular, state-of-the-art neural networks come with weights pre-trained on the ImageNet dataset. The PyTorch library includes many of these popular image classification networks.
When it comes to image classification, there is no dataset/challenge more famous than ImageNet. The goal of ImageNet is to accurately classify input images into a set of 1,000 common object categories that computer vision systems will “see” in everyday life.
Most popular deep learning frameworks, including PyTorch, Keras, TensorFlow, fast.ai, and others, include pre-trained networks. These are highly accurate, state-of-the-art models that computer vision researchers trained on the ImageNet dataset.
After training on ImageNet was complete, researchers saved their models to disk and then published them freely for other researchers, students, and developers to learn from and use in their own projects.
This tutorial will show how to use PyTorch to classify input images using the following state-of-the-art classification networks:
VGG16
VGG19
Inception
DenseNet
ResNet
Let’s get started!
Configuring your development environment
To follow this guide, you need to have both PyTorch and OpenCV installed on your system.
Luckily, both PyTorch and OpenCV are extremely easy to install using pip:
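As before, the exact pip commands aren’t reproduced in this excerpt, but an install along these lines should work:
$ pip install torch torchvision
$ pip install opencv-contrib-python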
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
Project structure
Before we implement image classification with PyTorch, let’s first review our project directory structure.
Start by accessing the “Downloads” section of this guide to retrieve the source code and example images. You’ll then be presented with the following directory structure.
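The listing itself isn’t reproduced in this excerpt, but based on the files described below, the project looks roughly like this (the sample image filenames and the __init__.py file are assumptions):
.
├── images/
│   └── (a handful of example images)
├── pyimagesearch/
│   ├── __init__.py
│   └── config.py
├── classify_image.py
└── ilsvrc2012_wordnet_lemmas.txt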
Inside the pyimagesearch module we have a single file, config.py. This file stores important configurations, such as:
Our input image dimensions
Mean and standard deviation for mean subtraction and scaling
Whether or not we are using a GPU for inference
Path to the human-readable ImageNet class labels (i.e., ilsvrc2012_wordnet_lemmas.txt)
Our classify_image.py script will load our config and then classify an input image using either VGG16, VGG19, Inception, DenseNet, or ResNet (depending on which model architecture we supply as our command line argument).
The images directory contains a number of sample images where we’ll apply these image classification networks.
Creating our configuration file
Before we implement our image classification driver script, let’s first create a configuration file to store important configurations.
Open the config.py file in the pyimagesearch module and insert the following code:
# import the necessary packages
import torch
# specify image dimension
IMAGE_SIZE = 224
# specify ImageNet mean and standard deviation
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# determine the device we will be using for inference
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# specify path to the ImageNet labels
IN_LABELS = "ilsvrc2012_wordnet_lemmas.txt"
Line 5 defines our input image spatial dimensions, meaning that each image will be resized to 224×224 pixels before being passed through our pre-trained PyTorch network for classification.
Note: Most networks trained on the ImageNet dataset accept images that are 224×224 or 227×227. Some networks, particularly fully convolutional networks, may accept larger image dimensions.
From there, we define the mean and standard deviation of RGB pixel intensities across our training set (Lines 8 and 9). Prior to passing an input image through our network for classification, we first normalize the image pixel intensities by subtracting the mean and then dividing by the standard deviation — this preprocessing is typical for CNNs trained on large, diverse image datasets such as ImageNet.
Next, Line 12 specifies whether we are using our CPU or GPU for inference, while Line 15 defines the path to our input text file of ImageNet class labels.
If you were to open this file in your text editor of choice, you would see the following contents:
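The full file contains 1,000 rows, one per ImageNet class, where each row is the comma-separated set of WordNet lemmas for that class. The first few rows look something like this:
tench, Tinca_tinca
goldfish, Carassius_auratus
great_white_shark, white_shark, man-eater, man-eating_shark, Carcharodon_carcharias
tiger_shark, Galeocerdo_cuvieri
...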
Each row in this text file maps to the name of a class label our pre-trained PyTorch networks were trained to recognize and classify.
Implementing our image classification script
With our configuration file taken care of, let’s move on to implementing our main driver script used to classify input images using our pre-trained PyTorch networks.
Open the classify_image.py file in your project directory structure, and let’s get to work:
# import the necessary packages
from pyimagesearch import config
from torchvision import models
import numpy as np
import argparse
import torch
import cv2
We start on Lines 2-7 by importing our Python packages, including:
config: The configuration file we implemented from the previous section
With our imports taken care of, let’s define a function to accept an input image and preprocess it:
def preprocess_image(image):
# swap the color channels from BGR to RGB, resize it, and scale
# the pixel values to [0, 1] range
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (config.IMAGE_SIZE, config.IMAGE_SIZE))
image = image.astype("float32") / 255.0
# subtract ImageNet mean, divide by ImageNet standard deviation,
# set "channels first" ordering, and add a batch dimension
image -= config.MEAN
image /= config.STD
image = np.transpose(image, (2, 0, 1))
image = np.expand_dims(image, 0)
# return the preprocessed image
return image
Our preprocess_image function takes a single argument, image, which is the image we’ll be preprocessing for classification.
We start the preprocessing operations by:
Swapping from BGR to RGB channel ordering (the pre-trained networks we’re using here utilized RGB channel ordering whereas OpenCV uses BGR ordering by default)
Resizing our image to fixed dimensions (i.e., 224×224), ignoring aspect ratio
Converting our image to a floating point data type and then scaling the pixel intensities to the range [0, 1]
From there, we perform a second set of preprocessing operations:
Subtracting the mean (Line 18) and dividing by the standard deviation (Line 19)
Moving the channels dimension to the front of the array (Line 20), which is called channels-first ordering and is the default channel ordering method that PyTorch expects
Adding a batch dimension to the array (Line 21)
The preprocessed image is then returned to the calling function.
Next, let’s parse our command line arguments:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="vgg16",
choices=["vgg16", "vgg19", "inception", "densenet", "resnet"],
help="name of pre-trained network to use")
args = vars(ap.parse_args())
We have two command line arguments to parse:
--image: The path to the input image that we wish to classify
--model: The pre-trained CNN model we’ll be using to classify the image
Let’s now define a MODELS dictionary which maps the name of the --model command line argument to its corresponding PyTorch function:
# define a dictionary that maps model names to their classes
# inside torchvision
MODELS = {
"vgg16": models.vgg16(pretrained=True),
"vgg19": models.vgg19(pretrained=True),
"inception": models.inception_v3(pretrained=True),
"densenet": models.densenet121(pretrained=True),
"resnet": models.resnet50(pretrained=True)
}
# load the network weights from disk, flash the model to the current
# device, and set it to evaluation mode
print("[INFO] loading {}...".format(args["model"]))
model = MODELS[args["model"]].to(config.DEVICE)
model.eval()
Lines 37-43 create our MODELS dictionary:
The key to the dictionary is the human-readable name of the model, passed in via the --model command line argument.
The value to the dictionary is the corresponding PyTorch function used to load the model with the weights pre-trained on ImageNet
You’ll be able to use the following pre-trained models to classify an input image with PyTorch:
VGG16
VGG19
Inception
DenseNet
ResNet
Specifying the pretrained=True flag instructs PyTorch to not only load the model architecture definition, but also download the pre-trained ImageNet weights for the model. Note that because each constructor is called when the MODELS dictionary is built, the weights for all five architectures will be downloaded and cached the first time you run this script.
Line 48 then loads the model and pre-trained weights (if you’ve never downloaded the model weights before they will be automatically downloaded and cached for you) and then sets the model to run either on your CPU or GPU, depending on your DEVICE from the configuration file.
Line 49 puts our model into evaluation mode, instructing PyTorch to handle special layers, such as dropout and batch normalization, differently than it would during training. Putting your model into evaluation mode before making predictions is critical, so don’t forget to do it!
Now that our model is loaded, we need an input image — let’s take care of that now:
# load the image from disk, clone it (so we can draw on it later),
# and preprocess it
print("[INFO] loading image...")
image = cv2.imread(args["image"])
orig = image.copy()
image = preprocess_image(image)
# convert the preprocessed image to a torch tensor and flash it to
# the current device
image = torch.from_numpy(image)
image = image.to(config.DEVICE)
# load the ImageNet class labels
print("[INFO] loading ImageNet labels...")
imagenetLabels = dict(enumerate(open(config.IN_LABELS)))
Line 54 loads our input image from disk. We make a copy of it on Line 55 so that we can draw on it and visualize the top prediction of our network. We also make use of our preprocess_image function on Line 56 to perform resizing and scaling.
Line 60 converts our image from a NumPy array to a PyTorch tensor, while Line 61 moves the image to our device (either CPU or GPU).
Finally, Line 65 loads our ImageNet class labels from disk.
We are now ready to make predictions on an input image using our model:
# classify the image and extract the predictions
print("[INFO] classifying image with '{}'...".format(args["model"]))
logits = model(image)
probabilities = torch.nn.Softmax(dim=-1)(logits)
sortedProba = torch.argsort(probabilities, dim=-1, descending=True)
# loop over the predictions and display the rank-5 predictions and
# corresponding probabilities to our terminal
for (i, idx) in enumerate(sortedProba[0, :5]):
    print("{}. {}: {:.2f}%".format(i, imagenetLabels[idx.item()].strip(),
        probabilities[0, idx.item()] * 100))
Line 69 performs a forward pass of our network, resulting in the raw output logits of the network.
We pass these through the Softmax function on Line 70 to obtain the predicted probabilities for each of the possible 1,000 class labels the model was trained on.
Line 71 then sorts the probabilities in descending order with higher probabilities at the front of the list.
We then display the top-5 predicted class labels and corresponding probabilities to our terminal on Lines 75-78 by:
Looping over the top-5 predictions
Looking up the name of the class label using our imagenetLabels dictionary
Displaying the predicted probability
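One small aside: since we are only performing inference in this script, you could optionally wrap the forward pass in a torch.no_grad() context so that PyTorch does not track gradients. The script above does not do this, but a minimal sketch would look like:
# optionally disable gradient tracking during inference to save memory
with torch.no_grad():
    logits = model(image)
    probabilities = torch.nn.Softmax(dim=-1)(logits)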
Our final code block draws the top-1 (i.e., top predicted label) on our output image:
# draw the top prediction on the image and display the image to
# our screen
(label, prob) = (imagenetLabels[probabilities.argmax().item()],
probabilities.max().item())
cv2.putText(orig, "Label: {}, {:.2f}%".format(label.strip(), prob * 100),
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
cv2.imshow("Classification", orig)
cv2.waitKey(0)
The result is then displayed to our screen.
Image classification with PyTorch results
We are now ready to apply image classification with PyTorch!
Be sure to access the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, try classifying an input image using the following command:
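(The exact image file name below is just a placeholder; substitute any of the example images included with the downloads.)
$ python classify_image.py --image images/example_01.jpg --model vgg16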
Figure 3: Using PyTorch and VGG16 to classify an input image.
It appears that Captain Jack Sparrow is stranded on the beach! And sure enough, the VGG16 network is able to correctly classify the input image as a “wreck” (i.e., shipwreck) with 99.99% probability.
It’s also interesting to see that “seashore” is the second top prediction from the model — this prediction is also accurate, due to the boat being on the beach.
Let’s try a different image, this time using the ResNet model:
Figure 5: Utilizing ResNet and PyTorch to correctly classify an input image.
Here we are using the ResNet architecture to classify our input image. Jemma is a “beagle” (a type of dog), which ResNet accurately predicts with 95.98% probability.
Interestingly, a “bluetick,” “walker hound,” and “English foxhound” are all types of dogs belonging to the “hound” family — all of these would be reasonable predictions from the model.
Figure 6: Using Inception and PyTorch to make predictions on an input image.
Our Inception model correctly classifies the input image as “soccer ball” with 100% probability.
Image classification allows us to assign one or more labels to an input image; however, it tells us nothing about where in the image the object resides.
To determine where in an input image a given object is, we need to apply object detection:
Figure 7: Object detection can not only tell us what is in an image but also where the object is.
Just like we have pre-trained networks for image classification, we also have pre-trained networks for object detection. Next week you’ll learn how to use PyTorch to detect objects in images using specialized object detection networks.
In this tutorial, you learned how to perform image classification using PyTorch. Specifically, we utilized popular pre-trained network architectures, including:
VGG16
VGG19
Inception
DenseNet
ResNet
These models were trained by the researchers responsible for inventing and proposing the novel architectures listed above. After training was complete, these researchers saved the model weights to disk and then published them for other researchers, students, and developers to learn from and use in their own projects.
While the models are free to use, make sure you check any terms/conditions associated with them, as some models are not free to use in commercial applications (typically entrepreneurs in the AI space get around this restriction by training the models themselves rather than using the pre-trained weights provided by the original authors).
Stay tuned for next week’s blog post, where you’ll learn how to perform object detection using PyTorch.
In this tutorial, you will learn how to perform object detection with pre-trained networks using PyTorch. Utilizing pre-trained object detection networks, you can detect and recognize 90 common objects that your computer vision application will “see” in everyday life.
Today’s tutorial is the final part in our five part series on PyTorch fundamentals:
PyTorch object detection with pre-trained networks (today’s tutorial)
Throughout the rest of this tutorial, you’ll gain experience using PyTorch to detect objects in input images using seminal, state-of-the-art object detection networks, including Faster R-CNN with ResNet, Faster R-CNN with MobileNet, and RetinaNet.
To learn how to perform object detection with pre-trained PyTorch networks, just keep reading.
PyTorch object detection with pre-trained networks
In the first part of this tutorial, we will discuss what pre-trained object detection networks are, including what object detection networks are built into the PyTorch library.
From there, we’ll configure our development environment and review our project directory structure.
We’ll review two Python scripts today. The first one will perform object detection in images, while the second one will show you how to perform real-time object detection in video streams (a GPU will be required to obtain real-time performance).
Finally, we’ll wrap up this tutorial with a discussion of our results.
What are pre-trained object detection networks?
Figure 1: Most popular, state-of-the-art neural networks come with weights pre-trained on the COCO dataset for object detection. The PyTorch library includes many of these popular object detection networks (image source).
Just like the ImageNet challenge tends to be the de facto standard for image classification, the COCO dataset (Common Objects in Context) tends to be the standard for object detection benchmarking.
This dataset includes over 90 classes of common objects you’ll see in the everyday world. Computer vision and deep learning researchers develop, train, and evaluate state-of-the-art object detection networks on the COCO dataset.
Most researchers also publish the pre-trained weights to their models so that computer vision practitioners can easily incorporate object detection into their own projects.
This tutorial will show how to use PyTorch to perform object detection using the following state-of-the-art object detection networks:
Faster R-CNN with a ResNet50 backbone (more accurate, but slower)
Faster R-CNN with a MobileNet v3 backbone (faster, but less accurate)
RetinaNet with a ResNet50 backbone (good balance between speed and accuracy)
Ready? Let’s get started.
Configuring your development environment
To follow this guide, you need to have both PyTorch and OpenCV installed on your system.
Luckily, both PyTorch and OpenCV are extremely easy to install using pip:
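(The package names below are the standard PyPI names; your exact PyTorch install command may differ depending on your CUDA version, so check the selector on the PyTorch website.)
$ pip install torch torchvision
$ pip install opencv-contrib-python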
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — it is comprehensive and will have you up and running quickly.
Having problems configuring your development environment?
Figure 2: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
Before we start reviewing any source code, let’s first review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
You’ll then be presented with the following directory structure:
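(The sketch below is reconstructed from the files described next; your download may contain additional example images.)
$ tree . --dirsfirst
.
├── images
│   └── ...
├── coco_classes.pickle
├── detect_image.py
└── detect_realtime.py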
Inside the images directory, you’ll find a number of example images where we’ll be applying object detection.
The coco_classes.pickle file contains the names of the class labels our PyTorch pre-trained object detection networks were trained on.
We then have two Python scripts to review:
detect_image.py: Performs object detection with PyTorch in static images
detect_realtime.py: Applies PyTorch object detection to real-time video streams
Implementing our PyTorch object detection script
In this section, you will learn how to perform object detection with pre-trained PyTorch networks.
Open the detect_image.py script and insert the following code:
# import the necessary packages
from torchvision.models import detection
import numpy as np
import argparse
import pickle
import torch
import cv2
Lines 2-7 import our required Python packages. The most important import is detection from torchvision.models. The detection module contains PyTorch’s pre-trained object detectors.
Let’s move on to parsing our command line arguments:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", type=str, required=True,
help="path to the input image")
ap.add_argument("-m", "--model", type=str, default="frcnn-resnet",
choices=["frcnn-resnet", "frcnn-mobilenet", "retinanet"],
help="name of the object detection model")
ap.add_argument("-l", "--labels", type=str, default="coco_classes.pickle",
help="path to file containing list of categories in COCO dataset")
ap.add_argument("-c", "--confidence", type=float, default=0.5,
help="minimum probability to filter weak detections")
args = vars(ap.parse_args())
We have a number of command line arguments here, including:
--image: The path to the input image we want to apply object detection to
--model: The type of PyTorch object detector we’ll be using (Faster R-CNN + ResNet, Faster R-CNN + MobileNet, or RetinaNet + ResNet)
--labels: The path to the COCO labels file, containing human readable class labels
--confidence: Minimum predicted probability to filter out weak detections
Here, we have a few important initializations:
# set the device we will be using to run the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the list of categories in the COCO dataset and then generate a
# set of bounding box colors for each class
CLASSES = pickle.loads(open(args["labels"], "rb").read())
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
Line 23 sets the device we’ll be using for inference (either CPU or GPU).
We then load our class labels from disk (Line 27) and initialize a random color for each unique label (Line 28). We’ll use these colors when drawing predicted bounding boxes and labels on our output image.
Next, we define a MODELS dictionary to map the name of a given object detector to its corresponding PyTorch function:
# initialize a dictionary containing model name and its corresponding
# torchvision function call
MODELS = {
"frcnn-resnet": detection.fasterrcnn_resnet50_fpn,
"frcnn-mobilenet": detection.fasterrcnn_mobilenet_v3_large_320_fpn,
"retinanet": detection.retinanet_resnet50_fpn
}
# load the model and set it to evaluation mode
model = MODELS[args["model"]](pretrained=True, progress=True,
num_classes=len(CLASSES), pretrained_backbone=True).to(DEVICE)
model.eval()
PyTorch provides us with three object detection models:
Faster R-CNN with a ResNet50 backbone (more accurate, but slower)
Faster R-CNN with a MobileNet v3 backbone (faster, but less accurate)
RetinaNet with a ResNet50 backbone (good balance between speed and accuracy)
We then load the model from disk and send it to the appropriate DEVICE on Lines 39 and 40. We pass in a number of key parameters, including:
pretrained: Tells PyTorch to load the model architecture with pre-trained weights on the COCO dataset
progress=True: Displays download progress bar if model has not already been downloaded and cached
num_classes: Total number of unique classes
pretrained_backbone: Tells PyTorch to also load pre-trained weights for the detector’s backbone network
We then place the model in evaluation mode on Line 41.
With our model loaded, let’s move on to preparing our input image for object detection:
# load the image from disk
image = cv2.imread(args["image"])
orig = image.copy()
# convert the image from BGR to RGB channel ordering and change the
# image from channels last to channels first ordering
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.transpose((2, 0, 1))
# add the batch dimension, scale the raw pixel intensities to the
# range [0, 1], and convert the image to a floating point tensor
image = np.expand_dims(image, axis=0)
image = image / 255.0
image = torch.FloatTensor(image)
# send the input to the device and pass it through the network to
# get the detections and predictions
image = image.to(DEVICE)
detections = model(image)[0]
Lines 44 and 45 load our input image from disk and clone it so that we can draw the bounding box predictions on it later in this script.
We then preprocess our image by:
Converting color channel ordering from BGR to RGB (since PyTorch models were trained on RGB-ordered images)
Swapping color channel ordering from “channels last” (OpenCV and Keras/TensorFlow default) to “channels first” (PyTorch default)
Adding a batch dimension
Scaling pixel intensities from the range [0, 255] to [0, 1]
Converting the image from a NumPy array to a tensor with a floating point data type
The image is then moved to the appropriate device (Line 60). At that point, we pass the image through the model to obtain our bounding box predictions.
Let’s loop over our bounding box predictions now:
# loop over the detections
for i in range(0, len(detections["boxes"])):
    # extract the confidence (i.e., probability) associated with the
    # prediction
    confidence = detections["scores"][i]
    # filter out weak detections by ensuring the confidence is
    # greater than the minimum confidence
    if confidence > args["confidence"]:
        # extract the index of the class label from the detections,
        # then compute the (x, y)-coordinates of the bounding box
        # for the object
        idx = int(detections["labels"][i])
        box = detections["boxes"][i].detach().cpu().numpy()
        (startX, startY, endX, endY) = box.astype("int")
        # display the prediction to our terminal
        label = "{}: {:.2f}%".format(CLASSES[idx], confidence * 100)
        print("[INFO] {}".format(label))
        # draw the bounding box and label on the image
        cv2.rectangle(orig, (startX, startY), (endX, endY),
            COLORS[idx], 2)
        y = startY - 15 if startY - 15 > 15 else startY + 15
        cv2.putText(orig, label, (startX, y),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2)
# show the output image
cv2.imshow("Output", orig)
cv2.waitKey(0)
Line 64 loops over all detections from the network. We then grab the confidence (i.e., probability) associated with the detection on Line 67.
We filter out weak detections that do not meet our minimum confidence test on Line 71. Doing so helps filter out false-positive detections.
From there, we:
Extract the idx of the class label with the largest corresponding probability (Line 75)
Obtain the bounding box coordinates and convert them to integers (Lines 76 and 77)
Display the prediction to our terminal (Lines 80 and 81)
Draw the predicted bounding box and class label on our output image (Lines 84-88)
We wrap up the script by displaying our output image with bounding boxes drawn on it.
Object detection with PyTorch results
We are now ready to see some PyTorch object detection results!
Be sure to access the “Downloads” section of this tutorial to retrieve the source code and example images.
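From there, you can apply the detect_image.py script to an example image. A hypothetical invocation (the image file name is just a placeholder) looks like this:
$ python detect_image.py --model frcnn-resnet --labels coco_classes.pickle --image images/example_01.jpg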
Figure 3: Using Faster R-CNN and PyTorch to perform object detection.
The object detector we are using here is a Faster R-CNN with a ResNet50 backbone. Due to how the network is designed, Faster R-CNNs tend to be really good at detecting small objects in images — this is evidenced by the fact that not only is each of the cars detected in the input image, but so is one of the drivers (who is barely visible to the human eye).
Here is another example image using our Faster R-CNN object detector:
Figure 4: Applying pre-trained object detection networks with PyTorch.
Here, we can see that our output object detections are quite accurate. Our model accurately detects me and Jemma, the family beagle, in the foreground of the scene. It also detects the television and chair in the background.
Let’s try one final image, this one of a more complicated scene that really demonstrates how good Faster R-CNN models are at detecting small objects:
Figure 5: Faster R-CNN and PyTorch can be used together to detect small objects in complex scenes.
Notice here how we are manually specifying our --confidence command line argument of 0.7, meaning that object detections with a predicted probability > 70% will be considered a true-positive detection (if you remember, the detect_image.py script defaults the minimum confidence to 50%).
Note: Lowering our default confidence will allow us to detect more objects but perhaps at the expense of false-positives.
That said, as the output of Figure 5 shows, our model has made highly accurate predictions. We’ve not only detected the foreground objects such as the dog, horse, and person on the horse, but we’ve also detected background objects, including the truck and multiple people in the background.
As an exercise to gain more experience with object detection using PyTorch, I suggest you swap out the --model command line argument for frcnn-mobilenet and retinanet, and then compare the results of your output.
Implementing real-time object detection with PyTorch
In our previous section, you learned how to apply object detection to single images with PyTorch. This section will show you how to use PyTorch to apply object detection to video streams.
As you’ll see, much of the code from the previous implementation can be reused, with only minor changes.
Open the detect_realtime.py script in your project directory structure, and let’s get to work:
# import the necessary packages
from torchvision.models import detection
from imutils.video import VideoStream
from imutils.video import FPS
import numpy as np
import argparse
import imutils
import pickle
import torch
import time
import cv2
Lines 2-11 import our required Python packages. All these imports are essentially the same as our detect_image.py script, but with two notable additions:
VideoStream: Accesses our webcam
FPS: Measures our approximate frames per second throughput rate of our object detection pipeline
Next comes our command line arguments:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", type=str, default="frcnn-resnet",
choices=["frcnn-resnet", "frcnn-mobilenet", "retinanet"],
help="name of the object detection model")
ap.add_argument("-l", "--labels", type=str, default="coco_classes.pickle",
help="path to file containing list of categories in COCO dataset")
ap.add_argument("-c", "--confidence", type=float, default=0.5,
help="minimum probability to filter weak detections")
args = vars(ap.parse_args())
Our first switch, --model, controls which PyTorch object detector we want to utilize.
The --labels argument provides the path to the COCO class labels file.
And finally, the --confidence switch allows us to provide a minimum predicted probability to help filter out weak, false-positive detections.
The next code block handles setting our inference device (CPU or GPU), along with loading our class labels:
# set the device we will be using to run the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load the list of categories in the COCO dataset and then generate a
# set of bounding box colors for each class
CLASSES = pickle.loads(open(args["labels"], "rb").read())
COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))
When performing object detection in video streams, I highly recommend that you use a GPU — a CPU will be too slow for anything close to real-time performance.
We then define our MODELS dictionary, just like in the previous script:
# initialize a dictionary containing model name and its corresponding
# torchvision function call
MODELS = {
"frcnn-resnet": detection.fasterrcnn_resnet50_fpn,
"frcnn-mobilenet": detection.fasterrcnn_mobilenet_v3_large_320_fpn,
"retinanet": detection.retinanet_resnet50_fpn
}
# load the model and set it to evaluation mode
model = MODELS[args["model"]](pretrained=True, progress=True,
num_classes=len(CLASSES), pretrained_backbone=True).to(DEVICE)
model.eval()
Lines 41-43 load the PyTorch object detection model from disk and place it in evaluation mode.
We are now ready to access our webcam:
# initialize the video stream, allow the camera sensor to warmup,
# and initialize the FPS counter
print("[INFO] starting video stream...")
vs = VideoStream(src=0).start()
time.sleep(2.0)
fps = FPS().start()
We insert a small sleep statement to allow our camera sensor to warm up.
A call to the start method of FPS allows us to start timing our approximate frames per second throughput rate.
The next step is to loop over frames from our video stream:
# loop over the frames from the video stream
while True:
    # grab the frame from the threaded video stream and resize it
    # to have a maximum width of 400 pixels
    frame = vs.read()
    frame = imutils.resize(frame, width=400)
    orig = frame.copy()
    # convert the frame from BGR to RGB channel ordering and change
    # the frame from channels last to channels first ordering
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame = frame.transpose((2, 0, 1))
    # add a batch dimension, scale the raw pixel intensities to the
    # range [0, 1], and convert the frame to a floating point tensor
    frame = np.expand_dims(frame, axis=0)
    frame = frame / 255.0
    frame = torch.FloatTensor(frame)
    # send the input to the device and pass it through the
    # network to get the detections and predictions
    frame = frame.to(DEVICE)
    detections = model(frame)[0]
Lines 56-58 read a frame from the video stream, resize it (the smaller the input frame, the faster inference will be), and then clone it so we can draw on it later.
Our preprocessing operations are identical to our previous script:
Convert from BGR to RGB channel ordering
Switch from “channels last” to “channels first” ordering
Add a batch dimension
Scale the pixel intensities in the frame from the range [0, 255] to [0, 1]
Convert the frame to a floating point PyTorch tensor
The preprocessed frame is then moved to the appropriate device, after which predictions are made (Lines 73 and 74).
Processing the results of the object detection model is identical to that of detect_image.py:
    # loop over the detections
    for i in range(0, len(detections["boxes"])):
        # extract the confidence (i.e., probability) associated with
        # the prediction
        confidence = detections["scores"][i]
        # filter out weak detections by ensuring the confidence is
        # greater than the minimum confidence
        if confidence > args["confidence"]:
            # extract the index of the class label from the
            # detections, then compute the (x, y)-coordinates of
            # the bounding box for the object
            idx = int(detections["labels"][i])
            box = detections["boxes"][i].detach().cpu().numpy()
            (startX, startY, endX, endY) = box.astype("int")
            # draw the bounding box and label on the frame
            label = "{}: {:.2f}%".format(CLASSES[idx], confidence * 100)
            cv2.rectangle(orig, (startX, startY), (endX, endY),
                COLORS[idx], 2)
            y = startY - 15 if startY - 15 > 15 else startY + 15
            cv2.putText(orig, label, (startX, y),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2)
Finally, we can display the output frame to our window:
    # show the output frame
    cv2.imshow("Frame", orig)
    key = cv2.waitKey(1) & 0xFF
    # if the 'q' key was pressed, break from the loop
    if key == ord("q"):
        break
    # update the FPS counter
    fps.update()
# stop the timer and display FPS information
fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))
# do a bit of cleanup
cv2.destroyAllWindows()
vs.stop()
We continue to monitor our FPS until we click on the window opened by OpenCV and press the q key to exit the script, after which we stop our FPS timer and display (1) the elapsed time of the script and (2) approximate frames per second throughput information.
PyTorch real-time object detection results
Let’s learn how to apply object detection to video streams using PyTorch.
Be sure to access the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, you can execute the detect_realtime.py script:
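(A hypothetical invocation is shown below; here we select the MobileNet-backed detector, which is the fastest of the three.)
$ python detect_realtime.py --model frcnn-mobilenet --labels coco_classes.pickle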
Using our Faster R-CNN model with a MobileNet backbone (best for speed), we’re achieving ≈7 FPS. We’re not quite at the true real-time speed of > 20 FPS, but with a faster GPU and more optimization we could easily get there.
In this tutorial, you learned how to perform object detection with PyTorch and pre-trained networks. You gained experience applying object detection with three popular networks:
Faster R-CNN with ResNet50 backbone
Faster R-CNN with MobileNet backbone
RetinaNet with ResNet50 backbone
When it comes to both accuracy and detecting small objects, Faster R-CNN will perform very well. However, that accuracy comes at a cost — Faster R-CNN models tend to be much slower than Single Shot Detectors (SSDs) and YOLO.
To help speed up the Faster R-CNN architecture, we can swap out the computationally expensive ResNet backbone for a lighter, more efficient (but less accurate) MobileNet backbone. Doing so will give you a boost in speed.
Otherwise, RetinaNet is a nice compromise between speed and accuracy.
This post covers the intuition of Generative Adversarial Networks (GANs) at a high level, the various GAN variants, and applications for solving real-world problems.
How GANs work
GANs are a type of generative model: they observe many samples of a distribution and generate more samples from that same distribution. Other generative models include variational autoencoders (VAEs) and autoregressive models.
The GAN architecture
There are two networks in a basic GAN architecture: the generator model and the discriminator model. GANs get the word “adversarial” in their name because the two networks are trained simultaneously and compete against each other, as in a zero-sum game such as chess.
Figure 1: Chess pieces on a board.
The generator model generates new images. The goal of the generator is to generate images that look so real that it fools the discriminator. In the simplest GAN architecture for image synthesis, the input is typically random noise, and its output is a generated image.
Figure 2: Generator input and output (image by the author).
The discriminator is just a binary image classifier which you should already be familiar with. Its job is to classify whether an image is real or fake.
Note: In more complex GANs, we could condition the discriminator on an image or text for image-to-image translation or text-to-image generation.
Figure 3: Discriminator input and output (image by the author).
Putting it all together, here is what a basic GAN architecture looks like: the generator makes fake images; we feed both the real images (training dataset) and the fake images into the discriminator in separate batches. The discriminator then tells whether an image is real or fake.
Figure 4: GAN architecture (image by the author).
Training GANs
The Minimax game: G vs. D
Most deep learning models (for example, image classification) are based on optimization: finding the minimum of a cost function. GANs are different because the two networks (the generator and the discriminator) each have their own cost function with opposing objectives:
The generator tries to fool the discriminator into classifying its fake images as real
The discriminator tries to classify real and fake images correctly
The minimax objective below illustrates this adversarial dynamic during training. Don’t worry too much if you don’t understand the math; I will explain it in more detail when coding the G loss and D loss in a future DCGAN post.
Figure 5: GANs Minimax Game (image by the author).
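For reference, the value function from the original Goodfellow et al. paper can be written (in LaTeX notation) as:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]

Here the discriminator D tries to maximize V(D, G) by scoring real samples high and fake samples low, while the generator G tries to minimize it.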
Both the generator and discriminator improve over time during training. The generator gets better and better at producing images that resemble the training data, while the discriminator gets better at telling the real and fake images apart.
Training a GAN amounts to finding an equilibrium in the game (see the training-loop sketch after this list), reached when:
The generator makes data that looks almost identical to the training data.
The discriminator can no longer tell the difference between the fake images and the real images.
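Below is a highly simplified, hypothetical sketch of the alternating updates in one GAN training step. The generator, discriminator, their optimizers, and a batch of real images are assumed to already exist, and the discriminator is assumed to output a probability (via a sigmoid); this is not the code from the upcoming DCGAN post, just an illustration of the game.
# hypothetical sketch of one GAN training step
import torch
import torch.nn.functional as F

def gan_training_step(generator, discriminator, opt_g, opt_d, real_images, noise_dim=100):
    batch_size = real_images.size(0)
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)
    # generate a batch of fake images from random noise
    noise = torch.randn(batch_size, noise_dim)
    fake_images = generator(noise)
    # discriminator update: push D(real) toward 1 and D(fake) toward 0
    opt_d.zero_grad()
    d_loss = F.binary_cross_entropy(discriminator(real_images), real_labels) + \
        F.binary_cross_entropy(discriminator(fake_images.detach()), fake_labels)
    d_loss.backward()
    opt_d.step()
    # generator update: push D(fake) toward 1 (i.e., fool the discriminator)
    opt_g.zero_grad()
    g_loss = F.binary_cross_entropy(discriminator(fake_images), real_labels)
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()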
The artist vs. the critic
Mimicking masterpieces is a great way to learn art — “How Artists Are Copying Masterpieces at World-Renowned Museums.” As a human artist mimicking a masterpiece, I’d find the artwork I like as an inspiration and try to copy it as much as possible: the contours, the colors, the compositions and the brushstrokes, and so on. Then a critic takes a look at the copy and tells me whether it looks like the real masterpiece.
Figure 6: An artist copies another painting.
GANs training is similar to that process. We can think of the generator as the artist and the discriminator as the critic. Note the difference in this analogy between the human artist and the machine (GANs) artist, though: the generator doesn’t have access or visibility to the masterpiece that it’s trying to copy. Instead, it only relies on the discriminator’s feedback to improve the images it’s generating.
Evaluation metrics
A good GAN model should produce images with good quality (for example, images that are sharp and resemble the training data) and good diversity (a wide variety of generated images that approximate the distribution of the training dataset).
To evaluate the GAN model, you can visually inspect the generated images during training or by inference with the generator model. If you’d like to evaluate your GANs quantitatively, here are two popular evaluation metrics:
Inception Score, which captures both the quality and diversity of the generated images
Fréchet Inception Distance, which compares the real vs. fake images and doesn’t just evaluate the generated images in isolation
GAN variants
Since Ian Goodfellow et al.’s original GAN paper in 2014, there have been many GAN variants. They tend to build upon each other, either to solve a particular training issue or to create new GAN architectures for finer control or better image quality.
Here are a few of these variants with breakthroughs that provided the foundation for future GAN advances. This is by no means a complete list of all the GAN variants.
Figure 7: GAN variants timeline (image by the author).
DCGAN (Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks) was the first GAN proposal to use Convolutional Neural Networks (CNNs) in its network architecture. Most of the GAN variations today are based, at least in part, on DCGAN. Thus, DCGAN is most likely your first GAN tutorial, the “Hello World” of learning GANs.
WGAN (Wasserstein GAN) and WGAN-GP were created to solve GAN training challenges such as mode collapse, which occurs when the generator produces the same images (or a small subset of similar images) repeatedly. WGAN-GP improves upon WGAN by using a gradient penalty instead of weight clipping for training stability.
cGAN (Conditional Generative Adversarial Nets) first introduced the concept of generating images based on a condition, which could be an image class label, image, or text, as in more complex GANs. Pix2Pix and CycleGAN are both conditional GANs, using images as conditions for image-to-image translation.
Pix2PixHD (High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs) disentangles the effects of multiple input conditions; as in the paper’s example, it can control the color, texture, and shape of a generated garment image for fashion design. In addition, it can generate realistic 2K high-resolution images.
SAGAN (Self-Attention Generative Adversarial Networks) improves image synthesis quality by generating details using cues from all feature locations, applying the self-attention module (a concept from NLP models) to CNNs. Google DeepMind scaled up SAGAN to make BigGAN.
BigGAN (Large Scale GAN Training for High Fidelity Natural Image Synthesis) can create high-resolution and high-fidelity images.
ProGAN, StyleGAN, and StyleGAN2 all create high-resolution images.
ProGAN (Progressive Growing of GANs for Improved Quality, Stability, and Variation) grows the network progressively.
StyleGAN (A Style-Based Generator Architecture for Generative Adversarial Networks), introduced by NVIDIA Research, combines the progressive growing of ProGAN with image style transfer via adaptive instance normalization (AdaIN), giving control over the style of generated images.
StyleGAN2 (Analyzing and Improving the Image Quality of StyleGAN) improves upon the original StyleGAN with several changes to normalization, progressive growing, and regularization.
GAN applications
GANs are versatile and can be used in a variety of applications.
Image synthesis
Image synthesis can be fun and provide practical use, such as image augmentation in machine learning (ML) training or help with creating artwork and design assets.
GANs can be used to create images that never existed before, which is perhaps what GANs are best known for. They can create unseen new faces, cat images and artwork, and more. I’ve included a few high-fidelity images below, which I generated from the websites powered by StyleGAN2. Go to these links, experiment yourself, and see what images you get from your experiments.
GANs can also help train reinforcement learning agents. For example, NVIDIA’s GameGAN simulates game environments.
Image-to-image translation
Image-to-image translation is a computer vision task that translates the input image to another domain (e.g., color or style) while preserving the original image content. This is perhaps one of the most important tasks to use GANs in art and design.
Pix2Pix (Image-to-Image Translation with Conditional Adversarial Networks) is perhaps the most famous image-to-image translation conditional GAN. However, one major drawback of Pix2Pix is that it requires paired training image datasets.
Figure 10: Inputs and outputs of Pix2Pix GANs (image source: Pix2Pix paper).
CycleGAN was built upon Pix2Pix and only needs unpaired images, which are much easier to come by in the real world. It can convert images of apples to oranges, day to night, horses to zebras, and so on. These may not be real-world use cases in themselves, but many other image-to-image GANs have since been developed for art and design.
Figure 11: CycleGAN converts a horse to a zebra (image source: CycleGAN Project Page).
Now you can translate your selfie to comics, painting, cartoons, or any other styles you can imagine. For example, I can use White-box CartoonGAN to turn my selfie into a cartoonized version:
Figure 12: Input and output of the White-box CartoonGAN (images by the author).
Colorization can be applied not only to black and white photos but also to artwork and design assets. In the artwork-making or UI/UX design process, we often start with outlines or contours and then add color. Automatic colorization could help provide inspiration for artists and designers.
Text-to-Image
We’ve seen a lot of Image-to-Image translation examples by GANs. We could also use words as the condition to generate images, which is much more flexible and intuitive than using class labels as the condition.
Figure 13: A GAN transforms NLP and computer vision (image source: StyleCLIP paper).
Beyond images
GANs can be used for not only images but also music and video. For example, GANSynth from the Magenta project can make music. Here is a fun example of GANs on video motion transfer called “Everybody Dance Now” (YouTube | Paper). I’ve always loved watching this charming video where the dance moves by professional dancers get transferred to the amateurs.
Figure 14: A GAN transforms professional dance moves (image source).
Super-resolution (SRGAN and ESRGAN) enhances an image from low resolution to high resolution. This could be very helpful in photo editing or medical image enhancement.
Here is an example of how GANs can be used for climate change. Earth Intelligent Engine, an FDL (Frontier Development Lab) 2020 project, uses Pix2PixHD to simulate what an area would look like after flooding.
We have seen GAN demos from papers, research labs, and open source projects. These days we are starting to see real commercial applications using GANs. Designers are familiar with using design assets from icons8. Take a look at their website, and you will notice the GAN applications: from the Smart Upscaler and Generated Photos to the Face Generator.
Summary
In this post, you learned a high-level overview of GANs, their variants, and fun applications. While most of the examples in this post are about using GANs for art and design, the same techniques can be easily adapted and applied to many other fields: medicine, agriculture, and climate change. As you see in the post, GANs are powerful and versatile. I hope you are excited to dive deeper into GANs: follow along with the upcoming posts as we explore GANs in depth with code examples!
In this blog post, I sit down with Raul Garcia-Martin, a PhD candidate in Biometric Recognition at the University Carlos III of Madrid.
Raul’s work focuses on identifying individual people by their biometrics. You’re likely already familiar with the most popular methods for biometric recognition:
Face recognition
Fingerprint recognition
Retinal scans
… but did you know that the veins in your body can also be used for human identification?
This type of biometric recognition is called Vein or Vascular Biometric Recognition (VBR).
It’s not as well researched as other biometric recognition systems, but research shows that it can be just as accurate, if not more accurate, than the other methods.
Raul has been studying VBR throughout his graduate school career. As his Google Scholar profile shows, he’s already had numerous publications in this sub-niche of computer vision.
I have to admit that Raul reminds me a lot of myself when I was in graduate school.
Not only is Raul performing research and finishing up his PhD, but he’s also an entrepreneur. His company develops specialized cameras (and associated software) that can be used for infrared and thermal imaging.
These cameras can be used for:
Fire detection
Industry inspections
Security
Military applications
COVID-19 body temperature detection
… and not to mention, vein recognition!
To learn more about Raul’s work in Vein/Vascular Biometric Recognition, including how his work has helped him build his company, be sure to give the full interview a read!
An interview with Raul Garcia-Martin, PhD candidate and computer vision entrepreneur
Adrian: Hi Raul! Thank you for taking the time to do this interview. It’s a pleasure to have you on the PyImageSearch blog.
Raul: Hi Adrian! Thank you very much for giving me this great opportunity: you can’t imagine what it means to me to be here, it is an honor and a pleasure to contribute to the PyImageSearch blog. I sincerely hope that the PyImageSearchers could find something interesting in this interview that motivates them, as you do with me, to follow their dreams and the path to become computer vision and deep learning experts.
Adrian: Before we get started, can you tell us a bit about yourself? You’re a PhD candidate at the University Carlos III of Madrid, in addition to carrying out other projects, correct?
Raul: I am in my third year, out of four, as a PhD candidate in Biometric Recognition at the University Carlos III of Madrid. As you mentioned, at the same time, I am trying to go ahead with an entrepreneurial computer vision project.
Adrian: What got you interested in studying computer vision?
Raul: Since I was a child, I remember that I wanted to be an inventor. Around 2002, computer vision and deep learning weren’t as advanced as nowadays. I didn’t know anything about them, but my dream was clear: to engineer and develop technological solutions to improve people’s lives.
Therefore, with this clear yet undefined goal, and thanks to the infinite fields of knowledge that technology offers us, I studied for a bachelor’s degree in Industrial Electronics and Automation Engineering. In this sense, I feel very fortunate because I have had the opportunity to go to university, and I have always had the support of my family.
Without having found my way yet, I started a master’s degree in Electronic Systems and Applications. I combined both university studies with my first jobs as an electronics hardware and firmware developer in a small company in the industry sector and as a software tester in a multinational company in the railway sector.
But it wasn’t until I started my MSc thesis that I fell in love with computer vision. This occurred when I managed to get, for the first time, a video stream in real-time from a webcam using Python and OpenCV.
Figure 2: Traditional Vascular Biometric Recognition preprocessing using the CLAHE algorithm to increase the contrast between the vascular patterns and the surrounding living tissue (image source: Figure 7 of Deep Learning for Vein Biometric Recognition on a Smartphone).
Adrian: Based on your Google Scholar profile, most of your work involves vein recognition and vascular biometric analysis using computer vision. Can you tell us a bit more about this research?
Raul: My PhD mainly addresses Vein or Vascular Biometric Recognition (VBR). It is a not very well-known biometric modality that uses the extraction and classification of unique human patterns to authenticate or identify people, just like facial or fingerprint recognition do.
There are four main VBR variants: finger, palm, hand dorsal, and wrist. I am researching wrist VBR because there are already patents and some commercial systems for the finger and palm vein modalities. Furthermore, I think wrist vein patterns are easier to visualize and capture.
Adrian: You recently published a paper in IEEE Access, Deep Learning for Vein Biometric Recognition on a Smartphone (Figure 1). Can you tell us a bit more about this paper? And in a COVID/pandemic world, why would we want to recognize veins using smartphones?
Raul: First of all, I would like to mention that I am very grateful to you because most of the deep learning knowledge presented in this article is based on the teachings extracted from your excellent Deep Learning for Computer Vision with Python book. I had no previous idea about deep learning (Convolutional Neural Networks, CNN, in this case), and in record time, I acquired solid knowledge with well-organized and structured information.
The main goal of this work is to bring Vein Biometric Recognition closer to our daily life, embedding this biometric variant into the small but powerful computer that has become an extension of our bodies: the smartphone. For this purpose, a deep learning model has been integrated into a smartphone for real-time video stream authentication and identification.
PyImageSearch readers can find a good video summary and demonstration here:
This authentication variant on smartphones, I think, could be a really interesting online payment or bank transaction method, being a comfortable and more secure alternative to facial or fingerprint verification.
In an attempt to contribute to this COVID/pandemic world, the other goal of this study is to develop vascular contactless multi-user devices. It is a more challenging computer vision technique, but it is optimal to prevent physical contact between the user and the device, providing a hygienic method of massive access control (e.g., airport border control).
In this sense, smartphones are portable and more comfortable devices that could be used, for instance, by an entrance operator in a sports stadium. In addition, this biometric variant is a secure alternative and more respectful in terms of user privacy than facial recognition, which nowadays presents an added challenge with masks.
Adrian: What type of hardware is needed to perform vein recognition? Is a standard iPhone/Android camera sufficient? Or do you need something more? I think this question will allow us to learn more about your passion: your computer vision entrepreneurship path.
Raul: Excellent question! To perform vein recognition, we only need a near-infrared camera (known as an IR camera). So answering your second question, a standard iPhone/Android or webcam is not sufficient. The RGB (standard) and IR cameras use the same sensor sensitive to visible radiation. But they physically mount, respectively, an IR blocking filter to render what for our eyes is a real image and a visible blocking filter (using the suitable IR torch) to see in the dark, through some plastics or visualize vein patterns.
I am completely in love with IR cameras.
Unfortunately (or not), it is not easy for us to access this type of camera on a smartphone because it is frequently used for facial authentication. However, that is why I love trying to access them. It is one of the reasons behind my entrepreneurship project, RGM Vision, where I’m “bringing infrared cameras and computer vision to people’s daily lives.”
On my website, PyImageSearch readers can find more than 5 infrared camera apps (Figure 3) for 6 different Android devices.
Figure 3: Left: Xiaomi InfraRed Camera Pro app designed for Xiaomi Pocophone F1 and Xiaomi Mi 8 devices. Middle: Pixel 4 InfraRed Camera Pro app designed for Google Pixel 4 and Pixel 4 XL devices. Right: OnePlus 8 Pro Photochrom app designed for OnePlus 8 Pro device. (image source: RGM Vision).
Furthermore, I have launched a new thermal camera available for all Android devices in my latest app: RGMVision ThermalCAM 1 (Figure 4).
Figure 4: Left: RGM Vision ThermalCAM 1 for all Android USB-C devices. Right: RGM Vision ThermalCAM 1 app screenshots.
I know that PyImageSearchers are more than acquainted with this type of IR camera (based on middle-infrared and far-infrared light). Still, if someone is interested, I recommend the latest interview on the PyImageSearch blog with Askat Kuzdeuov.
Adrian: What types of image preprocessing steps are required to perform vein recognition? Can you take the images directly from the infrared camera and apply deep learning models to them? Or do these images require additional preprocessing?
Raul: To perform vein recognition, near-infrared vein images are processed as grayscale images.
Before the eruption of deep learning in recent years, the first step was to preprocess the vein images to increase the contrast between the vascular patterns and the surrounding living tissue. Then, following the traditional biometric recognition paradigm, unique features were extracted and classified.
As we know, deep learning has changed this research methodology.
In this work, I have obtained good results both for preprocessed images, increasing the vein visualization using Contrast Limited Adaptive Histogram Equalization (CLAHE, Figure 2), and for raw grayscale images. The biometric recognition performance, using deep learning models, has not been substantially influenced by the preprocessing step. This seems to indicate that CNNs suffice to extract all relevant features.
Adrian: How did you settle on a model architecture for this project? Did you hand-design the model, or did you apply fine-tuning/transfer learning using existing architectures?
Raul: My first idea was to settle a state-of-the-art CNN architecture and train it from scratch. But in Deep Learning for Computer Vision with Python, I learned an even easier way: transfer learning. So, I implemented and tested both transfer learning variants, i.e., CNN as feature extractors and fine-tuning, over state-of-the-art CNN pre-trained architectures.
For the UC3M-CV2 dataset, only the CNN as a feature extractor technique obtained a high accuracy, as it could be expected, according to your advice, when we work with a small dataset.
Adrian: What computer vision and deep learning tools, libraries, and packages did you utilize in your research?
Raul: Since I started my journey in computer vision, I have always used OpenCV due to its ease of use and simplicity. I love this library! I program using Python for the same reason, even though I hadn’t tried it before starting with computer vision.
Since I began to apply deep learning, I have been using Keras and TensorFlow, but maybe I should also catch up with your latest PyTorch tutorials.
Other than that, as PyImageSearch readers will have already discovered, I love programming with Android: for me, it is the most ready-to-produce programming language.
Adrian: What are your next steps as a researcher? Are you going to continue working on automatic vein recognition, or are you moving on to other projects?
Raul: My idea is to finish my PhD in vein recognition and then try to make a living from computer vision and camera development: I’ll put all my efforts into building a thriving business.
At this point, I would like to change the roles of this interview and ask you for advice on how to become a successful entrepreneur and launch a startup/business (I know that it is related to one of the possible points of your brand new set of courses, “A technical education is not enough to succeed and hit your goals”).
I know this path is hard, but I believe I must try it. In addition, I keep the doors open to any other possibility.
Adrian: Congrats on being one step closer to completing your PhD! What advice would you give to someone who wants to follow in your footsteps, complete their PhD, and become a successfully published researcher?
Raul: Thank you very much, Adrian!
I recommend that they be persistent (in all aspects of their lives; anything worthwhile in our existence needs it) and enjoy the journey, putting all their passion into their research. Results will come along.
As practical advice, I have to mention that I have always been a great fan of self-taught learning. In my case, I think that the university has prepared me for it. I also believe in another powerful concept: always remain curious and learn as much as you can from everybody, especially from the leading references in every field, without fear of investing in yourself and your education.
Adrian: You’ve been a long-time reader and customer of PyImageSearch. Thank you for supporting us! How has PyImageSearch helped you with your research and journey to completing your PhD?
Raul: I am indebted to you for your teachings and inspiration!
And like everybody else in the computer vision world with Python, I have been learning throughout these years (and currently keep doing so) from your over 350 free tutorials. Thank you again!
If I had to start over again from scratch in computer vision, I would begin without hesitation by checking out your books. At this point, I think that PyImageSearch University could be my next step.
Adrian: If a PyImageSearch reader wants to connect with you, how can they do so?
Raul: It would be a pleasure if any PyImageSearch reader wants to connect with me!
Today we interviewed Raul Garcia-Martin, a PhD candidate at the University Carlos III of Madrid specializing in computer vision and vein/vascular biometric recognition.
Raul’s work shows that it’s possible to identify a person using the veins in their body, specifically the hands, wrist, and forearms. Furthermore, this method is just as accurate, if not more accurate, than other biometric recognition methods (i.e., face recognition, fingerprint recognition, etc.).
Additionally, Raul is developing his own set of cameras and software to facilitate further research in this area. Be sure to check out his company, RGM Vision, for more information.
This tutorial is part of a series:
PyTorch: Transfer Learning and Image Classification (this tutorial)
Introduction to Distributed Training in PyTorch (next week’s blog post)
If you are new to the PyTorch deep learning library, we suggest reading the following introductory series to help you learn the basics and become acquainted with the PyTorch library:
PyTorch: Transfer Learning and Image Classification
In the first part of this tutorial, we’ll learn what transfer learning is, including how PyTorch allows us to perform transfer learning.
We’ll then configure our development environment and review our project directory structure.
From there, we’ll implement several Python scripts, including:
A configuration script to store important variables
A dataset loader helper function
A script to build and organize our dataset on disk such that PyTorch’s ImageFolder and DataLoader classes can easily be utilized
A driver script that performs basic transfer learning via feature extraction
A second driver script that performs fine-tuning by replacing the fully connected (FC) layer head of a pre-trained network with a brand new, freshly initialized, FC head
A final script that allows us to perform inference with our trained models
We have a lot to review here today, so let’s get started!
What is transfer learning?
Training a Convolutional Neural Network from scratch poses many challenges, most notably the amount of data required to train the network and the amount of time training takes.
Transfer learning is a technique that allows us to use a model trained for a certain task as a starting point for a machine learning model for a different task.
For example, suppose a model is trained for image classification on the ImageNet dataset. We can then take this model and “re-train” it to recognize classes it was never trained to recognize in the first place!
Imagine, you know how to ride a bicycle and want to ride a motorcycle. Your experience of riding a bicycle — keeping balance, maintaining direction, turning, and braking — will help you learn to ride a motorcycle faster.
This is what transfer learning does in the case of a CNN. Using transfer learning, you can make direct use of a well-trained model by freezing the parameters, changing the output layer, and fine-tuning the weights.
In essence, you can shortcut the entire training procedure and obtain a high accuracy model in a fraction of the time.
How can we perform transfer learning with PyTorch?
There are two primary types of transfer learning:
Transfer learning via feature extraction: We remove the FC layer head from the pre-trained network and replace it with a softmax classifier. This method is super simple as it allows us to treat the pre-trained CNN as a feature extractor and then pass those features through a Logistic Regression classifier.
Transfer learning via fine-tuning: When applying fine-tuning, we again remove the FC layer head from the pre-trained network, but this time we construct a brand new, freshly initialized FC layer head and place it on top of the original body of the network. The weights in the body of the CNN are frozen, and then we train the new layer head (typically with a very small learning rate). We may then choose to unfreeze the body of the network and train the entire network.
The first method tends to be easier to work with, as there is less code involved and fewer parameters to tune. However, the second method tends to be more accurate, leading to models that generalize better.
Both transfer learning via feature extraction and fine-tuning can be implemented with PyTorch — I’ll show you how in the rest of this tutorial.
Configuring your development environment
To follow this guide, you need to have PyTorch, torchvision, OpenCV, imutils, matplotlib, and tqdm installed on your machine.
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
The Flower photos dataset
Let’s look at the Flowers dataset and visualize a few of the images from that dataset. Figure 2 provides a sense of how the images look.
Figure 2: A sample of the images in the Flowers dataset.
This dataset contains 3,670 images belonging to five distinct flower species:
Daisy: 633 images
Dandelion: 898 images
Roses: 641 images
Sunflowers: 699 images
Tulips: 799 images
Our job is to train an image classification model to recognize each of these flower species. We’ll achieve this goal by applying transfer learning with PyTorch.
Project structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, take a look at the directory structure:
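The tree output from the original post isn’t reproduced here, but based on the files described below, a rough sketch of the layout looks like this (the dataset directory, the serialized models and plots inside output, and the .png prediction visualizations only appear after you run the scripts):
$ tree . --dirsfirst --filelimit 10
.
├── flower_photos [five flower class subdirectories]
├── output
├── pyimagesearch
│   ├── config.py
│   └── create_dataloaders.py
├── build_dataset.py
├── fine_tune.py
├── inference.py
└── train_feature_extraction.py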
The flower_photos directory contains our set of flower images.
We’ll be training our models on this flowers dataset. The output directory will then be populated with our training/validation plots.
Inside the pyimagesearch module, we have two Python files:
config.py: Contains important configuration variables used in our driver scripts.
create_dataloaders.py: Implements the get_dataloader helper function, responsible for creating a DataLoader instance to parse our files from the flower_photos directory
We then have four Python driver scripts:
build_dataset.py: Takes the flower_photos directory and builds a dataset directory. We’ll create special subdirectories to store our training and validation splits, allowing PyTorch’s ImageFolder script to parse the directory and train our model.
train_feature_extraction.py: Performs transfer learning via feature extraction and serializes the output model to disk.
fine_tune.py: Performs transfer learning via fine-tuning and saves the model to disk.
inference.py: Accepts a trained PyTorch model and uses it to make predictions on input flower images.
The .png files in the project directory structure contain the visualizations of our output predictions.
Creating our configuration file
Before implementing any of our transfer learning scripts, we first need to create our configuration file.
This configuration file will store important variables and parameters used across our driver scripts. Instead of re-defining them in every script, we’ll simply define them once here (and thereby make our code cleaner and easier to read).
Open the config.py file in the pyimagesearch module and insert the following code:
# import the necessary packages
import torch
import os
# define path to the original dataset and base path to the dataset
# splits
DATA_PATH = "flower_photos"
BASE_PATH = "dataset"
# define validation split and paths to separate train and validation
# splits
VAL_SPLIT = 0.1
TRAIN = os.path.join(BASE_PATH, "train")
VAL = os.path.join(BASE_PATH, "val")
Line 7 defines DATA_PATH, the path to our input flower_photos directory.
We then set the BASE_PATH variable to point to our dataset directory (Line 8). This directory will be created and populated via our build_dataset.py script. When we run our transfer learning/inference scripts, we’ll be reading images from the BASE_PATH directory.
Line 12 sets our validation split to 10%, meaning that we’ll take 90% of our data for training and 10% for validation.
We also define the TRAIN and VAL subdirectories on Lines 13 and 14. Once we run build_dataset.py, we’ll have two subdirectories inside dataset:
dataset/train
dataset/val
Each subdirectory will store its respective images for each of the five flower classes.
We’ll fine-tune the ResNet architecture, pre-trained on the ImageNet dataset. This implies that we’ll have to set some important parameters for image pixel scaling:
# specify ImageNet mean and standard deviation and image size
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 224
# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
Lines 17 and 18 define the mean and standard deviation of the pixel intensities in the RGB color space.
These values were obtained by researchers training their models on the ImageNet dataset. They looped over all images in the ImageNet dataset, loaded them from disk, and computed the mean and standard deviation of RGB pixel intensities.
The mean and standard deviation values were then used for image pixel normalization before training.
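As a quick worked example: with these constants, a red-channel value of 0.5 (on the [0, 1] scale produced by ToTensor) is normalized to (0.5 - 0.485) / 0.229 ≈ 0.066 before it ever reaches the network.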
Even though we are not using the ImageNet dataset for transfer learning, we still need to perform the same preprocessing steps that ResNet was trained with; otherwise, the model will not interpret the input image correctly.
Line 19 sets our input IMAGE_SIZE to be 224 × 224 pixels.
The DEVICE variable controls whether we are using our CPU or GPU for training.
Next, we have some variables that will be used for feature extraction and fine-tuning:
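That code block isn’t shown in this listing, so here is a sketch of what it contains. The batch size and EPOCHS variable names match the ones used by the driver scripts later in this tutorial; the learning rate value is a placeholder assumption, not the exact value used in my experiments (and while the prose below mentions separate learning rates for feature extraction and fine-tuning, this sketch uses a single LR to match the fine_tune.py listing later in this post):
# specify training hyperparameters
FEATURE_EXTRACTION_BATCH_SIZE = 256
FINETUNE_BATCH_SIZE = 64
PRED_BATCH_SIZE = 4
EPOCHS = 20
LR = 1e-3  # placeholder learning rate; tune this for your own experiments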
When performing feature extraction, we’ll pass images through our network in batches of 256 (Line 25).
When performing transfer learning via fine-tuning, we’ll use image batches of 64 (Line 26).
When performing inference (i.e., making predictions via the inference.py script), we’ll use batch sizes of 4.
Finally, we set the number of EPOCHS we’ll train our model for, the learning rate for feature extraction, and the learning rate for fine-tuning. These values were determined by running simple hyperparameter tuning experiments.
We’ll wrap up our configuration script by setting our output file paths:
# define paths to store training plots and trained model
WARMUP_PLOT = os.path.join("output", "warmup.png")
FINETUNE_PLOT = os.path.join("output", "finetune.png")
WARMUP_MODEL = os.path.join("output", "warmup_model.pth")
FINETUNE_MODEL = os.path.join("output", "finetune_model.pth")
Lines 33 and 34 set the file paths to our output training history and serialized model for feature extraction.
Lines 35 and 36 do the same, only for fine-tuning.
Implementing our DataLoader helper
PyTorch allows us to easily construct DataLoader objects from images stored in directories on disk.
Note: If you’ve never used PyTorch’s DataLoader object before, I suggest you read our introduction to PyTorch tutorials, along with our guide on PyTorch image data loaders.
Open the create_dataloaders.py file inside the pyimagesearch module, and let’s get started:
# import the necessary packages
from . import config
from torch.utils.data import DataLoader
from torchvision import datasets
import os
config: The configuration file we created in the previous section
DataLoader: PyTorch’s data loading class used to handle data batching efficiently
datasets: A submodule from PyTorch that provides access to the ImageFolder class, used to read images from an input directory on disk
os: Used to determine the number of cores/workers on a CPU, allowing data loading to take place faster
From there, we define the get_dataloader function:
def get_dataloader(rootDir, transforms, batchSize, shuffle=True):
    # create a dataset and use it to create a data loader
    ds = datasets.ImageFolder(root=rootDir,
        transform=transforms)
    loader = DataLoader(ds, batch_size=batchSize,
        shuffle=shuffle,
        num_workers=os.cpu_count(),
        pin_memory=True if config.DEVICE == "cuda" else False)
    # return a tuple of the dataset and the data loader
    return (ds, loader)
This function accepts four arguments:
rootDir: Path to the input directory containing our dataset on disk (i.e., the dataset directory)
transforms: A list of data transforms to perform, including preprocessing steps and data augmentation
batchSize: Size of the batches to be yielded from the DataLoader
shuffle: Whether or not to shuffle the data — we’ll shuffle data for training but not for validation
Lines 9 and 10 create our ImageFolder class, used to read images from the rootDir. This is also where we’ll apply our set of transforms.
The DataLoader is then created on Lines 11-14. Here we:
Pass in our ImageFolder object
Set the batch size
Indicate whether or not shuffling will be performed
Set num_workers, which is the number of CPUs/cores on our machine
Set pin_memory based on whether we’re training on a GPU
The resulting ImageFolder and DataLoader instances are returned to the calling function on Line 17.
Creating our dataset organization script
Now that we’ve created our configuration file and implemented our DataLoader helper function, let’s create the build_dataset.py script used to build our dataset directory, along with the train and val subdirectories.
Open the build_dataset.py file in your project directory structure and insert the following code:
# USAGE
# python build_dataset.py
# import necessary packages
from pyimagesearch import config
from imutils import paths
import numpy as np
import shutil
import os
paths: A submodule of imutils used to gather paths to images inside a given directory
numpy: Numerical array processing
shutil: Used to copy files from one location to another
os: Operating system module used to create directories on disk
Next, we have our copy_images function:
def copy_images(imagePaths, folder):
    # check if the destination folder exists and if not create it
    if not os.path.exists(folder):
        os.makedirs(folder)
    # loop over the image paths
    for path in imagePaths:
        # grab image name and its label from the path and create
        # a placeholder corresponding to the separate label folder
        imageName = path.split(os.path.sep)[-1]
        label = path.split(os.path.sep)[1]
        labelFolder = os.path.join(folder, label)
        # check to see if the label folder exists and if not create it
        if not os.path.exists(labelFolder):
            os.makedirs(labelFolder)
        # construct the destination image path and copy the current
        # image to it
        destination = os.path.join(labelFolder, imageName)
        shutil.copy(path, destination)
The copy_images function requires two arguments:
imagePaths: The paths to all images in a given input directory
folder: The output base directory where copied images will be stored (i.e., the dataset directory)
Lines 13 and 14 make a quick check to see if the folder directory exists. If the directory does not exist, we create it.
From there, we loop over all imagePaths (Line 17). For each path, we:
Grab the filename (Line 20)
Extract the class label from the image path (Line 21)
Construct the base output directory (Line 22)
If the labelFolder subdirectory does not yet exist, we create it on Lines 25 and 26.
From there, we build the path to the destination file (Line 30) and copy it (Line 31).
Let’s now put this copy_images function to work:
# load all the image paths and randomly shuffle them
print("[INFO] loading image paths...")
imagePaths = list(paths.list_images(config.DATA_PATH))
np.random.shuffle(imagePaths)
# generate training and validation paths
valPathsLen = int(len(imagePaths) * config.VAL_SPLIT)
trainPathsLen = len(imagePaths) - valPathsLen
trainPaths = imagePaths[:trainPathsLen]
valPaths = imagePaths[trainPathsLen:]
# copy the training and validation images to their respective
# directories
print("[INFO] copying training and validation images...")
copy_images(trainPaths, config.TRAIN)
copy_images(valPaths, config.VAL)
Lines 35 and 36 read all imagePaths from our input DATA_PATH (i.e., the flower_photos directory) and then randomly shuffle them.
Lines 39-42 create our training and validation splits based on our VAL_SPLIT percentage.
Finally, we use the copy_images function to copy the trainPaths and valPaths to their respective output directories (Lines 47 and 48).
The following section will make this process more clear, including why we are going through all the trouble to organize our dataset directory structure in this specific manner.
Building our dataset on disk
We are now ready to build our dataset directory. Be sure to use the “Downloads” section of this tutorial to access the source code and example images.
From there, open a shell and execute the following command:
$ python build_dataset.py
[INFO] loading image paths...
[INFO] copying training and validation images...
After the script executes, you’ll see that a new dataset directory has been created:
$ tree dataset --dirsfirst --filelimit 10
dataset
├── train
│ ├── daisy [585 entries exceeds filelimit, not opening dir]
│ ├── dandelion [817 entries exceeds filelimit, not opening dir]
│ ├── roses [568 entries exceeds filelimit, not opening dir]
│ ├── sunflowers [624 entries exceeds filelimit, not opening dir]
│ └── tulips [709 entries exceeds filelimit, not opening dir]
└── val
├── daisy [48 entries exceeds filelimit, not opening dir]
├── dandelion [81 entries exceeds filelimit, not opening dir]
├── roses [73 entries exceeds filelimit, not opening dir]
├── sunflowers [75 entries exceeds filelimit, not opening dir]
└── tulips [90 entries exceeds filelimit, not opening dir]
Notice that the dataset directory has two subdirectories:
train: Contains training images for each of the five classes.
val: Stores the validation images for each of the five classes.
By creating a train and val directory, we can now easily utilize PyTorch’s ImageFolder class to build a DataLoader such that we can fine-tune our models.
Implementing feature extraction and transfer learning with PyTorch
The first method of transfer learning we are going to implement is feature extraction.
Transfer learning via feature extraction works by:
Taking a pre-trained CNN (typically on the ImageNet dataset)
Removing the FC layer head from the CNN
Treating the output of the body of the network as an arbitrary feature extractor with spatial dimensions M × N × C
From there, we have two choices:
Take a standard Logistic Regression classifier (like the one found in the scikit-learn library) and train it on the extracted features from each image
Or, more simply, place a softmax classifier on top of the body of the network
Either option is viable and more-or-less the “same” as the other.
The first option works great when your dataset of extracted features fits into the RAM of your machine. That way, you load the entire dataset, instantiate an instance of your favorite Logistic Regression classifier model, and then train it.
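As a quick, hypothetical sketch of that first option (this is not part of the project code we build below), you could chop the FC head off a pre-trained ResNet50, collect the 2048-dim feature vectors for every image, and hand them to scikit-learn:
# hypothetical example: feature extraction with scikit-learn's LogisticRegression
import numpy as np
import torch
from torch import nn
from torchvision.models import resnet50
from sklearn.linear_model import LogisticRegression

model = resnet50(pretrained=True)
model.fc = nn.Identity()  # the body now outputs raw 2048-dim feature vectors
model.eval()

features, labels = [], []
with torch.no_grad():
    for (x, y) in trainLoader:  # assumes a DataLoader like the ones built later
        features.append(model(x).numpy())
        labels.append(y.numpy())

# train a standard Logistic Regression classifier on the extracted features
clf = LogisticRegression(max_iter=1000)
clf.fit(np.vstack(features), np.hstack(labels))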
The problem happens when your dataset is too large to fit into your machine’s memory. When that happens, you could use something like online learning to train your Logistic Regression classifier, but that just introduces another set of libraries and dependencies.
Instead, it’s easier to just leverage the power of PyTorch and create a Logistic Regression-like classifier on top of the extracted features and then train it using PyTorch functions. This is the method we’ll be implementing here today.
Open the train_feature_extraction.py file in your project directory structure, and let’s get started:
# USAGE
# python train_feature_extraction.py
# import the necessary packages
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from imutils import paths
from torchvision.models import resnet50
from torchvision import transforms
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
We build data processing/augmentation steps using the Compose function, found inside the transforms submodule of PyTorch.
First, we create a trainTransform that, given an input image, will:
Randomly resizes and crops the image to IMAGE_SIZE dimensions
Randomly performs horizontal flipping
Randomly rotates the image in the range [-90, 90] degrees
Converts the resulting image into a PyTorch tensor
Performs mean subtraction and scaling
We then have our valTransform, which:
Resizes the input image to IMAGE_SIZE dimensions
Converts the image to a PyTorch tensor
Performs mean subtraction and scaling
Notice that we do not perform data augmentation inside the validation transformer — there is no need to perform data augmentation for our validation data.
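Putting those steps together, the corresponding code block (not reproduced in this post) would look roughly like the following sketch, using the trainTransform and valTransform names assumed by the rest of the script:
# build the training and validation data processing/augmentation pipelines
trainTransform = transforms.Compose([
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])
valTransform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])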
With both our training and validation Compose objects created, let’s apply our get_dataloader function:
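That block is also missing from this listing; a sketch of the two calls, using the trainDS, trainLoader, valDS, and valLoader names referenced throughout the rest of the script, would be:
# create data loaders for the training and validation directories
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
    transforms=trainTransform,
    batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
    transforms=valTransform,
    batchSize=config.FEATURE_EXTRACTION_BATCH_SIZE, shuffle=False)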
Lines 32-34 create our training data loaders, while Lines 35-37 create our validation data loaders.
Each of these loaders will yield images from the dataset/train and dataset/val directories, respectively.
Also, note that we do not perform shuffling for our validation data (just like we do not perform data augmentation for validation data).
Let’s now prepare the ResNet50 model for transfer learning via feature extraction:
# load up the ResNet50 model
model = resnet50(pretrained=True)
# since we are using the ResNet50 model as a feature extractor we set
# its parameters to non-trainable (by default they are trainable)
for param in model.parameters():
    param.requires_grad = False
# append a new classification top to our feature extractor and pop it
# on to the current device
modelOutputFeats = model.fc.in_features
model.fc = nn.Linear(modelOutputFeats, len(trainDS.classes))
model = model.to(config.DEVICE)
Line 40 loads ResNet50, pre-trained on ImageNet, from disk.
Since we’ll be using ResNet for feature extraction, and therefore no actual “learning” needs to take place in the body of the network, we freeze all layers in the body of the network (Lines 44 and 45).
From there, we create a new FC layer head that consists of a single FC layer. Effectively, this layer, when trained with categorical cross-entropy loss, will serve as our surrogate softmax classifier.
This new layer is then appended to the body of the network, and the model itself is moved to our DEVICE (either our CPU or GPU).
Next, we initialize our loss function and optimization method:
# initialize loss function and optimizer (notice that we are only
# providing the parameters of the classification top to our optimizer)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.fc.parameters(), lr=config.LR)
# calculate steps per epoch for training and validation set
trainSteps = len(trainDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
valSteps = len(valDS) // config.FEATURE_EXTRACTION_BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "train_acc": [], "val_loss": [],
"val_acc": []}
We’ll train our model using the Adam optimizer and categorical cross-entropy loss (Lines 55 and 56).
We also compute the number of steps our model will take per epoch, as a function of batch size, for both our training and validation sets (Lines 59 and 60).
Now, it’s time to train the model:
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # set the model in training mode
    model.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalValLoss = 0
    # initialize the number of correct predictions in the training
    # and validation step
    trainCorrect = 0
    valCorrect = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        # perform a forward pass and calculate the training loss
        pred = model(x)
        loss = lossFunc(pred, y)
        # calculate the gradients
        loss.backward()
        # check if we are updating the model parameters and if so
        # update them, and zero out the previously accumulated gradients
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()
        # add the loss to the total training loss so far and
        # calculate the number of correct predictions
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()
On Line 69, we loop over our desired number of epochs.
For each batch of data in the trainLoader, we:
Move the image and class label to our CPU/GPU (Line 85).
Make predictions on the data (Line 88)
Compute the loss, calculate the gradients, update the model weights, and zero the gradients (Lines 89-98)
Accumulate our total training loss for the epoch (Line 102)
Compute the total number of correct predictions (Lines 103 and 104)
Now that the epoch is complete, we can evaluate the model on the validation data:
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        model.eval()
        # loop over the validation set
        for (x, y) in valLoader:
            # send the input to the device
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
            # make the predictions and calculate the validation loss
            pred = model(x)
            totalValLoss += lossFunc(pred, y)
            # calculate the number of correct predictions
            valCorrect += (pred.argmax(1) == y).type(
                torch.float).sum().item()
Notice here that we turn off autograd and put the model in evaluation mode — this is a requirement when evaluating with PyTorch, so don’t forget to do it!
From there, we loop over all data points in our valLoader, make predictions on them, and compute our total loss and number of correct validation predictions.
The following code block aggregates our training/validation loss and accuracy, updates our training history, and then prints the loss/accuracy information to our terminal:
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    # calculate the training and validation accuracy
    trainCorrect = trainCorrect / len(trainDS)
    valCorrect = valCorrect / len(valDS)
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["train_acc"].append(trainCorrect)
    H["val_loss"].append(avgValLoss.cpu().detach().numpy())
    H["val_acc"].append(valCorrect)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
        avgTrainLoss, trainCorrect))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
        avgValLoss, valCorrect))
Our final code block plots our training history and serializes our model to disk:
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
endTime - startTime))
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.WARMUP_PLOT)
# serialize the model to disk
torch.save(model, config.WARMUP_MODEL)
After this script executes, you’ll find a file named warmup_model.pth in your output directory — this file is your serialized PyTorch model, which can then be used to make predictions inside the inference.py script.
PyTorch transfer learning with feature extraction
We are now ready to perform transfer learning via feature extraction with PyTorch.
Make sure that you have:
Used the “Downloads” section of this tutorial to access the source code, example images, etc.
Executed the build_dataset.py script to create our dataset directory structure
Provided you’ve accomplished both of these steps, you can move on to running the train_feature_extraction.py script:
$ python train_feature_extraction.py
[INFO] training the network...
0% 0/20 [00:00<?, ?it/s][INFO] EPOCH: 1/20
Train loss: 1.610827, Train accuracy: 0.4063
Val loss: 2.295713, Val accuracy: 0.6512
5% 1/20 [00:17<05:24, 17.08s/it][INFO] EPOCH: 2/20
Train loss: 1.190757, Train accuracy: 0.6703
Val loss: 1.720566, Val accuracy: 0.7193
10% 2/20 [00:33<05:05, 16.96s/it][INFO] EPOCH: 3/20
Train loss: 0.958189, Train accuracy: 0.7163
Val loss: 1.423687, Val accuracy: 0.8120
15% 3/20 [00:50<04:47, 16.90s/it][INFO] EPOCH: 4/20
Train loss: 0.805547, Train accuracy: 0.7811
Val loss: 1.200151, Val accuracy: 0.7793
20% 4/20 [01:07<04:31, 16.94s/it][INFO] EPOCH: 5/20
Train loss: 0.731831, Train accuracy: 0.7856
Val loss: 1.066768, Val accuracy: 0.8283
25% 5/20 [01:24<04:14, 16.95s/it][INFO] EPOCH: 6/20
Train loss: 0.664001, Train accuracy: 0.8044
Val loss: 0.996960, Val accuracy: 0.8311
...
75% 15/20 [04:13<01:24, 16.83s/it][INFO] EPOCH: 16/20
Train loss: 0.495064, Train accuracy: 0.8480
Val loss: 0.736332, Val accuracy: 0.8665
80% 16/20 [04:30<01:07, 16.86s/it][INFO] EPOCH: 17/20
Train loss: 0.502294, Train accuracy: 0.8435
Val loss: 0.732066, Val accuracy: 0.8501
85% 17/20 [04:46<00:50, 16.85s/it][INFO] EPOCH: 18/20
Train loss: 0.486568, Train accuracy: 0.8471
Val loss: 0.703661, Val accuracy: 0.8801
90% 18/20 [05:03<00:33, 16.82s/it][INFO] EPOCH: 19/20
Train loss: 0.470880, Train accuracy: 0.8480
Val loss: 0.715560, Val accuracy: 0.8474
95% 19/20 [05:20<00:16, 16.85s/it][INFO] EPOCH: 20/20
Train loss: 0.489092, Train accuracy: 0.8426
Val loss: 0.684679, Val accuracy: 0.8774
100% 20/20 [05:37<00:00, 16.86s/it]
[INFO] total time taken to train the model: 337.24s
Training took just over 5.5 minutes in total. We obtained 84.26% training accuracy and 87.74% validation accuracy.
Figure 3 displays a plot of our training history.
Figure 3: Applying feature extraction with PyTorch.
Not too bad for how little time we invested in the training process!
Fine-tuning a CNN with PyTorch
So far in this tutorial, you have learned how to perform transfer learning via feature extraction.
This method works well in some cases, but its simplicity has its drawbacks, namely that both accuracy and the ability of the model to generalize can suffer.
Most forms of transfer learning apply fine-tuning, which is the topic of this section.
Similar to feature extraction, we start by removing the FC layer head from the network, but this time we create a brand new layer head with a set of linear, ReLU, and dropout layers, similar to what you would see on a modern state-of-the-art CNN.
We then perform some combination of:
Freezing all layers in the body of the network and training the layer head
Freezing all layers, training the layer head, and then unfreezing the body and training that too
Simply leaving all layers unfrozen and training them all together
Exactly which method you use is an experiment you’ll run for yourself — be sure to measure which one gives you the lowest loss and highest accuracy!
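If you opt to unfreeze the body after warming up the new head, the change only takes a few lines; here is a hedged sketch (the smaller learning rate value is just an assumption):
# unfreeze the body of the network and continue training with a smaller
# learning rate so we don't destroy the pre-trained weights
for param in model.parameters():
    param.requires_grad = True
opt = torch.optim.Adam(model.parameters(), lr=1e-5)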
Let’s learn how to apply fine-tuning via transfer learning with PyTorch. Open the fine_tune.py file in your project directory structure, and let’s get started:
# USAGE
# python fine_tune.py
# import the necessary packages
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from imutils import paths
from torchvision.models import resnet50
from torchvision import transforms
from tqdm import tqdm
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch
import time
import os
We start on Lines 5-17 by importing our required Python packages. Note that these imports are essentially identical to our previous script.
We then define our training and validation transforms, just like we did for feature extraction:
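That block isn’t reproduced here since it mirrors train_feature_extraction.py; the only difference worth sketching is the batch size handed to get_dataloader:
# identical to the feature extraction script, except for the fine-tuning batch size
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
    transforms=trainTransform, batchSize=config.FINETUNE_BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
    transforms=valTransform, batchSize=config.FINETUNE_BATCH_SIZE, shuffle=False)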
The real change comes when we load ResNet from disk and modify the architecture itself, so let’s inspect this section closely:
# load up the ResNet50 model
model = resnet50(pretrained=True)
numFeatures = model.fc.in_features
# loop over the modules of the model and set the parameters of
# batch normalization modules as not trainable
for module, param in zip(model.modules(), model.parameters()):
    if isinstance(module, nn.BatchNorm2d):
        param.requires_grad = False
# define the network head and attach it to the model
headModel = nn.Sequential(
nn.Linear(numFeatures, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(256, len(trainDS.classes))
)
model.fc = headModel
# append a new classification top to our feature extractor and pop it
# on to the current device
model = model.to(config.DEVICE)
Line 41 loads our ResNet model from disk with weights pre-trained on the ImageNet dataset.
In this particular fine-tuning example, we are going to construct a new FC layer head and then train both the FC layer head and the body of the network at the same time.
However, we first need to pay close attention to the batch normalization layers in the network architecture. These layers have specific mean and standard deviation values that were obtained when the network was originally trained on the ImageNet dataset.
We do not want to update these statistics during training, so we freeze any instances of BatchNorm2d on Lines 46-48.
If you are performing fine-tuning in a network that utilizes batch normalization, make sure you freeze those layers before you start training!
From there, we construct our new headModel which consists of a series of FC => RELU => DROPOUT layers (Lines 51-59).
The output of the final Linear layer is the number of classes in the dataset (Line 58).
Finally, we add the new headModel to the network, thereby replacing the old FC layer head.
With our “network surgery” done, we can move on to instantiating our loss function and optimizer:
# initialize loss function and optimizer (notice that this time we are
# providing all of the model's parameters to the optimizer)
lossFunc = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=config.LR)
# calculate steps per epoch for training and validation set
trainSteps = len(trainDS) // config.FINETUNE_BATCH_SIZE
valSteps = len(valDS) // config.FINETUNE_BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "train_acc": [], "val_loss": [],
"val_acc": []}
And from there, we start our training pipeline:
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
    # set the model in training mode
    model.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalValLoss = 0
    # initialize the number of correct predictions in the training
    # and validation step
    trainCorrect = 0
    valCorrect = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
        # perform a forward pass and calculate the training loss
        pred = model(x)
        loss = lossFunc(pred, y)
        # calculate the gradients
        loss.backward()
        # check if we are updating the model parameters and if so
        # update them, and zero out the previously accumulated gradients
        if (i + 2) % 2 == 0:
            opt.step()
            opt.zero_grad()
        # add the loss to the total training loss so far and
        # calculate the number of correct predictions
        totalTrainLoss += loss
        trainCorrect += (pred.argmax(1) == y).type(
            torch.float).sum().item()
At this point, the code to fine-tune our model is identical to the feature extraction method, so you can defer to the previous section for a detailed review of the code.
With training complete, we can then move on to the validation part of the epoch:
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        model.eval()
        # loop over the validation set
        for (x, y) in valLoader:
            # send the input to the device
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
            # make the predictions and calculate the validation loss
            pred = model(x)
            totalValLoss += lossFunc(pred, y)
            # calculate the number of correct predictions
            valCorrect += (pred.argmax(1) == y).type(
                torch.float).sum().item()
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    # calculate the training and validation accuracy
    trainCorrect = trainCorrect / len(trainDS)
    valCorrect = valCorrect / len(valDS)
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["train_acc"].append(trainCorrect)
    H["val_loss"].append(avgValLoss.cpu().detach().numpy())
    H["val_acc"].append(valCorrect)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
        avgTrainLoss, trainCorrect))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
        avgValLoss, valCorrect))
After validation is complete, we plot our training history and serialize our model to disk:
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
endTime - startTime))
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.FINETUNE_PLOT)
# serialize the model to disk
torch.save(model, config.FINETUNE_MODEL)
After executing the fine_tune.py script, you will find a trained model named finetune_model.pth in your output directory.
You can use this model with inference.py to make predictions on new images.
PyTorch fine-tuning results
Let’s now apply fine-tuning using PyTorch.
Again, make sure you have:
Used the “Downloads” section of this tutorial to download the source code, dataset, etc.
Executed the build_dataset.py script to create our dataset directory
From there, you can execute the following command:
$ python fine_tune.py
[INFO] training the network...
0% 0/20 [00:00<?, ?it/s][INFO] EPOCH: 1/20
Train loss: 0.857740, Train accuracy: 0.6809
Val loss: 2.498850, Val accuracy: 0.6512
5% 1/20 [00:18<05:55, 18.74s/it][INFO] EPOCH: 2/20
Train loss: 0.581107, Train accuracy: 0.7972
Val loss: 0.432770, Val accuracy: 0.8665
10% 2/20 [00:38<05:40, 18.91s/it][INFO] EPOCH: 3/20
Train loss: 0.506620, Train accuracy: 0.8289
Val loss: 0.721634, Val accuracy: 0.8011
15% 3/20 [00:57<05:26, 19.18s/it][INFO] EPOCH: 4/20
Train loss: 0.477470, Train accuracy: 0.8341
Val loss: 0.431005, Val accuracy: 0.8692
20% 4/20 [01:17<05:10, 19.38s/it][INFO] EPOCH: 5/20
Train loss: 0.467796, Train accuracy: 0.8368
Val loss: 0.746030, Val accuracy: 0.8120
25% 5/20 [01:37<04:53, 19.57s/it][INFO] EPOCH: 6/20
Train loss: 0.429070, Train accuracy: 0.8523
Val loss: 0.607376, Val accuracy: 0.8311
...
75% 15/20 [04:51<01:36, 19.33s/it][INFO] EPOCH: 16/20
Train loss: 0.317167, Train accuracy: 0.8880
Val loss: 0.344129, Val accuracy: 0.9183
80% 16/20 [05:11<01:17, 19.32s/it][INFO] EPOCH: 17/20
Train loss: 0.295942, Train accuracy: 0.9013
Val loss: 0.375650, Val accuracy: 0.8992
85% 17/20 [05:30<00:58, 19.38s/it][INFO] EPOCH: 18/20
Train loss: 0.282065, Train accuracy: 0.9046
Val loss: 0.374338, Val accuracy: 0.8992
90% 18/20 [05:49<00:38, 19.30s/it][INFO] EPOCH: 19/20
Train loss: 0.254787, Train accuracy: 0.9116
Val loss: 0.302762, Val accuracy: 0.9264
95% 19/20 [06:08<00:19, 19.25s/it][INFO] EPOCH: 20/20
Train loss: 0.270875, Train accuracy: 0.9083
Val loss: 0.385452, Val accuracy: 0.9019
100% 20/20 [06:28<00:00, 19.41s/it]
[INFO] total time taken to train the model: 388.23s
Since our model is more complex (due to adding the new FC layer head to the body of the network), training is now taking ~6.5 minutes.
However, in Figure 4, we obtain higher accuracy than our simple feature extraction method (90.83%/90.19% versus 84.26%/87.74%, respectively):
Figure 4: Applying fine-tuning with PyTorch.
While performing fine-tuning does take more work, you’ll often find that accuracy is higher, and your model will generalize better.
Implementing our PyTorch prediction script
So far, you’ve learned two ways to apply transfer learning with PyTorch:
Feature extraction
Fine-tuning
Both methods have resulted in models obtaining 80-90% accuracy …
… but how do we use these models to make predictions?
The answer is to use our inference.py script:
# USAGE
# python inference.py --model output/warmup_model.pth
# python inference.py --model output/finetune_model.pth
# import the necessary packages
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
import argparse
import torch
We start our inference.py script with a number of imports, including:
config: Our configuration file
create_dataloaders: Our helper utility to create a DataLoader object from an input directory of images (in this case, our dataset/val directory)
transforms: Applies data preprocessing in a sequential manner
matplotlib: Displays our output images and predictions to our screen
torch and nn: Our PyTorch bindings
argparse: Parses any command line arguments
Speaking of command line arguments, let’s parse them now:
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
help="path to trained model model")
args = vars(ap.parse_args())
We only need a single argument here, --model, which is the path to our trained PyTorch model residing on disk.
Let’s now create a transform object for our input images:
# build our data pre-processing pipeline
testTransform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
# calculate the inverse mean and standard deviation
invMean = [-m/s for (m, s) in zip(config.MEAN, config.STD)]
invStd = [1/s for s in config.STD]
# define our de-normalization transform
deNormalize = transforms.Normalize(mean=invMean, std=invStd)
Just like our validation transformer in the previous section, all we’ll be doing here is:
Resizing our input images to IMAGE_SIZE dimensions
Converting the image to a PyTorch tensor
Applying mean subtraction and scaling (normalization) to the input image
However, to display the output images on our screen, we’ll actually need to “denormalize” them. Lines 28 and 29 compute the inverse mean and standard deviation, while Line 32 creates a deNormalize transform. Since Normalize computes (x - mean) / std, applying it a second time with mean = -mean/std and std = 1/std recovers the original pixel values.
Using the deNormalize transform, we’ll be able to “undo” the testTransform and then display the output image on our screen.
Let’s now build a DataLoader for our config.VAL directory:
# initialize our test dataset and data loader
print("[INFO] loading the dataset...")
(testDS, testLoader) = create_dataloaders.get_dataloader(config.VAL,
transforms=testTransform, batchSize=config.PRED_BATCH_SIZE,
shuffle=True)
From there, we can set our target computation device and load our trained PyTorch model:
# check if we have a GPU available, if so, define the map location
# accordingly
if torch.cuda.is_available():
    map_location = lambda storage, loc: storage.cuda()
# otherwise, we will be using CPU to run our model
else:
    map_location = "cpu"
# load the model
print("[INFO] loading the model...")
model = torch.load(args["model"], map_location=map_location)
# move the model to the device and set it in evaluation mode
model.to(config.DEVICE)
model.eval()
Lines 40-47 check to see if we are using our CPU or GPU.
Lines 51-55 proceed to:
Load our trained PyTorch model from disk
Move it to our target DEVICE
Place the model in evaluation mode
Let’s now grab a random set of testing data from our testLoader:
# grab a batch of test data
batch = next(iter(testLoader))
(images, labels) = (batch[0], batch[1])
# initialize a figure
fig = plt.figure("Results", figsize=(10, 10))
And finally, we can make predictions on our test data:
# switch off autograd
with torch.no_grad():
    # send the images to the device
    images = images.to(config.DEVICE)
    # make the predictions
    print("[INFO] performing inference...")
    preds = model(images)
    # loop over all the batch
    for i in range(0, config.PRED_BATCH_SIZE):
        # initialize a subplot
        ax = plt.subplot(config.PRED_BATCH_SIZE, 1, i + 1)
        # grab the image, de-normalize it, scale the raw pixel
        # intensities to the range [0, 255], and change the channel
        # ordering from channels first to channels last
        image = images[i]
        image = deNormalize(image).cpu().numpy()
        image = (image * 255).astype("uint8")
        image = image.transpose((1, 2, 0))
        # grab the ground truth label
        idx = labels[i].cpu().numpy()
        gtLabel = testDS.classes[idx]
        # grab the predicted label
        pred = preds[i].argmax().cpu().numpy()
        predLabel = testDS.classes[pred]
        # add the results and image to the plot
        info = "Ground Truth: {}, Predicted: {}".format(gtLabel,
            predLabel)
        plt.imshow(image)
        plt.title(info)
        plt.axis("off")
# show the plot
plt.tight_layout()
plt.show()
Line 65 turns off autograd computation (a requirement when placing a PyTorch model in evaluation mode) while Line 67 sends the images to the appropriate DEVICE.
Line 71 makes predictions on the images using our trained model.
To visualize the predictions, we first need to loop over them on Line 74. Inside the loop, we proceed to:
Initialize a subplot to display the image and prediction (Line 76)
Denormalize the image by “undoing” the mean scaling and swapping color channel ordering (Lines 81-84)
Grab the ground-truth label (Lines 87 and 88)
Grab the predicted label (Lines 91 and 92)
Add the image, ground-truth label, and predicted label to the plot (Lines 95-99)
The output visualization is then displayed on our screen.
Making predictions with our trained PyTorch model
Let’s now make predictions using our inference.py script and our trained PyTorch models.
Go to the “Downloads” section of this tutorial to access the source code, datasets, etc., and from there, you can execute the following command:
$ python inference.py --model output/finetune_model.pth
[INFO] loading the dataset...
[INFO] loading the model...
[INFO] performing inference...
You can see the results in Figure 5.
Figure 5: After applying fine-tuning with PyTorch, we are able to use the trained model to make correct, accurate predictions on input images.
Here you can see that we have correctly classified our flower images — and best of all, we were able to obtain such high accuracy with little effort on our part due to transfer learning.
In this tutorial, you learned how to perform transfer learning using PyTorch.
Specifically, we discussed two types of transfer learning:
Transfer learning via feature extraction
Transfer learning via fine-tuning
The first method is typically easier to implement and requires less effort. However, it tends to be less accurate than the second method.
I typically recommend using the feature extraction method to obtain a baseline accuracy. If the accuracy is sufficient for your application, fantastic! You’re done, and you can continue building the rest of your project.
However, if accuracy is not sufficient, then you should apply fine-tuning and see if you can boost your accuracy higher.
In either case, transfer learning, whether via feature extraction or fine-tuning, tends to save you a ton of time and effort, as opposed to training your model from scratch.
@article{Rosebrock_2021_Transfer,
author = {Adrian Rosebrock},
title = {{PyTorch}: Transfer Learning and Image Classification},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/10/11/pytorch-transfer-learning-and-image-classification/},
}
You’ve built a brand new home out in the country, far from major cities. You need a break from all the hustle and bustle, and you want to bring yourself back to nature.
The house you’ve built is beautiful. It’s quiet at night, you can see constellations dancing in the sky, and you sleep well, knowing you’re going to wake up rested and rejuvenated.
… then, you wake up in the middle of the night. Smoke? Is there a fire?!
You run down the stairs and onto the porch. On the horizon, you see an orange glow, as if the entire sky is on fire. Smoke is billowing, like an angry storm cloud ready to suffocate you.
Sure enough, it’s a wildfire. And based on how the wind is blowing, it’s headed right for you.
Your beautiful, serene home has turned into a combustible nightmare.
And you have to wonder … could computer vision have been used to detect this wildfire early on and thereby alerted firefighters in the area?
No time to think about that now though, just grab what precious belongings you can, throw them in the back of the truck, and get the hell out of there.
As the raging fire bears down on your house, you take one last look in the rearview mirror and vow that you’ll one day figure out how to detect wildfires early on … and then you’re driving down a dirt road in search of civilization.
Scary story, right?
And while I’ve embellished it a bit for dramatic effect, it’s not unlike what David Bonn has experienced in his home out in Washington State … multiple times!
David has dedicated the past few years of his professional career to developing an early-warning computer vision system to detect wildfires.
The system runs on a Raspberry Pi and connects to the internet via WiFi or a cellular modem. If a fire is detected, the RPi pings David, after which he can alert the fire department.
Additionally, the United States Patent Office just granted David multiple patents on his work!
It’s truly a pleasure to have David on the PyImageSearch blog today. He’s been a long-time reader, customer, and moderator in the community forums.
And most importantly, his work is helping prevent injury and loss of life in arguably one of the most notoriously hard natural disasters to detect early on.
To learn how David Bonn has created a computer vision system to detect wildfires, just keep reading the full interview!
An interview with David Bonn, computer vision and wildfire detection expert
Adrian: Hi David! Thank you for taking the time to do this interview. It’s a pleasure to have you on the PyImageSearch blog.
David: Thanks, Adrian. Always a pleasure to chat with you.
Adrian: Before we get started, can you tell us a bit about yourself? Where do you work, and what do you do?
David: I have been working off and on as a developer and engineer since college. In between those jobs, I had various “fun” jobs, as a Park Ranger, river guide, ski instructor, and fire lookout.
Along with a few other people, I founded WatchGuard Technologies in 1996, which became wildly successful and is, in fact, still around and independent today. After that adventure, I was semi-retired and traveled a great deal. During that same period, I spent a lot of time working on environmental education programs and other natural history projects.
These days I am trying to get a new company, Deepseek Labs, off the ground.
Adrian: What got you interested in studying computer vision and deep learning?
David: It became obvious to me about ten years ago that something big had changed with respect to neural networks and how they could work. They finally were on the path to having a toolkit that could solve practical problems. And what was even more interesting to me was that there was a large class of real problems that you likely couldn’t solve any better way.
A few years later, I started dabbling a bit with OpenCV, mostly by downloading books and going through their examples, and then I stumbled across your blog in 2015.
In 2014 and 2015, there were a series of large wildfires near my home. While my home was unscathed, almost 500 homes were lost in the combined fires, and several firefighters were killed. It occurred to me that there was a huge unsolved problem here, and I wondered what I could do to solve it.
Figure 1: Tripod Fire (2006) from my home, about 1 am August 8th, 2006 (time exposure).
In late 2017 I found myself in one of those situations where one needs to take stock of their life and what they were doing with it. At the same time, I had many conversations with friends and neighbors about the wildfire problem and what tools we would need to adapt to this new situation.
One thing that popped into my head during those conversations was (approximately), “gosh, couldn’t I make something that would detect fires and give people a few minutes warning to flee for their lives?” So, I resolved to learn enough about computer vision and deep learning to figure out if such a thing was even possible.
Usually, I find that if you wish to master a new skill, it helps a great deal if you have a project (or a set of goals) in mind that needs that new skill to help push you along. So my project for mastering computer vision and deep learning was to build a simple but practical fire detection system.
Adrian: One of my first memories of interacting with you was in the PyImageSearch community forums, where you were discussing fire/smoke detection and how that was such a big concern where you live. How do those wildfires start, and why is early detection so important?
David: Every fire is different. Right now, there are four substantial fires within thirty miles of my home. Three were started by lightning, and one was caused by a person working on an irrigation pump. Typically most “wildfires” in the United States happen in fairly developed areas caused by human activity. That might be anything from a power line shorting on a tree branch to a vehicle idling in dry grass to somebody carelessly cooking hot dogs on a campfire.
Early detection is very important, both from a cost and public safety perspective. A small fire might cost a few thousand dollars to suppress. A large fire can easily run into the tens of millions of dollars. The Cub Creek 2 Fire, near my home, had over 800 people fighting the fire at its peak. A small crew in a brush truck might cost $1000 per day, and a large helicopter typically runs $8000 per hour. Those costs add up quickly.
Figure 2: Helicopter drop near my house, July 17th, 2021. You never want to see this at your house!
Also, while most anyone can safely put out a campfire with a shovel and bucket of water, fighting a large fire is more like fighting a weather phenomenon. You might be able to slow it down or steer it a little bit, but you aren’t likely to stop it or completely suppress it.
Chances are the fires burning near me will still be burning, somewhere, until the snow starts falling in November. But, with a lot of the brush cleared out by the fires, the skiing is likely to be great this winter!
Adrian: A few weeks ago, your own home was affected by a wildfire. Can you tell us about what happened?
David: I can tell you what happened from my perspective.
An important bit of context: I live in north central Washington State, which, despite Washington’s reputation, is a pretty dry place, and summers can often be quite warm. It has been an extremely dry spring (most of the eastern part of the state is in a serious drought), and we had an incredible heat wave in late June. So vegetation was incredibly dry. How dry? Well, fire researchers take core samples from trees to evaluate fuel dryness. Core samples taken in my area two weeks before the fire found that living trees were drier than kiln-dried lumber, with a moisture content of about 2%. Your typical sheet of paper has a 3% moisture content.
On Friday, July 16th, I was at home and inside during the heat of the day. About 1:45 pm, I looked outside and noticed an ominous column of smoke just to the south. At that point, I went out, and there was a pretty strong wind blowing from the direction of the smoke.
At that point, a whole lot of practice and planning kicked in. I quickly closed up all of the windows (most were already closed). There was a bag of clothes and a box of documents and hard drives in the entry that I quickly loaded into the truck. Then I got all four dogs and loaded them in the truck as well. After a long last look at the house, I headed down the hill, and at the same time, started texting my neighbors about the fire and encouraging them to get out of there.
At this point, I and all of my neighbors got very lucky. There was another large fire in the area, and they immediately transferred fire crews and aircraft to this new fire. Hence, within an hour, there were several aircraft and around 100 firefighters on the scene (by the time the fire crews were on the scene, the fire was already estimated to be 1000 acres in size).
There is also a heavy equipment operator very close by, and even before all the firefighters were there, he was cutting firelines all over the place with a bulldozer.
All that quick action by my neighbors and firefighters produced a near-miraculous result: only a few buildings were lost, and a few others had minor damage. And nobody was injured or killed.
Despite a county-wide emergency warning system, none of us had any warning at all — the warnings from the county system got to my phone about 2:30 pm. If the exact series of events had happened at 1:45 am rather than 1:45 pm, things would have been tragically worse, certainly in terms of loss of homes and property and likely in terms of lost lives as well.
Interestingly, I had a prototype fire-detection system up and running outside while this all happened. By the time it detected the fire, I was quite a distance away, of course, and the WiFi was already out (power went out at my house when I evacuated and wasn’t restored until late the next day). As a result, I was unable to save any of the detection images. In addition, the detector itself was running on a Goal Zero battery pack.
Adrian: You and your company have developed a fire detection system that can be used in rural areas. Can you tell us about the solution? How does it work?
David: The 30-second answer is that I use a thermal camera (FLIR Lepton 3) and basic OpenCV image processing functions to find good candidate regions. I pass those regions to another program that inspects them with an optical camera and then passes slices of the optical image to a binary classifier.
The longer answer is that the thermal camera looks for two things close together: a hot spot (which is a very bright part of the thermal image) and a region of turbulent motion. So, if I can find turbulent motion close to a very hot spot and the turbulent motion is mostly above the hot spot (remember that hot air rises), the area where the hot spot and turbulent motion overlap (or where they overlap if suitably dilated) is a likely location to find flames.
Figure 3: Two candles in a false color thermal image. The red channel is the (normalized) thermal image, the green channel is turbulent motion, and the blue channel shows the bounding boxes where the algorithm thinks flames are likely.
The optical algorithm then takes each candidate region and translates its coordinates to the optical camera’s coordinate system. By carefully looking at the candidate region, I can choose a good slice of the optical image suitable for passing to a well-trained binary classifier.
A big discovery (well, for me, it was big) I made was that while, with a great dataset and a good network, you could get 96-97% accuracy on a full-frame image, if your classifier was looking at a well-chosen slice of an image, the accuracy went up to over 99%. I suspect that with a carefully constructed ensemble, you could reach far higher accuracies.
If both the thermal algorithm and the classifier agree there is a fire, the system goes into an alarm state.
Figure 4: Detection of a small fire at about 30 meters. The blue square is the candidate region as detected by the thermal camera algorithm. The green square is the slice passed to the classifier. The gray guide lines show the approximate field of view of the thermal camera.
By itself, this system gets an accuracy of over 99.99% — this translates to a “mistake” every 3-5 days when operating in sample (outdoors). Out of sample (e.g., in my kitchen with a gas range) gives 4 or 5 mistakes per day. Higher accuracies would likely be possible with frame averaging or ensembles. And since the thermal images have a very low resolution (160×120) and frame rate (9 fps), most of the time the system doesn’t have to work very hard to obtain those impressive results.
The approach I use is far from perfect and still struggles in some situations. Hot exhaust from internal combustion engines, especially heavy equipment or farm equipment, frequently confuses the thermal algorithm. The classifier struggles with brightly lit subjects and even more with brightly backlit subjects. Brightly colored birds close to the sensors have produced confusing results at times. These problems are being mitigated over time, often by collecting more representative training data.
I have applied for patents on many parts of this system, and on August 11, I was informed that those patents were allowed. After some more fees and paperwork, those patents will formally be issued and published, and then I can share more details about how the system works.
Adrian: What hardware does your fire detection system run on? Do you need a laptop/desktop, or are you using something like a Raspberry Pi, Jetson Nano, etc.?
David: The core fire detector runs on a Raspberry Pi. The reference implementation right now is a Pi 3B+ and works fine with detection times on the order of 1-2 seconds. As implemented right now, the system either connects to the internet through WiFi or using a cellular modem. My preference down the road is to use the cellular modem, as we can self-configure the system and have it up and running without any end-user setup.
I am booting the Pi read-only. This makes the system much more robust in the face of power outages and other failures, but it isn’t possible to save detection images directly on the SD card.
Other parts of this system will run on the cloud, and there will also be a client (either a web page or an app) that can show you the detectors deployed, where they are deployed, and what their status is.
Adrian: What is the hardest part of combining an infrared camera with “standard” image processing and OpenCV code? What roadblocks did you encounter?
David: The big obstacle for me was getting the hardware to work at all. There were many hardware barriers (including unsupported and deprecated parts) and a lot of obsolete code on GitHub that I had to work through. I finally found some halfway decent code that let me at least get started. For example, the normal Lepton breakout board uses I2C, and so you have a mess of wires connected to the GPIO bus on your Raspberry Pi. I got all that to work, but it wasn’t the best environment for exploring a better flame detection algorithm.
My rate of progress dramatically increased when I switched over to using a Purethermal 2 USB module. This was a huge improvement because, with very little effort, I could experiment with image processing algorithms for the thermal camera on a laptop. So rather than upload code to the Pi and reboot the Pi and look at another display to see the output, I could just have a code-test-debug cycle on my laptop and work on the software at my desk, on the kitchen table, or at a bakery. So I rapidly got a lot more time on the system and learned more in a month than I had in the previous six months.
Once you are talking to the hardware, the real work begins. The Lepton is a remarkably sensitive instrument, which has only 160×120 spatial resolution, but each pixel is 16 bits deep — a 1-bit change in pixel values represents a temperature change of approximately 0.05C. That is sensitive enough that if you walk barefoot on a cold floor, your footprints will “glow” in the thermal image for several minutes. On the other hand, a lot of OpenCV functions don’t really like a 16-bit grayscale image, so you need to be careful which functions you call, and you might need to use numpy for some operations.
The final thing that you need to watch out for is that thermal images have extremely low contrast. So, unless you enhance them somehow (usually normalizing a histogram is enough), you won’t be able to see anything interesting when you display the images.
Adrian: What do you think you learned from building the fire detection system?
David: Short answer: a lot!
I think the biggest takeaway (so far) is that you should expect to spend a lot of time building great datasets. You shouldn’t expect it to be easy, and you should expect that there will be a lot of trial and error and learning experiences along the way. However, once you start building a great dataset and have a framework in place that lets you continuously improve it, you are in a fantastic place.
In your writing, you talk a lot about how valuable a high-quality dataset is. That is one hundred percent true. However, I’d go further and say that you are creating something even more valuable if you build a framework and processes that let you easily grow and extend that dataset.
Adrian: What computer vision and deep learning tools, libraries, and packages did you utilize in building the fire detection system?
David: I used OpenCV, TensorFlow, Keras, the picamera module, and your imutils library.
Adrian: What are your next steps in the project?
David: We are doing two things right now: the first is we have a few prototypes, and we are getting time on them and learning both the limitations to the approach we are using and how to make it much better. At the same time, we are talking with potential customers, showing them what we have and also talking about what we have in mind, and trying to figure out how to solve their problems.
One big thing we have learned is that this whole idea of fire detectors works much better if you can deploy a lot of them (at least dozens, possibly hundreds) around a community. Then you can give people a web site or an app that lets them answer the question they care about: where is the fire?
Adrian: You’ve been a long-time reader and customer of PyImageSearch. Thank you for supporting us! (And an even bigger thank you for being a wonderfully helpful moderator in the community forums.) How has PyImageSearch helped you with your work and your company?
David: It greatly helps that I can get access to wiser and more experienced people (you and Sayak Paul have both been immensely helpful) who can help me out when I get stuck (when I switched over to TensorFlow 2.x and tf.data, I got all fouled up several times, and you and Sayak both made a huge difference in helping me puzzle out what went wrong).
The other thing I enjoy about participating in PyImageSearch is that helping others is a great way to sharpen my own skills and learn new things. So it has been a great experience all around for me.
Adrian: If a PyImageSearch reader wants to connect with you, how can they do so?
David: The best way to connect with me is on my LinkedIn at David Bonn.
Summary
Today we interviewed David Bonn, an expert in using computer vision to detect wildfires.
David has built a patented computer vision system that can successfully detect wildfires using:
Raspberry Pi
FLIR Lepton thermal camera
Cellular modem
Proprietary OpenCV and deep learning code
The system has already been used to detect wildfires and alert firefighters, preventing injury and loss of life.
Early wildfire detection is yet another example of how computer vision and deep learning are revolutionizing nearly every facet of our lives. David’s work demonstrates just how much our work, as computer vision practitioners, can impact the world.
I wish David the best of luck as he continues to develop this system — it truly has the potential to save lives, help the environment, and prevent tremendous property damage.
To be notified when future tutorials and interviews are published here on PyImageSearch, simply enter your email address in the form below!
Introduction to Distributed Training in PyTorch (today’s lesson)
When I first learned about PyTorch, I was quite indifferent to it. As someone who used TensorFlow throughout his Deep Learning days, I wasn’t yet ready to leave the comfort zone TensorFlow had created and try out something new.
As fate would have it, due to some unavoidable circumstances, I had to finally dive into PyTorch. Although to be very honest, I had a rough start. Having been accustomed to hiding behind TensorFlow’s abstractions, the verbose nature of PyTorch reminded me exactly why I had left Java and opted for Python.
However, after a while, the beauty of PyTorch started to unravel itself. The reason why it is more verbose is that it lets you have more control over your actions. Granting you a more definite grasp over every step you take, PyTorch gives you more freedom. Perhaps Java also had the same intention, but I’ll never know since that ship has sailed!
One of PyTorch’s stellar features is its support for distributed training. Distributed training gives you several ways to utilize every bit of computation power you have and make your model training much more efficient.
Today, we will learn about the Data Parallel package, which enables a single machine, multi-GPU parallelism. After completing this tutorial, the readers will have:
A clear understanding of PyTorch’s Data Parallelism
An idea on implementing Data Parallelism
A clear vision of your goal while traversing through PyTorch’s verbose code
To learn how to use Data Parallel Training in PyTorch, just keep reading.
Imagine having a computer with 4 RTX 2060 GPUs. You have been given a task where you have to deal with several gigabytes of data. Piece of cake, right? What if you had no way of using all that computation power together? That would be extremely frustrating, almost as if we had a billion dollars but were only allowed to spend $5 a month!
It wouldn’t be ideal if we had no way of using all our resources together. Thankfully, PyTorch has our back! Figure 1 shows how PyTorch utilizes multiple GPUs in a single system in a simple yet efficient manner.
Figure 1: Internal workings of PyTorch’s Data Parallel Module.
This is known as Data Parallel training, where you are using a single host system with multiple GPUs to boost your efficiency while dealing with huge piles of data.
The process is as simple as it can be. Once nn.DataParallel is called, a replica of the model is created on each of your GPUs. Each incoming batch of data is then split into equal parts, one slice for each replica. Finally, each replica computes its own gradients, which are then averaged and back-propagated across all the available instances.
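In its most bare-bones form, that whole process boils down to a single wrapper call. Here is a minimal, standalone sketch (assuming at least two CUDA-capable GPUs; the toy model below is just a placeholder, not the classifier we build later):

import torch
from torch import nn

# a toy model standing in for any nn.Module
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 11))

# replicate the model across all visible GPUs
if torch.cuda.device_count() > 1:
	model = nn.DataParallel(model)
model = model.to("cuda")

# a single forward pass: the batch of 512 is split evenly across the replicas,
# and the outputs are gathered back onto the default device
x = torch.randn(512, 128, device="cuda")
out = model(x)  # shape: (512, 11)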
Without further ado, let’s jump into the code and see distributed training in action!
Configuring your development environment
To follow this guide, first and foremost, you need to have PyTorch installed in your system. To access PyTorch’s own set of models for vision computing, you will also need to have Torchvision in your system. We are also using the imutils package for data handling. Finally, we will be using matplotlib to plot our results!
Luckily, all of the above-mentioned packages are pip-installable!
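For example, a single pip command along these lines should cover them (the exact package names can vary slightly with your platform and CUDA setup):

$ pip install torch torchvision imutils matplotlib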
Having problems configuring your development environment?
Figure 2: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
Before hopping into the project, let’s review the project structure.
First and foremost, comes the pyimagesearch directory. It houses:
config.py: houses several important parameters and paths which are used throughout the project
create_dataloaders.py: houses a function that will help us load, process, and handle datasets
food_classifier.py: the main model architecture residing inside this script
The other scripts we’ll use are in the parent directory. They are:
train_distributed.py: defines data processes and trains our model
distributed_inference.py: will be used to assess our trained model on individual test data
Finally, we have our output folder, which will house all the results (plots, models) that all the other scripts produce.
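Putting it all together, the directory layout looks roughly like this (reconstructed from the descriptions above, so your copy may differ slightly):

$ tree .
.
├── distributed_inference.py
├── output
├── prepare_dataset.py
├── pyimagesearch
│   ├── config.py
│   ├── create_dataloaders.py
│   └── food_classifier.py
└── train_distributed.py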
Configuring the Prerequisites
To begin our implementation, let’s start with config.py, the script that will house the configuration of the end-to-end training and inference pipeline. These values will be used throughout the project.
# import the necessary packages
import torch
import os
# define path to the original dataset
DATA_PATH = "Food-11"
# define base path to store our modified dataset
BASE_PATH = "dataset"
# define paths to separate train, validation, and test splits
TRAIN = os.path.join(BASE_PATH, "training")
VAL = os.path.join(BASE_PATH, "validation")
TEST = os.path.join(BASE_PATH, "evaluation")
We define a path to our original dataset (Line 6) and a base path (Line 9) to store our modified dataset. On Lines 12-14, we define separate train, validation, and test paths for our modified dataset using the os.path.join function.
# initialize the list of class label names
CLASSES = ["Bread", "Dairy_product", "Dessert", "Egg", "Fried_food",
"Meat", "Noodles/Pasta", "Rice", "Seafood", "Soup",
"Vegetable/Fruit"]
# specify ImageNet mean and standard deviation and image size
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = 224
On Lines 17-19, we define our target classes. We are choosing 11 classes into which our dataset will be grouped. On Lines 22-24, we specify the mean, standard deviation, and image size values for our ImageNet input. Notice how the mean and standard deviation have 3 values each; each value corresponds to one of the image’s RGB channels. The image size is set to 224 × 224 to match the standard input size expected by ImageNet-pretrained models.
# set the device to be used for training and evaluation
DEVICE = torch.device("cuda")
# specify training hyperparameters
LOCAL_BATCH_SIZE = 128
PRED_BATCH_SIZE = 4
EPOCHS = 20
LR = 0.0001
# define paths to store training plot and trained model
PLOT_PATH = os.path.join("output", "model_training.png")
MODEL_PATH = os.path.join("output", "food_classifier.pth")
Since today’s task involves demonstrating multiple Graphics Processing Units (GPUs) for training, we set torch.device to cuda (Line 27). CUDA (Compute Unified Device Architecture) is an API developed by NVIDIA that enables CUDA-capable GPUs to be used for general-purpose processing. Furthermore, since GPUs have more memory bandwidth and cores than CPUs, they are faster at training machine learning models.
On Lines 30-33, we set up a few hyperparameters: LOCAL_BATCH_SIZE (the per-GPU batch size during training), PRED_BATCH_SIZE (the batch size during inference), the number of epochs, and the learning rate. Then, on Lines 36 and 37, we define paths to store our training plot and trained model. The former lets us assess how well training fared against our metrics, while the latter will be loaded later by the inference script.
For our next task, we’ll move into the create_dataloaders.py script.
# import the necessary packages
from . import config
from torch.utils.data import DataLoader
from torchvision import datasets
import os
def get_dataloader(rootDir, transforms, bs, shuffle=True):
# create a dataset and use it to create a data loader
ds = datasets.ImageFolder(root=rootDir,
transform=transforms)
loader = DataLoader(ds, batch_size=bs, shuffle=shuffle,
num_workers=os.cpu_count(),
pin_memory=True if config.DEVICE.type == "cuda" else False)
# return a tuple of the dataset and the data loader
return (ds, loader)
On Line 7, we define a function called get_dataloader which takes the root directory, PyTorch’s transform instance, and batch size as external arguments.
On Lines 9 and 10, we are using torchvision.datasets.ImageFolder to map all items in the given directory to have the __getitem__ and __len__ methods. These methods have a very important role to play here.
Firstly, they help represent the dataset in a map-like structure from indices to data samples.
Secondly, the newly mapped dataset can now be passed through a torch.utils.data.DataLoader instance (Lines 11-13), which can load multiple data samples in parallel.
Finally, we are returning the dataset and the DataLoader instance (Line 16).
Preparing the Dataset for Distributed Training
For today’s tutorial, we are using the Food-11 dataset. If you’d like a quick way to download the Food-11 Dataset, please refer to this excellent blog post by Adrian on fine-tuning models created using Keras!
Although the dataset already has a training, testing, and validation split, we will organize it in a more easy-to-understand way.
In its original form, the dataset is in a format shown in Figure 3:
Figure 3: Folder Structure of dataset before processing.
Each filename is in the format class_index_imageNumber.jpg. For example, the file 0_10.jpg refers to an image belonging to the Bread label. Images from all classes are grouped together. In our custom dataset, we will arrange images by their labels and put them in their respective folder with label names. So, after the data preparation, our dataset structure will look something like Figure 4:
Figure 4: Dataset Structure after Processing.
Each label-wise folder will contain respective images belonging to these labels. This is done because many modern frameworks and functions prefer a folder structure like this when processing input.
So, let’s jump into our prepare_dataset.py script and code it out!
# USAGE
# python prepare_dataset.py
# import the necessary packages
from pyimagesearch import config
from imutils import paths
import shutil
import os
def copy_images(rootDir, destiDir):
# get a list of all the images present in the directory
imagePaths = list(paths.list_images(rootDir))
print(f"[INFO] total images found: {len(imagePaths)}...")
We start by defining a function copy_images (Line 10) which takes two arguments: The root directory where our images are and the destination directory where our custom dataset will be copied. Then, on Line 12, we use the paths.list_images function to generate a list of all images in the root directory. This will be used later while copying the files.
# loop over the image paths
for imagePath in imagePaths:
# extract class label from the filename
filename = imagePath.split(os.path.sep)[-1]
label = config.CLASSES[int(filename.split("_")[0])].strip()
# construct the path to the output directory
dirPath = os.path.sep.join([destiDir, label])
# if the output directory does not exist, create it
if not os.path.exists(dirPath):
os.makedirs(dirPath)
# construct the path to the output image file and copy it
p = os.path.sep.join([dirPath, filename])
shutil.copy2(imagePath, p)
We start iterating over the list of images on Line 16. First, we single out the exact name of the file by separating out the preceding pathname (Line 18), and then we identify the label of the file with filename.split("_")[0] and feed it to config.CLASSES as an index. If the label’s output directory does not exist yet, we create it (Lines 25 and 26). Finally, we construct the path to the current image and use the shutil package to copy the image to the destination path.
# calculate the total number of images in the destination
# directory and print it
currentTotal = list(paths.list_images(destiDir))
print(f"[INFO] total images copied to {destiDir}: "
f"{len(currentTotal)}...")
# copy over the images to their respective directories
print("[INFO] copying images...")
copy_images(os.path.join(config.DATA_PATH, "training"), config.TRAIN)
copy_images(os.path.join(config.DATA_PATH, "validation"), config.VAL)
copy_images(os.path.join(config.DATA_PATH, "evaluation"), config.TEST)
We run a sanity check on Lines 34 and 35 to see if all the files have been copied. This concludes the copy_images function. We call the function on Lines 40-42 and create our modified Train, Test, and Validation dataset!
Creating the PyTorch Classifier
Since our dataset creation is complete, it’s time for us to hop into the food_classifier.py script and define our classifier.
# import the necessary packages
from torch.cuda.amp import autocast
from torch import nn
class FoodClassifier(nn.Module):
def __init__(self, baseModel, numClasses):
super(FoodClassifier, self).__init__()
# initialize the base model and the classification layer
self.baseModel = baseModel
self.classifier = nn.Linear(baseModel.classifier.in_features,
numClasses)
# set the classifier of our base model to produce outputs
# from the last convolution block
self.baseModel.classifier = nn.Identity()
We first define our custom nn.Module class (Line 5). This is normally done when the architecture is more complex, allowing more flexibility while defining our model. Inside the class, our first job is to define the __init__ function to initialize the object’s state.
The super method on Line 7 will enable access to the methods of the base class. Then, on Line 10, we initialize the base model as the baseModel argument that was passed in the constructor (__init__). We then create a separate classification output layer (Line 11) with 11 outputs, each representing one of the classes that we had defined earlier. Finally, since we are using our own classification layer, we replace the inbuilt classifier layer of the baseModel with nn.Identity, which is nothing but a placeholder layer. Hence, the inbuilt classifier of the baseModel will just mirror the outputs of the convolution block just before its classification layer.
# we decorate the *forward()* method with *autocast()* to enable
# mixed-precision training in a distributed manner
@autocast()
def forward(self, x):
# pass the inputs through the base model and then obtain the
# classifier outputs
features = self.baseModel(x)
logits = self.classifier(features)
# return the classifier outputs
return logits
On Line 21, we define the forward() pass of our custom model, but before that, we decorate the model with @autocast(). This decorator function enables mixed-precision during training, which essentially makes your training faster due to the smart assignment of data types. I have linked to a blog by TensorFlow, which explains mixed precision in detail. Finally, on Lines 24 and 25, we get the baseModel output and pass it through the custom classifier layer to get the final output.
Using Distributed Training to Train the PyTorch Classifier
Our next destination is the train_distributed.py, where we will put our model training into motion and learn about putting multiple GPUs to use!
# USAGE
# python train_distributed.py
# import the necessary packages
from pyimagesearch.food_classifier import FoodClassifier
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from sklearn.metrics import classification_report
from torchvision.models import densenet121
from torchvision import transforms
from tqdm import tqdm
from torch import nn
from torch import optim
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
# determine the number of GPUs we have
NUM_GPU = torch.cuda.device_count()
print(f"[INFO] number of GPUs found: {NUM_GPU}...")
# determine the batch size based on the number of GPUs
BATCH_SIZE = config.LOCAL_BATCH_SIZE * NUM_GPU
print(f"[INFO] using a batch size of {BATCH_SIZE}...")
The torch.cuda.device_count() function (Line 20) returns the number of CUDA-compatible GPUs present in our system. This is used to determine our global batch size (Line 24), which is config.LOCAL_BATCH_SIZE * NUM_GPU. This is because if our global batch size is B, and we have N CUDA-compatible GPUs, each GPU will deal with data of batch size B/N. For example, for a global batch size of 12 and 2 CUDA-compatible GPUs, each GPU will process data of batch size 6.
Next, we use a very handy function by PyTorch, known as torchvision.transforms. Not only does it help build complex transformation pipelines, but it also grants us a lot of control over the transforms we choose to use.
Notice on Lines 28-34, we use several data augmentations for our training set images, like RandomHorizontalFlip, RandomRotation, etc. We also add the mean and standard deviation normalization values to our dataset using this function.
We again use torchvision.transforms for the test transformations (Lines 35-39), but we don’t add additional augmentations. Instead, we pass these instances through the get_dataloader function that we had created in the create_dataloaders script and get the training, validation, and testing datasets and data loaders, respectively (Lines 42-47).
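Concretely, those transform pipelines and data loader calls look something like the sketch below (the exact augmentations and their parameters here are assumptions on my part; only the overall structure matters):

# initialize our data augmentation functions for training and testing
trainTransform = transforms.Compose([
	transforms.RandomHorizontalFlip(),
	transforms.RandomRotation(25),
	transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
	transforms.ToTensor(),
	transforms.Normalize(mean=config.MEAN, std=config.STD)
])
testTransform = transforms.Compose([
	transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
	transforms.ToTensor(),
	transforms.Normalize(mean=config.MEAN, std=config.STD)
])
# create the training, validation, and test data loaders using the
# helper defined in create_dataloaders.py
(trainDS, trainLoader) = create_dataloaders.get_dataloader(config.TRAIN,
	transforms=trainTransform, bs=BATCH_SIZE)
(valDS, valLoader) = create_dataloaders.get_dataloader(config.VAL,
	transforms=testTransform, bs=BATCH_SIZE, shuffle=False)
(testDS, testLoader) = create_dataloaders.get_dataloader(config.TEST,
	transforms=testTransform, bs=BATCH_SIZE, shuffle=False)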
# load up the DenseNet121 model
baseModel = densenet121(pretrained=True)
# loop over the modules of the model and if the module is batch norm,
# set it to non-trainable
for module, param in zip(baseModel.modules(), baseModel.parameters()):
if isinstance(module, nn.BatchNorm2d):
param.requires_grad = False
# initialize our custom model and flash it to the current device
model = FoodClassifier(baseModel, len(trainDS.classes))
model = model.to(config.DEVICE)
We choose densenet121 as our base model to cover the bulk of our model architecture (Line 50). We then loop over the densenet121 layers and set the batch_norm layers to non-trainable (Lines 54-56). This is done to avoid the issue of an unstable Batch normalization due to varying batch sizes. Once this is complete, we send the densenet121 to the FoodClassifier class and initialize our custom model (Line 59). Finally, we load the model onto our device(s) (Line 60).
# if we have more than one GPU then parallelize the model
if NUM_GPU > 1:
model = nn.DataParallel(model)
# initialize loss function, optimizer, and gradient scaler
lossFunc = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=config.LR * NUM_GPU)
scaler = torch.cuda.amp.GradScaler(enabled=True)
# initialize a learning-rate (LR) scheduler to decay it by a factor
# of 0.1 after every 10 epochs
lrScheduler = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)
# calculate steps per epoch for training and validation set
trainSteps = len(trainDS) // BATCH_SIZE
valSteps = len(valDS) // BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "train_acc": [], "val_loss": [],
"val_acc": []}
First, we use a conditional statement to check if our system is eligible for PyTorch’s Data Parallel (Lines 63 and 64). If the condition is true, we pass our model through the nn.DataParallel module and parallelize it. Then, on Lines 67-69, we define our loss function and optimizer and create a PyTorch gradient scaler instance. The gradient scaler is a very helpful tool that brings mixed precision into the gradient calculations. We then initialize a learning-rate scheduler to decay its value by a factor of 0.1 every 10 epochs (Line 73).
On Lines 76 and 77, we calculate the steps per epoch for training and validation batches. The H variable on Lines 80 and 81 will be our training history dictionary, containing values like training loss, training accuracy, validation loss, and validation accuracy.
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.EPOCHS)):
# set the model in training mode
model.train()
# initialize the total training and validation loss
totalTrainLoss = 0
totalValLoss = 0
# initialize the number of correct predictions in the training
# and validation step
trainCorrect = 0
valCorrect = 0
# loop over the training set
for (x, y) in trainLoader:
with torch.cuda.amp.autocast(enabled=True):
# send the input to the device
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
# perform a forward pass and calculate the training loss
pred = model(x)
loss = lossFunc(pred, y)
# calculate the gradients
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
opt.zero_grad()
# add the loss to the total training loss so far and
# calculate the number of correct predictions
totalTrainLoss += loss.item()
trainCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
# update our LR scheduler
lrScheduler.step()
To assess how much faster our model will train, we time our training process (Line 85). To start our model training, we start looping over our epochs on Line 87. We first set our PyTorch custom model to train mode (Line 89) and initialize training and validation losses and correct predictions (Lines 92-98).
We then loop our training set using the train dataloader (Line 101). Once inside the training set loop, we first enable mixed precision (Line 102) and load the inputs (Data and labels) to the CUDA device (Line 104). Finally, on Lines 107 and 108, we make our model perform a forward pass and calculate the loss using our loss function.
The scaler.scale(loss).backward() call automatically calculates the gradients for us (Line 111), and scaler.step(opt) and scaler.update() then use them to update the model weights (Lines 111-113). Finally, we reset the gradients using opt.zero_grad after completing one pass, since backward() keeps accumulating gradients and we only need the gradients for the current step.
Lines 118-120 update our loss and correct prediction values while updating our LR scheduler after a complete training pass (Line 123).
# switch off autograd
with torch.no_grad():
# set the model in evaluation mode
model.eval()
# loop over the validation set
for (x, y) in valLoader:
with torch.cuda.amp.autocast(enabled=True):
# send the input to the device
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
# make the predictions and calculate the validation
# loss
pred = model(x)
totalValLoss += lossFunc(pred, y).item()
# calculate the number of correct predictions
valCorrect += (pred.argmax(1) == y).type(
torch.float).sum().item()
# calculate the average training and validation loss
avgTrainLoss = totalTrainLoss / trainSteps
avgValLoss = totalValLoss / valSteps
# calculate the training and validation accuracy
trainCorrect = trainCorrect / len(trainDS)
valCorrect = valCorrect / len(valDS)
During our evaluation, we will turn off PyTorch’s automatic gradients using torch.no_grad and switch our model to evaluation mode (Lines 126-128). Then, during the validation step, we loop over the validation data loader and enable mixed precision before loading the data into our CUDA devices (Lines 131-134). Next, we get predictions for our validation dataset and update the validation loss values (Lines 138 and 139).
Once out of the loop, we calculate the batchwise averages of the training and validation losses and accuracies (Lines 146-151).
# update our training history
H["train_loss"].append(avgTrainLoss)
H["train_acc"].append(trainCorrect)
H["val_loss"].append(avgValLoss)
H["val_acc"].append(valCorrect)
# print the model training and validation information
print("[INFO] EPOCH: {}/{}".format(e + 1, config.EPOCHS))
print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
avgTrainLoss, trainCorrect))
print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
avgValLoss, valCorrect))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
endTime - startTime))
Before the end of our epochs loop, we log all the loss and accuracy values into our history dictionary H (Lines 154-157).
Once outside the loop, we clock the time using the time.time() function on Line 167 to see how fast our model performed.
# evaluate the network
print("[INFO] evaluating network...")
with torch.no_grad():
# set the model in evaluation mode
model.eval()
# initialize a list to store our predictions
preds = []
# loop over the test set
for (x, _) in testLoader:
# send the input to the device
x = x.to(config.DEVICE)
# make the predictions and add them to the list
pred = model(x)
preds.extend(pred.argmax(axis=1).cpu().numpy())
# generate a classification report
print(classification_report(testDS.targets, preds,
target_names=testDS.classes))
Now it’s time to test our freshly trained model on the test data. Once again, turning the Automatic gradient calculation off, we set our model to evaluation mode (Lines 173-175).
Next, we initialize an empty list called preds on Line 178, which will store the model predictions for the test data. We finally follow the same procedure of loading the data into our devices, getting predictions for our batched test data, and storing the values inside the preds list (Lines 181-187).
Among the several handy tools scikit-learn provides us for assessment of our models, the classification_report provides a complete class-wise overview of the predictions given by our model (Lines 190 and 191).
The complete classification report of our model gives us a comprehensive idea of the classes for which our model predicts better or worse than others.
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["val_loss"], label="val_loss")
plt.plot(H["train_acc"], label="train_acc")
plt.plot(H["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig(config.PLOT_PATH)
# serialize the model state to disk
torch.save(model.module.state_dict(), config.MODEL_PATH)
The final step in our training script is to plot the values from our model history dictionary (Lines 194-204) and save the model state in our predefined path (Line 207).
Performing Distributed Training with PyTorch
Before executing the training script, we will need to run the prepare_dataset.py script.
$ python prepare_dataset.py
[INFO] copying images...
[INFO] total images found: 9866...
[INFO] total images copied to dataset/training: 9866...
[INFO] total images found: 3430...
[INFO] total images copied to dataset/validation: 3430...
[INFO] total images found: 3347...
[INFO] total images copied to dataset/evaluation: 3347...
Once this script has run its course, we can move onto executing the train_distributed.py script.
$ python train_distributed.py
[INFO] number of GPUs found: 4...
[INFO] using a batch size of 512...
[INFO] training the network...
0%| | 0/20 [00:00<?, ?it/s][INFO] EPOCH: 1/20
Train loss: 1.267870, Train accuracy: 0.6176
Val loss: 0.838317, Val accuracy: 0.7586
5%|███▏ | 1/20 [00:37<11:47, 37.22s/it][INFO] EPOCH: 2/20
Train loss: 0.669389, Train accuracy: 0.7974
Val loss: 0.580541, Val accuracy: 0.8394
10%|██████▍ | 2/20 [01:03<09:16, 30.91s/it][INFO] EPOCH: 3/20
Train loss: 0.545763, Train accuracy: 0.8305
Val loss: 0.516144, Val accuracy: 0.8580
15%|█████████▌ | 3/20 [01:30<08:14, 29.10s/it][INFO] EPOCH: 4/20
Train loss: 0.472342, Train accuracy: 0.8547
Val loss: 0.482138, Val accuracy: 0.8682
...
85%|█████████████████████████████████████████████████████▌ | 17/20 [07:40<01:19, 26.50s/it][INFO] EPOCH: 18/20
Train loss: 0.226185, Train accuracy: 0.9338
Val loss: 0.323659, Val accuracy: 0.9099
90%|████████████████████████████████████████████████████████▋ | 18/20 [08:06<00:52, 26.32s/it][INFO] EPOCH: 19/20
Train loss: 0.227704, Train accuracy: 0.9331
Val loss: 0.313711, Val accuracy: 0.9140
95%|███████████████████████████████████████████████████████████▊ | 19/20 [08:33<00:26, 26.46s/it][INFO] EPOCH: 20/20
Train loss: 0.228238, Train accuracy: 0.9332
Val loss: 0.318986, Val accuracy: 0.9105
100%|███████████████████████████████████████████████████████████████| 20/20 [09:00<00:00, 27.02s/it]
[INFO] total time taken to train the model: 540.37s
After 20 epochs, the average Train accuracy hit 0.9332 while the validation accuracy hit a commendable 0.9105. Let’s first look at the metric plots in Figure 5!
Figure 5: Training and Validation Plots.
By looking at how close the training and validation metrics evolved throughout, we can safely say that our model didn’t overfit.
Data Distributed Training inference
Although we have already evaluated the model on our test set, we will create a separate script, distributed_inference.py, where we will individually assess test images one by one instead of a full batch at a time.
# USAGE
# python distributed_inference.py
# import the necessary packages
from pyimagesearch.food_classifier import FoodClassifier
from pyimagesearch import config
from pyimagesearch import create_dataloaders
from torchvision import models
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
import torch
# determine the number of GPUs we have
NUM_GPU = torch.cuda.device_count()
print(f"[INFO] number of GPUs found: {NUM_GPU}...")
# determine the batch size based on the number of GPUs
BATCH_SIZE = config.PRED_BATCH_SIZE * NUM_GPU
print(f"[INFO] using a batch size of {BATCH_SIZE}...")
# define augmentation pipeline
testTransform = transforms.Compose([
transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
Before initializing the iterators, we set up the initial requirements for this script. These include setting the batch size dictated by the number of CUDA GPUs (Lines 15-19) and initializing a torchvision.transforms instance for our test dataset (Lines 23-27).
# calculate the inverse mean and standard deviation
invMean = [-m/s for (m, s) in zip(config.MEAN, config.STD)]
invStd = [1/s for s in config.STD]
# define our denormalization transform
deNormalize = transforms.Normalize(mean=invMean, std=invStd)
# create test data loader
(testDS, testLoader) = create_dataloaders.get_dataloader(config.TEST,
transforms=testTransform, bs=BATCH_SIZE, shuffle=True)
# load up the DenseNet121 model
baseModel = models.densenet121(pretrained=True)
# initialize our food classifier
model = FoodClassifier(baseModel, len(testDS.classes))
# load the model state
model.load_state_dict(torch.load(config.MODEL_PATH))
It is important to understand why we are calculating the inverse mean and inverse standard deviation values on Lines 30 and 31. This is because our torchvision.transforms instance normalizes the dataset before it is plugged into the model. So, to turn the image back to its original form, we are calculating these values beforehand. We’ll see how these are used pretty soon!
With these values, we create a torchvision.transforms.Normalize instance for later use (Line 34). Next, we create our test dataset and data loaders using the create_dataloaders method on Lines 37 and 38.
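If you want to convince yourself that these two Normalize instances really are inverses of each other, here is a tiny standalone sanity check (not part of the project scripts):

import torch
from torchvision import transforms

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=mean, std=std)

# (y - (-m/s)) / (1/s) = y*s + m, which undoes (x - m)/s
invMean = [-m / s for (m, s) in zip(mean, std)]
invStd = [1 / s for s in std]
deNormalize = transforms.Normalize(mean=invMean, std=invStd)

x = torch.rand(3, 224, 224)            # a fake image with values in [0, 1]
recovered = deNormalize(normalize(x))  # should match x up to float error
print(torch.allclose(recovered, x, atol=1e-5))  # True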
Note that we had saved the trained model state in train_distributed.py. Next, we’ll initialize the model as we had done in the training script (Lines 41-44) and use the model.load_state_dict function to plug in the trained model weights into the initialized model (Line 47).
# if we have more than one GPU then parallelize the model
if NUM_GPU > 1:
model = nn.DataParallel(model)
# move the model to the device and set it in evaluation mode
model.to(config.DEVICE)
model.eval()
# grab a batch of test data
batch = next(iter(testLoader))
(images, labels) = (batch[0], batch[1])
# initialize a figure
fig = plt.figure("Results", figsize=(10, 10 * NUM_GPU))
We parallelize the model again using nn.DataParallel, move it to our device, and set it to evaluation mode (Lines 50-55). Since we’ll be working with individual data points, we won’t need to loop over the full test dataset. Instead, we’ll just grab a batch of test data using next(iter(testLoader)) (Lines 58 and 59). Since the test loader was created with shuffle=True, rerunning this gives you a different random batch each time.
# switch off autograd
with torch.no_grad():
# send the images to the device
images = images.to(config.DEVICE)
# make the predictions
preds = model(images)
# loop over all the batch
for i in range(0, BATCH_SIZE):
# initialize a subplot
ax = plt.subplot(BATCH_SIZE, 1, i + 1)
# grab the image, de-normalize it, scale the raw pixel
# intensities to the range [0, 255], and change the channel
# ordering from channels first to channels last
image = images[i]
image = deNormalize(image).cpu().numpy()
image = (image * 255).astype("uint8")
image = image.transpose((1, 2, 0))
# grab the ground truth label
idx = labels[i].cpu().numpy()
gtLabel = testDS.classes[idx]
# grab the predicted label
pred = preds[i].argmax().cpu().numpy()
predLabel = testDS.classes[pred]
# add the results and image to the plot
info = "Ground Truth: {}, Predicted: {}".format(gtLabel,
predLabel)
plt.imshow(image)
plt.title(info)
plt.axis("off")
# show the plot
plt.tight_layout()
plt.show()
Again, since we have no intention of changing the weights of our model, we turn off automatic gradients (Line 65) and flash the test images into our device(s). Finally, on Line 70, we directly make our model predictions on the images in the batch.
Looping over the images in the batch, we select individual images, denormalize them, scale up their values, and change their dimension order (Lines 80-83). Changing dimensions is necessary for display because PyTorch’s modules take channels-first inputs. This means our image, fresh out of torchvision.transforms, is currently Channels × Height × Width; to display it with matplotlib, we have to rearrange the dimensions into Height × Width × Channels.
We use the individual label of the image to get the name of the class using testDS.classes (Lines 86 and 87). Next, we get the individual image’s predicted class (Lines 90 and 91). Finally, we compare the real and predicted labels for the individual image (Lines 94-98).
This concludes our inference script for Data Parallel training!
PyTorch Visualizations of Data Parallel Trained Model
Let’s look at a few results plotted by our inference script distributed_inference.py.
As we had taken a batch size of 4 in our inference script, our plot will show pictures of the present batch.
The batch of data sent to our inference script contains: an image of oyster shells (Figure 6), an image of french fries (Figure 7), an image containing meat (Figure 8), and an image of a chocolate cake (Figure 9).
Figure 6: Image of Oyster shells, predicted correctly as Seafood.
Figure 7: Image of French Fries, correctly identified as Fried food.
Figure 8: Image of meat, incorrectly Identified as Dessert.
Figure 9: Image of a Chocolate Cake, correctly identified as Dessert.
Here we see 3 predictions correct out of a possible 4. This, along with our complete test set scores, tells us that using PyTorch’s data parallel worked pretty nicely!
In today’s tutorial, we got a little taste of one of PyTorch’s vast array of distributed training procedures. nn.DataParallel may not be the most efficient or the fastest of them in terms of internal workings, but it sure is a great place to start! It’s easy to understand and takes only a single line of code to implement. As I mentioned before, the other procedures require more coding, but they were created to handle things more efficiently.
Some very evident problems with nn.DataParallel would be:
the redundancy of creating entire model instances itself
failing to work when the model becomes too big to fit on a single GPU
having no way of adaptively adjusting training when the GPUs available are different
Especially when dealing with a big architecture, model parallelism is preferred, where you split the layers of the model among the GPUs.
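To make the model parallelism idea concrete, here is a rough, hand-rolled sketch (a toy illustration assuming two GPUs, not part of today’s scripts): different layers live on different devices, and the activations are moved between them inside forward().

import torch
from torch import nn

class TwoGPUModel(nn.Module):
	# a toy model whose two halves live on different GPUs
	def __init__(self):
		super().__init__()
		self.part1 = nn.Sequential(nn.Linear(128, 256), nn.ReLU()).to("cuda:0")
		self.part2 = nn.Linear(256, 11).to("cuda:1")

	def forward(self, x):
		# run the first half on GPU 0, then ship the activations to GPU 1
		x = self.part1(x.to("cuda:0"))
		return self.part2(x.to("cuda:1"))

model = TwoGPUModel()
out = model(torch.randn(32, 128))  # the output lives on cuda:1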
With that being said, if you are someone who owns multiple GPUs in your system, make use of every bit of computational power your system can provide using nn.DataParallel.
I hope you found this tutorial helpful enough to pave the way for your curiosity about mastering distributed training as a whole!
@article{Chakraborty_2021_Distributed,
author = {Devjyoti Chakraborty},
title = {Introduction to Distributed Training in {PyTorch}},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/10/18/introduction-to-distributed-training-in-pytorch/},
}
To download the source code to this post (and be notified when future tutorials are published here on PyImageSearch), simply enter your email address in the form below!
Training an object detector from scratch in PyTorch (today’s tutorial)
U-Net: Training Image Segmentation Models in PyTorch (next week’s blog post)
Since my childhood, the idea of artificial intelligence (AI) has fascinated me (like every other kid). But, of course, the concept of AI that I had was vastly different from what it actually was, unquestionably due to pop culture. Until the end of my teenage years, I firmly believed that the unchecked growth of AI would lead to something like the T-800 (the terminator from The Terminator). Fortunately, the actual scenario can be better explained using Figure 1:
Figure 1: Machine Learning.
Don’t get me wrong, though. Machine Learning may be a bunch of matrices and calculus coalesced together, but the sheer amount of things we can do with these can be best described by a single word: limitless.
One such application, which always intrigued me, was Object Detection. Pouring in image data to get labels was one thing, but making our model learn where the label is? That’s a whole different ball game, something right out of some espionage movie. And that is exactly what we’ll be going through today!
In today’s tutorial, we’ll learn how to train our very own object detector from scratch in PyTorch. This blog will help you:
Understand the intuition behind Object Detection
Understand the step-by-step approach to building your own Object Detector
Learn how to fine-tune parameters to get ideal results
To learn how to train an object detector from scratch in Pytorch, just keep reading.
Training an Object Detector from scratch in PyTorch
Long before the powerful deep learning algorithms of today existed, object detection was a domain that researchers worked on extensively. From the late 1990s to the early 2010s, many new ideas were proposed, and they are still used as benchmarks for deep learning algorithms to this day. Unfortunately, back then, researchers didn’t have much computation power at their disposal, so most of these techniques relied on lots of additional mathematics to reduce compute time. Thankfully, we won’t be facing that problem.
Our Approach to Object Detection
Let’s first understand the intuition behind Object Detection. The approach we are going to take is quite similar to training a simple classifier. The weights of the classifier keep changing until it outputs the correct labels for a given dataset and the loss is reduced. We will be doing the exact same thing for today’s task, except our model will output 5 values, 4 of them being the coordinates of the bounding box surrounding our object. The 5th value is the label of the object being detected. Notice the architecture in Figure 2.
Figure 2: Schematic of Our Object Detector.
The main model will branch into two subsets: the regressor and the classifier. The former will output the bounding box’s starting and ending coordinates, while the latter will output the object label. The combined losses generated by these 5 values will serve in our backpropagation. Quite a simple way to start, isn’t it?
Of course, through the years, several powerful algorithms took over the Object Detection domain, like R-CNN and YOLO. But our approach will serve as a reasonable starting point to wrap your head around the basic idea behind Object Detection!
Configuring your development environment
To follow this guide, first and foremost, you need to have PyTorch installed in your system. To access PyTorch’s own set of models for vision computing, you will also need to have Torchvision in your system. For some array and storage operations, we have employed the use of numpy. We are also using the imutils package for data handling. For our plots, we will be using matplotlib. For better tracking of our model training, we’ll be using tqdm, and finally, we’ll be needing OpenCV in our system!
Luckily, all of the above-mentioned packages are pip-installable!
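The exact install commands aren’t reproduced here, but something along these lines should cover everything (swap opencv-contrib-python for whichever OpenCV build you prefer):
$ pip install torch torchvision
$ pip install numpy imutils matplotlib tqdm
$ pip install opencv-contrib-python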
If you need help configuring your development environment for OpenCV, I highly recommend that you read my pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having problems configuring your development environment?
Figure 3: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, take a look at the directory structure:
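The tree output itself isn’t reproduced here; based on the files described next, the layout looks roughly like the following (the exact contents of output will vary by run):
.
├── dataset.zip
├── output/
├── pyimagesearch/
│   ├── bbox_regressor.py
│   ├── config.py
│   └── custom_tensor_dataset.py
├── predict.py
└── train.py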
The first item in the directory is dataset.zip. This zip file contains the complete dataset (Images, labels, and bounding boxes). More about it in a later section.
Next, we have the output directory. This directory is where all our saved models, results, and other important requirements are dumped.
There are two scripts in the parent directory:
train.py: used to train our object detector
predict.py: used to draw inference from our model and see the object detector in action
Lastly, we have the most important directory, the pyimagesearch directory. It houses 3 very important scripts.
bbox_regressor.py: houses the complete object detector architecture
config.py: contains the configuration of the end-to-end training and inference pipeline
custom_tensor_dataset.py: contains a custom class for data preparation
That concludes the review of our project directory.
Configuring the prerequisites for Object Detection
Our first task is to configure several hyperparameters we’ll be using throughout the project. For that, let’s hop into the pyimagesearch folder and open the config.py script.
# import the necessary packages
import torch
import os
# define the base path to the input dataset and then use it to derive
# the path to the input images and annotation CSV files
BASE_PATH = "dataset"
IMAGES_PATH = os.path.sep.join([BASE_PATH, "images"])
ANNOTS_PATH = os.path.sep.join([BASE_PATH, "annotations"])
# define the path to the base output directory
BASE_OUTPUT = "output"
# define the path to the output model, label encoder, plots output
# directory, and testing image paths
MODEL_PATH = os.path.sep.join([BASE_OUTPUT, "detector.pth"])
LE_PATH = os.path.sep.join([BASE_OUTPUT, "le.pickle"])
PLOTS_PATH = os.path.sep.join([BASE_OUTPUT, "plots"])
TEST_PATHS = os.path.sep.join([BASE_OUTPUT, "test_paths.txt"])
We start by defining several paths which we will later use. Then on Lines 7-12, we define paths for our datasets (Images and annotations) and output. Next, we create separate paths for our detector and label encoder, followed by paths for our plots and testing images (Lines 16-19).
# determine the current device and based on that set the pin memory
# flag
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PIN_MEMORY = True if DEVICE == "cuda" else False
# specify ImageNet mean and standard deviation
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
# initialize our initial learning rate, number of epochs to train
# for, and the batch size
INIT_LR = 1e-4
NUM_EPOCHS = 20
BATCH_SIZE = 32
# specify the loss weights
LABELS = 1.0
BBOX = 1.0
Since we are training an object detector, it’s advisable to train on a GPU instead of a CPU since the computations are more complex. Hence, we set our PyTorch device to CUDA if a CUDA-compatible GPU is available in our system (Lines 23 and 24).
We will, of course, be using PyTorch’s transforms during our dataset preparation. Hence, we specify the mean and standard deviation values (Lines 27 and 28). The three values are the per-channel (RGB) means and standard deviations of the ImageNet dataset, which is what our pre-trained base model expects. Finally, we initialize hyperparameters like the learning rate, number of epochs, batch size, and loss weights for our model (Lines 32-38).
Creating the Custom Object Detection Data processor
The dataset subdivides into two folders: annotations (which contains CSV files of bounding box start and end points) and images (which are further divided into three folders, each representing the classes we’ll be using today).
Since we’ll use PyTorch’s own DataLoader, it’s important to preprocess the data in a way that the DataLoader will accept. The custom_tensor_dataset.py script will do exactly that.
# import the necessary packages
from torch.utils.data import Dataset
class CustomTensorDataset(Dataset):
    # initialize the constructor
    def __init__(self, tensors, transforms=None):
        self.tensors = tensors
        self.transforms = transforms
We have created a custom class, CustomTensorDataset, which inherits from the torch.utils.data.Dataset class (Line 4). This way, we can configure the internal functions to our needs while retaining the core properties of the torch.utils.data.Dataset class.
On Lines 6-8, the constructor function __init__ is created. The constructor takes in two arguments:
tensors: A tuple of three tensors, namely the image, label, and the bounding box coordinates.
transforms: A torchvision.transforms instance which will be used to process the image.
    def __getitem__(self, index):
        # grab the image, label, and its bounding box coordinates
        image = self.tensors[0][index]
        label = self.tensors[1][index]
        bbox = self.tensors[2][index]
        # transpose the image such that its channel dimension becomes
        # the leading one
        image = image.permute(2, 0, 1)
        # check to see if we have any image transformations to apply
        # and if so, apply them
        if self.transforms:
            image = self.transforms(image)
        # return a tuple of the images, labels, and bounding
        # box coordinates
        return (image, label, bbox)
Since we are using a custom class, we will override the parent (Dataset) class’s methods. So, the __getitem__ method is altered according to our needs. But, first, the tensor tuple is unpacked into its constituents (Lines 12-14).
The image tensor is originally in the form Height × Width × Channels. However, PyTorch models expect their input in “channels first” form. Accordingly, the image.permute method rearranges the image tensor (Line 18).
We add a check for the torchvision.transforms instance on Lines 22 and 23. If the check yields true, the image is passed through the transform instance. After this, the __getitem__ method returns the image, label, and bounding boxes.
    def __len__(self):
        # return the size of the dataset
        return self.tensors[0].size(0)
The second method that we’ll override is the __len__ method. It returns the size of the image dataset tensor (Lines 29-31). This concludes the custom_tensor_dataset.py script.
Building the Object Detection Architecture
Coming to the model we’ll need for this project, keep two things in mind. First, to avoid additional hassle and for efficient feature extraction, we’ll use a pre-trained model as the base model. Its output will then feed two heads: the box regressor and the label classifier, each an individual model entity.
Second, only the box regressor and the label classifier will have trainable weights. The weights of the pre-trained base model will be left untouched, as shown in Figure 4.
Figure 4: Model Architecture.
With this in mind, let’s hop into bbox_regressor.py!
# import the necessary packages
from torch.nn import Dropout
from torch.nn import Identity
from torch.nn import Linear
from torch.nn import Module
from torch.nn import ReLU
from torch.nn import Sequential
from torch.nn import Sigmoid
class ObjectDetector(Module):
    def __init__(self, baseModel, numClasses):
        super(ObjectDetector, self).__init__()
        # initialize the base model and the number of classes
        self.baseModel = baseModel
        self.numClasses = numClasses
For the custom model ObjectDetector, we’ll use torch.nn.Module as the parent class (Line 10). The constructor function __init__ takes two external arguments: the base model and the number of labels (Lines 11-16).
        # build the regressor head for outputting the bounding box
        # coordinates
        self.regressor = Sequential(
            Linear(baseModel.fc.in_features, 128),
            ReLU(),
            Linear(128, 64),
            ReLU(),
            Linear(64, 32),
            ReLU(),
            Linear(32, 4),
            Sigmoid()
        )
Moving on to the regressor, keep in mind that our end goal is to produce 4 separate values: the starting x-axis value, the starting y-axis value, the ending x-axis value, and the ending y-axis value. The first Linear layer takes the input feature size of the base model’s fully connected layer (baseModel.fc.in_features) and maps it to an output size of 128 (Line 21).
This is followed by a few Linear and ReLU layers (Lines 22-27), finally ending with a Linear layer which outputs 4 values followed by a Sigmoid layer (Line 28).
        # build the classifier head to predict the class labels
        self.classifier = Sequential(
            Linear(baseModel.fc.in_features, 512),
            ReLU(),
            Dropout(),
            Linear(512, 512),
            ReLU(),
            Dropout(),
            Linear(512, self.numClasses)
        )
        # set the classifier of our base model to produce outputs
        # from the last convolution block
        self.baseModel.fc = Identity()
The next step is the classifier for the object label. Just as in the regressor, we take the input feature size of the base model’s fully connected layer and plug it into the first Linear layer (Line 33). This is followed by a repeating pattern of ReLU, Dropout, and Linear layers (Lines 34-40). The Dropout layers help improve generalization and prevent overfitting.
The final step of the initialization is to make the base model’s fully connected layer into an Identity layer, which means it’ll mirror the outputs produced by the convolution block right before it (Line 44).
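As a quick aside (this snippet is not part of bbox_regressor.py), nn.Identity simply returns its input unchanged, which is why baseModel(x) now yields the pooled ResNet features instead of class scores:
# quick illustration: Identity passes its input through untouched
import torch
from torch.nn import Identity

x = torch.randn(8, 2048)
print(torch.equal(Identity()(x), x))  # True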
    def forward(self, x):
        # pass the inputs through the base model and then obtain
        # predictions from two different branches of the network
        features = self.baseModel(x)
        bboxes = self.regressor(features)
        classLogits = self.classifier(features)
        # return the outputs as a tuple
        return (bboxes, classLogits)
Next comes the forward step (Line 46). We simply take the output of the base model and pass it through the regressor and the classifier (Lines 49-51).
With that, we finish designing the architecture of our object detector.
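If you want to sanity check the design before training, a hypothetical shape check like the one below (not part of train.py; pretrained weights aren’t needed just to verify output shapes) confirms that the regressor and classifier heads produce the expected tensors:
# hypothetical shape check of the detector with 3 classes
import torch
from torchvision.models import resnet50
from pyimagesearch.bbox_regressor import ObjectDetector

detector = ObjectDetector(resnet50(pretrained=False), numClasses=3)
bboxes, classLogits = detector(torch.randn(2, 3, 224, 224))
print(bboxes.shape)       # torch.Size([2, 4])
print(classLogits.shape)  # torch.Size([2, 3])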
Training the Object Detection Model
Just one more step remaining before we can see the object detector in action. So let’s hop over to the train.py script and train the model!
# USAGE
# python train.py
# import the necessary packages
from pyimagesearch.bbox_regressor import ObjectDetector
from pyimagesearch.custom_tensor_dataset import CustomTensorDataset
from pyimagesearch import config
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn import CrossEntropyLoss
from torch.nn import MSELoss
from torch.optim import Adam
from torchvision.models import resnet50
from sklearn.model_selection import train_test_split
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pickle
import torch
import time
import cv2
import os
# initialize the list of data (images), class labels, target bounding
# box coordinates, and image paths
print("[INFO] loading dataset...")
data = []
labels = []
bboxes = []
imagePaths = []
After importing the necessary packages, we create empty lists for our data, labels, bounding boxes, and image paths (Lines 29-32).
Now it’s time for some data pre-processing.
# loop over all CSV files in the annotations directory
for csvPath in paths.list_files(config.ANNOTS_PATH, validExts=(".csv")):
    # load the contents of the current CSV annotations file
    rows = open(csvPath).read().strip().split("\n")
    # loop over the rows
    for row in rows:
        # break the row into the filename, bounding box coordinates,
        # and class label
        row = row.split(",")
        (filename, startX, startY, endX, endY, label) = row
        # derive the path to the input image, load the image (in
        # OpenCV format), and grab its dimensions
        imagePath = os.path.sep.join([config.IMAGES_PATH, label,
            filename])
        image = cv2.imread(imagePath)
        (h, w) = image.shape[:2]
        # scale the bounding box coordinates relative to the spatial
        # dimensions of the input image
        startX = float(startX) / w
        startY = float(startY) / h
        endX = float(endX) / w
        endY = float(endY) / h
        # load the image and preprocess it
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (224, 224))
        # update our list of data, class labels, bounding boxes, and
        # image paths
        data.append(image)
        labels.append(label)
        bboxes.append((startX, startY, endX, endY))
        imagePaths.append(imagePath)
On Line 35, we start looping over all available CSVs in the directory. Opening the CSVs, we then begin looping over the rows to split the data (Lines 37-44).
After splitting the row values into a tuple of individual values, we first single out the image path (Line 48). Then, we use OpenCV to read the image and get its height and width (Lines 50 and 51).
The height and width values are then used to scale the bounding box coordinates to the range of 0 and 1 (Lines 55-58).
Next, we load the image and do some slight preprocessing (Lines 61-63).
The empty lists are then updated with the unpacked values, and the process repeats as each iteration passes (Lines 67-70).
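For context, every annotation row follows the filename,startX,startY,endX,endY,label pattern. A tiny hypothetical example (the values below are made up purely for illustration) of how one row is unpacked:
# hypothetical annotation row and how it is split into its parts
row = "image_0001.jpg,49,30,349,137,airplane".split(",")
(filename, startX, startY, endX, endY, label) = row
print(label, startX, startY, endX, endY)  # airplane 49 30 349 137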
# convert the data, class labels, bounding boxes, and image paths to
# NumPy arrays
data = np.array(data, dtype="float32")
labels = np.array(labels)
bboxes = np.array(bboxes, dtype="float32")
imagePaths = np.array(imagePaths)
# perform label encoding on the labels
le = LabelEncoder()
labels = le.fit_transform(labels)
For faster processing of data, the lists are converted into numpy arrays (Lines 74-77). Since the labels are in string format, we use scikit-learn’s LabelEncoder to transform them into their respective indices (Lines 80 and 81).
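If you are curious what the label encoder actually does, here is a small, self-contained illustration using the three class names from our dataset:
# small illustration of LabelEncoder with our class names
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
print(le.fit_transform(["airplane", "face", "motorcycle", "face"]))  # [0 1 2 1]
print(le.classes_)  # ['airplane' 'face' 'motorcycle']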
# partition the data into training and testing splits using 80% of
# the data for training and the remaining 20% for testing
split = train_test_split(data, labels, bboxes, imagePaths,
test_size=0.20, random_state=42)
# unpack the data split
(trainImages, testImages) = split[:2]
(trainLabels, testLabels) = split[2:4]
(trainBBoxes, testBBoxes) = split[4:6]
(trainPaths, testPaths) = split[6:]
Using another handy scikit-learn tool called train_test_split, we partition the data into training and test sets, keeping an 80-20 ratio (Lines 85 and 86). Since the split applies to all the arrays passed into the train_test_split function, we can unpack the results into tuples using simple slicing (Lines 89-92).
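The code that the next few sentences refer to (Lines 95-107 of train.py) is not reproduced above; a minimal reconstruction consistent with the description below, and with the identical transform used later in predict.py, would look something like this:
# convert the NumPy arrays to PyTorch tensors
(trainImages, testImages) = (torch.tensor(trainImages),
    torch.tensor(testImages))
(trainLabels, testLabels) = (torch.tensor(trainLabels),
    torch.tensor(testLabels))
(trainBBoxes, testBBoxes) = (torch.tensor(trainBBoxes),
    torch.tensor(testBBoxes))
# define normalization transforms (ImageNet mean/std from config.py)
transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=config.MEAN, std=config.STD)
])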
The unpacked train and test data, labels, and bounding boxes are then converted into PyTorch tensors from the numpy format (Lines 95-100). Next, we proceed to create a torchvision.transforms instance to easily process the dataset (Lines 103-107). Through this, the dataset will also get normalized using the mean and standard deviation values defined in config.py.
# convert NumPy arrays to PyTorch datasets
trainDS = CustomTensorDataset((trainImages, trainLabels, trainBBoxes),
transforms=transforms)
testDS = CustomTensorDataset((testImages, testLabels, testBBoxes),
transforms=transforms)
print("[INFO] total training samples: {}...".format(len(trainDS)))
print("[INFO] total test samples: {}...".format(len(testDS)))
# calculate steps per epoch for training and validation set
trainSteps = len(trainDS) // config.BATCH_SIZE
valSteps = len(testDS) // config.BATCH_SIZE
# create data loaders
trainLoader = DataLoader(trainDS, batch_size=config.BATCH_SIZE,
shuffle=True, num_workers=os.cpu_count(), pin_memory=config.PIN_MEMORY)
testLoader = DataLoader(testDS, batch_size=config.BATCH_SIZE,
num_workers=os.cpu_count(), pin_memory=config.PIN_MEMORY)
Remember, in the custom_tensor_dataset.py script, we created a custom Dataset class to cater to our exact needs. As of now, our required entities are just tensors. So, to turn them into a format accepted by PyTorch’s DataLoader, we create training and testing instances of the CustomTensorDataset class, passing the images, labels, and bounding boxes as arguments (Lines 110-113).
On Lines 118 and 119, the steps per epoch values are calculated using the length of the datasets and the batch size value set in config.py.
Finally, we pass the CustomTensorDataset instances through the DataLoader and create the train and test Data loaders (Lines 122-125).
# write the testing image paths to disk so that we can use then
# when evaluating/testing our object detector
print("[INFO] saving testing image paths...")
f = open(config.TEST_PATHS, "w")
f.write("\n".join(testPaths))
f.close()
# load the ResNet50 network
resnet = resnet50(pretrained=True)
# freeze all ResNet50 layers so they will *not* be updated during the
# training process
for param in resnet.parameters():
    param.requires_grad = False
Since we’ll be using the test image paths for evaluation later, they are written to disk (Lines 129-132).
For the base model in our architecture, we’ll be using a pre-trained resnet50 (Line 135). However, as mentioned before, the weights of the base model will be left untouched. Hence, we freeze the weights (Lines 139 and 140).
# create our custom object detector model and flash it to the current
# device
objectDetector = ObjectDetector(resnet, len(le.classes_))
objectDetector = objectDetector.to(config.DEVICE)
# define our loss functions
classLossFunc = CrossEntropyLoss()
bboxLossFunc = MSELoss()
# initialize the optimizer, compile the model, and show the model
# summary
opt = Adam(objectDetector.parameters(), lr=config.INIT_LR)
print(objectDetector)
# initialize a dictionary to store training history
H = {"total_train_loss": [], "total_val_loss": [], "train_class_acc": [],
"val_class_acc": []}
With the model prerequisites complete, we create our custom model instance and load it to the current device (Lines 144 and 145). For the classifier, cross-entropy loss is used, while for the box regressor, we stick to mean squared error loss (Lines 148 and 149). On Line 153, Adam is set as the object detector’s optimizer. To track the training loss and other metrics, a dictionary H is initialized on Lines 157 and 158.
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
    # set the model in training mode
    objectDetector.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalValLoss = 0
    # initialize the number of correct predictions in the training
    # and validation step
    trainCorrect = 0
    valCorrect = 0
For training speed assessment, the start time is noted (Line 162). Looping over the number of epochs, we first set the object detector to training mode (Line 165) and initialize the losses and number of correct predictions (Lines 168-174).
    # loop over the training set
    for (images, labels, bboxes) in trainLoader:
        # send the input to the device
        (images, labels, bboxes) = (images.to(config.DEVICE),
            labels.to(config.DEVICE), bboxes.to(config.DEVICE))
        # perform a forward pass and calculate the training loss
        predictions = objectDetector(images)
        bboxLoss = bboxLossFunc(predictions[0], bboxes)
        classLoss = classLossFunc(predictions[1], labels)
        totalLoss = (config.BBOX * bboxLoss) + (config.LABELS * classLoss)
        # zero out the gradients, perform the backpropagation step,
        # and update the weights
        opt.zero_grad()
        totalLoss.backward()
        opt.step()
        # add the loss to the total training loss so far and
        # calculate the number of correct predictions
        totalTrainLoss += totalLoss
        trainCorrect += (predictions[1].argmax(1) == labels).type(
            torch.float).sum().item()
Looping over the train data loader, we first load the images, labels, and bounding boxes to the device in use (Lines 179 and 180). Next, we plug the images into our object detector and store the predictions (Line 183). Since the model gives two predictions (one for the bounding box and one for the label), we index those out and calculate the respective losses (Lines 184 and 185).
The combined value of both losses acts as the total loss for the architecture. We multiply the bounding box loss and the label loss by their respective weights defined in config.py and sum them up (Line 186).
With the help of PyTorch’s automatic differentiation, we simply reset the gradients, compute the gradients from the loss via backpropagation, and update the parameters based on the gradients of the current step (Lines 190-192). It is important to reset the gradients because backward keeps accumulating gradients by default. Since we only want the gradients pertaining to the current step, opt.zero_grad() flushes out the previously accumulated values.
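Here is a minimal illustration (not from train.py) of why zero_grad() matters; backward() adds to the existing gradients instead of replacing them:
# backward() accumulates gradients across calls unless they are reset
import torch

w = torch.tensor(2.0, requires_grad=True)
(w * 3).backward()
print(w.grad)   # tensor(3.)
(w * 3).backward()
print(w.grad)   # tensor(6.) -- accumulated, not overwritten
w.grad.zero_()  # this is what opt.zero_grad() does for every parameter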
On Lines 196-198, we update the loss values and correct predictions.
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        objectDetector.eval()
        # loop over the validation set
        for (images, labels, bboxes) in testLoader:
            # send the input to the device
            (images, labels, bboxes) = (images.to(config.DEVICE),
                labels.to(config.DEVICE), bboxes.to(config.DEVICE))
            # make the predictions and calculate the validation loss
            predictions = objectDetector(images)
            bboxLoss = bboxLossFunc(predictions[0], bboxes)
            classLoss = classLossFunc(predictions[1], labels)
            totalLoss = (config.BBOX * bboxLoss) + \
                (config.LABELS * classLoss)
            totalValLoss += totalLoss
            # calculate the number of correct predictions
            valCorrect += (predictions[1].argmax(1) == labels).type(
                torch.float).sum().item()
Moving on to the model evaluation, we’ll first turn off automatic gradients and switch to the evaluation mode of the object detector (Lines 201-203). Then, looping over the test data, we’ll repeat the same process as done in training apart from updating the weights (Lines 212-214).
The combined loss is calculated in the same manner as the training step (Lines 215 and 216). Consequently, the total loss value and correct predictions are updated (Lines 217-221).
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgValLoss = totalValLoss / valSteps
    # calculate the training and validation accuracy
    trainCorrect = trainCorrect / len(trainDS)
    valCorrect = valCorrect / len(testDS)
    # update our training history
    H["total_train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["train_class_acc"].append(trainCorrect)
    H["total_val_loss"].append(avgValLoss.cpu().detach().numpy())
    H["val_class_acc"].append(valCorrect)
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
    print("Train loss: {:.6f}, Train accuracy: {:.4f}".format(
        avgTrainLoss, trainCorrect))
    print("Val loss: {:.6f}, Val accuracy: {:.4f}".format(
        avgValLoss, valCorrect))
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))
After one epoch, the average batchwise training and testing losses are calculated on Lines 224 and 225. We also calculate the training and testing accuracies of the epoch using the number of correct predictions (Lines 228 and 229).
Following the calculations, all values are logged in the model history dictionary H (Lines 232-235). After exiting the loop, the end time is recorded so we can see how long the training took (Line 243).
# serialize the model to disk
print("[INFO] saving object detector model...")
torch.save(objectDetector, config.MODEL_PATH)
# serialize the label encoder to disk
print("[INFO] saving label encoder...")
f = open(config.LE_PATH, "wb")
f.write(pickle.dumps(le))
f.close()
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(H["total_train_loss"], label="total_train_loss")
plt.plot(H["total_val_loss"], label="total_val_loss")
plt.plot(H["train_class_acc"], label="train_class_acc")
plt.plot(H["val_class_acc"], label="val_class_acc")
plt.title("Total Training Loss and Classification Accuracy on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
# save the training plot
plotPath = os.path.sep.join([config.PLOTS_PATH, "training.png"])
plt.savefig(plotPath)
Since we’ll use the object detector for inference, we save it to disk (Line 249). We also save the label encoder so that the same label mapping can be reused at inference time (Lines 253-255).
To assess the model training, we plot all the metrics stored in the model history dictionary H (Lines 258-271).
This ends the model training. Next, let’s look at how well the object detector trained!
Assessing the Object Detection Training
Since the bulk of the model will have its weights unchanged, the training shouldn’t take long. First, let’s take a look at some of the training epochs.
[INFO] training the network...
0%| | 0/20 [00:00<?,
5%|▌ | 1/20 [00:16<05:08, 16.21s/it][INFO] EPOCH: 1/20
Train loss: 0.874699, Train accuracy: 0.7608
Val loss: 0.360270, Val accuracy: 0.9902
10%|█ | 2/20 [00:31<04:46, 15.89s/it][INFO] EPOCH: 2/20
Train loss: 0.186642, Train accuracy: 0.9834
Val loss: 0.052412, Val accuracy: 1.0000
15%|█▌ | 3/20 [00:47<04:28, 15.77s/it][INFO] EPOCH: 3/20
Train loss: 0.066982, Train accuracy: 0.9883
...
85%|████████▌ | 17/20 [04:27<00:47, 15.73s/it][INFO] EPOCH: 17/20
Train loss: 0.011934, Train accuracy: 0.9975
Val loss: 0.004053, Val accuracy: 1.0000
90%|█████████ | 18/20 [04:43<00:31, 15.67s/it][INFO] EPOCH: 18/20
Train loss: 0.009135, Train accuracy: 0.9975
Val loss: 0.003720, Val accuracy: 1.0000
95%|█████████▌| 19/20 [04:58<00:15, 15.66s/it][INFO] EPOCH: 19/20
Train loss: 0.009403, Train accuracy: 0.9982
Val loss: 0.003248, Val accuracy: 1.0000
100%|██████████| 20/20 [05:14<00:00, 15.73s/it][INFO] EPOCH: 20/20
Train loss: 0.006543, Train accuracy: 0.9994
Val loss: 0.003041, Val accuracy: 1.0000
[INFO] total time taken to train the model: 314.68s
We see that the model reached astounding accuracies of 0.9994 and 1.0000 for training and validation, respectively. Let’s see the epoch-wise variation in the training plot (Figure 5)!
Figure 5: Training Plot.
The model reached saturation levels fairly quickly in both the training and validation values. Now it’s time to see the object detector in action!
Drawing Inference from the Object Detector
The final step in this journey is the predict.py script. Here, we will loop over the test images individually and draw bounding boxes using our predicted values.
# USAGE
# python predict.py --input dataset/images/face/image_0131.jpg
# import the necessary packages
from pyimagesearch import config
from torchvision import transforms
import mimetypes
import argparse
import imutils
import pickle
import torch
import cv2
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--input", required=True,
help="path to input image/text file of image paths")
args = vars(ap.parse_args())
The argparse module is used to build a user-friendly command-line interface. On Lines 15-18, we construct an argument parser to help the user select the test image.
# determine the input file type, but assume that we're working with
# single input image
filetype = mimetypes.guess_type(args["input"])[0]
imagePaths = [args["input"]]
# if the file type is a text file, then we need to process *multiple*
# images
if "text/plain" == filetype:
# load the image paths in our testing file
imagePaths = open(args["input"]).read().strip().split("\n")
After parsing the arguments, we handle whichever kind of input the user provides. On Lines 22 and 23, the imagePaths variable is set up for a single input image, while Lines 27-29 handle the case of a text file listing multiple images.
# load our object detector, set it evaluation mode, and label
# encoder from disk
print("[INFO] loading object detector...")
model = torch.load(config.MODEL_PATH).to(config.DEVICE)
model.eval()
le = pickle.loads(open(config.LE_PATH, "rb").read())
# define normalization transforms
transforms = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor(),
transforms.Normalize(mean=config.MEAN, std=config.STD)
])
The model trained by the train.py script is loaded and put in evaluation mode (Lines 34 and 35). Similarly, the label encoder stored by that script is loaded (Line 36). Since we need to preprocess the input the same way it was preprocessed during training, another torchvision.transforms instance is created with the same arguments.
# loop over the images that we'll be testing using our bounding box
# regression model
for imagePath in imagePaths:
    # load the image, copy it, swap its colors channels, resize it, and
    # bring its channel dimension forward
    image = cv2.imread(imagePath)
    orig = image.copy()
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    image = image.transpose((2, 0, 1))
    # convert image to PyTorch tensor, normalize it, flash it to the
    # current device, and add a batch dimension
    image = torch.from_numpy(image)
    image = transforms(image).to(config.DEVICE)
    image = image.unsqueeze(0)
Looping over the test images, we read each image and apply some preprocessing to it (Lines 50-54). This is necessary because the image must match the format the object detector was trained on.
We proceed to turn the image into a tensor, apply the torchvision.transforms instance to it, and add a batching dimension to it (Lines 58-60). Our test image is now ready to be plugged into the object detector.
    # predict the bounding box of the object along with the class
    # label
    (boxPreds, labelPreds) = model(image)
    (startX, startY, endX, endY) = boxPreds[0]
    # determine the class label with the largest predicted
    # probability
    labelPreds = torch.nn.Softmax(dim=-1)(labelPreds)
    i = labelPreds.argmax(dim=-1).cpu()
    label = le.inverse_transform(i)[0]
First, the predictions from the model are obtained (Line 64). We proceed to unpack the bounding box values from the boxPreds variable (Line 65).
A softmax over the label predictions gives us a clearer picture of the scores corresponding to each class. For that purpose, we use PyTorch’s own torch.nn.Softmax on Line 69. Isolating the index with argmax, we plug it into the label encoder le and use inverse_transform (index to value) to get the name of the label (Lines 69-71).
    # resize the original image such that it fits on our screen, and
    # grab its dimensions
    orig = imutils.resize(orig, width=600)
    (h, w) = orig.shape[:2]
    # scale the predicted bounding box coordinates based on the image
    # dimensions
    startX = int(startX * w)
    startY = int(startY * h)
    endX = int(endX * w)
    endY = int(endY * h)
    # draw the predicted bounding box and class label on the image
    y = startY - 10 if startY - 10 > 10 else startY + 10
    cv2.putText(orig, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX,
        0.65, (0, 255, 0), 2)
    cv2.rectangle(orig, (startX, startY), (endX, endY),
        (0, 255, 0), 2)
    # show the output image
    cv2.imshow("Output", orig)
    cv2.waitKey(0)
On Line 75, we resize the original image to fit our screen. The height and width of the resized image are then stored and used to scale the predicted bounding box values back to image coordinates (Lines 76-83). This is done because we scaled the annotations to the range [0, 1] before fitting them to the model, so all outputs have to be scaled up for display purposes.
While displaying the bounding box, the label name will also be shown on top of it. For that purpose, we set up the y-axis value for our text on Line 86. Using OpenCV’s putText function, we set up the label displayed on the image (Lines 87 and 88).
Finally, we use OpenCV’s rectangle method to create the bounding box on the image (Lines 89 and 90). Since we have the starting x-axis, starting y-axis, ending x-axis, and ending y-axis values, it’s very easy to create a rectangle from them. This rectangle will surround our object.
This concludes our inference script. Let’s take a look at the results!
Object Detection in Action
Let’s see how our object detector fared, using one image from each class. We first use an image of an airplane (Figure 6), followed by an image under faces (Figure 7), and an image belonging to the motorcycle class (Figure 8).
Figure 6: An airplane, correctly predicted as an airplane.
Figure 7: A man, correctly categorized under face.
Figure 8: A Harley, correctly predicted as a motorcycle.
As it turns out, the accuracy values of our model weren’t lying. Not only did our model correctly guess the label, but the bounding boxes produced are also almost perfect!
With such precise detection and results, we can all agree that our little project was a success, can’t we?
While writing this Object Detection tutorial, I realized a few things in retrospect.
To be very honest, I had never liked using pre-trained models for my projects. It felt as if my work wasn’t my own anymore. Obviously, that turned out to be a misguided notion, driven home by the fact that my first one-shot face classifier said my best friend and I were the same person (believe me, we don’t look remotely similar).
I would say this tutorial served as a beautiful example of what happens when you have a well-trained feature extractor. Not only did we save time, but the end results were also brilliant. Take Figures 6 and 8 as examples. The predicted bounding boxes have minimal error.
Of course, this doesn’t mean there isn’t room for improvement. In Figure 7, the image has many elements, yet the object detector has managed to capture the general area of the object. However, it could have been more compact. We urge you to tinker around with the parameters to see if your results are better!
That being said, object detection plays a vital role in our world today. Automated traffic monitoring, face detection, and self-driving cars are just a few of the real-world applications where object detection thrives. Each year, algorithms are designed to make the process faster and more compact. We have reached a stage where algorithms can concurrently detect all objects inside scenes of a video! I hope this tutorial has piqued your curiosity toward uncovering the intricacies of this domain.
In today’s tutorial, we will be looking at image segmentation and building our own segmentation model from scratch, based on the popular U-Net architecture.
This lesson is the last of a 3-part series on Advanced PyTorch Techniques:
U-Net: Training Image Segmentation Models in PyTorch (today’s tutorial)
The computer vision community has devised various tasks, such as image classification, object detection, localization, etc., for understanding images and their content. These tasks give us a high-level understanding of the object class and its location in the image.
In Image Segmentation, we go a step further and ask our model to classify each pixel in our image to the object category it represents. This can be viewed as pixel-level image classification and is a much harder task than simple image classification, detection, or localization. Our model must automatically determine all objects and their precise location and boundaries at a pixel level in the image.
Thus image segmentation provides an intricate understanding of the image and is widely used in medical imaging, autonomous driving, robotic manipulation, etc.
To learn how to train a U-Net-based segmentation model in PyTorch, just keep reading.
U-Net: Training Image Segmentation Models in PyTorch
Throughout this tutorial, we will be looking at image segmentation and building and training a segmentation model in PyTorch. We will focus on a very successful architecture, U-Net, which was originally proposed for medical image segmentation. Furthermore, we will understand the salient features of the U-Net model, which make it an apt choice for the task of image segmentation.
Specifically, we will discuss the following, in detail, in this tutorial:
The architectural details of U-Net that make it a powerful segmentation model
Creating a custom PyTorch Dataset for our image segmentation task
Training the U-Net segmentation model from scratch
Making predictions on novel images with our trained U-Net model
U-Net Architecture Overview
The U-Net architecture (see Figure 1) follows an encoder-decoder cascade structure, where the encoder gradually compresses information into a lower-dimensional representation. Then the decoder decodes this information back to the original image dimension. Owing to this, the architecture gets an overall U-shape, which leads to the name U-Net.
Figure 1: Architecture of the U-Net Image Segmentation Model.
In addition to this, one of the salient features of the U-Net architecture is the skip connections (shown with grey arrows in Figure 1), which enable the flow of information from the encoder side to the decoder side, enabling the model to make better predictions.
Specifically, as we go deeper, the encoder processes information at higher levels of abstraction. This simply means that at the initial layers, the feature maps of the encoder capture low-level details about object texture and edges, and as we gradually go deeper, the features capture high-level information about object shapes and categories.
It is worth noting that to segment objects in an image, both low-level and high-level information is important. For example, a change in texture between objects and edge information can help determine the boundaries of various objects. On the other hand, high-level information about the class to which an object shape belongs can help segment corresponding pixels to correct object classes they represent.
Thus, to use both these pieces of information during predictions, the U-Net architecture implements skip connections between the encoder and decoder. This enables us to take intermediate feature map information from various depths on the encoder side and concatenate it at the decoder side to process and facilitate better predictions.
We will look at the U-Net model in further detail and build it from scratch in PyTorch later in this tutorial.
Our TGS Salt Segmentation Dataset
For this tutorial, we will use the TGS Salt Segmentation dataset. The dataset was introduced as part of the TGS Salt Identification Challenge on Kaggle.
In practice, it is difficult to accurately identify the location of salt deposits from images even with the help of human experts. Therefore, the challenge asked participants to help experts precisely identify the locations of salt deposits from seismic images of the Earth’s subsurface. This is practically important since incorrect estimates of salt presence can lead companies to set up drill sites at the wrong locations, wasting time and resources.
We use a sub-part of this dataset which comprises 4000 images of size 101×101 pixels, taken from various locations on earth. Here, each pixel corresponds to either salt deposit or sediment. In addition to images, we are also provided with the ground-truth pixel-level segmentation masks of the same dimension as the image (see Figure 2).
Figure 2: Sample images and corresponding ground-truth segmentation masks from our TGS Salt Segmentation dataset.
The white pixels in the masks represent salt deposits, and the black pixels represent sediment. We aim to correctly predict the pixels that correspond to salt deposits in the images. Thus, we have a binary classification problem where we have to classify each pixel into one of the two classes, Class 1: Salt or Class 2: Not Salt (or, in other words, sediment).
Configuring Your Development Environment
To follow this guide, you need to have the PyTorch deep learning library, matplotlib, OpenCV, imutils, scikit-learn, and tqdm packages installed on your system.
Luckily, these packages are extremely easy to install using pip:
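The command block itself isn’t reproduced here, but an install along these lines should work (torchvision is included because the model code imports from it, and opencv-contrib-python can be swapped for your preferred OpenCV build):
$ pip install torch torchvision
$ pip install matplotlib opencv-contrib-python imutils scikit-learn tqdm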
If you need help configuring your development environment for PyTorch, I highly recommend that you read the PyTorch documentation — PyTorch’s documentation is comprehensive and will have you up and running quickly.
Having Problems Configuring Your Development Environment?
Figure 3: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project Structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
From there, take a look at the directory structure:
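The tree output itself isn’t shown here; based on the files described below, the layout should look roughly like this:
.
├── dataset/
│   └── train/
│       ├── images/
│       └── masks/
├── output/
├── pyimagesearch/
│   ├── config.py
│   ├── dataset.py
│   └── model.py
├── predict.py
└── train.py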
The dataset folder stores the TGS Salt Segmentation dataset we will use for training our segmentation model.
Furthermore, we will be storing our trained model and training loss plots in the output folder.
The config.py file in the pyimagesearch folder stores our code’s parameters, initial settings, and configurations.
On the other hand, the dataset.py file consists of our custom segmentation dataset class, and the model.py file contains the definition of our U-Net model.
Finally, our model training and prediction codes are defined in train.py and predict.py files, respectively.
Creating Our Configuration File
We start by discussing the config.py file, which stores configurations and parameter settings used in the tutorial.
# import the necessary packages
import torch
import os
# base path of the dataset
DATASET_PATH = os.path.join("dataset", "train")
# define the path to the images and masks dataset
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")
# define the test split
TEST_SPLIT = 0.15
# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# determine if we will be pinning memory during data loading
PIN_MEMORY = True if DEVICE == "cuda" else False
We start by importing the necessary packages on Lines 2 and 3. Then, we define the path for our dataset (i.e., DATASET_PATH) on Line 6 and the paths for images and masks within the dataset folder (i.e., IMAGE_DATASET_PATH and MASK_DATASET_PATH) on Lines 9 and 10.
On Line 13, we define the fraction of the dataset we will keep aside for the test set. Then, on Line 16, we define the DEVICE parameter, which determines, based on availability, whether we will be using a GPU or CPU for training our segmentation model. If a CUDA-enabled GPU is available, the PIN_MEMORY parameter is set to True on Line 19.
# define the number of channels in the input, number of classes,
# and number of levels in the U-Net model
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3
# initialize learning rate, number of epochs to train for, and the
# batch size
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 64
# define the input image dimensions
INPUT_IMAGE_WIDTH = 128
INPUT_IMAGE_HEIGHT = 128
# define threshold to filter weak predictions
THRESHOLD = 0.5
# define the path to the base output directory
BASE_OUTPUT = "output"
# define the path to the output serialized model, model training
# plot, and testing image paths
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_tgs_salt.pth")
PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])
TEST_PATHS = os.path.sep.join([BASE_OUTPUT, "test_paths.txt"])
Next, we define the NUM_CHANNELS, NUM_CLASSES, and NUM_LEVELS parameters on Lines 23-25, which we will discuss in more detail later in the tutorial. Finally, on Lines 29-31, we define the training parameters such as initial learning rate (i.e., INIT_LR), the total number of epochs (i.e., NUM_EPOCHS), and batch size (i.e., BATCH_SIZE).
On Lines 34 and 35, we also define input image dimensions to which our images should be resized for our model to process them. We further define a threshold parameter on Line 38, which will later help us classify the pixels into one of the two classes in our binary classification-based segmentation task.
Finally, we define the path to our output folder (i.e., BASE_OUTPUT) on Line 41 and the corresponding paths to the trained model weights, training plots, and test images within the output folder on Lines 45-47.
Creating Our Custom Segmentation Dataset Class
Now that we have defined our initial configurations and parameters, we are ready to understand the custom dataset class we will be using for our segmentation dataset.
Let’s open the dataset.py file from the pyimagesearch folder in our project directory.
# import the necessary packages
from torch.utils.data import Dataset
import cv2
class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms
    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)
    def __getitem__(self, idx):
        # grab the image path from the current index
        imagePath = self.imagePaths[idx]
        # load the image from disk, swap its channels from BGR to RGB,
        # and read the associated mask from disk in grayscale mode
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], 0)
        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both image and its mask
            image = self.transforms(image)
            mask = self.transforms(mask)
        # return a tuple of the image and its mask
        return (image, mask)
We begin by importing the Dataset class from the torch.utils.data module on Line 2. This is important since all PyTorch datasets must inherit from this base dataset class. Furthermore, on Line 3, we import the OpenCV package, which will enable us to use its image handling functionalities.
We are now ready to define our own custom segmentation dataset. Each PyTorch dataset is required to inherit from Dataset class (Line 5) and should have a __len__ (Lines 13-15) and a __getitem__ (Lines 17-34) method. We discuss each of these methods below.
We start by defining our initializer constructor, that is, the __init__ method on Lines 6-11. The method takes as input the list of image paths (i.e., imagePaths) of our dataset, the corresponding ground-truth masks (i.e., maskPaths), and the set of transformations (i.e., transforms) we want to apply to our input images (Line 6).
On Lines 9-11, we initialize the attributes of our SegmentationDataset class with the parameters input to the __init__ constructor.
Next, we define the __len__ method, which returns the total number of image paths in our dataset, as shown on Line 15.
The task of the __getitem__ method is to take an index as input (Line 17) and return the corresponding sample from the dataset. On Line 19, we simply grab the image path at the idx index in our list of input image paths. Then, we load the image using OpenCV (Line 23). By default, OpenCV loads an image in the BGR format, which we convert to the RGB format as shown on Line 24. We also load the corresponding ground-truth segmentation mask in grayscale mode on Line 25.
Finally, we check for input transformations that we want to apply to our dataset images (Line 28) and transform both the image and mask with the required transforms on Lines 30 and 31, respectively. This is important since we want our image and ground-truth mask to correspond and have the same dimension. On Line 34, we return the tuple containing the image and its corresponding mask (i.e., (image, mask)) as shown.
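As a quick, hypothetical sanity check (the file names and the transform pipeline below are placeholders, not the tutorial’s actual training code), the dataset class can be exercised like this:
from torchvision import transforms
from pyimagesearch.dataset import SegmentationDataset
from pyimagesearch import config

# placeholder transforms: resize both image and mask and convert to tensors
tfms = transforms.Compose([transforms.ToPILImage(),
    transforms.Resize((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),
    transforms.ToTensor()])
ds = SegmentationDataset(imagePaths=["dataset/train/images/sample.png"],
    maskPaths=["dataset/train/masks/sample.png"], transforms=tfms)
(image, mask) = ds[0]
print(image.shape, mask.shape)  # torch.Size([3, 128, 128]) torch.Size([1, 128, 128])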
This completes the definition of our custom Segmentation dataset. Next, we will discuss the implementation of the U-Net architecture.
Building Our U-Net Model in PyTorch
It is time to look at our U-Net model architecture in detail and build it from scratch in PyTorch.
We open our model.py file from the pyimagesearch folder in our project directory and get started.
# import the necessary packages
from . import config
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch
On Lines 2-11, we import the necessary layers, modules, and activation functions from PyTorch, which we will use to build our model.
Overall, our U-Net model will consist of an Encoder class and a Decoder class. The encoder will gradually reduce the spatial dimension to compress information. Furthermore, it will increase the number of channels, that is, the number of feature maps at each stage, enabling our model to capture different details or features in our image. On the other hand, the decoder will take the final encoder representation and gradually increase the spatial dimension and reduce the number of channels to finally output a segmentation mask of the same spatial dimension as the input image.
Next, we define a Block module as the building unit of our encoder and decoder architecture. It is worth noting that all models or model sub-parts that we define are required to inherit from the PyTorch Module class, which is the parent class in PyTorch for all neural network modules.
class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)
    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))
We start by defining our Block class on Lines 13-23. The function of this module is to take an input feature map with the inChannels number of channels, apply two convolution operations with a ReLU activation between them and return the output feature map with the outChannels channels.
The __init__ constructor takes as input two parameters, inChannels and outChannels (Line 14), which determine the number of channels in the input feature map and the output feature map, respectively.
We initialize the two convolution layers (i.e., self.conv1 and self.conv2) and a ReLU activation on Lines 17-19. On Lines 21-23, we define the forward function, which takes as input our feature map x, applies the self.conv1 => self.relu => self.conv2 sequence of operations, and returns the output feature map.
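Note that the convolutions are unpadded, so each 3×3 convolution trims 2 pixels from both the height and the width. A quick hypothetical check:
# each unpadded 3x3 convolution shrinks the spatial dimensions by 2
import torch
from pyimagesearch.model import Block

block = Block(inChannels=3, outChannels=16)
print(block(torch.randn(1, 3, 128, 128)).shape)  # torch.Size([1, 16, 124, 124])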
class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)
    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []
        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        # return the list containing the intermediate outputs
        return blockOutputs
Next, we define our Encoder class on Lines 25-47. The class constructor (i.e., the __init__ method) takes as input a tuple (i.e., channels) of channel dimensions (Line 26). Note that the first value denotes the number of channels in our input image, and the subsequent numbers gradually double the channel dimension.
We start by initializing a list of blocks for the encoder (i.e., self.encBlocks) with the help of PyTorch’s ModuleList functionality on Lines 29-31. Each Block takes the input channels of the previous block and doubles the channels in the output feature map. We also initialize a MaxPool2d() layer, which reduces the spatial dimension (i.e., height and width) of the feature maps by a factor of 2.
Finally, we define the forward function for our encoder on Lines 34-47. The function takes as input an image x as shown on Line 34. On Line 36, we initialize an empty blockOutputs list, storing the intermediate outputs from the blocks of our encoder. Note that this will enable us to later pass these outputs to that decoder where they can be processed with the decoder feature maps.
On Lines 39-44, we loop through each block in our encoder, process the input feature map through the block (Line 42), and add the output of the block to our blockOutputs list. We then apply the max pool operation on our block output (Line 44). This is done for each block in the encoder.
Finally, we return our blockOutputs list on Line 47.
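To make the shapes concrete, here is a hypothetical check (assuming model.py exposes the Encoder class) of the intermediate feature maps returned for a 128×128 RGB input:
# inspect the intermediate feature maps returned by the encoder
import torch
from pyimagesearch.model import Encoder

enc = Encoder(channels=(3, 16, 32, 64))
feats = enc(torch.randn(1, 3, 128, 128))
print([tuple(f.shape) for f in feats])
# [(1, 16, 124, 124), (1, 32, 58, 58), (1, 64, 25, 25)]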
class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                for i in range(len(channels) - 1)])
    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)
            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        # return the final decoder output
        return x
    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        # return the cropped features
        return encFeatures
Now we define our Decoder class (Lines 50-87). Similar to the encoder definition, the decoder __init__ method takes as input a tuple (i.e., channels) of channel dimensions (Line 51). Note that the difference here, when compared with the encoder side, is that the channels gradually decrease by a factor of 2 instead of increasing.
We initialize the number of channels on Line 55. Furthermore, on Lines 56-58, we define a list of upsampling blocks (i.e., self.upconvs) that use the ConvTranspose2d layer to upsample the spatial dimension (i.e., height and width) of the feature maps by a factor of 2. In addition, the layer also reduces the number of channels by a factor of 2.
Finally, we initialize a list of blocks for the decoder (i.e., self.dec_blocks), similar to that on the encoder side.
On Lines 63-75, we define the forward function, which takes as input our feature map x and the list of intermediate outputs from the encoder (i.e., encFeatures). Starting on Line 65, we loop through the number of channels and perform the following operations:
First, we upsample the input to our decoder (i.e., x) by passing it through our i-th upsampling block (Line 67)
Since we have to concatenate (along the channel dimension) the i-th intermediate feature map from the encoder (i.e., encFeatures[i]) with our current output x from the upsampling block, we need to ensure that the spatial dimensions of encFeatures[i] and x match. To accomplish this, we use our crop function on Line 73.
Next, we concatenate our cropped encoder feature maps (i.e., encFeat) with our current upsampled feature map x, along the channel dimension on Line 74
Finally, we pass the concatenated output through our i-th decoder block (Line 75)
After the completion of the loop, we return the final decoder output on Line 78.
On Lines 80-87, we define our crop function which takes an intermediate feature map from the encoder (i.e., encFeatures) and a feature map output from the decoder (i.e., x) and spatially crops the former to the dimension of the latter.
To do this, we first grab the spatial dimensions of x (i.e., height H and width W) on Line 83. Then, we crop encFeatures to the spatial dimension [H, W] using the CenterCrop function (Line 84) and finally return the cropped output on Line 87.
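The following is a minimal sketch of what the crop step does, using hypothetical feature map sizes; it relies only on torchvision's CenterCrop transform, which operates directly on tensors.
# hypothetical sketch of the crop step
import torch
from torchvision.transforms import CenterCrop
encFeatures = torch.randn(1, 32, 58, 58)  # larger encoder feature map
x = torch.randn(1, 32, 50, 50)            # smaller upsampled decoder feature map
(_, _, H, W) = x.shape
cropped = CenterCrop([H, W])(encFeatures)
print(cropped.shape)  # torch.Size([1, 32, 50, 50]) -- now safe to concatenate with x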
Now that we have defined the sub-modules that make up our U-Net model, we are ready to build our U-Net model class.
class UNet(Module):
def __init__(self, encChannels=(3, 16, 32, 64),
decChannels=(64, 32, 16),
nbClasses=1, retainDim=True,
outSize=(config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)):
super().__init__()
# initialize the encoder and decoder
self.encoder = Encoder(encChannels)
self.decoder = Decoder(decChannels)
# initialize the regression head and store the class variables
self.head = Conv2d(decChannels[-1], nbClasses, 1)
self.retainDim = retainDim
self.outSize = outSize
We start by defining the __init__ constructor method (Lines 91-103). It takes the following parameters as input:
encChannels: The tuple defines the gradual increase in channel dimension as our input passes through the encoder. We start with 3 channels (i.e., RGB) and subsequently double the number of channels.
decChannels: The tuple defines the gradual decrease in channel dimension as our input passes through the decoder. We reduce the channels by a factor of 2 at every step.
nbClasses: This defines the number of segmentation classes where we have to classify each pixel. This usually corresponds to the number of channels in our output segmentation map, where we have one channel for each class.
Since we are working with two classes (i.e., binary classification), we keep a single channel and use thresholding for classification, as we will discuss later.
retainDim: This indicates whether we want to retain the original output dimension.
outSize: This determines the spatial dimensions of the output segmentation map. We set this to the same dimension as our input image (i.e., (config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)).
On Lines 97 and 98, we initialize our encoder and decoder networks. Furthermore, we initialize a convolution head, which will later take our decoder output as input and produce our segmentation map with nbClasses number of channels (Line 101).
We also initialize the self.retainDim and self.outSize attributes on Lines 102 and 103.
def forward(self, x):
# grab the features from the encoder
encFeatures = self.encoder(x)
# pass the encoder features through decoder making sure that
# their dimensions are suited for concatenation
decFeatures = self.decoder(encFeatures[::-1][0],
encFeatures[::-1][1:])
# pass the decoder features through the regression head to
# obtain the segmentation mask
map = self.head(decFeatures)
# check to see if we are retaining the original output
# dimensions and if so, then resize the output to match them
if self.retainDim:
map = F.interpolate(map, self.outSize)
# return the segmentation map
return map
Finally, we are ready to discuss our U-Net model’s forward function (Lines 105-124).
We begin by passing our input x through the encoder. This outputs the list of encoder feature maps (i.e., encFeatures) as shown on Line 107. Note that the encFeatures list contains all the feature maps starting from the first encoder block output to the last, as discussed previously. Therefore, we can reverse the order of feature maps in this list: encFeatures[::-1].
Now the encFeatures[::-1] list contains the feature map outputs in reverse order (i.e., from the last to the first encoder block). Note that this is important since, on the decoder side, we will be utilizing the encoder feature maps starting from the last encoder block output to the first.
Next, we pass the output of the final encoder block (i.e., encFeatures[::-1][0]) and the feature map outputs of all intermediate encoder blocks (i.e., encFeatures[::-1][1:]) to the decoder on Line 111. The output of the decoder is stored as decFeatures.
We pass the decoder output to our convolution head (Line 116) to obtain the segmentation mask.
Finally, we check if the self.retainDim attribute is True (Line 120). If yes, we interpolate the final segmentation map to the output size defined by self.outSize (Line 121). We return our final segmentation map on Line 124.
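Before moving on, a short, hypothetical smoke test can confirm that the pieces fit together (assuming UNet and config are importable from the pyimagesearch package and the input dimensions in config are 128×128).
# hypothetical smoke test for the full model
import torch
from pyimagesearch.model import UNet
from pyimagesearch import config
model = UNet()
x = torch.randn(2, 3, config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)
out = model(x)
print(out.shape)
# with retainDim=True the map is interpolated back to the input resolution,
# e.g., torch.Size([2, 1, 128, 128])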
This completes the implementation of our U-Net model. Next, we will look at the training procedure for our segmentation pipeline.
Training Our Segmentation Model
Now that we have implemented our dataset class and model architecture, we are ready to construct and train our segmentation pipeline in PyTorch. Let’s open the train.py file from our project directory.
Specifically, we will be looking at the following in detail:
Structuring the data-loading pipeline
Initializing the model and training parameters
Defining the training loop
Visualizing the training and test loss curves
# USAGE
# python train.py
# import the necessary packages
from pyimagesearch.dataset import SegmentationDataset
from pyimagesearch.model import UNet
from pyimagesearch import config
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os
We begin by importing our custom-defined SegmentationDataset class and the UNet model on Lines 5 and 6. Next, we import our config file on Line 7.
Since our salt segmentation task is a pixel-level binary classification problem, we will be using binary cross-entropy loss to train our model. On Line 8, we import the binary cross-entropy loss function (i.e., BCEWithLogitsLoss) from the PyTorch nn module. In addition to this, we import the Adam optimizer from the PyTorch optim module, which we will be using to train our network (Line 9).
Next, on Line 11, we import the in-built train_test_split function from the sklearn library, enabling us to split our dataset into training and testing sets. Furthermore, we import the transforms module from torchvision on Line 12 to apply image transformations on our input images.
Finally, we import other useful packages for handling our file system, keeping track of progress during training, timing our training process, and plotting loss curves on Lines 13-18.
Once we have imported all necessary packages, we will load our data and structure the data loading pipeline.
# load the image and mask filepaths in a sorted manner
imagePaths = sorted(list(paths.list_images(config.IMAGE_DATASET_PATH)))
maskPaths = sorted(list(paths.list_images(config.MASK_DATASET_PATH)))
# partition the data into training and testing splits using 85% of
# the data for training and the remaining 15% for testing
split = train_test_split(imagePaths, maskPaths,
test_size=config.TEST_SPLIT, random_state=42)
# unpack the data split
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]
# write the testing image paths to disk so that we can use them
# when evaluating/testing our model
print("[INFO] saving testing image paths...")
f = open(config.TEST_PATHS, "w")
f.write("\n".join(testImages))
f.close()
On Lines 21 and 22, we first define two lists (i.e., imagePaths and maskPaths) that store the paths of all images and their corresponding segmentation masks, respectively.
We then partition our dataset into a training and test set with the help of scikit-learn’s train_test_split on Line 26. Note that this function takes as input a sequence of lists (here, imagePaths and maskPaths) and simultaneously returns the training and test set images and corresponding training and test set masks which we unpack on Lines 30 and 31.
We then write the paths in the testImages list to the file defined by config.TEST_PATHS on Line 36 so that we can use them later when evaluating our model.
Now, we are ready to set up our data loading pipeline.
# define transformations
transforms = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((config.INPUT_IMAGE_HEIGHT,
config.INPUT_IMAGE_WIDTH)),
transforms.ToTensor()])
# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages, maskPaths=trainMasks,
transforms=transforms)
testDS = SegmentationDataset(imagePaths=testImages, maskPaths=testMasks,
transforms=transforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")
# create the training and test data loaders
trainLoader = DataLoader(trainDS, shuffle=True,
batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
num_workers=os.cpu_count())
testLoader = DataLoader(testDS, shuffle=False,
batch_size=config.BATCH_SIZE, pin_memory=config.PIN_MEMORY,
num_workers=os.cpu_count())
We first define the transformations that we want to apply while loading our input images and consolidate them with the help of the Compose function on Lines 41-44. Our transformations include:
ToPILImage(): converts our input images to PIL format. Note that this is necessary since we used OpenCV to load images in our custom dataset, but the torchvision transforms that follow expect their input as PIL images.
Resize(): allows us to resize our images to a particular input dimension (i.e., config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH) that our model can accept
ToTensor(): converts the input PIL image to a PyTorch tensor and scales its pixel values from the original [0, 255] range to [0, 1]
Finally, we pass the train and test images and corresponding masks to our custom SegmentationDataset to create the training dataset (i.e., trainDS) and test dataset (i.e., testDS) on Lines 47-50. Note that we can simply pass the transforms defined on Line 41 to our custom PyTorch dataset to apply these transformations while loading the images automatically.
We can now print the number of samples in trainDS and testDS with the help of the len() method, as shown in Lines 51 and 52.
On Lines 55-60, we create our training dataloader (i.e., trainLoader) and test dataloader (i.e., testLoader) by passing our train and test datasets to the PyTorch DataLoader class. We set the shuffle parameter to True in the train dataloader so that samples from all classes are uniformly represented in each batch, which is important for optimal learning and convergence of batch gradient-based optimization approaches.
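If you want to verify the pipeline before training, a quick (hypothetical) check is to pull one batch from the train loader and inspect its shapes; the exact mask shape depends on how SegmentationDataset returns its samples.
# hypothetical check: grab a single batch from the training dataloader
(images, masks) = next(iter(trainLoader))
print(images.shape)  # e.g., torch.Size([config.BATCH_SIZE, 3, 128, 128])
print(masks.shape)   # e.g., torch.Size([config.BATCH_SIZE, 1, 128, 128])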
Now that we have structured and defined our data loading pipeline, we will initialize our U-Net model and the training parameters.
# initialize our UNet model
unet = UNet().to(config.DEVICE)
# initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=config.INIT_LR)
# calculate steps per epoch for training and test set
trainSteps = len(trainDS) // config.BATCH_SIZE
testSteps = len(testDS) // config.BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}
We start by defining our UNet() model on Line 63. Note that the to() function takes as input our config.DEVICE and moves our model and its parameters to that device.
On Lines 66 and 67, we define our loss function and optimizer, which we will use to train our segmentation model. The Adam optimizer class takes as input the parameters of our model (i.e., unet.parameters()) and the learning rate (i.e., config.INIT_LR) we will be using to train our model.
We then define the number of steps required to iterate over our entire train and test set, that is, trainSteps and testSteps, on Lines 70 and 71. Given that the dataloader provides our model config.BATCH_SIZE number of samples to process at a time, the number of steps required to iterate over the entire dataset (i.e., train or test set) can be calculated by dividing the total samples in the dataset by the batch size.
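For example, with hypothetical numbers of 3,400 training samples, 600 test samples, and a batch size of 64, the calculation would look like this:
# hypothetical example of the steps-per-epoch calculation
trainSteps = 3400 // 64  # 53 steps per training epoch
testSteps = 600 // 64    # 9 steps per test pass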
We also create an empty dictionary, H, on Line 74, that we will use to keep track of our training and test loss history.
Finally, we are in good shape to start understanding our training loop.
# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
# set the model in training mode
unet.train()
# initialize the total training and validation loss
totalTrainLoss = 0
totalTestLoss = 0
# loop over the training set
for (i, (x, y)) in enumerate(trainLoader):
# send the input to the device
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
# perform a forward pass and calculate the training loss
pred = unet(x)
loss = lossFunc(pred, y)
# first, zero out any previously accumulated gradients, then
# perform backpropagation, and then update model parameters
opt.zero_grad()
loss.backward()
opt.step()
# add the loss to the total training loss so far
totalTrainLoss += loss
# switch off autograd
with torch.no_grad():
# set the model in evaluation mode
unet.eval()
# loop over the validation set
for (x, y) in testLoader:
# send the input to the device
(x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))
# make the predictions and calculate the validation loss
pred = unet(x)
totalTestLoss += lossFunc(pred, y)
# calculate the average training and validation loss
avgTrainLoss = totalTrainLoss / trainSteps
avgTestLoss = totalTestLoss / testSteps
# update our training history
H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
# print the model training and validation information
print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
print("Train loss: {:.6f}, Test loss: {:.4f}".format(
avgTrainLoss, avgTestLoss))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
endTime - startTime))
To time our training process, we use the time() function on Line 78. This function outputs the time when it is called. Thus, we can call it once at the start and once at the end of our training process and subtract the two outputs to get the time elapsed.
We iterate for config.NUM_EPOCHS in the training loop, as shown on Line 79. Before we start training, it is important to set our model to train mode, as we see on Line 81. This directs the PyTorch engine to track our computations and gradients and build a computational graph to backpropagate later.
We initialize variables totalTrainLoss and totalTestLoss on Lines 84 and 85 to track our losses in the given epoch. Next, on Line 88, we iterate over our trainLoader dataloader, which provides a batch of samples at a time. The training loop, as shown on Lines 88-103, comprises the following steps:
First, on Line 90, we move our data samples (i.e., x and y) to the device we are training our model on, defined by config.DEVICE
We then pass our input image sample x through our unet model on Line 93 and get the output prediction (i.e., pred)
On Line 94, we compute the loss between the model prediction, pred and our ground-truth label y
On Lines 98-100, we backpropagate our loss through the model and update the parameters
This is executed with the help of three simple steps; we start by clearing all accumulated gradients from previous steps on Line 98. Next, we call the backward method on our computed loss function as shown on Line 99. This directs PyTorch to compute gradients of our loss w.r.t. all variables involved in the computation graph. Finally, we call opt.step() to update our model parameters as shown on Line 100.
In the end, Line 103 enables us to keep track of our training loss by adding the loss for the step to the totalTrainLoss variable, which accumulates the training loss for all samples.
This process is repeated until we have iterated through all dataset samples once (i.e., completed one epoch).
Once we have processed our entire training set, we would want to evaluate our model on the test set. This is helpful since it allows us to monitor the test loss and ensure that our model is not overfitting to the training set.
While evaluating our model on the test set, we do not track gradients since we will not be learning or backpropagating. Thus, we can switch off gradient computation with the help of torch.no_grad(), effectively freezing the model weights, as shown on Line 106. This directs the PyTorch engine not to calculate and save gradients, saving memory and compute during evaluation.
We set our model to evaluation mode by calling the eval() function on Line 108. Then, we iterate through the test set samples and compute the predictions of our model on test data (Line 116). The test loss is then added to the totalTestLoss, which accumulates the test loss for the entire test set.
We then obtain the average training loss and test loss over all steps, that is, avgTrainLoss and avgTestLoss on Lines 120 and 121, and store them on Lines 124 and 125, to our dictionary, H, which we had created in the beginning to keep track of our losses.
Finally, we print the current epoch statistics, including train and test losses on Lines 128-130. This brings us to the end of one epoch, consisting of one full cycle of training on our train set and evaluation on our test set. This entire process is repeated config.NUM_EPOCHS times until our model converges.
On Lines 133 and 134, we note the end time of our training loop and subtract endTime from startTime (which we had initialized at the beginning of training) to get the total time elapsed during our network training.
# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(config.PLOT_PATH)
# serialize the model to disk
torch.save(unet, config.MODEL_PATH)
Next, we use the pyplot package of matplotlib to visualize and save our training and test loss curves on Lines 138-146. We can do this by simply passing the train_loss and test_loss keys of our loss history dictionary, H, to the plot function as shown on Lines 140 and 141. Finally, we set the title and legends of our plots (Lines 142-145) and save our visualizations on Line 146.
Finally, on Line 149, we serialize our trained U-Net model with the help of the torch.save() function, which takes our trained unet model and config.MODEL_PATH (the location where we want the model to be saved) as input.
Once our model is trained, we will see a loss trajectory plot similar to the one shown in Figure 4. Notice that train_loss gradually reduces over epochs and slowly converges. Furthermore, we see that test_loss also consistently reduces, following a similar trend and comparable values to train_loss, implying that our model generalizes well and is not overfitting to the training set.
Figure 4: Train and Test Loss trajectory for our U-Net segmentation model.
Using Our Trained U-Net Model for Prediction
Once we have trained and saved our segmentation model, we are ready to see it in action and use it for segmentation tasks.
Open the predict.py file from our project directory.
# USAGE
# python predict.py
# import the necessary packages
from pyimagesearch import config
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os
def prepare_plot(origImage, origMask, predMask):
# initialize our figure
figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))
# plot the original image, its mask, and the predicted mask
ax[0].imshow(origImage)
ax[1].imshow(origMask)
ax[2].imshow(predMask)
# set the titles of the subplots
ax[0].set_title("Image")
ax[1].set_title("Original Mask")
ax[2].set_title("Predicted Mask")
# set the layout of the figure and display it
figure.tight_layout()
figure.show()
We import the necessary packages and modules as always on Lines 5-10.
To use our segmentation model for prediction, we will need a function that can take our trained model and test images, predict the output segmentation mask and finally, visualize the output predictions.
To this end, we start by defining the prepare_plot function to help us to visualize our model predictions.
This function takes as input an image, its ground-truth mask, and the segmentation output predicted by our model, that is, origImage, origMask, and predMask (Line 12) and creates a grid with a single row and three columns (Line 14) to display them (Lines 17-19).
Finally, Lines 22-24 set titles for our plots, displaying them on Lines 27 and 28.
def make_predictions(model, imagePath):
# set model to evaluation mode
model.eval()
# turn off gradient tracking
with torch.no_grad():
# load the image from disk, swap its color channels, cast it
# to float data type, and scale its pixel values
image = cv2.imread(imagePath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image.astype("float32") / 255.0
# resize the image and make a copy of it for visualization
image = cv2.resize(image, (128, 128))
orig = image.copy()
# find the filename and generate the path to ground truth
# mask
filename = imagePath.split(os.path.sep)[-1]
groundTruthPath = os.path.join(config.MASK_DATASET_PATH,
filename)
# load the ground-truth segmentation mask in grayscale mode
# and resize it
gtMask = cv2.imread(groundTruthPath, 0)
gtMask = cv2.resize(gtMask, (config.INPUT_IMAGE_HEIGHT,
config.INPUT_IMAGE_HEIGHT))
Next, we define our make_predictions function (Lines 31-77), which takes as input our trained segmentation model and the path to a test image, and plots the predicted output.
Since we are only using our trained model for prediction, we start by setting our model to eval mode and switching off PyTorch gradient computation on Line 33 and Line 36, respectively.
On Lines 39-41, we load the test image (i.e., image) from imagePath using OpenCV (Line 39), convert it to RGB format (Line 40), and normalize its pixel values from the standard [0-255] to the range [0, 1], which our model is trained to process (Line 41).
The image is then resized to the standard image dimension that our model can accept on Line 44. Since we will have to modify and process the image variable before passing it through the model, we make an additional copy of it on Line 45 and store it in the orig variable, which we will use later.
On Lines 49-51, we get the path to the ground-truth mask for our test image and load the mask on Line 55. Note that we resize the mask to the same dimensions as the input image (Lines 56 and 57).
Now we process our image to a format that our model can process. Note that currently, our image has the shape [128, 128, 3]. However, our segmentation model accepts four-dimensional inputs of the format [batch_dimension, channel_dimension, height, width].
# make the channel axis to be the leading one, add a batch
# dimension, create a PyTorch tensor, and flash it to the
# current device
image = np.transpose(image, (2, 0, 1))
image = np.expand_dims(image, 0)
image = torch.from_numpy(image).to(config.DEVICE)
# make the prediction, pass the results through the sigmoid
# function, and convert the result to a NumPy array
predMask = model(image).squeeze()
predMask = torch.sigmoid(predMask)
predMask = predMask.cpu().numpy()
# filter out the weak predictions and convert them to integers
predMask = (predMask > config.THRESHOLD) * 255
predMask = predMask.astype(np.uint8)
# prepare a plot for visualization
prepare_plot(orig, gtMask, predMask)
On Line 62, we transpose the image to convert it to channel-first format, that is, [3, 128, 128], and on Line 63, we add an extra dimension using the expand_dims function of numpy to convert our image into a four-dimensional array (i.e., [1, 3, 128, 128]). Note that the first dimension here represents the batch dimension equal to one since we are processing one test image at a time. We then convert our image to a PyTorch tensor with the help of the torch.from_numpy() function and move it to the device our model is on with the help of Line 64.
Finally, on Lines 68-70, we process our test image by passing it through our model and saving the output prediction as predMask. We then apply the sigmoid activation to get our predictions in the range [0, 1]. As discussed earlier, the segmentation task is a classification problem where we have to classify each pixel into one of two discrete classes. Since sigmoid outputs continuous values in the range [0, 1], we use our config.THRESHOLD on Line 73 to binarize our output, assigning each pixel a value of 0 or 1. This implies that anything greater than the threshold will be assigned the value 1, and everything else will be assigned 0.
Since the thresholded output (i.e., (predMask > config.THRESHOLD)) now consists of the values 0 and 1, multiplying it with 255 makes the final pixel values in our predMask either 0 (i.e., the pixel value for black) or 255 (i.e., the pixel value for white). As discussed earlier, the white pixels correspond to the regions where our model has detected salt deposits, and the black pixels correspond to regions where salt is not present.
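To make the thresholding step concrete, here is a tiny, hypothetical example with a 2×2 mask of sigmoid outputs and a threshold of 0.5:
# hypothetical toy example of thresholding the sigmoid outputs
import numpy as np
predMask = np.array([[0.91, 0.12],
	[0.47, 0.66]])
predMask = (predMask > 0.5) * 255
print(predMask.astype(np.uint8))
# [[255   0]
#  [  0 255]]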
We plot our original image (i.e., orig), ground-truth mask (i.e., gtMask), and our predicted output (i.e., predMask) with the help of our prepare_plot function on Line 77. This completes the definition of our make_predictions function.
We are ready to see our model in action now.
# load the image paths in our testing file and randomly select 10
# image paths
print("[INFO] loading up test image paths...")
imagePaths = open(config.TEST_PATHS).read().strip().split("\n")
imagePaths = np.random.choice(imagePaths, size=10)
# load our model from disk and flash it to the current device
print("[INFO] load up model...")
unet = torch.load(config.MODEL_PATH).to(config.DEVICE)
# iterate over the randomly selected test image paths
for path in imagePaths:
# make predictions and visualize the results
make_predictions(unet, path)
On Lines 82 and 83, we open the file where our test image paths are stored and randomly grab 10 image paths. Line 87 loads our serialized U-Net model from config.MODEL_PATH and moves it to config.DEVICE.
We finally iterate over our randomly chosen test imagePaths and predict the outputs with the help of our make_predictions function on Lines 90-92.
Figure 5 shows sample visualization outputs from our make_predictions function. The yellow region represents Class 1: Salt, and the dark blue region represents Class 2: Not Salt (sediment).
Figure 5: Sample test images, corresponding ground-truth segmentation mask and predicted segmentation mask for our trained U-Net segmentation model.
We see that in case 1 and case 2 (i.e., row 1 and row 2, respectively), our model correctly identified most of the locations containing salt deposits. However, some regions where the salt deposit exists are not identified.
However, in case 3 (i.e., row 3), our model has identified some regions as salt deposits where there is no salt (the yellow blob in the middle). This is a false positive, where our model has incorrectly predicted the positive class, that is, the presence of salt, in a region where it does not exist in the ground truth.
It is worth noting that, from a practical, application point of view, the prediction in case 3 is misleading and riskier than those in the other two cases. This is because, for the first two cases, if experts set up drillers to mine salt deposits at the predicted (yellow) locations, they will successfully find salt deposits. However, if they do the same at the location of the false-positive prediction (as seen in case 3), it will waste time and resources, since salt deposits do not exist at that location.
In this tutorial, we learned about image segmentation and built a U-Net-based image segmentation pipeline from scratch in PyTorch.
Specifically, we discussed the architectural details and salient features of the U-Net model that make it the de-facto choice for image segmentation.
In addition, we learned how we can define our own custom dataset in PyTorch for the segmentation task at hand.
Finally, we saw how we can train our U-Net based-segmentation pipeline in PyTorch and use the trained model to make predictions on test images in real-time.
After following the tutorial, you will be able to understand the internal working of any image segmentation pipeline and build your own segmentation models from scratch in PyTorch.
The uniqueness of NeRF is proved by the number of doors it opens up in the field of computer graphics and deep learning. These range from medical imaging, 3D scene reconstruction, animation industry, relighting a scene to depth estimation.
In our previous week’s tutorial, we familiarized ourselves with the prerequisites of NeRF. We also explored the dataset that will be used. Now, it is best to remind ourselves of the initial problem statement.
What if there was a way to capture the entire 3D scene just from a sparse set of 2D pictures?
In this tutorial, we will focus on the algorithm that NeRF takes to capture the 3D scene from the sparse set of images.
This lesson is part 2 of a 3-part series on Computer Graphics and Deep Learning with NeRF using TensorFlow and Keras.
The parent directory has two Python scripts and two folders.
The dataset folder contains three subfolders: train, test, and val for the train, test, and validation images, respectively.
The pyimagesearch folder contains all of the Python scripts we will be using for training.
Finally, we have the two driver scripts: train.py and inference.py. We will be looking at training and inference in next week’s tutorial.
Note: In the interest of time, we have divided the implementation of NeRF into two parts. This blog introduces the concepts, while next week’s blog will cover the training and inference scripts.
Let’s talk about the premise of the paper. You have images of a particular scene from a few specific viewpoints. Now you want to generate an image of the scene from an entirely new view. This problem falls under novel image synthesis, as shown in Figure 2.
Figure 2: Novel view generation.
The immediate solution to novel view synthesis that comes to our mind is to use a Generative Adversarial Network (GAN) on the training dataset. With GANs, we are constraining ourselves to the 2D space of images.
Why not capture the entire 3D scenery from the images itself?
Let’s take a moment and try to absorb this idea.
We are now looking at a transformed problem statement. From novel view synthesis, we have moved on to capturing the 3D scene from a sparse set of 2D images.
This new problem statement will also serve as a solution to the novel view synthesis problem. How difficult is it to generate a novel view if we have the 3D scenery at our hands?
Note that NeRF is not the first to tackle this problem. Its predecessors have used various methods, including Convolutional Neural Networks (CNNs) and gradient-based mesh optimization. However, according to the paper, these methods could not scale to higher resolutions due to their space and time complexity. NeRF instead aims at optimizing an underlying continuous volumetric scene function.
Do not worry if you don’t get all of these terms at first glance. The rest of the blog is dedicated to breaking each of these topics down in the finest details and explaining them one by one.
We begin with a sparse set of images and their corresponding camera metadata (orientation and position). Next, we want to achieve a 3D representation of the entire scene, as shown in Figure 3.
The steps for NeRF can be visualized in the following figures:
Generate Rays: In this step, we march rays through each pixel of the image. The rays (Ray A and Ray B) are the red lines (Figure 4) that intersect the image and traverse through the 3D box (scene).
Sample points: In this step we sample points on the rays as shown in Figure 5. We must note that these points are located on the rays, making them 3D points inside the box.
Each point has a unique position and a direction component linked as shown (Figure 6). The direction of each point is the same as the direction of the ray.
Volume Rendering: Let’s consider a single ray (Ray A here) and send all the sample points to the MLP to get the corresponding color and density, as shown in Figure 8. After we have the color and density of each point, we can apply classical volume rendering (defined in a later section) to predict the color of the image pixel (pixel P here) through which the ray passes.
Photometric Loss: The difference between the predicted color of the pixel (shown in Figure 9) and the actual color of the pixel makes the photometric loss. This eventually allows us to perform backpropagation on the MLP and minimize the loss.
At this point, we have a bird’s eye view of NeRF. However, before describing the algorithm further, we first need to define an input data pipeline.
We know from the previous week’s tutorial that our dataset contains images and the corresponding camera orientations. So now, we need to build a data pipeline that produces images and the corresponding rays.
In this section, we will build this data pipeline step by step using the tf.data API. tf.data ensures an efficient way to build and use the dataset. If you want a primer on tf.data, you can refer to this tutorial.
The entire data pipeline is written in the pyimagesearch/data.py file. So, let’s open the file and start digging!
# import the necessary packages
from tensorflow.io import read_file
from tensorflow.image import decode_jpeg
from tensorflow.image import convert_image_dtype
from tensorflow.image import resize
from tensorflow import reshape
import tensorflow as tf
import json
We begin by importing the necessary packages on Lines 2-8:
tensorflow to build the data pipeline
json for reading and working with json data
def read_json(jsonPath):
# open the json file
with open(jsonPath, "r") as fp:
# read the json data
data = json.load(fp)
# return the data
return data
On Lines 10-17, we define the read_json function. This function takes the path to the json file (jsonPath) and returns the parsed data.
We open the json file with the open function on Line 12. Then, with the file pointer in hand, we read the contents and parse it with the json.load function on Line 14. Finally, Line 17 returns the parsed json data.
def get_image_c2w(jsonData, datasetPath):
# define a list to store the image paths
imagePaths = []
# define a list to store the camera2world matrices
c2ws = []
# iterate over each frame of the data
for frame in jsonData["frames"]:
# grab the image file name
imagePath = frame["file_path"]
imagePath = imagePath.replace(".", datasetPath)
imagePaths.append(f"{imagePath}.png")
# grab the camera2world matrix
c2ws.append(frame["transform_matrix"])
# return the image file names and the camera2world matrices
return (imagePaths, c2ws)
On Lines 19-37, we define the get_image_c2w function. This function takes the parsed json data (jsonData) and the path to the dataset (datasetPath) and returns the path to the images (imagePaths) and its corresponding camera-to-world (c2ws) matrices.
On Lines 21-24, we define two empty lists: imagePaths and c2ws. On Lines 27-34, we iterate over the parsed json data and add the image paths and camera-to-world matrices to the empty lists. After iterating over the entire data, we return both lists (Line 37).
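A hypothetical usage sketch of these two helpers might look like the following (the json filename and dataset path are placeholders, not values from this tutorial, and the functions are assumed to be in scope, e.g., imported from pyimagesearch.data):
# hypothetical usage of read_json and get_image_c2w
jsonData = read_json("dataset/transforms_train.json")
(imagePaths, c2ws) = get_image_c2w(jsonData, datasetPath="dataset")
print(len(imagePaths), len(c2ws))  # one camera-to-world matrix per image path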
Working with tf.data.Dataset instances, we will need a way to transform our dataset while feeding it to the model. To do this efficiently, we use the map functionality, which applies a given function to each element of the tf.data.Dataset instance.
The later part of the pyimagesearch/data.py defines functions used with the map function to transform the dataset.
class GetImages():
def __init__(self, imageWidth, imageHeight):
# define the image width and height
self.imageWidth = imageWidth
self.imageHeight = imageHeight
def __call__(self, imagePath):
# read the image file
image = read_file(imagePath)
# decode the image string
image = decode_jpeg(image, 3)
# convert the image dtype from uint8 to float32
image = convert_image_dtype(image, dtype=tf.float32)
# resize the image to the height and width in config
image = resize(image, (self.imageWidth, self.imageHeight))
image = reshape(image, (self.imageWidth, self.imageHeight, 3))
# return the image
return image
Before moving ahead, let’s discuss why we chose to build a class with a __call__ method instead of building a function that could be applied with the map function.
The problem is that the function passed to the map function cannot accept anything other than the element of the dataset. This is an imposed constraint which we need to bypass.
To overcome this problem, we have created a class that can hold some properties (here imageWidth and imageHeight) used during the function call.
On Lines 39-60, we build the GetImages class with a custom __call__ and __init__ function.
__init__: we will be using this function to initialize the parameters imageWidth and imageHeight (Lines 40-43)
__call__: this method makes the object callable. We use it to read the image from the given imagePath (Line 47), decode the image string as a JPEG (Line 50), and then convert the image from uint8 to float32, resize it, and reshape it (Lines 53-57).
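A minimal sketch of how this callable class plugs into the tf.data pipeline (the image dimensions here are hypothetical) could look like this:
# hypothetical sketch: using the callable class with tf.data's map
import tensorflow as tf
imageDs = tf.data.Dataset.from_tensor_slices(imagePaths)
imageDs = imageDs.map(GetImages(imageWidth=100, imageHeight=100),
	num_parallel_calls=tf.data.AUTOTUNE)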
A ray in computer graphics can be parameterized as
r(t) = o + t * d
where
r(t) is the ray
o is the origin of the ray
d is the unit vector for the direction of the ray
t is the parameter (e.g., time)
To build the ray equation, we need the origin and the direction. In the context of NeRF, we generate rays by taking the origin of the ray as the pixel position of the image plane and the direction as the straight line joining the pixel and the camera aperture. This is illustrated in Figure 10.
Figure 10: The process of ray generation.
We can easily devise the pixel positions of the 2D image with respect to the camera coordinate frame using the following equations:
x_c = (i - W/2) / f
y_c = (j - H/2) / f
where (i, j) are the pixel indices, W and H are the image width and height, and f is the focal length.
It is easy to locate the origin of the pixel points but a little challenging to get the direction of the rays. From the previous section, we have r(t) = o + t * d.
The camera-to-world matrix from the dataset is the transformation that we need.
To define the direction vector, we do not need the entire camera-to-world matrix; instead, we use its upper 3x3 submatrix, which defines the camera’s orientation (the rotation matrix R).
With the rotation matrix, we can get the unit direction vector by the following equation:
d = R c / ||R c||
where c is the direction of the pixel expressed in camera coordinates.
The difficult calculations are now over. For the easy part, the rays’ origin o will simply be the translation vector of the camera-to-world matrix.
Let’s see how we can translate this to code. We will be continuing with the pyimagesearch/data.py file.
class GetRays:
def __init__(self, focalLength, imageWidth, imageHeight, near,
far, nC):
# define the focal length, image width, and image height
self.focalLength = focalLength
self.imageWidth = imageWidth
self.imageHeight = imageHeight
# define the near and far bounding values
self.near = near
self.far = far
# define the number of samples for coarse model
self.nC = nC
On Lines 62-75, we create the class GetRays with custom __init__ and __call__ functions.
__init__: we initialize the focalLength, imageWidth, and imageHeight on Lines 66-68, the near and far bounds of the camera viewing field (Lines 71 and 72), and the number of coarse sample points nC (Line 75). We will need these to construct the rays to be marched into the scene, as shown in Figure 8.
def __call__(self, camera2world):
# create a meshgrid of image dimensions
(x, y) = tf.meshgrid(
tf.range(self.imageWidth, dtype=tf.float32),
tf.range(self.imageHeight, dtype=tf.float32),
indexing="xy",
)
# define the camera coordinates
xCamera = (x - self.imageWidth * 0.5) / self.focalLength
yCamera = (y - self.imageHeight * 0.5) / self.focalLength
# define the camera vector
xCyCzC = tf.stack([xCamera, -yCamera, -tf.ones_like(x)],
axis=-1)
# slice the camera2world matrix to obtain the rotation and
# translation matrix
rotation = camera2world[:3, :3]
translation = camera2world[:3, -1]
__call__: we input the camera2world matrix to this method which in turn returns
rayO: the origin points
rayD: the set of direction points
tVals: the sampled points
On Lines 79-83, we create a meshgrid of the image dimension. This is the same as the Image Plane shown in Figure 10.
Next, we obtain the camera coordinates (Lines 86 and 87) using the equation derived from our previous blog.
We define a homogeneous representation (Lines 90 and 91) of the camera vector xCyCzC by stacking the camera coordinates.
On Lines 95 and 96, we extract the rotation matrix and the translation vector from the camera-to-world matrix.
# expand the camera coordinates so they can be broadcast against the rotation matrix
xCyCzC = xCyCzC[..., None, :]
# get the world coordinates
xWyWzW = xCyCzC * rotation
# calculate the direction vector of the ray
rayD = tf.reduce_sum(xWyWzW, axis=-1)
rayD = rayD / tf.norm(rayD, axis=-1, keepdims=True)
# calculate the origin vector of the ray
rayO = tf.broadcast_to(translation, tf.shape(rayD))
# get the sample points from the ray
tVals = tf.linspace(self.near, self.far, self.nC)
noiseShape = list(rayO.shape[:-1]) + [self.nC]
noise = (tf.random.uniform(shape=noiseShape) *
(self.far - self.near) / self.nC)
tVals = tVals + noise
# return ray origin, direction, and the sample points
return (rayO, rayD, tVals)
We then transform the camera coordinates to world coordinates using the rotation matrix (Lines 99-102).
Next, we calculate the direction rayD (Lines 105 and 106) and the origin vector rayO (Line 109).
On Lines 112-116, we sample points from the ray.
Note: We will learn about sampling points on a ray in the following section.
Finally we return rayO, rayD, and tVals on Line 119.
After the generation of rays, we need to draw sample 3D points from the rays. To do this, we suggest two ways.
Sample points at regular intervals: The name of the method is self-explanatory. Here, we sample points on the ray at regular intervals, as shown in Figure 11.
Figure 11: Sample points at regular intervals.
The sampling equation is as follows:
t_i = t_n + ((i - 1) / (N - 1)) * (t_f - t_n), for i = 1, ..., N
where t_f and t_n are the farthest and nearest points on the ray, respectively. We divide the entire ray into N equidistant parts, and the divisions serve as the sample points.
Sample points randomly: In this method, we add randomness into the process of sampling points. The idea here is that if the sample points come from random positions of the ray, the model will be exposed to new data. This will regularize it to produce better results. The strategy is shown in Figure 12.
Figure 12: Sample points at random.
This is demonstrated by the equation below:
t_i ~ U[ t_n + ((i - 1) / N) * (t_f - t_n), t_n + (i / N) * (t_f - t_n) ]
where U refers to uniform sampling. Here, we take a random point from the space between two adjacent regularly spaced points.
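A small toy comparison of the two strategies (with hypothetical near and far bounds) mirrors what the GetRays class does above:
# hypothetical toy comparison of regular vs. random (jittered) sampling
import tensorflow as tf
(near, far, nC) = (2.0, 6.0, 8)
tValsRegular = tf.linspace(near, far, nC)
noise = tf.random.uniform(shape=(nC,)) * (far - near) / nC
tValsRandom = tValsRegular + noise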
Each sample point is 5-dimensional. The spatial location of the point is a 3D vector (x, y, z), and the viewing direction of the point is a 2D vector (θ, φ). Mildenhall et al. (2020) advocate expressing the viewing direction as a 3D Cartesian unit vector d.
These 5D points serve as the input to the MLP. This field of rays with 5D points is referred to as the neural radiance field in the paper.
The MLP network predicts each input point’s color c and volume density σ. Color refers to the (r, g, b) content of the point. The volume density σ can be interpreted as the differential probability of a ray terminating at an infinitesimal particle at that point.
We encourage the representation to be multiview consistent by restricting the network to predict the volume density as a function of only the location while allowing the RGB color to be predicted as a function of both locations and viewing direction.
With all that theory out of the way, we can start building the NeRF architecture in TensorFlow. So, let’s open the file pyimagesearch/nerf.py and start digging.
# import the necessary packages
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import concatenate
from tensorflow.keras import Input
from tensorflow.keras import Model
We begin with importing our necessary packages on Lines 2-5.
def get_model(lxyz, lDir, batchSize, denseUnits, skipLayer):
# build input layer for rays
rayInput = Input(shape=(None, None, None, 2 * 3 * lxyz + 3),
batch_size=batchSize)
# build input layer for direction of the rays
dirInput = Input(shape=(None, None, None, 2 * 3 * lDir + 3),
batch_size=batchSize)
# creating an input for the MLP
x = rayInput
for i in range(8):
# build a dense layer
x = Dense(units=denseUnits, activation="relu")(x)
# check if we have to include residual connection
if i % skipLayer == 0 and i > 0:
# inject the residual connection
x = concatenate([x, rayInput], axis=-1)
# get the sigma value
sigma = Dense(units=1, activation="relu")(x)
# create the feature vector
feature = Dense(units=denseUnits)(x)
# concatenate the feature vector with the direction input and put
# it through a dense layer
feature = concatenate([feature, dirInput], axis=-1)
x = Dense(units=denseUnits//2, activation="relu")(feature)
# get the rgb value
rgb = Dense(units=3, activation="sigmoid")(x)
# create the nerf model
nerfModel = Model(inputs=[rayInput, dirInput],
outputs=[rgb, sigma])
# return the nerf model
return nerfModel
Next, on Lines 7-46, we create our MLP model in the function called get_model. This method takes in the following inputs:
lxyz: the number of dimensions used for positional encoding of the xyz coordinates
lDir: the number of dimensions used for positional encoding of the direction vector
batchSize: the batch size of the data
denseUnits: the number of units in each layer of MLP
skipLayer: the layer at which we want the skip connection
On Lines 9-14, we define the rayInput and the dirInput layers. Next, we create the MLP with the skip connection (Lines 17-25).
To align with the paper (multiview consistency), only the rayInput is passed through the model to produce sigma (volume density) and a feature vector on Lines 28-31. Finally, the feature vector is concatenated with the dirInput (Line 35) to produce color (Line 39).
On Lines 42 and 43,we build the nerfModel using the Keras functional API. Finally, we return the nerfModel on Line 46.
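As a quick sanity check, we could instantiate the model with hypothetical hyperparameters (L=10 for positions and L=4 for directions) and inspect it:
# hypothetical instantiation of the coarse NeRF MLP
nerfModel = get_model(lxyz=10, lDir=4, batchSize=1, denseUnits=256,
	skipLayer=4)
nerfModel.summary()
# the ray input then expects 2 * 3 * 10 + 3 = 63 channels and the direction
# input expects 2 * 3 * 4 + 3 = 27 channels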
In this section, we study how to achieve volume rendering. We use the predicted color and volume density from the MLP to render the 3D scene.
The predictions from the network are plugged into the classical volume rendering equation to derive the color of a ray:
C(r) = ∫_{t_n}^{t_f} T(t) σ(r(t)) c(r(t), d) dt
Sounds complicated?
Let us break this equation down into simple parts.
The term C(r) is the color of the ray (and hence of the pixel it passes through).
r(t) = o + t * d is the ray that is fed into the network, where the variables stand for the following:
o is the origin of the ray
d is the direction of the ray
t takes values between the near (t_n) and far (t_f) bounds used for the integral
σ(r(t)) is the volume density, which can also be interpreted as the differential probability of the ray terminating at the point r(t)
c(r(t), d) is the color of the ray at the point r(t)
These are the building blocks of the equation. Apart from these, there is another term:
T(t) = exp( -∫_{t_n}^{t} σ(r(s)) ds )
This represents the transmittance along the ray from the near point t_n to the current point t. Think of this as a measure of how much of the ray penetrates the 3D space up to a certain point.
Now that we have all the terms together, we can finally make sense of the equation.
The color of a ray is defined as the integral, over all points existing between the near (t_n) and far (t_f) bounds of the viewing plane, of the transmittance T(t), the volume density σ, and the color c of the current point along the ray direction d.
Let’s look at how to express this in code. First, we will look at the render_image_depth in the pyimagesearch/utils.py file.
def render_image_depth(rgb, sigma, tVals):
# squeeze the last dimension of sigma
sigma = sigma[..., 0]
# calculate the delta between adjacent tVals
delta = tVals[..., 1:] - tVals[..., :-1]
deltaShape = [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 1]
delta = tf.concat(
[delta, tf.broadcast_to([1e10], shape=deltaShape)], axis=-1)
# calculate alpha from sigma and delta values
alpha = 1.0 - tf.exp(-sigma * delta)
# calculate the exponential term for easier calculations
expTerm = 1.0 - alpha
epsilon = 1e-10
# calculate the transmittance and weights of the ray points
transmittance = tf.math.cumprod(expTerm + epsilon, axis=-1,
exclusive=True)
weights = alpha * transmittance
# build the image and depth map from the points of the rays
image = tf.reduce_sum(weights[..., None] * rgb, axis=-2)
depth = tf.reduce_sum(weights * tVals, axis=-1)
# return rgb, depth map and weights
return (image, depth, weights)
On Lines 15-42, we are building a render_image_depth function which takes as inputs:
rgb: the red-green-blue color matrix of the ray points
sigma: the volume density of the sample points
tVals: the sample points
It produces the volume-rendered image (image), its depth map (depth), and the weights (required for hierarchical sampling).
On Line 17, we reshape sigma for ease of calculation. Next, we calculate the space (delta) between adjacent tVals (Lines 20-23).
Next, we create alpha using sigma and delta (Line 26).
We create the transmittance and weight vectors (Lines 33-35).
On Lines 38 and 39, we create the image and the depth map.
Finally, we return image, depth, and weights on Line 42.
We refer to the loss function used by NeRF as the photometric loss. This is computed by comparing the colors of the synthesized image with the ground-truth image. Mathematically, this can be expressed as:
L = Σ_r || C(r) - Ĉ(r) ||²
where C(r) is the pixel color from the real image and Ĉ(r) is the corresponding color from the synthesized image. This function, when applied to the entire pipeline, is still fully differentiable. This allows us to train the model parameters (θ) using backpropagation.
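A minimal sketch of this loss in TensorFlow, assuming the rendered and ground-truth images are float tensors of the same shape, is simply the mean squared error:
# minimal sketch of the photometric (mean squared error) loss
import tensorflow as tf
def photometric_loss(realImage, renderedImage):
	# average squared difference between ground-truth and synthesized pixels
	return tf.reduce_mean(tf.square(realImage - renderedImage))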
We have learned about computer graphics and their fundamentals in the first part of our blog series. In this tutorial, we have taken those concepts and applied them to 3D scene representation. Here we have:
Built an image and a ray dataset from the given json files.
Sampled points from the rays using the random sampling strategy.
Passed these points into the NeRF MLP.
Rendered a novel image using the color and volume density predicted by the MLP.
Established a loss function (photometric loss) with which we will optimize the parameters of the MLP.
These steps are sufficient to train a NeRF model and render novel views. However, this vanilla architecture will eventually produce renderings of low quality. To mitigate these issues, Mildenhall et al. (2020) propose additional enhancements.
In the next section, we will learn about these enhancements and their implementation using TensorFlow.
Positional Encoding is a popular encoding format used in architectures like transformers. Mildenhall et al. (2020) justify using this to better render high-frequency features such as texture and details.
Rahaman et al. (2019) suggest that deep networks are biased toward learning low-frequency functions. To bypass this problem NeRF proposes mapping the input vector to a higher dimensional representation. Since the 5D input space is the position of the points, we are essentially encoding the positions from which it gets the name.
Let’s say we have 10 positions indexed as 0, 1, 2, ..., 9. The indices are in the decimal system. If we encode the digits in the binary system, we will get something as shown in Figure 15.
Figure 15: Binary encoding.
The binary system is an easy encoding system. The only problem we face here is that the binary system is filled with zeros, making it a sparse representation. We would want to make this system continuous and compact.
The encoding function used in NeRF is as follows:

$$\gamma(p) = \left(\sin\left(2^{0}\pi p\right), \cos\left(2^{0}\pi p\right), \ldots, \sin\left(2^{L-1}\pi p\right), \cos\left(2^{L-1}\pi p\right)\right)$$

(In the implementation below, we follow the same idea, additionally appending the raw input $p$ and dropping the $\pi$ factor, which only rescales the frequencies.)
To draw a parallel between the binary and the NeRF encoding, let’s look at Figure 16.
Figure 16: Similarity between binary encoding and NeRF’s positional encoding.
The sine and cosine functions make the encoding continuous, and the $2^{i}$ term makes it similar to the binary system.
A visualization of the positional encoding function is given in Figure 17. The blue line depicts the cosine component, while the red line is the sine component.
Figure 17: Visualization of the sinusoids used for positional encoding.
We can create this fairly simply in a function called encoder_fn in the pyimagesearch/encoder.py file.
# import the necessary packages
import tensorflow as tf
def encoder_fn(p, L):
# build the list of positional encodings
gamma = [p]
# iterate over the number of dimensions in time
for i in range(L):
# insert sine and cosine of the product of current dimension
# and the position vector
gamma.append(tf.sin((2.0 ** i) * p))
gamma.append(tf.cos((2.0 ** i) * p))
# concatenate the positional encodings into a positional vector
gamma = tf.concat(gamma, axis=-1)
# return the positional encoding vector
return gamma
We start with importing tensorflow (Line 2). On Lines 4-19, we define the encoder function, which takes in the following parameters:
p: position of each element to be encoded
L: the dimension into which the encoding will take place
On Line 6, we define a list that will hold the positional encoding. Next, we iterate over dimensions and append the encoded values into the list (Lines 9-13). Lines 16-19 are used to convert the same list into a tensor and finally return it.
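As a quick sanity check (our own example, not part of the original code), encoding a 3D point with L frequencies produces 3 + 3 * 2 * L values, so L = 10 maps a (..., 3) tensor to a (..., 63) tensor:

import tensorflow as tf
from pyimagesearch.encoder import encoder_fn

# a toy batch of 5 xyz points
points = tf.random.uniform((5, 3))
# encode with L = 10 frequencies: 3 + 3 * 2 * 10 = 63 output values
encoded = encoder_fn(points, 10)
print(encoded.shape)  # (5, 63)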
Mildenhall et al. (2020) found another problem with the original structure. The random sampling method samples $N_c$ points along each camera ray, which means we have no prior understanding of where along the ray the samples should be concentrated. That ultimately leads to inefficient rendering.
They propose the following solution to remedy this:
Build two identical NeRF MLP models, the coarse and fine network.
Sample a set of points along the camera ray using the random sampling strategy, as shown in Figure 12. These points will be used to query the coarse network.
The output of the coarse network is used to produce a more informed sampling of points along each ray. These samples are biased towards the more relevant parts of the 3D scene.
To do this, we rewrite the color equation as a weighted sum of all sample colors $c_i$:

$$\hat{C}_c(r) = \sum_{i=1}^{N_c} w_i c_i$$

where the weight term is $w_i = T_i\left(1 - \exp\left(-\sigma_i \delta_i\right)\right)$.
The weights, when normalized, produce a piecewise-constant probability density function.
The entire procedure of turning the weights into a probability density function is visualized in Figure 18.
Figure 18: From weights to PDF.
From the probability density function, we sample the second set of locations using the inverse transform sampling method, as shown in Figure 19.
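To build some intuition for inverse transform sampling, here is a small toy example (ours, not part of the original implementation): given a handful of weights, bins with more weight capture more of the uniformly drawn samples once we invert the CDF.

import tensorflow as tf

# toy (unnormalized) weights over 4 bins
weights = tf.constant([0.1, 0.4, 0.3, 0.2])
pdf = weights / tf.reduce_sum(weights)
cdf = tf.concat([[0.0], tf.cumsum(pdf)], axis=0)

# draw uniform samples and locate each one in the CDF
u = tf.random.uniform((1000,))
indices = tf.searchsorted(cdf, u, side="right") - 1

# bins with larger weight receive proportionally more samples
print(tf.math.bincount(indices, minlength=4))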
Now we have both the $N_c$ (coarse) and $N_f$ (fine) sets of sampled points. We send these points to the fine network to produce the final rendered color of the ray.
This process of converting weights to a new set of sample points can be expressed through a function called sample_pdf. First, let’s refer to the utils.py file inside the pyimagesearch folder.
def sample_pdf(tValsMid, weights, nF):
# add a small value to the weights to prevent it from nan
weights += 1e-5
# normalize the weights to get the pdf
pdf = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
# from pdf to cdf transformation
cdf = tf.cumsum(pdf, axis=-1)
# start the cdf with 0s
cdf = tf.concat([tf.zeros_like(cdf[..., :1]), cdf], axis=-1)
# get the sample points
uShape = [BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, nF]
u = tf.random.uniform(shape=uShape)
# get the indices of the points of u when u is inserted into cdf in a
# sorted manner
indices = tf.searchsorted(cdf, u, side="right")
# define the boundaries
below = tf.maximum(0, indices-1)
above = tf.minimum(cdf.shape[-1]-1, indices)
indicesG = tf.stack([below, above], axis=-1)
# gather the cdf according to the indices
cdfG = tf.gather(cdf, indicesG, axis=-1,
batch_dims=len(indicesG.shape)-2)
# gather the tVals according to the indices
tValsMidG = tf.gather(tValsMid, indicesG, axis=-1,
batch_dims=len(indicesG.shape)-2)
# create the samples by inverting the cdf
denom = cdfG[..., 1] - cdfG[..., 0]
denom = tf.where(denom < 1e-5, tf.ones_like(denom), denom)
t = (u - cdfG[..., 0]) / denom
samples = (tValsMidG[..., 0] + t *
(tValsMidG[..., 1] - tValsMidG[..., 0]))
# return the samples
return samples
This code snippet has been inspired by the official NeRF implementation. On Lines 44-86, we create a function called sample_pdf that takes in the following parameters:
tValsMid: the midpoints between two adjacent tVals
weights: the weights used in the volume rendering function
nF: number of points used by the fine model
On Lines 46-49, we define the probability density function from the weights and then convert the same into a cumulative distribution function (cdf). This is then converted back into sample points for the fine model using inverse transform sampling (Lines 52-86).
We have gone through the core concepts proposed in the paper NeRF and also implemented them using TensorFlow.
We can recall what we have learned so far in the following steps:
Building the image and ray dataset for 5D scene representation
Sample points from the rays using any of the sampling strategies
Passing these points through the NeRF MLP model
Volume rendering based on the output of the MLP model
Calculating the photometric loss
Using positional encoding and hierarchical sampling to improve the quality of rendering
In next week’s tutorial, we will cover how to utilize all of these concepts to train the NeRF model. In addition, we will also render a 360-degree video of a 3D scene from 2D images.
We hope you enjoyed this week’s tutorial, and as always, you can download the source code and try it out yourself.
@article{Gosthipaty_Raha_2021_pt2,
author = {Aritra Roy Gosthipaty and Ritwik Raha},
title = {Computer Graphics and Deep Learning with {NeRF} using {TensorFlow} and {Keras}: Part 2},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/11/17/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-2/},
}
Let’s wind the clocks back a week or two. In the first tutorial, we learned about the fundamentals of Computer Graphics and image rendering. In the second tutorial, we went deeper into the core ideas proposed by NeRF and implemented them using TensorFlow and Keras.
We begin by reminding ourselves of the original problem we set out to solve:
What if there was a way to capture the entire 3D scene just from a sparse set of 2D pictures?
We have come a long way to solve this problem. We have created the architecture and the components needed to build NeRF. But we don’t know yet how each piece fits in the larger picture.
In this tutorial, we assemble all the details to train the NeRF model.
This lesson is the final part of a 3-part series on Computer Graphics and Deep Learning with NeRF using TensorFlow and Keras:
In this week’s tutorial, we will be explicitly looking at training the NeRF Multilayer Perceptron (MLP) that we built last week. We have divided this tutorial into the following sections:
NeRF Assemble: How to train a NeRF
NeRF Trainer: A helper model which trains the coarse and the fine NeRF models
Custom callback: A custom callback that helps us visualize the training process
Tying it all together: Bringing together all of the components
Inference: Build the 3D scene from a trained NeRF model
Start by accessing the “Downloads” section of this tutorial to retrieve the source code. We also expect you to download the dataset and keep it handy. You can find details about the dataset in the first tutorial.
From there, let’s take a look at the directory structure:
The dataset folder contains three subfolders, train, test, and val for the training, testing, and validation images.
The pyimagesearch folder contains all of the python scripts we will be using for training. These were discussed and explained in the previous week’s tutorial.
Next, we have the two driver scripts: train.py and inference.py. We train our NeRF model with the train.py script. With the inference.py, we generate a video of a 360-degree view of the scenery from the trained NeRF model.
In this section, we assemble (pun intended) all of the components explained in the previous blog post and head on to training the NeRF model. This section will cover three python scripts.
nerf_trainer.py: custom keras model to train the coarse and fine models
train_monitor.py: a custom callback to visualize and draw insights from the training process
train.py: the final script that brings everything together
Consider this section the ultimate battle cry, as in Figure 2. By the time we finish it, we will be ready with our trained NeRF model.
tf.keras has a beautiful fit API for training a model. When the training pipeline becomes complicated, we build a custom tf.keras.Model with a custom train_step; this way, we can still leverage the fit function. We recommend the official keras tutorial on customizing the fit call to anyone who wants to go deeper.
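As a refresher, the pattern looks like the minimal sketch below (a generic example under TensorFlow 2.x, not the NeRF trainer itself):

import tensorflow as tf

class CustomTrainer(tf.keras.Model):
    def __init__(self, model, lossFn):
        super().__init__()
        self.model = model
        self.lossFn = lossFn

    def train_step(self, data):
        # unpack one batch of inputs and targets
        (x, y) = data
        with tf.GradientTape() as tape:
            # forward pass and loss computation
            yPred = self.model(x, training=True)
            loss = self.lossFn(y, yPred)
        # backward pass: compute gradients and update the weights
        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(
            zip(grads, self.model.trainable_variables))
        return {"loss": loss}

# compile and fit as usual, e.g.:
# trainer = CustomTrainer(model, tf.keras.losses.MeanSquaredError())
# trainer.compile(optimizer="adam")
# trainer.fit(dataset, epochs=10)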
In the NeRF training pipeline, the MLP is simple. The only complications we face are volume rendering and hierarchical sampling.
Take note that we train two models (coarse and fine) with hierarchical sampling instead of one. To encapsulate everything inside the fit call, we build a custom Nerf_Trainer model.
The Nerf_Trainer is written in pyimagesearch/nerf_trainer.py. Let’s open the file and go through the script to understand it better.
# import the necessary packages
from tensorflow.keras.metrics import Mean
import tensorflow as tf
We begin with the necessary imports in Lines 2-3.
class Nerf_Trainer(tf.keras.Model):
def __init__(self, coarseModel, fineModel, lxyz, lDir,
encoderFn, renderImageDepth, samplePdf, nF):
super().__init__()
# define the coarse model and fine model
self.coarseModel = coarseModel
self.fineModel = fineModel
# define the dimensions for positional encoding for spatial
# coordinates and direction
self.lxyz = lxyz
self.lDir = lDir
# define the positional encoder
self.encoderFn = encoderFn
# define the volume rendering function
self.renderImageDepth = renderImageDepth
# define the hierarchical sampling function and the number of
# samples for the fine model
self.samplePdf = samplePdf
self.nF = nF
On Lines 6-27, the __init__ method serves as the Nerf_Trainer model constructor. The method accepts the following parameters:
coarseModel: the coarse NeRF model
fineModel: the fine NeRF model
lxyz: the number of dimensions used for positional encoding of the xyz coordinates
lDir: the number of dimensions used for positional encoding of the direction vector
encoderFn: positional encoding function for the model
renderImageDepth: the volume rendering function
samplePdf: utility function for hierarchical sampling
nF: number of fine model samples
def compile(self, optimizerCoarse, optimizerFine, lossFn):
super().compile()
# define the optimizer for the coarse and fine model
self.optimizerCoarse = optimizerCoarse
self.optimizerFine = optimizerFine
# define the photometric loss function
self.lossFn = lossFn
# define the loss and psnr tracker
self.lossTracker = Mean(name="loss")
self.psnrMetric = Mean(name="psnr")
On Lines 29-40, we define the compile method, which is called when the Nerf_Trainer model is compiled. The method accepts the following parameters:
optimizerCoarse: the optimizer for the coarse model
optimizerFine: the optimizer for the fine model
lossFn: the loss function for the NeRF models
On Lines 39 and 40, we define two trackers, namely lossTracker and psnrMetric. We use these trackers to track the model loss and the PSNR between the original and predicted images.
def train_step(self, inputs):
# get the images and the rays
(elements, images) = inputs
(raysOriCoarse, raysDirCoarse, tValsCoarse) = elements
# generate the coarse rays
raysCoarse = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsCoarse[..., None]))
# positional encode the rays and dirs
raysCoarse = self.encoderFn(raysCoarse, self.lxyz)
dirCoarseShape = tf.shape(raysCoarse[..., :3])
dirsCoarse = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirCoarseShape)
dirsCoarse = self.encoderFn(dirsCoarse, self.lDir)
Now we start with the train_step method (Lines 42-127). This method is called when we do a model.fit() on the Nerf_Trainer custom model. The following points explain the train_step method:
Lines 44 and 45 unpack the input.
Lines 48 and 49 generate the rays for the coarse model.
Lines 52-56 encode the ray and direction using the positional encoding function.
# keep track of our gradients
with tf.GradientTape() as coarseTape:
# compute the predictions from the coarse model
(rgbCoarse, sigmaCoarse) = self.coarseModel([raysCoarse,
dirsCoarse])
# render the image from the predictions
renderCoarse = self.renderImageDepth(rgb=rgbCoarse,
sigma=sigmaCoarse, tVals=tValsCoarse)
(imagesCoarse, _, weightsCoarse) = renderCoarse
# compute the photometric loss
lossCoarse = self.lossFn(images, imagesCoarse)
On Lines 59-70, we define the forward pass of the coarse model. On Lines 61 and 62, the model takes in rays and directions as input and produces rgb (color) and sigma (volume density).
These outputs (rgb and sigma) are then passed through the renderImageDepth function (for volume rendering) and produce the image depth map and the weights (Lines 65-67).
On Line 70, we compute the mean-squared error between the target image and the rendered image for the coarse model.
# compute the middle values of t vals
tValsCoarseMid = (0.5 *
(tValsCoarse[..., 1:] + tValsCoarse[..., :-1]))
# apply hierarchical sampling and get the t vals for the fine
# model
tValsFine = self.samplePdf(tValsMid=tValsCoarseMid,
weights=weightsCoarse, nF=self.nF)
tValsFine = tf.sort(
tf.concat([tValsCoarse, tValsFine], axis=-1), axis=-1)
# build the fine rays and positional encode it
raysFine = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsFine[..., None]))
raysFine = self.encoderFn(raysFine, self.lxyz)
# build the fine directions and positional encode it
dirsFineShape = tf.shape(raysFine[..., :3])
dirsFine = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirsFineShape)
dirsFine = self.encoderFn(dirsFine, self.lDir)
On Lines 73-81, we compute tValsFine for the fine model using the sample_pdf function.
Next, we build the rays and directions for the fine model (Lines 84-92).
# keep track of our gradients
with tf.GradientTape() as fineTape:
# compute the predictions from the fine model
rgbFine, sigmaFine = self.fineModel([raysFine, dirsFine])
# render the image from the predictions
renderFine = self.renderImageDepth(rgb=rgbFine,
sigma=sigmaFine, tVals=tValsFine)
(imageFine, _, _) = renderFine
# compute the photometric loss
lossFine = self.lossFn(images, imageFine)
Lines 94-105 are used to define the forward pass of the fine model. This is identical to the forward pass of the coarse model.
# get the trainable variables from the coarse model and
# apply back propagation
tvCoarse = self.coarseModel.trainable_variables
gradsCoarse = coarseTape.gradient(lossCoarse, tvCoarse)
self.optimizerCoarse.apply_gradients(zip(gradsCoarse,
tvCoarse))
# get the trainable variables from the coarse model and
# apply back propagation
tvFine = self.fineModel.trainable_variables
gradsFine = fineTape.gradient(lossFine, tvFine)
self.optimizerFine.apply_gradients(zip(gradsFine, tvFine))
psnr = tf.image.psnr(images, imageFine, max_val=1.0)
# compute the loss and psnr metrics
self.lossTracker.update_state(lossFine)
self.psnrMetric.update_state(psnr)
# return the loss and psnr metrics
return {"loss": self.lossTracker.result(),
"psnr": self.psnrMetric.result()}
On Line 109, we obtain the trainable parameters of the coarse model. The gradient of these parameters is computed (Line 110). We apply the computed gradients to these parameters using the optimizer (Lines 111 and 112).
The same is then repeated for the parameters of the fine model (Lines 116-119).
Lines 122 and 123 are used to update the loss and peak signal-to-noise ratio (PSNR) tracker, which is then returned on Lines 126 and 127.
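For reference, the PSNR returned by tf.image.psnr is

$$\mathrm{PSNR} = 10 \cdot \log_{10}\!\left(\frac{\mathrm{MAX}^2}{\mathrm{MSE}}\right)$$

where MAX is the maximum possible pixel value (max_val=1.0 here, since our images are scaled to [0, 1]) and MSE is the mean-squared error between the two images; the higher the PSNR, the closer the rendered image is to the target.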
def test_step(self, inputs):
# get the images and the rays
(elements, images) = inputs
(raysOriCoarse, raysDirCoarse, tValsCoarse) = elements
# generate the coarse rays
raysCoarse = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsCoarse[..., None]))
# positional encode the rays and dirs
raysCoarse = self.encoderFn(raysCoarse, self.lxyz)
dirCoarseShape = tf.shape(raysCoarse[..., :3])
dirsCoarse = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirCoarseShape)
dirsCoarse = self.encoderFn(dirsCoarse, self.lDir)
# compute the predictions from the coarse model
(rgbCoarse, sigmaCoarse) = self.coarseModel([raysCoarse,
dirsCoarse])
# render the image from the predictions
renderCoarse = self.renderImageDepth(rgb=rgbCoarse,
sigma=sigmaCoarse, tVals=tValsCoarse)
(_, _, weightsCoarse) = renderCoarse
# compute the middle values of t vals
tValsCoarseMid = (0.5 *
(tValsCoarse[..., 1:] + tValsCoarse[..., :-1]))
# apply hierarchical sampling and get the t vals for the fine
# model
tValsFine = self.samplePdf(tValsMid=tValsCoarseMid,
weights=weightsCoarse, nF=self.nF)
tValsFine = tf.sort(
tf.concat([tValsCoarse, tValsFine], axis=-1), axis=-1)
# build the fine rays and positional encode it
raysFine = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsFine[..., None]))
raysFine = self.encoderFn(raysFine, self.lxyz)
# build the fine directions and positional encode it
dirsFineShape = tf.shape(raysFine[..., :3])
dirsFine = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirsFineShape)
dirsFine = self.encoderFn(dirsFine, self.lDir)
# compute the predictions from the fine model
rgbFine, sigmaFine = self.fineModel([raysFine, dirsFine])
# render the image from the predictions
renderFine = self.renderImageDepth(rgb=rgbFine,
sigma=sigmaFine, tVals=tValsFine)
(imageFine, _, _) = renderFine
# compute the photometric loss and psnr
lossFine = self.lossFn(images, imageFine)
psnr = tf.image.psnr(images, imageFine, max_val=1.0)
# compute the loss and psnr metrics
self.lossTracker.update_state(lossFine)
self.psnrMetric.update_state(psnr)
# return the loss and psnr metrics
return {"loss": self.lossTracker.result(),
"psnr": self.psnrMetric.result()}
@property
def metrics(self):
# return the loss and psnr tracker
return [self.lossTracker, self.psnrMetric]
Now we define the test_step (Lines 129-194). The test_step and train_step are nearly identical; the only difference is that we do not compute gradients or update the model parameters in the test_step.
Finally, we define the loss tracker and the PSNR tracker as class properties (Lines 196-199).
An important point to note here is that the NeRF model is very memory intensive. Therefore, while it would be great to see the final result, it is equally important to visualize each step of the training process.
To visualize each step, we create a custom callback. We recommend going through this tutorial to get a better understanding of custom callbacks in Keras.
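The callback mechanism itself is simple; a minimal sketch (generic, not the NeRF-specific monitor we build below) looks like this:

import tensorflow as tf

class SimpleMonitor(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # logs holds the metrics returned by train_step/test_step
        print(f"epoch {epoch:03d}: {logs}")

# pass an instance via model.fit(..., callbacks=[SimpleMonitor()])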
Let’s open pyimagesearch/train_monitor.py and start digging.
# import the necessary packages
from tensorflow.keras.preprocessing.image import array_to_img
from tensorflow.keras.callbacks import Callback
import matplotlib.pyplot as plt
import tensorflow as tf
We begin with importing the necessary packages for this script (Lines 2-5).
def get_train_monitor(testDs, encoderFn, lxyz, lDir, imagePath):
# grab images and rays from the testing dataset
(tElements, tImages) = next(iter(testDs))
(tRaysOriCoarse, tRaysDirCoarse, tTvalsCoarse) = tElements
# build the test coarse ray
tRaysCoarse = (tRaysOriCoarse[..., None, :] +
(tRaysDirCoarse[..., None, :] * tTvalsCoarse[..., None]))
# positional encode the rays and direction vectors for the coarse
# ray
tRaysCoarse = encoderFn(tRaysCoarse, lxyz)
tDirsCoarseShape = tf.shape(tRaysCoarse[..., :3])
tDirsCoarse = tf.broadcast_to(tRaysDirCoarse[..., None, :],
shape=tDirsCoarseShape)
tDirsCoarse = encoderFn(tDirsCoarse, lDir)
On Line 7, we define the get_train_monitor method which builds and returns a custom callback.
On Lines 9 and 10, we unpack the inputs from the testDs (test dataset).
Next on Lines 13 and 14, we generate the rays for the coarse model.
On Lines 18-22, we encode the rays and directions for the coarse model using positional encoding.
class TrainMonitor(Callback):
def on_epoch_end(self, epoch, logs=None):
# compute the coarse model prediction
(tRgbCoarse, tSigmaCoarse) = self.model.coarseModel.predict(
[tRaysCoarse, tDirsCoarse])
# render the image from the model prediction
tRenderCoarse = self.model.renderImageDepth(rgb=tRgbCoarse,
sigma=tSigmaCoarse, tVals=tTvalsCoarse)
(tImageCoarse, _, tWeightsCoarse) = tRenderCoarse
# compute the middle values of t vals
tTvalsCoarseMid = (0.5 *
(tTvalsCoarse[..., 1:] + tTvalsCoarse[..., :-1]))
# apply hierarchical sampling and get the t vals for the
# fine model
tTvalsFine = self.model.samplePdf(
tValsMid=tTvalsCoarseMid, weights=tWeightsCoarse,
nF=self.model.nF)
tTvalsFine = tf.sort(
tf.concat([tTvalsCoarse, tTvalsFine], axis=-1),
axis=-1)
# build the fine rays and positional encode it
tRaysFine = (tRaysOriCoarse[..., None, :] +
(tRaysDirCoarse[..., None, :] * tTvalsFine[..., None])
)
tRaysFine = self.model.encoderFn(tRaysFine, lxyz)
# build the fine directions and positional encode it
tDirsFineShape = tf.shape(tRaysFine[..., :3])
tDirsFine = tf.broadcast_to(tRaysDirCoarse[..., None, :],
shape=tDirsFineShape)
tDirsFine = self.model.encoderFn(tDirsFine, lDir)
# compute the fine model prediction
tRgbFine, tSigmaFine = self.model.fineModel.predict(
[tRaysFine, tDirsFine])
# render the image from the model prediction
tRenderFine = self.model.renderImageDepth(rgb=tRgbFine,
sigma=tSigmaFine, tVals=tTvalsFine)
(tImageFine, tDepthFine, _) = tRenderFine
# plot the coarse image, fine image, fine depth map and
# target image
(_, ax) = plt.subplots(nrows=1, ncols=4, figsize=(10, 10))
ax[0].imshow(array_to_img(tImageCoarse[0]))
ax[0].set_title(f"Coarse Image")
ax[1].imshow(array_to_img(tImageFine[0]))
ax[1].set_title(f"Fine Image")
ax[2].imshow(array_to_img(tDepthFine[0, ..., None]),
cmap="inferno")
ax[2].set_title(f"Fine Depth Image")
ax[3].imshow(array_to_img(tImages[0]))
ax[3].set_title(f"Real Image")
plt.savefig(f"{imagePath}/{epoch:03d}.png")
plt.close()
# instantiate a train monitor callback
trainMonitor = TrainMonitor()
# return the train monitor
return trainMonitor
We define the on_epoch_end function inside the custom callback class to help visualize training logs and figures (Line 25). As the name suggests, this function is triggered only at the end of each training epoch.
On Lines 27 and 28, we predict the color and volume density using the coarse model. Next, on Lines 31-33, we render the coarse image using the volumetric rendering function renderImageDepth.
We then generate fine sample points using hierarchical sampling (Lines 36-46).
On Lines 49-51, we use the fine sample points and generate the fine rays by multiplying the fine sample points with the coarse rays.
On Line 52, we encode the fine rays using positional encoding.
We then extract the direction component from the rays (Line 55), reshape it (Lines 56 and 57), and finally encode the directions using positional encoding (Line 58).
The fine rays, directions, and the model are then used to predict the refined color and volume density (Lines 61 and 62). We use these to render the image and the depth map on Lines 65-67.
The coarse image, fine image, and depth maps are then visualized on Lines 71-86.
On Line 89, we instantiate the train monitor callback and then return it on Line 92.
With all the components in hand, we will finally be able to train our NeRF model using the script given below. Let’s open train.py and start going through it.
# USAGE
# python train.py
# setting seed for reproducibility
import tensorflow as tf
tf.random.set_seed(42)
# import the necessary packages
from pyimagesearch.data import read_json
from pyimagesearch.data import get_image_c2w
from pyimagesearch.data import GetImages
from pyimagesearch.data import GetRays
from pyimagesearch.utils import get_focal_from_fov, render_image_depth, sample_pdf
from pyimagesearch.encoder import encoder_fn
from pyimagesearch.nerf import get_model
from pyimagesearch.nerf_trainer import Nerf_Trainer
from pyimagesearch.train_monitor import get_train_monitor
from pyimagesearch import config
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
import os
We import tensorflow and set the random seed for reproducibility on Lines 5 and 6, and then import the rest of the necessary packages (Lines 9-21).
# get the train validation and test data
print("[INFO] grabbing the data from json files...")
jsonTrainData = read_json(config.TRAIN_JSON)
jsonValData = read_json(config.VAL_JSON)
jsonTestData = read_json(config.TEST_JSON)
focalLength = get_focal_from_fov(
fieldOfView=jsonTrainData["camera_angle_x"],
width=config.IMAGE_WIDTH)
# print the focal length of the camera
print(f"[INFO] focal length of the camera: {focalLength}...")
On Lines 25-27, we extract the train, test, and validation data from the respective json files. We then calculate the camera’s focal length (Lines 29-34) and print the same.
# get the train, validation, and test image paths and camera2world
# matrices
print("[INFO] grabbing the image paths and camera2world matrices...")
trainImagePaths, trainC2Ws = get_image_c2w(jsonData=jsonTrainData,
datasetPath=config.DATASET_PATH)
valImagePaths, valC2Ws = get_image_c2w(jsonData=jsonValData,
datasetPath=config.DATASET_PATH)
testImagePaths, testC2Ws = get_image_c2w(jsonData=jsonTestData,
datasetPath=config.DATASET_PATH)
# instantiate a object of our class used to load images from disk
getImages = GetImages(imageHeight=config.IMAGE_HEIGHT,
imageWidth=config.IMAGE_WIDTH)
# get the train, validation, and test image dataset
print("[INFO] building the image dataset pipeline...")
trainImageDs = (
tf.data.Dataset.from_tensor_slices(trainImagePaths)
.map(getImages, num_parallel_calls=config.AUTO)
)
valImageDs = (
tf.data.Dataset.from_tensor_slices(valImagePaths)
.map(getImages, num_parallel_calls=config.AUTO)
)
testImageDs = (
tf.data.Dataset.from_tensor_slices(testImagePaths)
.map(getImages, num_parallel_calls=config.AUTO)
)
We construct the image paths and camera-to-world matrices (Lines 39-44) from the json data extracted earlier.
Next, we build the tf.data image dataset (Lines 52-63). These include the train, test, and validation datasets, respectively.
# instantiate the GetRays object
getRays = GetRays(focalLength=focalLength, imageWidth=config.IMAGE_WIDTH,
imageHeight=config.IMAGE_HEIGHT, near=config.NEAR, far=config.FAR,
nC=config.N_C)
# get the train validation and test rays dataset
print("[INFO] building the rays dataset pipeline...")
trainRayDs = (
tf.data.Dataset.from_tensor_slices(trainC2Ws)
.map(getRays, num_parallel_calls=config.AUTO)
)
valRayDs = (
tf.data.Dataset.from_tensor_slices(valC2Ws)
.map(getRays, num_parallel_calls=config.AUTO)
)
testRayDs = (
tf.data.Dataset.from_tensor_slices(testC2Ws)
.map(getRays, num_parallel_calls=config.AUTO)
)
On Lines 66-68, we instantiate an object of the GetRays class. We then create the tf.data train, validation, and testing ray dataset (Lines 72-83).
# zip the images and rays dataset together
trainDs = tf.data.Dataset.zip((trainRayDs, trainImageDs))
valDs = tf.data.Dataset.zip((valRayDs, valImageDs))
testDs = tf.data.Dataset.zip((testRayDs, testImageDs))
# build data input pipeline for train, val, and test datasets
trainDs = (
trainDs
.shuffle(config.BATCH_SIZE)
.batch(config.BATCH_SIZE)
.repeat()
.prefetch(config.AUTO)
)
valDs = (
valDs
.shuffle(config.BATCH_SIZE)
.batch(config.BATCH_SIZE)
.repeat()
.prefetch(config.AUTO)
)
testDs = (
testDs
.batch(config.BATCH_SIZE)
.prefetch(config.AUTO)
)
The image and ray datasets are then zipped together (Lines 86-88). The train and validation datasets are then shuffled, batched, repeated, and prefetched, while the test dataset is only batched and prefetched (Lines 91-109).
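If the chain of tf.data calls is new to you, here is a toy example (our own, not from the training script) showing what each stage does:

import tensorflow as tf

ds = tf.data.Dataset.range(8)
ds = (
    ds
    .shuffle(4)                   # randomize order within a buffer of 4 elements
    .batch(2)                     # group elements into batches of 2
    .repeat()                     # loop over the data indefinitely
    .prefetch(tf.data.AUTOTUNE)   # overlap data preparation with training
)

# take a few batches from the (infinite) dataset
for batch in ds.take(3):
    print(batch.numpy())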
# instantiate the coarse model
coarseModel = get_model(lxyz=config.L_XYZ, lDir=config.L_DIR,
batchSize=config.BATCH_SIZE, denseUnits=config.DENSE_UNITS,
skipLayer=config.SKIP_LAYER)
# instantiate the fine model
fineModel = get_model(lxyz=config.L_XYZ, lDir=config.L_DIR,
batchSize=config.BATCH_SIZE, denseUnits=config.DENSE_UNITS,
skipLayer=config.SKIP_LAYER)
# instantiate the nerf trainer model
nerfTrainerModel = Nerf_Trainer(coarseModel=coarseModel, fineModel=fineModel,
lxyz=config.L_XYZ, lDir=config.L_DIR, encoderFn=encoder_fn,
renderImageDepth=render_image_depth, samplePdf=sample_pdf,
nF=config.N_F)
# compile the nerf trainer model with Adam optimizer and MSE loss
nerfTrainerModel.compile(optimizerCoarse=Adam(),optimizerFine=Adam(),
lossFn=MeanSquaredError())
Now we define the coarse and the fine models (Lines 112-119). Next, we define the nerfTrainerModel, which is a custom keras model that trains the coarse and fine models together (Lines 122-125).
On Lines 128 and 129, we compile the nerfTrainerModel with a suitable optimizer (Adam) and loss function (mean-squared error).
# check if the output image directory already exists, if it doesn't,
# then create it
if not os.path.exists(config.IMAGE_PATH):
os.makedirs(config.IMAGE_PATH)
# get the train monitor callback
trainMonitorCallback = get_train_monitor(testDs=testDs,
encoderFn=encoder_fn, lxyz=config.L_XYZ, lDir=config.L_DIR,
imagePath=config.IMAGE_PATH)
# train the NeRF model
print("[INFO] training the nerf model...")
nerfTrainerModel.fit(trainDs, steps_per_epoch=config.STEPS_PER_EPOCH,
validation_data=valDs, validation_steps=config.VALIDATION_STEPS,
epochs=config.EPOCHS, callbacks=[trainMonitorCallback],
)
# save the coarse and fine model
nerfTrainerModel.coarseModel.save(config.COARSE_PATH)
nerfTrainerModel.fineModel.save(config.FINE_PATH)
Lines 133-139 create the output directory and initialize the trainMonitorCallback. Finally, we train the nerfTrainerModel with the training dataset and validate it with the validation dataset (Lines 143-146).
We wrap up the training process by storing the trained coarse and fine models to disk (Lines 149 and 150).
Take a minute and congratulate yourself, as in Figure 3. We started from the basics, and now we have successfully trained NeRF. This was a long journey, and I am happy we did this together.
After all the hard work, what is better than seeing the results?
We have modeled the entire 3D scenery in the MLP, right? Why not rotate the camera around the entire scene and click pictures?
In this section, we will ask our model to synthesize novel views from the 3D scenery that it just modeled. We will be synthesizing novel views by sweeping the camera through 360 degrees of $\theta$.
If you are unfamiliar with the $\theta$ and $\phi$ axes in the 3D coordinate system, you can quickly revise your concepts with Figures 4 and 5.
Figure 4: A full rotation about the $\theta$ axis.
Figure 5: A full rotation about the $\phi$ axis.
Let’s open inference.py to visualize the complete rotation about the $\theta$ axis.
# import the necessary packages
from pyimagesearch import config
from pyimagesearch.utils import pose_spherical
from pyimagesearch.data import GetRays
from pyimagesearch.utils import get_focal_from_fov
from pyimagesearch.data import read_json
from pyimagesearch.encoder import encoder_fn
from pyimagesearch.utils import render_image_depth
from pyimagesearch.utils import sample_pdf
from tensorflow.keras.models import load_model
from tqdm import tqdm
import tensorflow as tf
import numpy as np
import imageio
import os
We begin with our usual necessary imports (Lines 2-15).
# create a camera2world matrix list to store the novel view
# camera2world matrices
c2wList = []
# iterate over theta and generate novel view camera2world matrices
for theta in np.linspace(0.0, 360.0, config.SAMPLE_THETA_POINTS,
endpoint=False):
# generate camera2world matrix
c2w = pose_spherical(theta, -30.0, 4.0)
# append the new camera2world matrix into the collection
c2wList.append(c2w)
# get the train validation and test data
print("[INFO] grabbing the data from json files...")
jsonTrainData = read_json(config.TRAIN_JSON)
focalLength = get_focal_from_fov(
fieldOfView=jsonTrainData["camera_angle_x"],
width=config.IMAGE_WIDTH)
# instantiate the GetRays object
getRays = GetRays(focalLength=focalLength, imageWidth=config.IMAGE_WIDTH,
imageHeight=config.IMAGE_HEIGHT, near=config.NEAR, far=config.FAR,
nC=config.N_C)
# create a dataset from the novel view camera2world matrices
ds = (
tf.data.Dataset.from_tensor_slices(c2wList)
.map(getRays)
.batch(config.BATCH_SIZE)
)
# load the coarse and the fine model
coarseModel = load_model(config.COARSE_PATH, compile=False)
fineModel = load_model(config.FINE_PATH, compile=False)
Next, on Line 19, we build an empty list of camera-to-world matrices, c2wList. On Line 22, we iterate over a range of 0 to 360 degrees; these are the $\theta$ values that we will be using. We keep $\phi$ fixed at -30 degrees and the distance at 4 units. These values, theta, phi, and distance, are passed into the function pose_spherical to obtain our camera-to-world matrices (Lines 25-28).
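We did not walk through pose_spherical in this excerpt, but conceptually it composes a translation along the viewing axis with rotations about the $\phi$ and $\theta$ axes to produce a camera-to-world matrix. A rough, self-contained sketch of that idea (our own simplification, not the exact utility in pyimagesearch/utils.py):

import numpy as np

def get_translation(t):
    # translate the camera t units along the z-axis
    return np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ], dtype=np.float32)

def get_rotation_phi(phi):
    # rotate about the x-axis by phi radians
    return np.array([
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)

def get_rotation_theta(theta):
    # rotate about the y-axis by theta radians
    return np.array([
        [np.cos(theta), 0, -np.sin(theta), 0],
        [0, 1, 0, 0],
        [np.sin(theta), 0, np.cos(theta), 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)

def pose_spherical_sketch(theta, phi, t):
    # compose translation and rotations (angles given in degrees)
    c2w = get_translation(t)
    c2w = get_rotation_phi(phi / 180.0 * np.pi) @ c2w
    c2w = get_rotation_theta(theta / 180.0 * np.pi) @ c2w
    return c2w

# e.g., pose_spherical_sketch(30.0, -30.0, 4.0) yields a 4x4 camera-to-world matrix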
On Lines 31-48, we grab the training json data, compute the focal length, and instantiate the GetRays object. We then create a dataset from the novel-view camera-to-world matrices and batch it as needed.
On Lines 51 and 52, we load the pretrained coarse and fine model.
# create a list to hold all the novel view from the nerf model
print("[INFO] grabbing the novel views...")
frameList = []
for element in tqdm(ds):
(raysOriCoarse, raysDirCoarse, tValsCoarse) = element
# generate the coarse rays
raysCoarse = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsCoarse[..., None]))
# positional encode the rays and dirs
raysCoarse = encoder_fn(raysCoarse, config.L_XYZ)
dirCoarseShape = tf.shape(raysCoarse[..., :3])
dirsCoarse = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirCoarseShape)
dirsCoarse = encoder_fn(dirsCoarse, config.L_DIR)
# compute the predictions from the coarse model
(rgbCoarse, sigmaCoarse) = coarseModel.predict(
[raysCoarse, dirsCoarse])
# render the image from the predictions
renderCoarse = render_image_depth(rgb=rgbCoarse,
sigma=sigmaCoarse, tVals=tValsCoarse)
(_, _, weightsCoarse) = renderCoarse
# compute the middle values of t vals
tValsCoarseMid = (0.5 *
(tValsCoarse[..., 1:] + tValsCoarse[..., :-1]))
# apply hierarchical sampling and get the t vals for the fine
# model
tValsFine = sample_pdf(tValsMid=tValsCoarseMid,
weights=weightsCoarse, nF=config.N_F)
tValsFine = tf.sort(
tf.concat([tValsCoarse, tValsFine], axis=-1), axis=-1)
# build the fine rays and positional encode it
raysFine = (raysOriCoarse[..., None, :] +
(raysDirCoarse[..., None, :] * tValsFine[..., None]))
raysFine = encoder_fn(raysFine, config.L_XYZ)
# build the fine directions and positional encode it
dirsFineShape = tf.shape(raysFine[..., :3])
dirsFine = tf.broadcast_to(raysDirCoarse[..., None, :],
shape=dirsFineShape)
dirsFine = encoder_fn(dirsFine, config.L_DIR)
# compute the predictions from the fine model
(rgbFine, sigmaFine) = fineModel.predict([raysFine, dirsFine])
# render the image from the predictions
renderFine = render_image_depth(rgb=rgbFine, sigma=sigmaFine,
tVals=tValsFine)
(imageFine, _, _) = renderFine
# insert the rendered fine image to the collection
frameList.append(imageFine.numpy()[0])
We iterate through our created dataset and unpack ray origin, ray direction, and sample points for each element in the dataset (Lines 57 and 58). We use these to render our coarse and fine scenes exactly as we did in training. This is explained in the following points:
The inputs are then broadcasted to suitable shapes, passed through the encoder function, and finally into the coarse model to predict rgbCoarse and sigmaCoarse (Lines 61-73).
On Lines 76-78, the color and volume density obtained is passed through the render_image_depth function to produce rendered images.
On Lines 81-89, we use the mid values of these samples and the weights derived from the rendered image to compute tValsFine using the sample_pdf function.
On Lines 92-100, we build the fine rays, positionally encode them, and then repeat the same for the directions of the fine rays.
We render the fine image using the predictions from the fine model. The novel views are then appended to frameList (Lines 103-111).
# check if the output video directory exists, if it does not, then
# create it
if not os.path.exists(config.VIDEO_PATH):
os.makedirs(config.VIDEO_PATH)
# build the video from the frames and save it to disk
print("[INFO] creating the video from the frames...")
imageio.mimwrite(config.OUTPUT_VIDEO_PATH, frameList, fps=config.FPS,
quality=config.QUALITY, macro_block_size=config.MACRO_BLOCK_SIZE)
Finally, we use these frames to render a 360-degree video of the object (Lines 115-121).
Let’s have a look at the fruits of our hard work, shall we? The 360-degree video of the rendered hot dog object is shown in Figure 6.
In this tutorial, we have successfully implemented a training and inference script that is scalable and compact.
In the final stage of the tutorial, we have synthesized novel views from sparse static images and rendered a video out of them.
NeRF is an example of groundbreaking research in both Deep Learning and Computer Graphics. It advances the field a great deal by achieving results that very few methods have been able to match thus far. Numerous variants and improvements will surely follow in the coming years.
Tell us which of these variants you would like us to cover next!
@article{Gosthipaty_Raha_2021_pt3,
author = {Aritra Roy Gosthipaty and Ritwik Raha},
title = {Computer Graphics and Deep Learning with {NeRF} using {TensorFlow} and {Keras}: Part 3},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/11/24/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-3/},
}
So far in this course, we’ve relied on the Tesseract OCR engine to detect the text in an input image. However, as we discovered in a previous tutorial, sometimes Tesseract needs a bit of help before we can actually OCR the text.
This tutorial will explore this idea more, demonstrating that computer vision and image processing techniques can localize text regions in a complex input image. Once the text is localized, we can extract the text ROI from the input image and then OCR it using Tesseract.
As a case study, we’ll be developing a computer vision system that can automatically locate the machine-readable zones (MRZs) in a scan of a passport. The MRZ contains information such as the passport holder’s name, passport number, nationality, date of birth, sex, and passport expiration date.
By automatically OCR’ing this region, we can help Transportation Security Administration (TSA) agents and immigration officials more quickly process travelers, reducing long lines (and not to mention stress and anxiety waiting in the queue).
Learning Objectives
In this tutorial, you will:
Learn how to use image processing techniques and the OpenCV library to localize text in an input image
Extract the localized text and OCR it with Tesseract
Build a sample passport reader project that can automatically detect, extract, and OCR the MRZ in a passport image
Finding Text in Images with Image Processing
In the first part of this tutorial, we’ll briefly review what a passport MRZ is. From there, I’ll show you how to implement a Python script to detect and extract the MRZ from an input image. Once the MRZ is extracted, we can use Tesseract to OCR the MRZ.
What Is a Machine-Readable Zone?
A passport is a travel document that looks like a small notebook. This document is issued by your country’s government and includes information that identifies you personally, including your name, address, etc.
You typically use your passport when traveling internationally. Once you arrive in your destination country, an immigration official checks your passport, validates your identity, and stamps your passport with your arrival date.
Inside your passport, you’ll find your personal identifying information (Figure 1). If you look at the bottom of the passport, you’ll see 2-3 lines of fixed-width characters.
Figure 1. Passport showing 3 lines of fixed-width characters at the bottom.
Type 1 passports have three lines, each with 30 characters, while Type 3 passports have two lines, each with 44 characters.
These lines are called the MRZ of your passport.
The MRZ encodes your personal identifying information, including:
Name
Passport number
Nationality
Date of birth/age
Sex
Passport expiration date
Before computers and MRZs, TSA agents and immigration officials had to review your passport and tediously validate your identity. It was a time-consuming task that was monotonous for the officials and frustrating for travelers who patiently waited for their turn in long immigration lines.
MRZs allow TSA agents to quickly scan your information, validate who you are, and enable you to pass through the queue more quickly, thereby reducing queue length (and reducing the stress on travelers and officials alike).
In the rest of this tutorial, you will learn how to implement an automatic passport MRZ scanner with OpenCV and Tesseract.
Configuring Your Development Environment
To follow this guide, you need to have the OpenCV library installed on your system.
Luckily, OpenCV is pip-installable:
$ pip install opencv-contrib-python
If you need help configuring your development environment for OpenCV, I highly recommend that you read my pip install OpenCV guide — it will have you up and running in a matter of minutes.
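The script below also imports pytesseract and imutils; if they are not already installed in your environment, they are pip-installable as well (and pytesseract additionally requires the Tesseract engine itself to be installed on your system):
$ pip install pytesseract imutils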
Project Structure
We first need to review our project directory structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
Before we can build our MRZ reader and scan passport images, let’s first review the directory structure for this project:
We only have a single Python script here, ocr_passport.py, which, as the name suggests, is used to load passport images from disk and scan them.
Inside the passports directory, we have two images, passport_01.png and passport_02.png — these images contain sample scanned passports. Our ocr_passport.py script will load these images from disk, locate their MRZ regions, and then OCR them.
Locating MRZs in Passport Images
Let’s learn how to locate the MRZ of a passport image using OpenCV and image processing.
Open the ocr_passport.py file in your project directory structure and insert the following code:
# import the necessary packages
from imutils.contours import sort_contours
import numpy as np
import pytesseract
import argparse
import imutils
import sys
import cv2
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
help="path to input image to be OCR'd")
args = vars(ap.parse_args())
We start on Lines 2-8 by importing our required Python packages. These imports should begin to feel pretty standard to you by this point in the text. The only exception is perhaps the sort_contours import on Line 2 — what does this function do?
The sort_contours function will accept an input set of contours found by using OpenCV’s cv2.findContours function. Then, sort_contours will sort these contours either horizontally (left-to-right or right-to-left) or vertically (top-to-bottom or bottom-to-top).
We perform this sorting operation because OpenCV’s cv2.findContours does not guarantee the ordering of the contours. We’ll need to sort them explicitly to access the MRZ lines at the bottom of the passport image. Performing this sorting operation will make detecting the MRZ region far easier (as we’ll see later in this implementation).
Lines 11-14 parse our command line arguments. Only a single argument is required here, the path to the input --image.
With our imports and command line arguments taken care of, we can move on to loading our input image and preparing it for MRZ detection:
# load the input image, convert it to grayscale, and grab its
# dimensions
image = cv2.imread(args["image"])
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
(H, W) = gray.shape
# initialize a rectangular and square structuring kernel
rectKernel = cv2.getStructuringElement(cv2.MORPH_RECT, (25, 7))
sqKernel = cv2.getStructuringElement(cv2.MORPH_RECT, (21, 21))
# smooth the image using a 3x3 Gaussian blur and then apply a
# blackhat morphological operator to find dark regions on a light
# background
gray = cv2.GaussianBlur(gray, (3, 3), 0)
blackhat = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, rectKernel)
cv2.imshow("Blackhat", blackhat)
Lines 18 and 19 load our input image from disk and then convert it to grayscale, such that we can apply basic image processing routines to it (again, keep in mind that our goal is to detect the MRZ of the passport without having to utilize machine learning). We then grab the spatial dimensions (width and height) of the input image on Line 20.
Lines 23 and 24 initialize two kernels, which we’ll later use when applying morphological operations, specifically the closing operation. For the time being, note that the first kernel is rectangular with a width approximately 3x larger than the height. The second kernel is square. These kernels will allow us to close gaps between MRZ characters and openings between MRZ lines.
Gaussian blurring is applied on Line 29 to reduce high-frequency noise. We then apply a blackhat morphological operation to the blurred, grayscale image on Line 30.
A blackhat operator is used to reveal dark regions (i.e., MRZ text) against light backgrounds (i.e., the passport’s background). Since the passport text is always black on a light background (at least in this dataset), a blackhat operation is appropriate. Figure 3 shows the output of applying the blackhat operator.
Figure 3. Output results of applying the blackhat operator to a passport.
In Figure 3, the left-hand side shows our original input image, while the right-hand side displays the output of the blackhat operation. Notice that the text is visible after this operation, while much of the background noise has been removed.
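If you want to convince yourself of what the blackhat operator is doing, it is equivalent to subtracting the input image from its morphological closing. A quick check (reusing the gray image and rectKernel from the script above):

# blackhat(I) = close(I) - I: the closing fills in the dark text, so the
# difference isolates dark regions sitting on a light background
closed = cv2.morphologyEx(gray, cv2.MORPH_CLOSE, rectKernel)
blackhatManual = cv2.subtract(closed, gray)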
The next step in MRZ detection is to compute the gradient magnitude representation of the blackhat image using the Scharr operator:
# compute the Scharr gradient of the blackhat image and scale the
# result into the range [0, 255]
grad = cv2.Sobel(blackhat, ddepth=cv2.CV_32F, dx=1, dy=0, ksize=-1)
grad = np.absolute(grad)
(minVal, maxVal) = (np.min(grad), np.max(grad))
grad = (grad - minVal) / (maxVal - minVal)
grad = (grad * 255).astype("uint8")
cv2.imshow("Gradient", grad)
Lines 35 and 36 compute the Scharr gradient along the x-axis of the blackhat image, revealing regions of the image that are dark against a light background and contain vertical changes in the gradient, such as the MRZ text region. We then take this gradient image and scale it back into the range [0, 255] using min/max scaling (Lines 37-39). The resulting gradient image is then displayed on our screen (Figure 4).
Figure 4. Results of min/max scaling, Otsu’s thresholding method, and square closure of our image.
The next step is to try to detect the actual lines of the MRZ:
# apply a closing operation using the rectangular kernel to close
# gaps in between letters -- then apply Otsu's thresholding method
grad = cv2.morphologyEx(grad, cv2.MORPH_CLOSE, rectKernel)
thresh = cv2.threshold(grad, 0, 255,
cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
cv2.imshow("Rect Close", thresh)
# perform another closing operation, this time using the square
# kernel to close gaps between lines of the MRZ, then perform a
# series of erosions to break apart connected components
thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, sqKernel)
thresh = cv2.erode(thresh, None, iterations=2)
cv2.imshow("Square Close", thresh)
First, we apply a closing operation using our rectangular kernel (Lines 44-46). This closing operation is meant to close gaps between MRZ characters. We then apply thresholding using Otsu’s method to automatically threshold the image (Figure 4). As we can see, each of the MRZ lines is present in our threshold map.
We then close the gaps between the actual lines, using a closing operation with our square kernel (Line 52). The sqKernel is a 21 x 21 kernel that attempts to close the gaps between the lines, yielding one large rectangular region corresponding to the MRZ.
A series of erosions are then performed to break apart connected components that may have joined during the closing operation (Line 53). These erosions are also helpful in removing small blobs that are irrelevant to the MRZ.
The result of these operations can be seen in Figure 4. Notice how the MRZ region is a large rectangular blob in the bottom third of the image.
Now that our MRZ region is visible, let’s find contours in the thresh image — this process will allow us to detect and extract the MRZ region:
# find contours in the thresholded image and sort them from bottom
# to top (since the MRZ will always be at the bottom of the passport)
cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_SIMPLE)
cnts = imutils.grab_contours(cnts)
cnts = sort_contours(cnts, method="bottom-to-top")[0]
# initialize the bounding box associated with the MRZ
mrzBox = None
Lines 58-61 detect contours in the thresholded image. We then sort them bottom-to-top. Why bottom-to-top, you may ask?
Simple: the MRZ region is always located in the bottom third of the input passport image. We use this a priori knowledge to exploit the structure of the image. If we know we are looking for a large rectangular region that always appears at the bottom of the image, why not search the bottom first?
Whenever applying image processing operations, always see if there is a way you can exploit your knowledge of the problem. Don’t overcomplicate your image processing pipeline. Use any domain knowledge to make the problem simpler.
Line 64 then initializes mrzBox, the bounding box associated with the MRZ region.
We’ll attempt to find the mrzBox in the following code block:
# loop over the contours
for c in cnts:
# compute the bounding box of the contour and then derive the
# how much of the image the bounding box occupies in terms of
# both width and height
(x, y, w, h) = cv2.boundingRect(c)
percentWidth = w / float(W)
percentHeight = h / float(H)
# if the bounding box occupies > 80% width and > 4% height of the
# image, then assume we have found the MRZ
if percentWidth > 0.8 and percentHeight > 0.04:
mrzBox = (x, y, w, h)
break
We start a loop over the detected contours on Line 67. We compute the bounding box for each contour and then determine the percentage of the image the bounding box occupies (Lines 72 and 73).
We compute how large the bounding box is (w.r.t. the original input image) to filter our contours. Remember that our MRZ is a large rectangular region that spans near the passport’s entire width.
Therefore, Line 77 takes advantage of this knowledge by making sure the detected bounding box spans at least 80% of the image’s width along with 4% of the height. Provided that the current bounding box region passes those tests, we update our mrzBox and break from the loop.
We can now move on to processing the MRZ region itself:
# if the MRZ was not found, exit the script
if mrzBox is None:
    print("[INFO] MRZ could not be found")
    sys.exit(0)
# pad the bounding box since we applied erosions and now need to
# re-grow it
(x, y, w, h) = mrzBox
pX = int((x + w) * 0.03)
pY = int((y + h) * 0.03)
(x, y) = (x - pX, y - pY)
(w, h) = (w + (pX * 2), h + (pY * 2))
# extract the padded MRZ from the image
mrz = image[y:y + h, x:x + w]
Lines 82-84 handle the case where no MRZ region was found — here, we exit the script. This could happen if an image that does not contain a passport is accidentally passed through the script, or if the passport image is too low quality or too noisy for our basic image processing pipeline to handle.
Provided we did indeed find the MRZ, the next step is to pad the bounding box region. We performed this padding because we applied a series of erosions (back on Line 53) when attempting to detect the MRZ itself.
However, we need to pad this region so that the MRZ characters are not touching the ROI’s borders. If the characters touch the image’s border, Tesseract’s OCR procedure may not be accurate.
Line 88 unpacks the bounding box coordinates. We then pad the MRZ region by 3% in each direction (Lines 89-92).
Once the MRZ is padded, we extract it from the image using array slicing (Line 95).
With the MRZ extracted, the final step is to apply Tesseract to OCR it:
# OCR the MRZ region of interest using Tesseract, removing any
# occurrences of spaces
mrzText = pytesseract.image_to_string(mrz)
mrzText = mrzText.replace(" ", "")
print(mrzText)
# show the MRZ image
cv2.imshow("MRZ", mrz)
cv2.waitKey(0)
Line 99 OCRs the MRZ region of the passport. We then explicitly remove any spaces from the MRZ text (Line 100) as Tesseract may have accidentally introduced spaces during the OCR process.
We then wrap up our passport OCR implementation by displaying the OCR’d mrzText on our terminal and displaying the final mrz ROI on our screen. You can see the result in Figure 5.
Figure 5. MRZ extracted results from our image processing pipeline.
Text Blob Localization Results
We are now ready to put our text localization script to the test.
Open a terminal and execute the following command:
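The exact script and image file names depend on how you structured the project download, but a hypothetical invocation looks something like this:
$ python ocr_passport.py --image passports/passport_01.png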
Figure 6(left) shows our original input image, while Figure 6(right) displays the MRZ extracted via our image processing pipeline. Our terminal output shows that we’ve correctly OCR’d the MRZ area using Tesseract.
Figure 6. Original image and MRZ extracted results from our image processing pipeline.
Let’s try another passport image, this one a Type-1 passport with three MRZ lines instead of two:
As Figure 7 shows, we detected the MRZ in the input image and then extracted it. The MRZ was then passed into Tesseract for OCR, and our terminal output shows the result.
Figure 7. Original image on the left and MRZ extracted results from our image processing pipeline on the right.
However, our MRZ OCR is not 100% accurate — notice there is an “L” between the “T” and “I” in “KATIA.”
For higher OCR accuracy, we should consider training a custom Tesseract model specifically on the fonts used in passports, making it easier for Tesseract to recognize these characters.
In this tutorial, you learned how to implement an OCR system capable of localizing, extracting, and OCR’ing the text in the MRZ of a passport.
When you build your own OCR applications, don’t blindly throw Tesseract at them and see what sticks. Instead, carefully examine the problem as a computer vision practitioner.
Ask yourself:
Can I use image processing to localize the text in an image, thereby reducing my reliance on Tesseract text localization?
Can I use OpenCV functions to extract these regions automatically?
What image processing steps would be required to detect the text?
The image processing pipeline presented in this tutorial is an example of a text localization pipeline you can build. It will not work in all situations. Still, computing gradients and using morphological operations to close gaps in the text will work in a surprising number of applications.
In this tutorial, you will learn how to train a DCGAN to generate fashion images in color. You will learn the common challenges, techniques to address these challenges, and GAN evaluation metrics through the training process.
This lesson is the third post of a GAN tutorial series:
In my previous post, Get Started: DCGAN for Fashion-MNIST, you learned how to train a DCGAN to generate grayscale Fashion-MNIST images. In this post, let’s train a DCGAN with color images to demonstrate the common challenges of GAN training. We will also briefly discuss some improvement techniques and GAN evaluation metrics. Please follow the tutorial with the Colab notebook here for a complete code example.
DCGAN for Color Images
We will take the DCGAN code from my previous post as the baseline and then make adjustments to train color images. Since we already walked through the DCGAN training end-to-end in detail in my previous post, now we will focus only on the key changes needed to train DCGAN for color images:
Data: download the color images from Kaggle and preprocess them to the range of [-1, 1].
Generator: adjust the upsampling in the model architecture to generate a color image.
Discriminator: adjust the input image shape from 28×28×1 to 64×64×3.
With these changes, you can start training the DCGAN on the color images; however, when working with color images or any data other than MNIST or Fashion-MNIST, you will realize how challenging GAN training can be. Even training with Fashion-MNIST grayscale images can be tricky.
1. Prepare the Data
We will train the DCGAN with a dataset called Clothing & Models from Kaggle, which is a collection of clothing pieces scraped from Zalando.com. There are six categories and over 16K color images of size 606×875, which will be resized to 64×64 for training.
To download data from Kaggle, you will need to provide your Kaggle credentials. You could either upload the Kaggle JSON file to Colab or put your Kaggle username and key in the notebook. We chose the latter option.
Then we use Keras’ image_dataset_from_directory to create a tf.data.Dataset from the images in the directory, which will be used for training the model later on. We also specify an image size of 64×64 and a batch size of 32.
As before, we normalize the images to the range of [-1, 1] because the generator’s final layer activation uses tanh. We apply the normalization using the map function of tf.data.Dataset with a lambda function.
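A minimal sketch of this data pipeline is shown below (the "dataset/" directory name is a placeholder for wherever you extracted the Kaggle images):
import tensorflow as tf

# create a tf.data.Dataset from the images on disk; no labels are needed
# for GAN training, so label_mode is set to None
train_images = tf.keras.preprocessing.image_dataset_from_directory(
    "dataset/", label_mode=None, image_size=(64, 64), batch_size=32)

# normalize pixel values from [0, 255] to [-1, 1] to match the
# generator's tanh output activation
train_images = train_images.map(lambda x: (x - 127.5) / 127.5)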
2. Generator
We create the generator architecture with the Keras Sequential API in the build_generator function. We already went through the details of how to create the generator architecture in my previous DCGAN post. Here let’s look at how to adjust the upsampling to generate the desired color image size of 64×64×3:
We update CHANNELS = 3 for color images instead of 1, which is for grayscale images.
A stride of 2 halves the width and height so you can work backward to figure out the initial image size dimension: for Fashion-MNIST, we upsampled as 7 -> 14 -> 28. Now we are working with a training image size of 64×64, so we upsample a few times as 8 -> 16 -> 32 -> 64. This means we add one more set of Conv2DTranspose -> BatchNormalization -> ReLU.
Another change made to the generator is to update the kernel size from 5 to 4 to reduce checkerboard artifacts in the generated images (see Figure 2).
Figure 2: Checkerboard artifacts (image by the author).
This is because the kernel size of 5 is not divisible by the stride of 2, according to the post Deconvolution and Checkerboard Artifacts. So the solution is to use a kernel size of 4 instead of 5.
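Here is a minimal sketch of a generator following that recipe (the filter counts and the final Conv2D output layer are illustrative choices, not necessarily the exact values in the notebook):
from tensorflow import keras
from tensorflow.keras import layers

CHANNELS = 3      # color images
LATENT_DIM = 100  # noise vector dimension

def build_generator():
    model = keras.Sequential(name="generator")
    # project the noise vector and reshape it into an 8x8 feature map
    model.add(layers.Dense(8 * 8 * 256, input_shape=(LATENT_DIM,)))
    model.add(layers.Reshape((8, 8, 256)))
    # upsample 8 -> 16 -> 32 -> 64 with stride-2 transposed convolutions;
    # a kernel size of 4 is divisible by the stride of 2, which helps
    # avoid checkerboard artifacts
    for filters in (128, 64, 32):
        model.add(layers.Conv2DTranspose(filters, (4, 4), strides=(2, 2),
            padding="same"))
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())
    # final layer maps to a 64x64x3 image in the range [-1, 1]
    model.add(layers.Conv2D(CHANNELS, (4, 4), padding="same",
        activation="tanh"))
    return model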
We can visualize the DCGAN generator architecture in Figure 3:
Figure 3: Generator architecture diagram (image by the author).
Visualize the generator architecture in code by calling generator.summary() in Figure 4:
Figure 4: Generator architecture with Keras code (image by the author).
3. Discriminator
The main change in the discriminator architecture is the image input shape: we are using the shape of [64, 64, 3] instead of [28, 28, 1]. We also added one more set of Conv2D -> BatchNormalization -> LeakyReLU to balance out the increased architecture complexity in the generator as mentioned above. Everything else remains the same.
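A corresponding sketch of the discriminator is shown below (again, the filter counts are illustrative assumptions):
from tensorflow import keras
from tensorflow.keras import layers

def build_discriminator(height=64, width=64, depth=3, alpha=0.2):
    model = keras.Sequential(name="discriminator")
    input_shape = (height, width, depth)
    # downsample 64 -> 32 -> 16 -> 8 with stride-2 convolutions,
    # each followed by BatchNormalization and LeakyReLU
    for (i, filters) in enumerate((64, 128, 256)):
        if i == 0:
            model.add(layers.Conv2D(filters, (4, 4), strides=(2, 2),
                padding="same", input_shape=input_shape))
        else:
            model.add(layers.Conv2D(filters, (4, 4), strides=(2, 2),
                padding="same"))
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha))
    # flatten and output a single real/fake probability
    model.add(layers.Flatten())
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(1, activation="sigmoid"))
    return model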
We can visualize the DCGAN discriminator architecture in Figure 5:
Figure 5: Discriminator architecture diagram (image by the author).
Visualize the discriminator architecture in code by calling discriminator.summary() in Figure 6:
Figure 6: Discriminator architecture with Keras code (image by the author).
The DCGAN Model
Again we define the DCGAN model architecture by subclassing keras.Model and overriding train_step to define the custom training loop. The only slight change in code is to apply one-sided label smoothing to the real labels.
This technique reduces the overconfidence of the discriminator and therefore helps stabilize the GAN training. Refer to Adrian Rosebrock’s post Label smoothing with Keras, TensorFlow, and Deep Learning for details on label smoothing in general. The “one-sided label smoothing” technique for regularizing GAN training is proposed in the paper Improved Techniques for Training GANs, where you may find other improvement techniques as well.
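A minimal illustration of the idea inside train_step is shown below (the tensor shapes and variable names here are placeholders rather than the notebook’s exact code):
import tensorflow as tf

# labels for *real* images are smoothed down from 1.0 to 0.9, while
# labels for fake images stay at 0.0 -- hence "one-sided"
batch_size = 32  # placeholder; in practice derived from the real image batch
real_labels = tf.ones((batch_size, 1)) * 0.9   # one-sided label smoothing
fake_labels = tf.zeros((batch_size, 1))        # fake labels are untouched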
Define Keras Callback for Training Monitoring
Same code with no change — override Keras Callback to monitor and visualize the generated images during training.
We compile the dcgan model, and the main change is the learning rates. Here I have set the discriminator learning rate to 0.0001 and the generator learning rate to 0.0003. This is to make sure that the discriminator doesn’t overpower the generator.
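Creating the two optimizers might look like this (the beta_1 value of 0.5 is a common DCGAN choice and an assumption here; the exact compile arguments depend on how the custom DCGAN class was written):
from tensorflow import keras

# two separate Adam optimizers: the discriminator learns more slowly
# (1e-4) than the generator (3e-4) so it doesn't overpower it
d_optimizer = keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0003, beta_1=0.5)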
Now we simply call model.fit() to train the dcgan model!
NUM_EPOCHS = 50 # number of epochs
dcgan.fit(train_images, epochs=NUM_EPOCHS,
callbacks=[GANMonitor(num_img=16, latent_dim=LATENT_DIM)])
Here are the screenshots with images created by the generator throughout the DCGAN training process (Figure 7):
Figure 7: DCGAN for Fashion Color Images Training Results (image by the author).
GAN Training Challenges
Now that we have finished training the DCGAN with color images, let’s discuss some of the common challenges of GAN training.
GANs are very difficult to train, and here are some of the well-known challenges:
Non-convergence: instability, vanishing gradients, or slow training
Mode collapse
Difficult to evaluate
Failure to Converge
Unlike training other models such as an image classifier, the losses or accuracy of D and G during training only measure D and G individually and don’t measure the GAN’s overall performance or how good the generator is at creating images. The GAN model is “good” when an equilibrium is reached between the generator and discriminator, typically when the discriminator’s loss is around 0.5.
GAN training instability: it’s difficult to keep D and G balanced to reach an equilibrium. Looking at the losses during training, you will notice they may oscillate wildly, and both D and G could get stuck and never improve. Training for a long time doesn’t always make the generator better; the image quality produced by the generator may deteriorate over time.
Vanishing gradients: in the custom training loop, we went over how to calculate the discriminator and generator losses, compute gradients, and then use the gradients to make updates. The generator relies on the discriminator’s feedback to make improvements. If the discriminator is so strong that it overpowers the generator (it can tell every time an image is fake), then the generator stops making progress in its training.
You may notice that sometimes the generated images remain poor quality even after training for a while. This means the model fails to find an equilibrium between the discriminator and generator.
Experiment: make D’s architecture much stronger (more parameters in the model architecture) or have it train faster than G (e.g., increase D’s learning rate to be much higher than G’s).
Mode Collapse
Mode collapse occurs when the generator produces the same images or a small subset of the training images repeatedly. A good generator should make a wide variety of images that resemble the training images in all its categories. Mode collapse happens when the discriminator can’t tell the generated images are fake, so the generator keeps producing those same images to fool the discriminator.
Experiment: to simulate the mode collapse issue in the code, try reducing the noise vector dimension from 100 to 10; or increase the noise vector dimension from 100 to 128 to increase image diversity.
Difficult to Evaluate
It’s challenging to evaluate GAN models because there is no easy way to determine whether a generated image is “good.” Unlike with an image classifier, where a prediction is either correct or incorrect according to the ground truth label, there is no such ground truth for generated images. This leads to the discussion below on how we evaluate GAN models.
GAN Evaluation Metrics
There are two criteria for a successful generator — it should generate images with:
good quality: high fidelity and realistic,
diversity (or variety): a good representation of the training images’ different types (or categories).
We can evaluate the model either qualitatively (visually inspect images) or quantitatively with some metrics.
Qualitative evaluation via visual inspection. As we did in the DCGAN training, we look at a set of images generated on the same seed and visually inspect whether the images look better as training goes on. This works for a toy example, but it’s too labor-intensive for large-scale training.
Inception Score (IS) and Fréchet Inception Distance (FID) are two popular metrics to compare GAN models quantitatively.
The Inception Score was introduced in this paper: Improved Techniques for Training GANs. It measures both the quality and diversity of the generated images. The idea is to use the inception model to classify the generated images and use the predictions to evaluate the generator. A higher score indicates the model is better.
The Fréchet Inception Distance (FID) also uses the inception network for feature extraction and calculates the data distribution. FID improves upon IS by looking at both the generated images and training images instead of only the generated images in isolation. A lower FID means the generated images are more similar to the real images, therefore a better GAN model.
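For readers who want the precise definitions, the two metrics can be written as follows, where $p(y \mid x)$ is the Inception model’s class prediction for a generated image $x$, $p(y)$ is the marginal class distribution, and $(\mu_r, \Sigma_r)$, $(\mu_g, \Sigma_g)$ are the mean and covariance of Inception features computed on real and generated images:
\mathrm{IS} = \exp\Big(\mathbb{E}_{x \sim p_g}\, D_{\mathrm{KL}}\big(p(y \mid x)\,\|\,p(y)\big)\Big)
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\Big(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\Big)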
In this post, you have learned how to train a DCGAN to generate fashion images in color. You have also learned about the common challenges of GAN training, some improvement techniques, and the GAN evaluation metrics. In my next post, we will learn how to further improve training stability with Wasserstein GAN (WGAN) and Wasserstein GAN with Gradient Penalty (WGAN-GP).
@article{Maynard-Reid_2021_GAN_Training,
author = {Margaret Maynard-Reid},
title = {{GAN} Training Challenges: {DCGAN} for Color Images},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/12/13/gan-training-challenges-dcgan-for-color-images/},
}
It was 2020 when my friends and I worked night and day to finish our final year project. Like most students in our year, we decided it was a good idea to leave it till the very end.
That wasn’t the brightest idea from our end. What followed were neverending nights of constant model calibration and training, burning through gigabytes of cloud storage, and maintaining records for deep learning model results.
The environment that we had created for ourselves not only harmed our efficiency, but it also affected our morale. Thanks to the sheer individual brilliance of my teammates, we managed to complete our project.
In retrospect, I realized how much more efficient our work could have been — and so much more enjoyable — had we chosen a better ecosystem to work in.
Fortunately, you don’t have to make the same mistakes I made.
The creators of PyTorch often emphasized that a key intention behind this initiative is to bridge the gap between research and production. PyTorch now stands toe to toe with its contemporaries on many fronts, being utilized equally in both research and production ecosystems.
One of the ways they’ve achieved this is through Torch Hub. Torch Hub as a concept was conceived to further extend PyTorch’s credibility as a production-based framework. In today’s tutorial, we’ll learn how to utilize Torch Hub to store and publish pre-trained models for wide-scale use.
What Is Torch Hub?
In Computer Science, many believe that a key puzzle piece in the bridge between research and production is reproducibility. Building on that very notion, PyTorch introduced Torch Hub, an Application Programming Interface (API) that allows programs to interact with each other and enhances the workflow for easy research reproducibility.
Torch Hub lets you publish pre-trained models to help in the cause of research sharing and reproducibility. The process of harnessing Torch Hub is simple, but before moving further, let’s configure the prerequisites of our system!
Configuring Your Development Environment
To follow this guide, you need to have the OpenCV library installed on your system.
Luckily, OpenCV is pip-installable:
$ pip install opencv-contrib-python
If you need help configuring your development environment for OpenCV, I highly recommend that you read our pip install OpenCV guide — it will have you up and running in a matter of minutes.
Having Problems Configuring Your Development Environment?
Figure 1: Having trouble configuring your dev environment? Want access to pre-configured Jupyter Notebooks running on Google Colab? Be sure to join PyImageSearch University — you’ll be up and running with this tutorial in a matter of minutes.
All that said, are you:
Short on time?
Learning on your employer’s administratively locked system?
Wanting to skip the hassle of fighting with the command line, package managers, and virtual environments?
Ready to run the code right now on your Windows, macOS, or Linux system?
Gain access to Jupyter Notebooks for this tutorial and other PyImageSearch guides that are pre-configured to run on Google Colab’s ecosystem right in your web browser! No installation required.
And best of all, these Jupyter Notebooks will run on Windows, macOS, and Linux!
Project Structure
We first need to review our project structure.
Start by accessing the “Downloads” section of this tutorial to retrieve the source code and example images.
Before moving to the directory, let’s take a look at the project structure in Figure 2.
Figure 2: Project Overview.
Today, we’ll be working with two directories. This is to help you better understand the use of Torch Hub.
The subdirectory is where we’ll initialize and train our model. Here, we’ll create a hubconf.py script. The hubconf.py script contains callable functions called entry points. These callable functions initialize and return the models the user requires. Hence, this script connects the model we created to Torch Hub.
In our main Project Directory, we’ll be using torch.hub.load to load our model from Torch Hub. After loading the model with pre-trained weights, we’ll evaluate it on some sample data.
A Generalized Outlook on Torch Hub
Torch Hub already hosts an array of models for various tasks, as seen in Figure 3.
As you can see, there are a total of 42 research models that Torch Hub has accepted in its official showcase. Each model belongs to one or more of the following labels: Audio, Generative, Natural Language Processing (NLP), scriptable, and vision. These models have also been trained on widely accepted benchmark datasets (e.g., Kinetics 400 and COCO 2017).
It’s easy to use these models in your projects using the torch.hub.load function. Let’s look at an example of how it works.
(If you want to know more about DCGANs, do check out this blog.)
# USAGE
# python inference.py
# import the necessary packages
import matplotlib.pyplot as plt
import torchvision
import argparse
import torch
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-n", "--num-images", type=int, default=64,
help="# of images you want the DCGAN to generate")
args = vars(ap.parse_args())
# check if gpu is available for use
useGpu = True if torch.cuda.is_available() else False
# load the DCGAN model
model = torch.hub.load("facebookresearch/pytorch_GAN_zoo:hub", "DCGAN",
pretrained=True, useGPU=useGpu)
On Lines 11-14, we create an argument parser to give the user more freedom to choose the number of images the DCGAN generates.
To use the Facebook Research pretrained DCGAN model, we just need the torch.hub.load function as shown on Lines 20 and 21. The torch.hub.load function here takes in the following arguments:
repo_or_dir: The repository name in the format repo_owner/repo_name:branch/tag_name if the source argument is set to github. Otherwise, it will point to the desired path in your local machine.
entry_point: To publish a model on Torch Hub, you need to have a script called hubconf.py in your repository/directory. In that script, you’ll define normal callable functions known as entry points. Calling an entry point returns the desired model. You’ll learn more about entry_point later in this blog.
pretrained and useGpu: These fall under the *args/**kwargs banner of this function. These arguments are forwarded to the callable model (the entry point).
Now, this isn’t the only major function Torch Hub offers. You can use several other notable functions like torch.hub.list to list all available entry points (callable functions) belonging to the repository, and torch.hub.help to show the documentation docstring of the target entry point.
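For example, for the same repository you could run something like the following (the entry point names returned depend on that repository’s hubconf.py):
import torch

# list all entry points (callable functions) exposed by the repository
print(torch.hub.list("facebookresearch/pytorch_GAN_zoo:hub"))

# show the documentation docstring of a specific entry point, e.g., "DCGAN"
print(torch.hub.help("facebookresearch/pytorch_GAN_zoo:hub", "DCGAN"))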
# generate random noise to input to the generator
(noise, _) = model.buildNoiseData(args["num_images"])
# turn off autograd and feed the input noise to the model
with torch.no_grad():
    generatedImages = model.test(noise)
# reconfigure the dimensions of the images to make them channel
# last and display the output
output = torchvision.utils.make_grid(generatedImages).permute(
1, 2, 0).cpu().numpy()
plt.imshow(output)
plt.show()
On Line 24, we use a function exclusive to the called model named buildNoiseData to generate random input noise, keeping the input size in mind.
Turning off automatic gradients (Line 27), we generate images by feeding the noise to the model.
Before plotting the images, we reorder the image dimensions on Lines 32-35 (since PyTorch works with channels-first tensors, we need to make them channels-last again). The output will look like Figure 4.
Figure 4: DCGAN output.
Voila! This is all you need to use a pre-trained state-of-the-art DCGAN model for your purposes. Using the pre-trained models in Torch Hub is THAT easy. However, we are not stopping there, are we?
Calling a pre-trained model to see how the latest state-of-the-art research performs is fine, but what about when we produce state-of-the-art results using our research? For that, we’ll next learn how to publish our own created models on Torch Hub.
Today, we’ll train our simple neural network and publish it using Torch Hub. I will not go into a full dissection of the code since a tutorial for that already exists. For a detailed and precise dive into building a simple neural network, refer to this blog.
Building a Simple Neural Network
Next, we’ll go through the salient parts of the code. For that, we’ll be moving into the subdirectory. First, let’s build our simple neural network in mlp.py!
# import the necessary packages
from collections import OrderedDict
import torch.nn as nn
# define the model function
def get_training_model(inFeatures=4, hiddenDim=8, nbClasses=3):
    # construct a shallow, sequential neural network
    mlpModel = nn.Sequential(OrderedDict([
        ("hidden_layer_1", nn.Linear(inFeatures, hiddenDim)),
        ("activation_1", nn.ReLU()),
        ("output_layer", nn.Linear(hiddenDim, nbClasses))
    ]))
    # return the sequential model
    return mlpModel
The get_training_model function on Line 6 takes in parameters (input size, hidden layer size, number of output classes). Inside the function, we use nn.Sequential to create a 2-layer neural network, consisting of a single hidden layer with ReLU activation and an output layer (Lines 8-12).
Training the Neural Network
We won’t be using any external dataset to train the model. Instead, we’ll generate data points ourselves. Let’s hop into train.py.
# import the necessary packages
from pyimagesearch import mlp
from torch.optim import SGD
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs
import torch.nn as nn
import torch
import os
# define the path to store your model weights
MODEL_PATH = os.path.join("output", "model_wt.pth")
# data generator function
def next_batch(inputs, targets, batchSize):
    # loop over the dataset
    for i in range(0, inputs.shape[0], batchSize):
        # yield a tuple of the current batched data and labels
        yield (inputs[i:i + batchSize], targets[i:i + batchSize])
# specify our batch size, number of epochs, and learning rate
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-2
# determine the device we will be using for training
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("[INFO] training using {}...".format(DEVICE))
First, we create a path to save the trained model weights on Line 11, which will be used later. The next_batch function on Lines 14-18 will act as the data generator for our project, yielding batches of data for efficient training.
Next, we set up hyperparameters (Lines 21-23) and set our DEVICE to cuda if a compatible GPU is available (Line 26).
# generate a 3-class classification problem with 1000 data points,
# where each data point is a 4D feature vector
print("[INFO] preparing data...")
(X, y) = make_blobs(n_samples=1000, n_features=4, centers=3,
cluster_std=2.5, random_state=95)
# create training and testing splits, and convert them to PyTorch
# tensors
(trainX, testX, trainY, testY) = train_test_split(X, y,
test_size=0.15, random_state=95)
trainX = torch.from_numpy(trainX).float()
testX = torch.from_numpy(testX).float()
trainY = torch.from_numpy(trainY).float()
testY = torch.from_numpy(testY).float()
On Lines 32 and 33, we use the make_blobs function to mimic data points of an actual three-class dataset. Using scikit-learn’s train_test_split function, we create the training and test splits of the data.
# initialize our model and display its architecture
mlp = mlp.get_training_model().to(DEVICE)
print(mlp)
# initialize optimizer and loss function
opt = SGD(mlp.parameters(), lr=LR)
lossFunc = nn.CrossEntropyLoss()
# create a template to summarize current training progress
trainTemplate = "epoch: {} test loss: {:.3f} test accuracy: {:.3f}"
On Line 45, we call the get_training_model function from the mlp.py module and initialize the model.
We choose stochastic gradient descent as the optimizer (Line 49) and Cross-Entropy loss as the loss function (Line 50).
The trainTemplate variable on Line 53 will act as a string template to print accuracy and loss.
# loop through the epochs
for epoch in range(0, EPOCHS):
    # initialize tracker variables and set our model to trainable
    print("[INFO] epoch: {}...".format(epoch + 1))
    trainLoss = 0
    trainAcc = 0
    samples = 0
    mlp.train()
    # loop over the current batch of data
    for (batchX, batchY) in next_batch(trainX, trainY, BATCH_SIZE):
        # flash data to the current device, run it through our
        # model, and calculate loss
        (batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))
        predictions = mlp(batchX)
        loss = lossFunc(predictions, batchY.long())
        # zero the gradients accumulated from the previous steps,
        # perform backpropagation, and update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # update training loss, accuracy, and the number of samples
        # visited
        trainLoss += loss.item() * batchY.size(0)
        trainAcc += (predictions.max(1)[1] == batchY).sum().item()
        samples += batchY.size(0)
    # display model progress on the current training batch
    trainTemplate = "epoch: {} train loss: {:.3f} train accuracy: {:.3f}"
    print(trainTemplate.format(epoch + 1, (trainLoss / samples),
        (trainAcc / samples)))
Looping over the training epochs, we initialize the losses (Lines 59-61) and set the model to training mode (Line 62).
Using the next_batch function, we iterate through batches of training data (Line 65). After loading them to the device (Line 68), the predictions for the data batch are obtained on Line 69. These predictions are then fed to the loss function for loss calculation (Line 70).
The gradients are flushed using zero_grad (Line 74), followed by backpropagation on Line 75. Finally, the model parameters are updated by the optimizer on Line 76.
For each batch, the training loss, accuracy, and sample count variables are updated (Lines 80-82), and the epoch’s progress is displayed using the template on Line 85.
    # initialize tracker variables for testing, then set our model to
    # evaluation mode
    testLoss = 0
    testAcc = 0
    samples = 0
    mlp.eval()
    # initialize a no-gradient context
    with torch.no_grad():
        # loop over the current batch of test data
        for (batchX, batchY) in next_batch(testX, testY, BATCH_SIZE):
            # flash the data to the current device
            (batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))
            # run data through our model and calculate loss
            predictions = mlp(batchX)
            loss = lossFunc(predictions, batchY.long())
            # update test loss, accuracy, and the number of
            # samples visited
            testLoss += loss.item() * batchY.size(0)
            testAcc += (predictions.max(1)[1] == batchY).sum().item()
            samples += batchY.size(0)
    # display model progress on the current test batch
    testTemplate = "epoch: {} test loss: {:.3f} test accuracy: {:.3f}"
    print(testTemplate.format(epoch + 1, (testLoss / samples),
        (testAcc / samples)))
    print("")
# save model to the path for later use
torch.save(mlp.state_dict(), MODEL_PATH)
We set the model to evaluation mode and repeat the same procedure as in the training phase, except for the backpropagation steps.
On Line 121, we have the most important step of saving the model weights for later use.
Let’s assess the epoch-wise performance of our model!
[INFO] training using cpu...
[INFO] preparing data...
Sequential(
  (hidden_layer_1): Linear(in_features=4, out_features=8, bias=True)
  (activation_1): ReLU()
  (output_layer): Linear(in_features=8, out_features=3, bias=True)
)
[INFO] epoch: 1...
epoch: 1 train loss: 0.798 train accuracy: 0.649
epoch: 1 test loss: 0.788 test accuracy: 0.613
[INFO] epoch: 2...
epoch: 2 train loss: 0.694 train accuracy: 0.665
epoch: 2 test loss: 0.717 test accuracy: 0.613
[INFO] epoch: 3...
epoch: 3 train loss: 0.635 train accuracy: 0.669
epoch: 3 test loss: 0.669 test accuracy: 0.613
...
[INFO] epoch: 7...
epoch: 7 train loss: 0.468 train accuracy: 0.693
epoch: 7 test loss: 0.457 test accuracy: 0.740
[INFO] epoch: 8...
epoch: 8 train loss: 0.385 train accuracy: 0.861
epoch: 8 test loss: 0.341 test accuracy: 0.973
[INFO] epoch: 9...
epoch: 9 train loss: 0.286 train accuracy: 0.980
epoch: 9 test loss: 0.237 test accuracy: 0.993
[INFO] epoch: 10...
epoch: 10 train loss: 0.211 train accuracy: 0.985
epoch: 10 test loss: 0.173 test accuracy: 0.993
Since we are training on data generated with parameters we set ourselves, our training process went smoothly, reaching a final training accuracy of 0.985.
Configuring the hubconf.py script
With model training complete, our next step is to configure the hubconf.py file in the repo to make our model accessible through Torch Hub.
# import the necessary packages
import torch
from pyimagesearch import mlp
# define entry point/callable function
# to initialize and return model
def custom_model():
    """ # This docstring shows up in hub.help()
    Initializes the MLP model instance
    Loads weights from path and
    returns the model
    """
    # initialize the model
    # load weights from path
    # returns model
    model = mlp.get_training_model()
    model.load_state_dict(torch.load("model_wt.pth"))
    return model
As mentioned earlier, we have created an entry point called custom_model on Line 7. Inside the entry point, we initialize the simple neural network from the mlp.py module (Line 16). Next, we load the weights we previously saved (Line 17). The current setup is such that this function looks for the model weights in your project directory. You can also host the weights on a cloud platform and point the path there accordingly.
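For instance, if the weights were hosted at some URL, the entry point could fetch them with torch.hub.load_state_dict_from_url instead (the URL below is purely a placeholder):
# import the necessary packages
import torch
from pyimagesearch import mlp

def custom_model():
    """Initializes the MLP and loads weights hosted at a remote URL."""
    model = mlp.get_training_model()
    # the URL is a placeholder; point it to wherever you host model_wt.pth
    weights = torch.hub.load_state_dict_from_url(
        "https://example.com/path/to/model_wt.pth", map_location="cpu")
    model.load_state_dict(weights)
    return model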
Now, we’ll use Torch Hub to access this model and test it on our data.
Using torch.hub.load to Call Our Model
Coming back to our main project directory, let’s hop into the hub_usage.py script.
# USAGE
# python hub_usage.py
# import the necessary packages
from pyimagesearch.data_gen import next_batch
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_blobs
import torch.nn as nn
import argparse
import torch
# construct the argument parser and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-b", "--batch-size", type=int, default=64,
help="input batch size")
args = vars(ap.parse_args())
After importing the necessary packages, we create an argument parser (Lines 13-16) for the user to input the batch size for the data.
# load the model using torch hub
print("[INFO] loading the model using torch hub...")
model = torch.hub.load("cr0wley-zz/torch_hub_test:main",
"custom_model")
# generate a 3-class classification problem with 1000 data points,
# where each data point is a 4D feature vector
print("[INFO] preparing data...")
(X, Y) = make_blobs(n_samples=1000, n_features=4, centers=3,
cluster_std=2.5, random_state=95)
# create training and testing splits, and convert them to PyTorch
# tensors
(trainX, testX, trainY, testY) = train_test_split(X, Y,
test_size=0.15, random_state=95)
trainX = torch.from_numpy(trainX).float()
testX = torch.from_numpy(testX).float()
trainY = torch.from_numpy(trainY).float()
testY = torch.from_numpy(testY).float()
On Lines 20 and 21, we use torch.hub.load to initialize our own model, the same way we loaded the DCGAN model earlier. The model is initialized and the weights are loaded according to the entry point in the hubconf.py script in our subdirectory. As you can see, we pass the subdirectory’s GitHub repository as the parameter.
Now, to evaluate the model, we’ll create data the same way we created it during model training (Lines 26 and 27) and use train_test_split to create data splits (Lines 31-36).
# initialize the neural network loss function
lossFunc = nn.CrossEntropyLoss()
# set device to cuda if available and initialize
# testing loss and accuracy
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
testLoss = 0
testAcc = 0
samples = 0
# set model to eval and grab a batch of data
print("[INFO] setting the model in evaluation mode...")
model.eval()
(batchX, batchY) = next(next_batch(testX, testY, args["batch_size"]))
On Line 39, we initialize the cross-entropy loss function as done during the model training. We proceed to initialize the evaluation metrics on Lines 44-46.
The model is set to evaluation mode (Line 50), and a single batch of data is grabbed to be evaluated upon by the model (Line 51).
# initialize a no-gradient context
with torch.no_grad():
    # load the data into device
    (batchX, batchY) = (batchX.to(DEVICE), batchY.to(DEVICE))
    # pass the data through the model to get the output and calculate
    # loss
    predictions = model(batchX)
    loss = lossFunc(predictions, batchY.long())
    # update test loss, accuracy, and the number of
    # samples visited
    testLoss += loss.item() * batchY.size(0)
    testAcc += (predictions.max(1)[1] == batchY).sum().item()
    samples += batchY.size(0)
print("[INFO] test loss: {:.3f}".format(testLoss / samples))
print("[INFO] test accuracy: {:.3f}".format(testAcc / samples))
Turning off the automatic gradients (Line 54), we load the batch of data onto the device and feed it to the model (Lines 56-60). The lossFunc then calculates the loss on Line 61.
With the predictions and the loss in hand, we update the test loss and accuracy variables (Lines 65 and 66), along with the sample count (Line 67).
Let’s see how the model fared!
[INFO] loading the model using torch hub...
[INFO] preparing data...
[INFO] setting the model in evaluation mode...
Using cache found in /root/.cache/torch/hub/cr0wley-zz_torch_hub_test_main
[INFO] test loss: 0.086
[INFO] test accuracy: 0.969
Since we created our test data using the same parameters used during model training, the model performed as expected, with a test accuracy of 0.969.
I cannot emphasize enough how important the reproduction of results is in today’s research world. Especially in machine learning, we’ve slowly reached a point where novel research ideas are getting more complex by the day. In a situation like that, giving researchers a platform to easily make their research and results public takes a huge burden off their shoulders.
When you already have enough things to worry about as a researcher, having a tool to make your model and results public with a single script and a few lines of code is a great boon. Of course, as a project, Torch Hub will evolve further according to users’ needs as time goes on. Regardless, the ecosystem advocated by the creation of Torch Hub will help machine learning enthusiasts for generations to come.
@article{dev_2021_THS1,
author = {Devjyoti Chakraborty},
title = {{Torch Hub} Series \#1: Introduction to {Torch Hub}},
journal = {PyImageSearch},
year = {2021},
note = {https://www.pyimagesearch.com/2021/12/20/torch-hub-series-1-introduction-to-torch-hub/},
}