
Pruning Neural Networks with Catalyst

Catalyst Team

Hi! My name is Nikita. I am one of the Catalyst contributors. I want to tell you about pruning with PyTorch and Catalyst.


In the past few years, state-of-the-art architectures have become more and more complex, and the number of parameters keeps growing. But what if all these networks are over-parameterized and more than half of the parameters don't influence the result? Several methods can help us. Not so long ago I wrote a post about one of them, knowledge distillation. You can find it here.

Today I will continue this series on reducing model size with an introduction to pruning neural networks. Let's start!

If you haven't heard about Catalyst before, I recommend reading this post, which introduces the ideas behind the framework with minimal examples.

If a network is over-parameterized, let's try to simply zero out some parameters. The process of removing connections between neurons is called pruning. The idea comes from biology: the human brain is also over-parameterized in the first stages of growth, and we learn by pruning unnecessary connections.

But what can we do with this theory in practice? When it comes to neural networks, the connections between neurons can be represented as a matrix, so the result of applying one layer is

y = f(Wx)

Where “f” is a non-linear activation function, for example, ReLU.


We often have a bias term there as well, but let's focus on the values of the matrix elements. In this case, we can write the matrix multiplication element-wise as

y_i = f(Σ_j w_ij · x_j)

So we can assume that the smaller the absolute value of w, the less influence that weight has on the result. This method is called magnitude pruning. I will run all my experiments on the MNIST dataset. Here is a code sample:
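A minimal sketch of magnitude pruning with torch.nn.utils.prune; the small fully connected model and the 30% pruning amount here are just for illustration and may differ from the exact experimental setup.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# A small fully connected net for MNIST (28*28 inputs, 10 classes);
# the exact architecture used in the experiments may differ.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Magnitude (L1) pruning: zero out the 30% of weights with the smallest
# absolute value in every Linear layer.
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.3)

# Each pruned layer now stores `weight_orig` and a binary `weight_mask`;
# `module.weight` becomes weight_orig * weight_mask.
zeros = sum(int((m.weight == 0).sum()) for m in model.modules()
            if isinstance(m, nn.Linear))
total = sum(m.weight.numel() for m in model.modules()
            if isinstance(m, nn.Linear))
print(f"global sparsity: {zeros / total:.2%}")
```

Catalyst also provides a pruning callback that can apply such a pruning function during training, so the same thing can be driven from a Catalyst runner; check the Catalyst documentation for its exact arguments.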

Let’s try to prune these connections and see the results!


Even this simple method can reduce the number of parameters by about a factor of three without losing quality! But can we do better?

We can fine-tune our network after each pruning step; this is known as iterative pruning.
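A rough sketch of such a loop; `train_one_epoch` and `evaluate` are placeholders for your training and validation code (the original experiments drive them with Catalyst runners):

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

def iterative_pruning(model, train_one_epoch, evaluate, n_steps=5, amount=0.2):
    """Alternate pruning and fine-tuning."""
    for step in range(n_steps):
        # Prune another 20% of the still-unpruned weights in each Linear layer;
        # repeated calls are combined by PyTorch into a PruningContainer.
        for module in model.modules():
            if isinstance(module, nn.Linear):
                prune.l1_unstructured(module, name="weight", amount=amount)
        # Fine-tune so the remaining weights can compensate for the removed ones.
        train_one_epoch(model)
        print(f"step {step}: accuracy = {evaluate(model):.4f}")
    return model
```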


Let's look at the results!


If our network is over-parameterized, then maybe the randomly initialized network already contains a subnetwork that can solve our task more efficiently. All we need to do is find this subnetwork. So after pruning we can restore the initial weights but keep the pruning mask, and then fine-tune this subnetwork. This is the lottery ticket hypothesis.
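A rough sketch of this reset with torch.nn.utils.prune, assuming the `model` from the magnitude pruning sketch above; the training loop itself is omitted:

```python
import copy

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# 1. Remember the random initialization before any training.
initial_state = copy.deepcopy(model.state_dict())

# ... train the model here, then prune it ...
for module in model.modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.5)

# 2. Restore the initial weights but keep the pruning mask: the mask lives
#    in `weight_mask`, the trainable values in `weight_orig`.
with torch.no_grad():
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and hasattr(module, "weight_orig"):
            module.weight_orig.copy_(initial_state[f"{name}.weight"])

# 3. Fine-tune the resulting subnetwork (the "winning ticket").
```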


As you can see, the results are very close to iterative pruning.


This result is not very intuitive, but it is as good as iterative pruning. The most interesting part is that if we could somehow guess the lottery ticket mask right from the start, we could reach the same result in a single iteration!

Since sparse layers are not available in PyTorch (except nn.Embedding), we don't get any acceleration or model size reduction: we just replace some weights with zeros. The model even runs slower if we keep the pruning mask, because additional pre-forward hooks have to be executed to apply it. What can we do?
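One small thing we can do is at least get rid of those hooks by making the pruning permanent with prune.remove; a small sketch, again assuming the pruned model from above:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Make pruning permanent: fold weight_orig * weight_mask into `weight` and
# drop the pre-forward hook. The zeros are still stored, so the model is not
# smaller or faster, it just stops paying for the mask application.
for module in model.modules():
    if isinstance(module, nn.Linear) and prune.is_pruned(module):
        prune.remove(module, "weight")
```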

But wait: as we remember, we can prune neurons instead of connections. What does that mean in practice?


Every neuron is represented by a row in the weight matrix, so we need to apply pruning over the first dimension (dim=0). After pruning we can remove the entire row and therefore reduce the number of operations!

However, there is one problem here. If we prune several neurons, the output shape changes. In other words, if we remove a row from the weight matrix of, say, layer 1, we also have to remove the corresponding column from the weight matrix of layer 2.

For convolutional layers, we can reduce the number of channels by pruning the weight tensors over the first dimension. But remember that if you want to actually speed up your model, you have to remove the zeroed rows or columns manually, as in the sketch below.
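A sketch of structured pruning over dim=0 with torch.nn.utils.prune; the layer sizes and the 50% amount are arbitrary, and the second half shows the manual step of physically removing the zeroed rows and the matching columns of the next layer:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

fc1 = nn.Linear(784, 128)
fc2 = nn.Linear(128, 10)

# Remove 50% of the neurons of fc1: zero whole rows of its weight matrix,
# ranked by their L2 norm (n=2), along dim=0.
prune.ln_structured(fc1, name="weight", amount=0.5, n=2, dim=0)

# Rows that were fully zeroed correspond to dead neurons.
kept = fc1.weight_mask.sum(dim=1) != 0   # boolean mask over output units
n_kept = int(kept.sum())
print("neurons kept:", n_kept)

# To actually shrink the model, build smaller layers and copy the surviving
# rows of fc1 and the matching *columns* of fc2.
new_fc1 = nn.Linear(784, n_kept)
new_fc2 = nn.Linear(n_kept, 10)
with torch.no_grad():
    new_fc1.weight.copy_(fc1.weight[kept])
    new_fc1.bias.copy_(fc1.bias[kept])
    new_fc2.weight.copy_(fc2.weight[:, kept])
    new_fc2.bias.copy_(fc2.bias)
```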

Imagine you have a big, slow model and you are trying to speed it up. What should you try first? The answer depends on the task, but pruning is not the first thing to try. Here is what I would try, in order:

1. TorchScript

When you convert your model to TorchScript, you can run it from languages that are faster than Python (for example, C++). This is a suitable solution for almost every case. You can also convert your model to other formats such as ONNX. But what if that is not enough?
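A minimal sketch of exporting a model to TorchScript; `model` and the MNIST-style input shape are placeholders here:

```python
import torch

model.eval()
example = torch.randn(1, 1, 28, 28)

# Tracing records the operations for a fixed example input; torch.jit.script
# is an alternative that also captures control flow.
scripted = torch.jit.trace(model, example)
scripted.save("model_scripted.pt")

# The saved file can be loaded from C++ (libtorch) or back in Python:
loaded = torch.jit.load("model_scripted.pt")
```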

2. Quantization

Quantization is also a beta feature in PyTorch, but it works well in almost every case. For example, here is a tutorial with BERT. Don't forget to apply the first step (TorchScript) after quantization.
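For example, dynamic quantization of the Linear layers takes one call; `model` is again a placeholder for your trained network:

```python
import torch
import torch.nn as nn

# Dynamic quantization: weights of Linear layers are stored in int8 and
# dequantized on the fly during inference.
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# The first step from above, applied after quantization:
scripted = torch.jit.script(quantized_model)
```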

3. Try a different architecture

For example, if you have enough data for text classification, you can simply replace your transformer with logistic regression on top of tf-idf features, and the quality may stay almost the same. You can check out Yury Kashnitsky's talk about such a case.
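A sketch of such a baseline with scikit-learn; `texts` and `labels` stand for your own labelled data:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# tf-idf features + logistic regression as a fast, small baseline.
clf = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2), max_features=100_000),
    LogisticRegression(max_iter=1000),
)
clf.fit(texts, labels)
print(clf.predict(["this baseline is surprisingly hard to beat"]))
```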

4. KD + Quantization

If you have enough time, you can try to transfer knowledge from the big model to a smaller student. But this requires more time, and there is probably no ready-made pipeline for your particular case.

5. Pruning + KD + Quantization

The last step is pruning. For example, you can prune your network and then add KD losses from the full model to fine-tune it to the quality you need.
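A sketch of a typical distillation loss that could be added while fine-tuning the pruned student; the temperature and mixing weight below are common defaults, not values from my experiments:

```python
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    """Cross-entropy on the labels plus a soft-target distillation term."""
    ce = F.cross_entropy(student_logits, targets)
    kd = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    return alpha * ce + (1 - alpha) * kd
```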

That said, pruning is not a common thing in production nowadays, but it is one of the most interesting fields to research. Someday there will be sparse layers in PyTorch, and we can already see that the newest GPU cards handle sparse operations more efficiently.

The code for all the experiments above is available here.

Thank you!

If you have any questions, feel free to join our Catalyst community Slack 😉
