The Unreasonable Effectiveness of Deep Feature Extraction

There have been a small handful of times in my life where I've read the abstract of a paper -- thought about it for a few moments -- and then audibly exclaimed "Oh shit!".

Reading my first paper on deep feature extraction, back in 2014, was one of those times. The results from back then were amazing, and they've only gotten more impressive since.

But let's start at the beginning. Back in 2014, deep learning was producing impressive results, but was still in its awkward adolescent period. TensorFlow was about a year away from being released, and for most people, myself included, deep learning felt very far away. I didn't have a dataset with millions of images, or a GPU cluster to train on. What did I need deep learning for?

Then, I encountered "CNN Features off-the-shelf: an Astounding Baseline for Recognition" by Razavian et al. I'm sure it isn't the first paper on the topic, or even the most well-known, but it's the one I encountered. To summarize, they took a deep neural network (OverFeat), pretrained it on millions of images (ImageNet), and then repurposed it as a feature extractor.

If you haven't heard the term before, pretraining is when you train a network on a big dataset that's different from the dataset you're actually interested in. It's a form of transfer learning -- a way of learning how to solve one problem, and using that knowledge to perform better on a second problem. A feature extractor is anything that takes in data and spits out a set of numbers that describe the important parts of it, in the same way a tailor might describe your shape with a small set of measurements.

Once they had this feature extractor, they fed images from smaller datasets into it, and then fed the resulting features into a support-vector machine, a very simple model that's existed since the 90s. Basically, the OverFeat network was used to preprocess images so that they could be modeled using well-known ML tools, like a mother bird regurgitating partially-digested food for her young.

Doing this outperformed the existing state of the art in about a dozen tasks, many of which diverged wildly from the neural network's original input distribution. Performing well on highly-divergent inputs is a big deal. It's obvious that if you train a network on one problem, it will perform well on similar problems. But the network they'd trained on ImageNet, learning things like how to distinguish dogs from strawberries, worked well as a feature extractor for things like the Oxford Flowers dataset, where it has to make fine-grained distinctions between different types of flowers.

Performing well on highly-divergent inputs means you've found a general technique. In the authors' words: "The results strongly suggest that features obtained from deep learning with convolutional nets should be the primary candidate in most visual recognition tasks."

This was the "oh shit!" moment for me. Any time someone finds a general technique that outperforms something hand-tuned by experts, it's pretty exciting. When that general technique works across a dozen different tasks, it's even more exciting.

But here's what really caught my attention: it was something I could actually use. All I needed to do was download a pretrained deep convolutional net, and I could expect the features it extracted to work pretty well on whatever problem I had. Before reading this paper, I thought of "deep learning" as "training a neural network from scratch". I usually didn't have a big enough dataset to do that, and even if I did, I didn't have access to the sort of specialized hardware that would let me do it quickly.

But suddenly, I didn't need a million images to train on. I just needed a few thousand to run through the pretrained network and feed into a support-vector machine. I didn't need a GPU cluster. I could do it all on my laptop CPU, if I was willing to wait a while for the feature extraction to run.

After reading this, deep learning didn't feel so far away any more.

How does one actually repurpose a pretrained neural network as a feature extractor? Fortunately, doing the simplest possible thing works really well.

As you probably already know, a deep neural network consists of a series of layers. The word "layer" usually refers to a big linear transformation, followed by a non-linear function called an activation. (A common activation is ReLU, which is just f(x) = max(x, 0).) The outputs of the activation function are called the "activations" of the layer, and they're what get fed into the next layer.
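As a minimal sketch of that definition, here is a single layer written out by hand. NumPy is my choice for illustration, and the layer sizes and random weights are made up; a real network would have learned weights.

```python
import numpy as np

def relu(x):
    # ReLU activation: f(x) = max(x, 0), applied elementwise
    return np.maximum(x, 0)

def dense_layer(x, weights, bias):
    # A "layer": a linear transformation followed by a non-linear activation
    return relu(x @ weights + bias)

rng = np.random.default_rng(0)
x = rng.normal(size=(1, 4))        # one input with 4 values (hypothetical size)
weights = rng.normal(size=(4, 3))  # random stand-in for learned weights
bias = np.zeros(3)

# The "activations" of this layer: shape (1, 3), ready to feed into the next layer
activations = dense_layer(x, weights, bias)
```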

To turn a pretrained deep neural network into a feature extractor, you pick a layer toward the end of the network, feed in the image you want to extract features from, and then use the activations of that layer as features. It's sometimes helpful to do a little bit of post-processing, but once you have the activations of your chosen layer, you're basically done.
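Here is a rough sketch of that recipe, assuming a toy fully-connected network in NumPy. The random weights stand in for pretrained ones, the `forward` helper and its `stop_after` parameter are hypothetical names of my own, and every layer uses ReLU purely to keep the example short.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

def forward(x, layers, stop_after=None):
    # Run x through the layers in order. If stop_after is given, stop once
    # that layer has run and return its activations instead of the final output.
    for i, (weights, bias) in enumerate(layers):
        x = relu(x @ weights + bias)
        if stop_after is not None and i == stop_after:
            break
    return x

rng = np.random.default_rng(0)
# Hypothetical 3-layer network: 64 inputs -> 32 -> 16 -> 10 outputs.
sizes = [64, 32, 16, 10]
layers = [
    (rng.normal(size=(n_in, n_out)) * 0.1, np.zeros(n_out))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]

images = rng.normal(size=(5, 64))  # 5 flattened "images" (made-up data)

# Pick a layer toward the end (here the second-to-last one, index 1)
# and use its activations as the features.
features = forward(images, layers, stop_after=1)

# Optional post-processing: L2-normalize each feature vector
features = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-12)
```

Each image comes out as a vector of 16 floats, which is what you would hand to the downstream model.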

Calling this vector of floats yanked from the inside of a neural network "features" might seem strange, since they don't have an obvious meaning. But they're features in the most important sense of the word, because you can use them to train a model that makes accurate predictions.

I was confused for a long time about why this works as well as it does. Why would you expect the intermediate state of a neural network to produce features that are good for training traditional ML models?

It turns out that you can visualize what images correspond to a specific extracted feature. If you do, you see that the activations of late layers are capturing very high-level semantic concepts like eyes and clothing. (Activations of earlier layers tend to capture lower-level concepts like horizontal edges or blobs of color.)

Interestingly, you can also choose an arbitrary linear combination of extracted features, and see what images cause that to increase. If you do, you find that a surprising number of linear combinations also capture some high-level semantic concept. Essentially, the late layers of the neural network are learning to recognize features that can be linearly recombined to capture high-level semantic properties of the images in the training dataset.

With that in mind, it kind of makes sense that feeding another dataset into this network produces features which are good for a huge variety of tasks.

What's changed since 2014?

I promised you "unreasonable effectiveness", but so far all we've seen are a handful of impressive results from half a decade ago. What's happened since then?

The field has changed a lot. Deep learning tools have improved dramatically, and knowledge on how to use them has spread. GPUs are faster, with more memory. Parallelization techniques have improved. Networks that used to take a month to train can train in days or hours.

You might think this would make pretrained networks unnecessary, since it's now so much easier to train a deep neural network from scratch. But I find myself relying on deep feature extraction more, not less, as time goes on. It turns out that for most problems, the major limiting factor isn't ease of use -- it's dataset size.

There's been some empirical research on this, which confirms the unfortunate fact that the error rate of a neural network tends to improve extremely sublinearly with the amount of data used to train it. It also confirms the even more unfortunate fact that most architectural improvements don't change the shape of those curves very much. After a certain point, it takes a lot of data to meaningfully improve your results.
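To get a feel for what "extremely sublinearly" can mean, here is a back-of-the-envelope sketch. It assumes the error follows a power law in dataset size, error = a * n ** (-b). Both the power-law form and the exponent b = 0.3 are assumptions picked for illustration, not numbers taken from the research mentioned above.

```python
def data_multiplier(error_reduction, exponent):
    # Under the assumed law error = a * n ** (-exponent), cutting the error by
    # a factor of `error_reduction` (2.0 means halving it) requires multiplying
    # the dataset size by error_reduction ** (1 / exponent).
    return error_reduction ** (1.0 / exponent)

b = 0.3  # hypothetical exponent, for illustration only

halve = data_multiplier(2.0, b)    # roughly 10x more data to halve the error
quarter = data_multiplier(4.0, b)  # roughly 100x more data to quarter it
```

Under those assumptions, every halving of the error costs about an order of magnitude more data.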

But let's focus on a specific example. In 2018, Facebook put out a paper called "Exploring the Limits of Weakly Supervised Pretraining". This was the paper that really convinced me of where things are headed, so I want to take a little while to dig into it.

Back in 2014, ImageNet was a huge dataset. It was common to train on ImageNet, and then transfer that training to another problem. By 2018, ImageNet is considered small, and it's becoming common to train on another dataset and transfer that training to ImageNet.

At the time of publication, Facebook's paper achieved the best performance ever recorded on ImageNet-1k single-crop top-1 accuracy. The key to these results was pretraining on billions of Instagram images, using the user-provided tags as noisy labels, and increasing the model capacity accordingly.

Reading this paper was both exciting and discouraging for me. On the one hand, it's exciting that existing techniques seem to scale pretty much indefinitely, just by increasing their capacity and the amount of data available. On the other hand, who has billions of images to train on? Who has the resources to do architecture and hyperparameter search on these enormous models and datasets?

Buried in the paper, though, is this very important aside:

The highest accuracies on val-IN-1k are 83.3% (source: IG-940M-1k) and 83.6% (source: IG-3.5B-17k), both with ResNeXt-101 32×16d. These results are obtained by training a linear classifier on fixed features and yet are nearly as good as full network finetuning, demonstrating the effectiveness of the feature representation learned from hashtag prediction.

(The accuracy of the fully-finetuned network was 84.2%, the accuracy of a linear classifier with deep feature extraction was 83.6%, and the same architecture trained directly on ImageNet got 79.6%.)

In other words, if you take this huge network pretrained on a billion images, repurpose it as a feature extractor, and use it to train a simple linear model on your dataset -- it's almost as good as finetuning the whole thing. Crucially, it significantly outperforms training a neural network directly on your dataset. Even for a dataset as large as ImageNet.

Unless you have literally the largest dataset in the world, training a neural network on it from scratch will probably give worse results than using a huge pretrained net as a feature extractor and training a simple linear model on that. This is absolutely remarkable, because training a linear model is way easier. We're talking ten lines of Python to implement, and training times measured in minutes.

(You could get slightly better results by finetuning the huge pretrained net, and this will be worth it in some cases. But the performance improvement will probably be negligible for dataset sizes in the thousands, and less than a percent for larger datasets like ImageNet.)
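The "ten lines of Python" claim can be sketched roughly like this. The embeddings here are synthetic stand-ins (two made-up classes with different means) rather than real network outputs, and closed-form ridge regression is just one simple choice of linear model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up stand-in for precomputed embeddings: 200 vectors of 32 floats,
# two classes whose means differ. Real embeddings would come from a
# pretrained network.
labels = rng.integers(0, 2, size=200)
embeddings = rng.normal(size=(200, 32)) + 2.0 * labels[:, None]

train_x, test_x = embeddings[:150], embeddings[150:]
train_y, test_y = labels[:150], labels[150:]

# Linear model: ridge regression onto +/-1 targets, solved in closed form
X = np.hstack([train_x, np.ones((len(train_x), 1))])  # add a bias column
targets = 2.0 * train_y - 1.0
w = np.linalg.solve(X.T @ X + 1e-3 * np.eye(X.shape[1]), X.T @ targets)

# Predict class 1 wherever the linear score is positive
test_X = np.hstack([test_x, np.ones((len(test_x), 1))])
predictions = (test_X @ w > 0).astype(int)
accuracy = (predictions == test_y).mean()
```

Fitting is a single linear solve, which is why the turnaround is so short once the embeddings have been computed.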

Another thing that's changed since 2014 is that deep feature extraction has sort of been eaten by the concept of embeddings. You might be familiar with word embeddings like Word2vec, which map words from a dictionary to a vector of floats. The idea is actually very general: anything which maps an object in one space to a point in a vector space can be called an embedding. Usually the term implies either that the second space is lower-dimensional, or that objects which are similar by some important metric are close together in the second space. (In the case of Word2vec, two words are considered similar if they're used in similar contexts.)

Deep feature extraction takes in an image, and spits out a vector of floats, so it's clearly an embedding in that sense. It turns out it meets the second property as well; semantically similar images tend to have similar features, so their points in the vector space end up being close together.
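One common way to measure "close together" is cosine similarity. The sketch below assumes that metric and uses three hand-written 3-dimensional vectors as stand-ins for real embeddings; the function names are mine.

```python
import numpy as np

def cosine_similarity(a, b):
    # Cosine of the angle between two vectors: 1.0 means same direction
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def nearest_neighbors(query, embeddings, k=3):
    # Indices of the k embeddings most similar to the query, best first
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = normed @ (query / np.linalg.norm(query))
    return np.argsort(-scores)[:k]

# Made-up 3-dimensional "embeddings"; real ones are much higher-dimensional
embeddings = np.array([
    [1.0, 0.1, 0.0],  # 0: say, a photo of a dog
    [0.9, 0.2, 0.1],  # 1: another dog
    [0.0, 1.0, 0.9],  # 2: a strawberry
])

# The two items most similar to the first dog: itself, then the other dog
neighbors = nearest_neighbors(embeddings[0], embeddings, k=2)
```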

Thinking of deep feature extraction as embedding images into a more tractable space turns out to be a very powerful idea. It lets you think interesting thoughts like "what if I embed two different data types, like images and text, into the same space?" (The goal, of course, being to cluster points in that space in order to e.g. search for images based on a description.)

I usually hate switching terminology mid-stream in a post, but you'll see both terms in the wild, and I think the idea of embeddings is important enough to warrant it. We'll use the term "embedding" from now on.

There's a growing consensus that deep learning is going to be a centralizing technology rather than a decentralizing one. We seem to be headed toward a world where the only people with enough data and compute to train truly state-of-the-art networks are a handful of large tech companies.

I think this consensus is probably correct, but that this world is better than it sounds. Right now, most people work in a paradigm where they're given a dataset, and their job is to fit a model on it from scratch. In this paradigm, it's pretty depressing if only a few companies have the data and compute to actually fit the best models from scratch.

But in the future, I think ML will look more like a tower of transfer learning. You'll have a sequence of models, each of which specializes the previous model, which was trained on a more general task with more data available. If, for example, you want to train a visual similarity model on a catalogue of product images, your future pipeline will probably look something like this:

  1. Some giant tech company pretrains a neural network on 1B-1T images.
  2. Someone else with access to significantly more product images than you finetunes the giant pretrained image network on 1M-1B product images.
  3. You use the pretrained product image network to embed your product images, and train your model on those embeddings.

Now, predicting the future is hard, and that last part is a pretty specific prediction. Why embeddings?

We already talked about efficacy. Linear models trained on embeddings perform incredibly well. But embeddings have a bunch of other properties which are more important than people realize.

Embeddings have short train-test cycles

Models that take a long time to train kill productivity in the same way as long compile times or slow test suites. Real-world data science requires being able to quickly test hypotheses. Does this hand-engineered feature improve my model? Hm, no. What about this one? Maybe I should be normalizing differently? Ooh, maybe my labels are noisy. Does it work on a different dataset I know is clean?

If you're training a deep neural network, the turnaround time on these questions might be measured in days, and your life will be awful. If you have a big hardware budget, it might be measured in hours, and your life will be merely miserable. A linear model trained on embeddings will often have a turnaround time measured in minutes.

Embeddings let you mix and match model architectures

Embeddings provide a clean interface between models. You can change the model which produced the embedding, or change the model consuming the embedding, without worrying about the other one.

Right now, the best image embeddings are produced by deep convolutional nets. But that might not be true forever. Maybe we invent something better in the future. Maybe an old technique makes a comeback, and in ten years we're using genetic algorithms to train image embeddings. If you're using embeddings as the input to a simpler model, your model can benefit from these improvements with as little effort as swapping Word2vec out for ELMo takes today.

Embeddings plug directly into existing pipelines

The biggest thing preventing most companies from adopting deep learning is legacy infrastructure. This is entirely reasonable. If my business depended on a random forest I'd spent five years painstakingly tuning, I'd be reluctant to get rid of it too.

Embeddings provide an easy way to use deep learning to improve existing models. The features from the embedding can be fed into a traditional model alongside hand-crafted ones. Often the model's accuracy will improve with little or no feature engineering on the embedding.
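In practice this can be as simple as concatenating columns. The sketch below uses random numbers in place of both feature sets, and the `standardize` helper is a hypothetical name of my own. Standardizing is a choice that matters for scale-sensitive models and is usually unnecessary for tree-based ones like random forests.

```python
import numpy as np

rng = np.random.default_rng(0)
n_items = 100

# Hypothetical inputs: hand-crafted features an existing model already uses,
# and embeddings produced by a pretrained network (random stand-ins here).
handcrafted = rng.normal(loc=50.0, scale=10.0, size=(n_items, 4))
embeddings = rng.normal(size=(n_items, 128))

def standardize(x):
    # Zero mean and unit variance per column, so neither block dominates on scale
    return (x - x.mean(axis=0)) / (x.std(axis=0) + 1e-12)

# One combined feature matrix, ready for whatever traditional model is in place
combined = np.hstack([standardize(handcrafted), standardize(embeddings)])
```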

What still needs to happen?

All these amazing things we've been talking about -- fast training, good performance on small datasets, tools that work on your laptop, interoperability with existing models -- all of these depend on one thing. Someone has to do the grunt work of putting together giant datasets and training huge networks on them, so that everyone else can make use of the embeddings.

Researchers today frequently release pretrained networks, but there are a handful of problems with this state of the world:

  1. The best pretrained networks are coming out of the big tech companies, and it isn't obvious they'll keep releasing them forever. Today's good PR is tomorrow's competitive advantage.

  2. Running state-of-the-art pretrained networks with reasonable latency requires specialized hardware. A lot less hardware than it takes to train them, to be clear, but you don't want to be doing it on your laptop.

  3. There are good pretrained models available for big research topics, but nobody is training and releasing models specialized for the sorts of data most people work with. There's no good pretrained network for things like product images or typo-filled support requests.

I'm not sure exactly how to solve those problems. I have a few ideas, but I'll leave them for a future post.

Even if I'm not sure exactly how to solve them, though, those problems seem very surmountable. I'm basically convinced at this point that embeddings are going to be the main way most data scientists get value out of deep learning, which is really exciting. It's rare to get a convincing-feeling glimpse of the future, and this future seems particularly bright.