What is transfer learning?
Deep learning is a hot topic right now. Since artificial neural networks regained relevance, new architectures, tools, methodologies and applications have emerged.
One of the biggest challenges when working with neural networks is training them, and the high computational cost that comes with it. To give an idea, AlexNet has 55x55x96 neurons at the output of its first convolutional layer, and each of those neurons is connected to an 11x11x3 region of the input.
That is a huge number of neurons and connections, and AlexNet is one of the earliest well-known deep networks. Architectures have only grown since then.
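To put numbers on this, the short sketch below multiplies out the AlexNet figures quoted above. The weight-sharing line goes slightly beyond the text: it assumes, as is standard for convolutional layers, that all neurons in one of the 96 depth slices share the same 11x11x3 filter plus one bias.

```python
# Figures for AlexNet's first convolutional layer, as quoted in the text.
out_h, out_w, out_c = 55, 55, 96   # output volume: 55x55x96 neurons
k_h, k_w, in_c = 11, 11, 3         # each neuron sees an 11x11x3 input region

neurons = out_h * out_w * out_c              # neurons in the output volume
weights_per_neuron = k_h * k_w * in_c        # input connections per neuron
connections = neurons * weights_per_neuron   # total connections in this one layer

# Assumption (standard for conv layers): neurons in the same depth slice share
# one filter, so the layer stores only out_c filters plus one bias per filter.
shared_params = out_c * (weights_per_neuron + 1)

print(f"neurons:       {neurons:,}")        # 290,400
print(f"connections:   {connections:,}")    # 105,415,200
print(f"shared params: {shared_params:,}")  # 34,944
```

In other words, this single layer performs roughly 105 million multiply-accumulate operations per input image in the forward pass, even though it stores only about 35 thousand parameters.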
Networks of this size are expensive to train: even on a recent GPU it can take days. Luckily, transfer learning comes to the rescue!
Convolutional networks are normally composed of an input layer, a stack of convolutional and max pooling (subsampling) layers, one or more fully connected layers, and the output layer.
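The convolutional layers are filters applied to the images, and they are essentially the ones in charge of feature extraction. The deeper a layer sits in the network, the more specific the features it extracts: early layers tend to detect generic patterns such as edges, while later layers respond to more complex, task-specific shapes.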
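To make the "filters" idea concrete, the sketch below applies a single hand-written 3x3 vertical-edge kernel to a tiny synthetic image using PyTorch's `torch.nn.functional.conv2d`. The kernel (a Prewitt-style filter) and the image are illustrative assumptions; in a trained network the kernel values are learned rather than written by hand.

```python
import torch
import torch.nn.functional as F

# Tiny synthetic 5x5 grayscale image: dark on the left (0), bright on the right (1).
# Shape is (batch, channels, height, width).
image = torch.tensor([[0., 0., 1., 1., 1.]] * 5).reshape(1, 1, 5, 5)

# Hand-crafted vertical-edge kernel (illustrative; a CNN would learn its own values).
# Shape is (out_channels, in_channels, kernel_height, kernel_width).
kernel = torch.tensor([[-1., 0., 1.],
                       [-1., 0., 1.],
                       [-1., 0., 1.]]).reshape(1, 1, 3, 3)

# Stride 1 and no padding: the 5x5 input produces a 3x3 feature map.
feature_map = F.conv2d(image, kernel)
print(feature_map[0, 0])
# Every row is [3., 3., 0.]: a strong response around the edge, zero in the flat region.
```

The feature map is large only where the intensity changes from left to right, which is exactly the kind of local pattern an early convolutional layer picks up.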
The last part, made up of the fully connected layers, is basically a classic neural network. It takes the features extracted by the first part and classifies them, and on its own it is much cheaper to train than the whole network.
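The two-part layout described above can be sketched as a small PyTorch module. The `TinyCNN` name, the layer sizes, the 32x32 input and the 10 classes are all hypothetical; the point is only the split into a `features` block (convolution and max pooling) and a `classifier` block (fully connected layers), which mirrors how torchvision's VGG models are organized.

```python
import torch
from torch import nn


class TinyCNN(nn.Module):
    """Hypothetical toy network: convolutional feature extractor + fully connected classifier."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Feature extraction: convolutions (filters) followed by max pooling (subsampling).
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),   # 3x32x32  -> 16x32x32
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 16x32x32 -> 16x16x16
            nn.Conv2d(16, 32, kernel_size=3, padding=1),  # 16x16x16 -> 32x16x16
            nn.ReLU(),
            nn.MaxPool2d(2),                              # 32x16x16 -> 32x8x8
        )
        # Classification: a classic fully connected network on the flattened features.
        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),                  # output layer: one score per class
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten everything else
        return self.classifier(x)


model = TinyCNN()
scores = model(torch.randn(4, 3, 32, 32))  # a batch of 4 random 32x32 RGB images
print(scores.shape)  # torch.Size([4, 10])
```

Transfer learning keeps a trained `features` block and swaps out the `classifier` block.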
Transfer learning is a technique where we take a pre-trained network, remove its last part (the classifier) and replace it with our own. We then train our new classifier while keeping the pre-trained feature extractor as it is. This lets us reuse battle-tested architectures, and the features they have already learned, at a fraction of the cost.
An example with VGG16 in PyTorch is pretty simple. VGG16 is the 16-layer version of the popular VGG architecture (13 convolutional layers plus 3 fully connected ones). First we load the pre-trained network from torchvision:
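```python
from torchvision import models

# ImageNet-pre-trained VGG16 (torchvision >= 0.13; older versions use pretrained=True)
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
```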
After that, we freeze the convolutional part of the network (`vgg16.features`) by disabling gradients for its parameters:
```python
# Freeze the convolutional feature extractor: no gradients, no weight updates.
for param in vgg16.features.parameters():
    param.requires_grad = False
```
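With that done, we replace the last layer of the classifier. In torchvision's VGG16, `vgg16.classifier` is a sequence of layers, and the final one (the 1000-class ImageNet output layer) is at index 6: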
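```python
from torch import nn

n_inputs = vgg16.classifier[6].in_features  # 4096 for VGG16
num_labels = 10  # number of classes in your dataset (example value)
last_layer = nn.Linear(n_inputs, num_labels)
vgg16.classifier[6] = last_layer
```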
In the example we used a single layer, but you can use any structure you deem necessary. It does not need to resemble the original classifier.
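As one hypothetical alternative, the replacement for `classifier[6]` could be a small stack of its own, for example a hidden layer with dropout. The sizes below (256 hidden units, 10 classes, dropout 0.3) are illustrative assumptions; the only hard constraint is that the first layer accepts the 4096 features produced by the preceding classifier layers.

```python
import torch
from torch import nn

n_inputs = 4096   # in_features of VGG16's classifier[6]
num_labels = 10   # example number of target classes

# A custom head (hypothetical sizes): any module mapping 4096 features to num_labels scores works.
custom_head = nn.Sequential(
    nn.Linear(n_inputs, 256),
    nn.ReLU(),
    nn.Dropout(p=0.3),
    nn.Linear(256, num_labels),
)

# Check it on a dummy batch of 8 feature vectors, standing in for
# what the earlier classifier layers would output.
dummy_features = torch.randn(8, n_inputs)
print(custom_head(dummy_features).shape)  # torch.Size([8, 10])

# It would be plugged in exactly like the single layer:
# vgg16.classifier[6] = custom_head
```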
Once the last layer has been replaced, you can retrain the network. Since no gradients are calculated for the frozen convolutional part, training is much faster, and only the classifier's weights are updated.
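What "no gradients for the convolutional part" means in practice can be checked with a single training step. To keep the sketch light, it uses a hypothetical two-part stand-in model instead of VGG16, random data and plain SGD; the pattern carries over directly. Only parameters with `requires_grad=True` are handed to the optimizer.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in: a small "feature extractor" and a linear "classifier" (10 classes).
features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
                         nn.AdaptiveAvgPool2d(1), nn.Flatten())
classifier = nn.Linear(8, 10)
model = nn.Sequential(features, classifier)

# Freeze the feature extractor, as done for vgg16.features.
for param in features.parameters():
    param.requires_grad = False

# Hand only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)
loss_fn = nn.CrossEntropyLoss()

frozen_before = features[0].weight.clone()
head_before = classifier.weight.clone()

# One training step on a random batch: 4 images of 16x16 with random labels.
inputs = torch.randn(4, 3, 16, 16)
labels = torch.randint(0, 10, (4,))
optimizer.zero_grad()
loss = loss_fn(model(inputs), labels)
loss.backward()
optimizer.step()

print(features[0].weight.grad)                         # None: no gradient was computed
print(torch.equal(features[0].weight, frozen_before))  # True: frozen weights unchanged
print(torch.equal(classifier.weight, head_before))     # False: the classifier was updated
```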
The full training process is left out of this post, as it is just a small showcase of what transfer learning is. If you want more details, explore PyTorch: it is a powerful and flexible deep learning framework. It takes a little longer to learn than Keras, but you will love it once you get your first network working.