In the past few years, state-of-the-art architectures have become more and more complex, and the number of parameters grows exponentially. But what if all these networks are over-parameterized and more than half of the parameters don’t influence the result? Several methods can help us here. Not long ago I wrote a post about one of them, called knowledge distillation. You can find it here.
Today I will continue this series about reducing model size with an introduction to pruning neural networks. Let’s start!
If you haven’t heard of Catalyst before, I recommend reading this post, which introduces the ideas behind the framework along with minimal examples.
If a network is over-parameterized, let’s simply try to zero out some of its parameters. The process of removing connections between neurons is called pruning. The idea comes from biology: the human brain is also over-parameterized in the first stages of growth, and part of learning is pruning unnecessary connections.
But what can we do with this theory in practice? When it comes to neural networks, the connections between neurons can be represented as a matrix W. So the result of applying one layer to an input x is

$$y = f(Wx)$$

where f is a non-linear activation function, for example ReLU.

We often have a bias term there as well, but let’s focus on the values of the matrix elements. In this case, we can write the matrix multiplication element-wise as

$$(Wx)_i = \sum_j w_{ij} x_j$$
So we can assume that the smaller the absolute value of a weight w, the less influence it has on the result. This method is called magnitude pruning. I will run all my experiments on the MNIST dataset. Here is a code sample:
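A minimal sketch of magnitude pruning, written with plain PyTorch's `torch.nn.utils.prune` utilities rather than the Catalyst pipeline used for the experiments. The two-layer classifier and the helper names `magnitude_prune` and `sparsity` are assumptions made for illustration.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

# Hypothetical MNIST-sized classifier: 784 input pixels, 10 classes.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)


def magnitude_prune(model: nn.Module, amount: float) -> None:
    """Zero out the `amount` fraction of smallest-|w| weights in each Linear layer."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=amount)


def sparsity(model: nn.Module) -> float:
    """Fraction of Linear weights that are exactly zero."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            zeros += int((module.weight == 0).sum())
            total += module.weight.numel()
    return zeros / total


magnitude_prune(model, amount=0.5)
print(f"sparsity: {sparsity(model):.2f}")
```

Here the threshold is chosen per layer. Pruning with a single global threshold across all layers is another common choice and may behave differently.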
Let’s try to prune these connections and see the results!
Even this simple method can reduce the number of parameters by about a factor of 3 without losing quality! But can we do better?
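We can fine-tune our network every time we apply pruning, so that the remaining weights compensate for the removed ones. This is known as iterative pruning.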
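The prune-then-tune loop can be sketched as follows. This is only an illustration in plain PyTorch: the random tensors stand in for MNIST so the snippet runs without downloading data, and the number of rounds, the 20% step, and the helpers `fine_tune` and `sparsity` are assumptions.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

# Random stand-in for MNIST (hypothetical data, for illustration only).
x = torch.randn(256, 784)
y = torch.randint(0, 10, (256,))

model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
linears = [m for m in model if isinstance(m, nn.Linear)]


def fine_tune(model: nn.Module, steps: int = 20) -> float:
    """Run a few SGD steps and return the last loss value."""
    # The optimizer is rebuilt on each call because pruning renames
    # the `weight` parameter to `weight_orig`.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


def sparsity(layers) -> float:
    """Fraction of weights that are exactly zero across the given layers."""
    zeros = sum(int((layer.weight == 0).sum()) for layer in layers)
    return zeros / sum(layer.weight.numel() for layer in layers)


fine_tune(model)  # initial training of the dense network
for round_idx in range(3):
    for layer in linears:
        # Each call prunes 20% of the weights that are still unpruned.
        prune.l1_unstructured(layer, name="weight", amount=0.2)
    loss = fine_tune(model)
    print(f"round {round_idx}: sparsity={sparsity(linears):.2f}, loss={loss:.3f}")
```

Because each round removes a fraction of the surviving weights, three rounds at 20% leave roughly 0.8³ ≈ 51% of the weights in place.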
Let’s look at the results!
If our network is over-parameterized, then maybe the randomly initialized network already contains a subnetwork that could solve our task more efficiently. This is the idea behind the lottery ticket hypothesis. All we need to do is find this subnetwork. So after pruning we can restore the initial weights but keep the pruning mask, and then tune the resulting subnetwork.
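A sketch of this rewind-and-retrain procedure, under the same assumptions as before: plain PyTorch, random tensors in place of MNIST, and hypothetical names such as `initial` and `train`. The mask lives in the `weight_mask` buffer, so only `weight_orig` has to be reset.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

# Random stand-in for MNIST (hypothetical data, for illustration only).
x = torch.randn(256, 784)
y = torch.randint(0, 10, (256,))

model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
linears = [m for m in model if isinstance(m, nn.Linear)]

# 1. Remember the random initialization of every layer.
initial = [(l.weight.detach().clone(), l.bias.detach().clone()) for l in linears]


def train(model: nn.Module, steps: int = 20) -> float:
    """Run a few SGD steps and return the last loss value."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


# 2. Train the dense network, then prune it by magnitude.
train(model)
for layer in linears:
    prune.l1_unstructured(layer, name="weight", amount=0.5)

# 3. Rewind the parameters to their initial values; the mask stays untouched.
with torch.no_grad():
    for layer, (w0, b0) in zip(linears, initial):
        layer.weight_orig.copy_(w0)
        layer.bias.copy_(b0)

# 4. Train the sparse subnetwork starting from the original initialization.
final_loss = train(model)
print(f"loss of the rewound subnetwork: {final_loss:.3f}")
```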
As you can see, the results are very close to those of iterative pruning.
This is not very intuitive, but the subnetwork trained from its initial weights is as good as the iteratively pruned one. The most interesting thing is that if we could somehow guess the lottery ticket mask right from the start, we could reach the same result in a single iteration!
Since sparse layers are not available in PyTorch (except nn.Embedding), we don’t get any speed-up or model size reduction: we just replace some weights with zeros. The model even runs slower if we keep the pruning mask, because a forward pre-hook has to apply the mask on every call. What can we do?
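To make this concrete, the sketch below inspects a single layer before and after pruning with `torch.nn.utils.prune`. The layer size and the 90% amount are arbitrary choices for illustration.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

layer = nn.Linear(784, 128)
n_before = sum(p.numel() for p in layer.parameters())

prune.l1_unstructured(layer, name="weight", amount=0.9)

# The layer now stores the dense original weights plus a dense mask, and a
# forward pre-hook recomputes weight = weight_orig * weight_mask on every call.
param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = sorted(name for name, _ in layer.named_buffers())
print(param_names, buffer_names, prune.is_pruned(layer))

# Make the pruning permanent: the mask and the hook are dropped,
# but the zeros are still kept in a dense tensor of the same shape.
prune.remove(layer, "weight")
n_after = sum(p.numel() for p in layer.parameters())
zero_fraction = (layer.weight == 0).float().mean().item()
print(n_before, n_after, round(zero_fraction, 2))
```

The parameter count is identical before and after, even though most of the entries are zero.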
But wait: we can prune whole neurons instead of individual connections. What does that mean in practice?
Every neuron is represented by a row in the weight matrix, so we need to apply pruning over the first dimension (dim=0). After pruning, we can remove the entire row and therefore reduce the number of operations!
However, there is one problem here. If we prune several neurons, the output shape of the layer changes. In other words, if we remove a row from the weight matrix of, say, layer 1, we must also remove the corresponding column from the weight matrix of layer 2.
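The row and column bookkeeping can be sketched like this, assuming two hypothetical fully connected layers `fc1` and `fc2` and PyTorch's `prune.ln_structured` for choosing the rows. Removing a neuron also drops its bias, so the sketch zeroes those biases in the dense model to keep the two models comparable.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

fc1, fc2 = nn.Linear(784, 64), nn.Linear(64, 10)
dense = nn.Sequential(fc1, nn.ReLU(), fc2)

# Zero out the 50% of fc1's rows (output neurons) with the smallest L1 norm.
prune.ln_structured(fc1, name="weight", amount=0.5, n=1, dim=0)
keep = fc1.weight_mask[:, 0].bool()  # True for rows that survived
prune.remove(fc1, "weight")  # bake the mask into fc1.weight

with torch.no_grad():
    # A removed neuron loses its bias too; zero it in the dense model.
    fc1.bias[~keep] = 0.0

# Build physically smaller layers: drop rows of fc1 and matching columns of fc2.
n_kept = int(keep.sum())
small_fc1, small_fc2 = nn.Linear(784, n_kept), nn.Linear(n_kept, 10)
with torch.no_grad():
    small_fc1.weight.copy_(fc1.weight[keep])
    small_fc1.bias.copy_(fc1.bias[keep])
    small_fc2.weight.copy_(fc2.weight[:, keep])
    small_fc2.bias.copy_(fc2.bias)
small = nn.Sequential(small_fc1, nn.ReLU(), small_fc2)

print(f"hidden units: {fc1.out_features} -> {small_fc1.out_features}")
```

The smaller model computes the same function as the masked dense one, but with matrices that are actually smaller, which is where the speed-up comes from.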
For convolutional layers, we can reduce the number of channels by pruning the weight tensors over the first dimension. But remember that if you want to speed up your model, you have to remove the zeroed rows or columns manually.
Imagine you have a big, slow model and you are trying to speed it up. What should you try first? For now, the answer depends on the task, but pruning is not the first thing to try. Here is what I would try, in order:
1. TorchScript
When you convert your model to TorchScript, you can run it from languages that are faster than Python (for example, C++). It is a suitable solution for almost every case. You can also export your model to a format such as ONNX. But what if that is not enough?
2. Quantization
Quantization is also a beta feature in PyTorch, but it works well in almost every case. For example, here is a tutorial with BERT. Don’t forget to apply the first step (TorchScript conversion) after quantization.
3. Try a different architecture
For example, if you have enough data for text classification, you can replace your transformer with logistic regression on top of tf-idf features, and the quality may remain almost the same. You can check out Yury Kashnitsky’s talk about such a case.
4. KD + Quantization
If you have enough time, you can try to transfer knowledge from the big model to a smaller student. But this takes more time, and there is probably no ready-made pipeline for your specific case.
5. Pruning + KD + Quantization
The last thing to try is pruning. For example, you can prune your network and then add KD losses from the full model while tuning it to reach the quality you need.
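The TorchScript and quantization steps from the list above can be sketched together. This assumes dynamic quantization (one of several quantization modes in PyTorch) and a CPU build with a quantized backend available. The model, the `example` input, and the in-memory buffer are placeholders for illustration.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical "big" model; any module built from nn.Linear layers would do.
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
example = torch.randn(1, 784)

# Quantization: store Linear weights as int8, quantize activations on the fly.
quantized = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# TorchScript, applied after quantization: trace the model and serialize it.
traced = torch.jit.trace(quantized, example)
buffer = io.BytesIO()  # a real deployment would write to a file path instead
torch.jit.save(traced, buffer)
buffer.seek(0)
restored = torch.jit.load(buffer)

with torch.no_grad():
    max_diff = (model(example) - restored(example)).abs().max().item()
print(f"max |float - int8| difference: {max_diff:.4f}")
```

The serialized module can then be loaded without the original Python class definitions, for example from C++ via LibTorch.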
That said, pruning is not common in production nowadays, but it is one of the most interesting fields to research. Someday there will be sparse layers in PyTorch, and we can already see that the newest GPUs handle sparse operations more efficiently.
The code for all the experiments above is available here.
If you have any questions, feel free to join our Catalyst community Slack 😉