Make sure you have completed the getting started tutorial.

Pruning tutorial

Introduction

Deep learning has achieved unprecedented performance on image recognition tasks such as ImageNet and on natural language processing tasks such as question answering and machine translation. These models typically have millions or even billions of parameters. For example, Google's recently released large language model, Pathways Language Model (PaLM), performs well on tasks such as conceptual understanding and cause-and-effect reasoning, and it contains over 540 billion parameters!

While this is great for advancing the state of the art (SOTA) in accuracy, applications such as self-driving cars or client-side video processing need these models deployed on the edge to meet real-time deadlines. Running the model on the device avoids the latency of sending a request to a server and waiting for the output to come back. In these settings, we cannot use gigantic models for two reasons: on-device memory is limited (so the model may not fit), and the number of operations needed to produce an output would violate latency and power constraints. Luckily, we can rely on model compression to address these concerns.

Model compression is the area of research focused on deploying SOTA models on resource-constrained devices while minimizing accuracy degradation. Approaches to compressing a model include weight pruning, quantization, knowledge distillation, low-rank tensor decomposition, and hardware-aware neural architecture search. Here we discuss and target what is arguably the simplest of these methods: weight pruning.

Weight Pruning

Unstructured Pruning

Pruning a deep learning model involves finding a fraction of the weights that contribute little to the classification output and setting their values to 0. Setting weights to 0 reduces the memory footprint of the model as well as the number of multiplies and accumulates during inference. The de facto method is magnitude pruning: we rank the weights in the network by absolute value and eliminate the smallest ones, up to a threshold dictated by the chosen compression ratio. Let's explore how to do this in Julia with the Flux.jl package.
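To make the ranking-and-threshold idea concrete, here is a minimal sketch on plain Julia arrays, assuming each layer's weights are just an array and that one threshold is shared by the whole network (global pruning). The helper `global_magnitude_prune` is hypothetical and is not part of Flux.jl or FluxPrune.jl.

```julia
# Hypothetical helper illustrating global low-magnitude pruning on plain arrays;
# it is not part of Flux.jl or FluxPrune.jl.
function global_magnitude_prune(layers::AbstractVector, ratio::Real)
    @assert 0 <= ratio < 1 "ratio must be in [0, 1)"
    # Rank every weight in the network by absolute value.
    mags = sort!(reduce(vcat, [vec(abs.(W)) for W in layers]))
    k = floor(Int, ratio * length(mags))               # number of weights to remove
    threshold = k == 0 ? typemin(eltype(mags)) : mags[k]
    # Weights at or below the threshold are dropped (ties may remove a few extra).
    masks = [abs.(W) .> threshold for W in layers]     # true = keep
    return [W .* M for (W, M) in zip(layers, masks)], masks
end

W1 = Float32[0.1 -0.5; 0.05 0.9]   # toy "layer 1"
W2 = Float32[-0.02 0.3 0.7]        # toy "layer 2"
pruned, masks = global_magnitude_prune([W1, W2], 0.4)
println(pruned)
```

With 7 weights and a ratio of 0.4, the two smallest magnitudes (0.02 and 0.05) are removed, one from each toy layer, because the ranking is done across the whole network rather than per layer.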

First, let's begin with some imports that will help us load the dataset and model.

include("_tutorials/src/setup.jl");

Next, let's define our model. We use MobileNetv1, a popular deep learning model that achieves high classification accuracy while remaining very resource-efficient. Keep in mind the number of parameters the model contains and the amount of memory needed to store it. We can also calculate the number of multiplies and accumulates that MobileNetv1 incurs to produce an output.

m = MobileNet(slopehtanh, 0.25; fcsize = 64, nclasses = 1)
mults, adds, output_size = compute_dot_prods(m, (96, 96, 3, 1)) # input: height and width 96, 3 input channels, batch size 1
println("MobileNet Mults ", mults, " Adds ", adds)
MobileNet Mults 7505600 Adds 7273983
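For intuition about where counts like these come from, the sketch below counts the multiplies and adds of a single 2-D convolution, assuming every output element is one dot product of length k·k·(cin/groups), that a dot product of length n needs n multiplies and n - 1 adds, and that the bias is ignored. `conv_mults_adds` is a hypothetical helper, not the tutorial's `compute_dot_prods`, and the layer shapes are assumptions loosely modeled on the first layers of a width-0.25 MobileNet.

```julia
# Hypothetical helper (not the tutorial's compute_dot_prods): counts the
# multiplies and adds of one 2-D convolution, ignoring the bias.
function conv_mults_adds(h::Int, w::Int, k::Int, cin::Int, cout::Int;
                         stride::Int = 1, pad::Int = 0, groups::Int = 1)
    hout = div(h + 2pad - k, stride) + 1      # output height
    wout = div(w + 2pad - k, stride) + 1      # output width
    dot_len = k * k * div(cin, groups)        # length of each dot product
    n_dots = hout * wout * cout               # one dot product per output element
    # n products per dot product need n - 1 additions to sum them
    return n_dots * dot_len, n_dots * (dot_len - 1)
end

# Assumed shapes: a 3x3 stride-2 conv from 3 to 8 channels on a 96x96 input,
# then a 3x3 depthwise conv (groups = channels) on the 48x48 result.
std_mults, std_adds = conv_mults_adds(96, 96, 3, 3, 8; stride = 2, pad = 1)
dw_mults, dw_adds = conv_mults_adds(48, 48, 3, 8, 8; pad = 1, groups = 8)
println("standard conv: ", std_mults, " mults, ", std_adds, " adds")
println("depthwise conv: ", dw_mults, " mults, ", dw_adds, " adds")
```

Summing such per-layer counts over the whole network gives totals of the kind printed above; the depthwise layer is much cheaper because each dot product only spans one input channel.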

Next, we need to load the dataset used to prune and finetune our model. Now that we've finished our setup, let's prune our model. We can use the FluxPrune.jl package to prune the lowest-magnitude weights by calling LevelPrune.

m_lv_pruned = prune(LevelPrune(0.2), m);

FluxPrune's prune function takes two inputs: the pruning strategy and the model to prune. We use the LevelPrune strategy, which traverses each layer of the model and zeroes the lowest-magnitude p% of weights (20% in this case) in each layer. This is called unstructured pruning because we only care about removing the lowest-magnitude weights, without regard to whether the resulting sparsity has any structure. FluxPrune also lets you set a different pruning strategy for every layer in the model if you wish. Typically, we also have to finetune the pruned model to recover some of the accuracy lost by setting weights to 0. Let's compute the number of multiplies and accumulates to see how much we have saved.
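As a rough picture of what "the lowest p% in each layer" means, here is a sketch in which every layer is ranked on its own and may use its own ratio. `layerwise_magnitude_prune` is a hypothetical helper that only mimics the idea behind LevelPrune on plain arrays; it does not call FluxPrune.jl, and the toy layers are made up.

```julia
# Hypothetical helper mimicking the idea behind LevelPrune (it does not call
# FluxPrune.jl): each layer gets its own threshold and its own pruning ratio.
function layerwise_magnitude_prune(layers::AbstractVector, ratios::AbstractVector)
    @assert length(layers) == length(ratios) "one ratio per layer"
    return map(zip(layers, ratios)) do (W, p)
        mags = sort(vec(abs.(W)))              # rank this layer's weights only
        k = floor(Int, p * length(mags))      # weights to zero in this layer
        k == 0 && return copy(W)
        return W .* (abs.(W) .> mags[k])      # per-layer threshold
    end
end

conv_w = Float32[1, -2, 3, -4, 5, -6, 7, -8, 9, -10]  # toy layer with 10 weights
fc_w = Float32[0.5 -0.1; 0.2 -0.8]                  # toy layer with 4 weights
pruned = layerwise_magnitude_prune([conv_w, fc_w], [0.2, 0.5])
println(pruned)
```

Because each layer is thresholded separately, every layer ends up with exactly its requested sparsity (here 2 of 10 and 2 of 4 weights), which a single global threshold does not guarantee.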

mults, adds, output_size = compute_dot_prods(m_lv_pruned, (96, 96, 3, 1)) # input: height and width 96, 3 input channels, batch size 1
println("MobileNet Mults ", mults, " Adds ", adds)
MobileNet Mults 7460806 Adds 7273983

Relative to the unpruned baseline, the number of multiplies drops from 7,505,600 to 7,460,806 (about 0.6%), while the number of adds is unchanged. Unstructured pruning is powerful because it can be applied so aggressively that the resulting sparse models perform just as well as the baseline with less than 10% of the original weights remaining. However, while unstructured pruning achieves the best compression vs. accuracy trade-offs, it may not translate into faster inference: the irregular placement of zeros in the weight matrices induces irregular memory access patterns, and sparse GEMM kernels are competitive with dense ones only at extreme sparsities. For these reasons, one may consider structured pruning instead.
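A related storage-side effect can be checked with Julia's SparseArrays standard library: a compressed sparse format stores an index for every nonzero, so moderate unstructured sparsity can cost more memory than dense storage. The sketch below assumes a 256×256 Float32 weight matrix of random values pruned by magnitude; it measures only storage, not inference speed.

```julia
using SparseArrays, Random

# Bytes needed to store a Float32 weight matrix densely vs. in Julia's CSC
# sparse format (column pointers + row indices + nonzero values).
function storage_bytes(W::Matrix{Float32})
    S = sparse(W)
    dense = sizeof(W)
    csc = sizeof(S.colptr) + sizeof(S.rowval) + sizeof(S.nzval)
    return dense, csc
end

Random.seed!(0)
W = randn(Float32, 256, 256)   # toy weight matrix
for sparsity in (0.5, 0.9)
    Wp = copy(W)
    k = round(Int, sparsity * length(Wp))           # number of weights to zero
    Wp[partialsortperm(vec(abs.(Wp)), 1:k)] .= 0     # unstructured magnitude pruning
    dense, csc = storage_bytes(Wp)
    println("sparsity ", sparsity, ": dense ", dense, " bytes, CSC ", csc, " bytes")
end
```

On a 64-bit system each stored nonzero costs 4 bytes of value plus 8 bytes of row index, so CSC only beats dense storage once roughly two thirds of the weights are zero; at 50% sparsity the sparse copy is larger than the dense one.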

Structured Pruning

In structured pruning, we remove entire channels (typically) or filters rather than individual weights. Here, this type of pruning applies only to the convolutional layers, where channels and filters form natural groups of weights, and not to the fully-connected layers. By removing the lowest-magnitude channels, we can drastically reduce the number of multiplies and accumulates that our model has to perform.
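The sketch below shows the core of channel pruning on a single convolution weight tensor. It assumes Flux's weight layout (kernel width, kernel height, input channels, output channels) and ranks output channels (filters) by their L1 norm; FluxPrune's ChannelPrune may use a different criterion, and `channel_prune` is a hypothetical helper.

```julia
# Hypothetical helper: zero the `ratio` fraction of output channels (filters)
# with the smallest L1 norm. Assumes a (kw, kh, cin, cout) weight layout.
function channel_prune(W::AbstractArray{<:Real,4}, ratio::Real)
    cout = size(W, 4)
    norms = [sum(abs, view(W, :, :, :, c)) for c in 1:cout]  # L1 norm per filter
    k = floor(Int, ratio * cout)                              # filters to remove
    drop = partialsortperm(norms, 1:k)                        # k weakest filters
    Wp = copy(W)
    Wp[:, :, :, drop] .= 0                                    # zero whole filters
    return Wp, sort(drop)
end

# Toy 3x3 conv with 2 input and 5 output channels; filter magnitude grows with index.
W = reshape(Float32.(1:3*3*2*5), 3, 3, 2, 5)
Wp, dropped = channel_prune(W, 0.4)
println("dropped filters: ", dropped)
```

Unlike the unstructured case, the zeros here cover entire filters, so a later step can delete those filters outright and skip all of their multiplies.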

To prune channels, we can use the ChannelPrune strategy, which targets only the convolutional layers.

m_ch_pruned = prune(ChannelPrune(0.2), m);
mults, adds, output_size = compute_dot_prods(m_ch_pruned, (96, 96, 3, 1)) # input: height and width 96, 3 input channels, batch size 1
println("MobileNet Mults ", mults, " Adds ", adds)
MobileNet Mults 6105200 Adds 5915181

Compared to unstructured pruning, structured pruning drastically reduces the computational cost of inference: multiplies drop from 7,505,600 to 6,105,200 and adds from 7,273,983 to 5,915,181, about 18.7% in both cases. The caveat is that, because structured pruning eliminates whole groups of weights, it is typically run at much lower compression ratios than unstructured methods, so the memory savings are more limited. Choosing the right trade-off between compression and inference latency is a design decision that must be made during model design, before deployment.

Propagating the pruning

Once we have a sparse model from structured or unstructured pruning, some convolution nodes may have been eliminated entirely. During the propagation phase, we determine which neighbouring nodes are affected by those removals (for example, the weights in the next layer that read from an eliminated channel) and zero them out as well, extracting further sparsity from our pruning without impacting accuracy.
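For two consecutive convolutions, propagation can be sketched as follows: a channel that layer 1 no longer produces is useless to layer 2, and a channel that layer 2 no longer reads is useless to layer 1, so both slices can be zeroed. This assumes Flux's (kw, kh, cin, cout) layout and no bias, normalization, or other layer in between that could make a zeroed channel's output nonzero. `propagate_zero_channels` is hypothetical and not FluxPrune's `prune_propagate`.

```julia
# Hypothetical helper: propagate zeroed channels between two consecutive convs.
# Assumes (kw, kh, cin, cout) layout and nothing (e.g. bias) between the layers.
function propagate_zero_channels(W1::AbstractArray{<:Real,4}, W2::AbstractArray{<:Real,4})
    @assert size(W1, 4) == size(W2, 3) "layer 1 outputs must feed layer 2 inputs"
    W1p, W2p = copy(W1), copy(W2)
    for c in 1:size(W1, 4)
        produced = !all(iszero, view(W1p, :, :, :, c))   # does layer 1 still compute channel c?
        consumed = !all(iszero, view(W2p, :, :, c, :))   # does layer 2 still read channel c?
        if !(produced && consumed)
            W1p[:, :, :, c] .= 0    # filter whose output is never used
            W2p[:, :, c, :] .= 0    # input slice that only ever sees zeros
        end
    end
    return W1p, W2p
end

W1 = ones(Float32, 3, 3, 2, 4); W1[:, :, :, 2] .= 0   # layer 1: filter 2 was pruned
W2 = ones(Float32, 1, 1, 4, 3); W2[:, :, 4, :] .= 0   # layer 2: no longer reads channel 4
W1p, W2p = propagate_zero_channels(W1, W2)
println(count(!iszero, W1p), " / ", length(W1p), " nonzero in layer 1")
println(count(!iszero, W2p), " / ", length(W2p), " nonzero in layer 2")
```

Each layer gains extra zeros from the other's pruning, which is the "further sparsity" described above.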

m_pruned = keepprune(m_ch_pruned)
m_prop = prune_propagate(m_pruned)
mults, adds, output_size = compute_dot_prods(m_prop, (96, 96, 3, 1))
println("Propagated MobileNet Mults ", mults, " Adds ", adds)
Propagated MobileNet Mults 5002864 Adds 4899018

Resizing the propagated model

If enough nodes are pruned, some slices of the model accomplish nothing computationally. Instead of wasting resources passing these all-zero kernels around, we can eliminate them from the structure of our model.
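Resizing can be pictured as physically deleting those dead slices so the arrays themselves shrink. The sketch assumes two consecutive convolutions in Flux's (kw, kh, cin, cout) layout with nothing between them that would make a zero channel's output nonzero; `drop_dead_channels` is a hypothetical helper, not FluxPrune's `resize`.

```julia
# Hypothetical helper: remove channels that layer 1 never produces or layer 2
# never reads, shrinking both weight arrays. Assumes (kw, kh, cin, cout) layout.
function drop_dead_channels(W1::AbstractArray{<:Real,4}, W2::AbstractArray{<:Real,4})
    @assert size(W1, 4) == size(W2, 3) "layer 1 outputs must feed layer 2 inputs"
    keep = [c for c in 1:size(W1, 4)
            if !all(iszero, view(W1, :, :, :, c)) && !all(iszero, view(W2, :, :, c, :))]
    return W1[:, :, :, keep], W2[:, :, keep, :]
end

W1 = ones(Float32, 3, 3, 2, 5); W1[:, :, :, [1, 3]] .= 0  # filters 1 and 3 are all zero
W2 = ones(Float32, 1, 1, 5, 3)
W1s, W2s = drop_dead_channels(W1, W2)
println(size(W1s), " ", size(W2s))
```

Layer 1 goes from 5 filters to 3 (40% fewer multiplies per output pixel) and layer 2's dot products shrink from length 5 to 3, even though no nonzero weight was removed.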

m_resized = resize(m_prop)
mults, adds, output_size = compute_dot_prods(m_resized, (96, 96, 3, 1))
println("Resized MobileNet Mults ", mults, " Adds ", adds)
Resized MobileNet Mults 3680115 Adds 3527886
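To compare the stages side by side, the short computation below turns the multiply counts printed in this tutorial into percentage reductions relative to the unpruned baseline. Propagation and resizing build on the channel-pruned model, so their reductions are cumulative; the LevelPrune figure is a separate branch.

```julia
# Multiply counts copied from the outputs printed earlier in this tutorial.
baseline = 7_505_600
stages = [("LevelPrune(0.2)", 7_460_806),
          ("ChannelPrune(0.2)", 6_105_200),
          ("+ propagation", 5_002_864),
          ("+ resizing", 3_680_115)]

# Percentage of multiplies removed relative to the baseline, to one decimal.
reduction_pct(m) = round(100 * (1 - m / baseline); digits = 1)

for (name, m) in stages
    println(rpad(name, 18), m, " mults (", reduction_pct(m), "% fewer than baseline)")
end
```

This gives roughly 0.6%, 18.7%, 33.3%, and 51.0%: most of the savings come from structured pruning and the clean-up steps that follow it.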

Pruning and Finetuning pipeline

Now that we have seen how to prune our model, let's finetune it to recover some of the accuracy we lost. The trainer function provides a basic template for training the model, which you can use as a starting point for your own training methodology.

include("_tutorials/trainerfunc.jl");
m_finetuned = trainer(m_resized, 1); # finetune the pruned, resized model for 1 epoch
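When finetuning an unstructured-pruned model, gradient updates would normally push the zeroed weights away from zero again. A common remedy is to reapply the pruning mask after every update. The sketch below shows this on a toy least-squares model with plain gradient descent; it is an assumed illustration of the general technique, not a description of what the tutorial's `trainer` function does, and `masked_finetune!` is hypothetical.

```julia
using Random

# Hypothetical helper: gradient descent on mean squared error, reapplying the
# pruning mask after every step so pruned weights stay exactly zero.
function masked_finetune!(w::Vector{Float64}, mask::AbstractVector{Bool},
                          X::Matrix{Float64}, y::Vector{Float64};
                          lr = 0.1, epochs = 500)
    n = size(X, 1)
    for _ in 1:epochs
        grad = X' * (X * w - y) ./ n   # gradient of (1/2n) * ||Xw - y||^2
        w .-= lr .* grad
        w .*= mask                     # keep pruned weights at zero
    end
    return w
end

Random.seed!(1)
X = randn(100, 4)
y = X * [2.0, -1.0, 0.0, 0.5]          # toy target that ignores feature 3
w = zeros(4)
mask = [true, true, false, true]       # weight 3 was pruned
masked_finetune!(w, mask, X, y)
println(w)
```

For structured pruning followed by resizing this step is unnecessary, since the removed channels no longer exist in the model.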

CC BY-SA 4.0 UW-Madison PHARM Group. Last modified: November 15, 2023. Website built with Franklin.jl and the Julia programming language.