Differential Privacy Series Part 3 | Efficient Per-Sample Gradient Computation for More Layers in Opacus


In the previous blog post, we covered how performance-improving vectorized computation is done in Opacus and why Opacus can compute “per-sample gradients” a lot faster than “microbatching”. We also introduced the vectorized computation for nn.Linear layers. In this blog post, we explain how per-sample gradients are computed efficiently for other layer types: convolutions, RNNs, GRUs, LSTMs, normalizations, embeddings, and multi-head attention.


To recap, the previous blog post covered the following points:

  • One of the features of Opacus is “vectorized computation”, in that it can compute per-sample gradients a lot faster than microbatching. To do so, we derive the per-sample gradient formula, and implement a vectorized version of it.
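  • The per-sample gradient of the loss with respect to the weights of a linear layer is the outer product of that sample's gradient with respect to the layer's output and that sample's input activation.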
  • We call the gradients with respect to activations the “highway gradients” and the gradients with respect to the weights the “exit gradients”.
  • Highway gradients retain per-sample information, but exit gradients do not.
  • einsum facilitates vectorized computation.
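For reference, the linear-layer formula can be written out as follows. The notation is ours and assumes a layer z = W a applied to each sample i of a batch of size B, with the batch loss L being the sum of the per-sample losses:

```latex
\begin{aligned}
z^{(i)} &= W a^{(i)}, \qquad W \in \mathbb{R}^{n \times m},\; a^{(i)} \in \mathbb{R}^{m},\; z^{(i)} \in \mathbb{R}^{n},\\
\frac{\partial L}{\partial W} &= \sum_{i=1}^{B} \frac{\partial L^{(i)}}{\partial W},
\qquad \frac{\partial L^{(i)}}{\partial W} = \frac{\partial L^{(i)}}{\partial z^{(i)}} \, \big(a^{(i)}\big)^{\top},\\
\left(\frac{\partial L^{(i)}}{\partial W}\right)_{jk} &= \frac{\partial L^{(i)}}{\partial z^{(i)}_{j}} \; a^{(i)}_{k}.
\end{aligned}
```

Here the highway gradient ∂L^{(i)}/∂z^{(i)} and the activation a^{(i)} are both indexed by the sample, while the exit gradient ∂L/∂W is a sum over the batch, which is why it no longer carries per-sample information. With batched tensors of shape (B, n) and (B, m), all B outer products come from a single call such as einsum("bn,bm->bnm", highway, activations).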

Extending the idea to other modules

Now that we have seen how to efficiently compute per-sample gradients for linear layers (the building blocks of multilayer perceptrons (MLPs)), we can apply the same techniques to other layers too. Why should this be possible? All a linear layer does is a matrix multiplication (matmul) between its inputs and its parameters, and most other trainable layers are doing something similar. The difference is that they come with additional constraints, such as weight sharing in a convolution, or the sequential accumulation of gradients over time steps in the backward pass of an LSTM. Here is how we handle convolutions, recurrent layers (RNNs, GRUs, and LSTMs), multi-head attention, normalization layers, and embedding layers.


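Convolution

As a refresher, let’s look at the forward pass of a convolution module. For simplicity, we consider a Conv2D with a 2x2 kernel operating on an input with just one channel (shape 1x3x3). Each of the four output values is the dot product between the kernel and one 2x2 patch of the input, so the same four weights are reused at every output position. Because of this weight sharing, a sample's gradient for a kernel weight is the sum of its contributions from all output positions.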
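A common way to turn this weight sharing into a single matmul is “im2col”: unfold every receptive field into a column, then take the same outer product as for a linear layer and sum it over output positions (but never over samples). The sketch below does this with torch.nn.functional.unfold for the 2x2-kernel, 1x3x3 example above and checks it against microbatching. The helper name conv2d_per_sample_grads is hypothetical, and it assumes groups=1 and integer (zero) padding; it illustrates the technique rather than reproducing Opacus's own code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv2d_per_sample_grads(conv, activations, highway):
    """Hypothetical helper: per-sample gradients of an nn.Conv2d (groups=1, integer zero padding).

    activations: conv input, shape (B, C_in, H, W)
    highway:     gradient of the loss w.r.t. the conv output, shape (B, C_out, H_out, W_out)
    """
    B, C_out = highway.shape[:2]
    # im2col: each receptive field becomes one column -> (B, C_in * kH * kW, L)
    cols = F.unfold(activations, kernel_size=conv.kernel_size, dilation=conv.dilation,
                    padding=conv.padding, stride=conv.stride)
    grads = highway.reshape(B, C_out, -1)  # (B, C_out, L)
    # weight sharing: sum the outer products over the L output positions, keep the batch dim
    grad_weight = torch.einsum("bol,bkl->bok", grads, cols).reshape(B, *conv.weight.shape)
    grad_bias = grads.sum(dim=2)  # (B, C_out)
    return grad_weight, grad_bias


torch.manual_seed(0)
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2)  # 2x2 kernel, one channel
x = torch.randn(4, 1, 3, 3)  # a batch of four 1x3x3 inputs
out = conv(x)  # (4, 1, 2, 2)
highway = torch.autograd.grad(out.pow(2).sum(), out)[0]
grad_w, grad_b = conv2d_per_sample_grads(conv, x, highway)

# microbatching reference: one backward pass per sample
ref_w, ref_b = zip(*[
    torch.autograd.grad(conv(x[i:i + 1]).pow(2).sum(), (conv.weight, conv.bias))
    for i in range(4)
])
ref_w, ref_b = torch.stack(ref_w), torch.stack(ref_b)
```

The vectorized version does one unfold and one einsum for the whole batch, while the reference needs one backward pass per sample.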

Recurrent: RNN, GRU, and LSTM

A little background. Recurrent neural networks capture temporal dependencies by using intermediate hidden states connected in a sequence. Like other neural network blocks, they map a sequence of input vectors to a sequence of output vectors. A recurrent neural network can be represented as a stack of consecutive flat layers, each consisting of a chain of cells (directed either forward or backward). A cell, the basic element of a recurrent neural network, transforms a single input token (or its intermediate representation) and updates the hidden state vector. The parameters of a recurrent layer are simply the parameters of its underlying cells. All cells in one flat sublayer share the same set of parameters, i.e., at every time step, the input and the current hidden state go through the same transformation. There are different approaches to handling temporal dependencies and implementing recurrent neural networks; RNN, GRU, and LSTM are the three most popular. They introduce different cell types, all based on parameterized linear transformations, but the basic form of the network remains unchanged. Because the cell parameters are shared across time steps, the per-sample gradient of a recurrent layer is the sum, over time steps, of the per-step gradients of its linear transformations, much as a convolution sums over spatial positions.
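To illustrate the sharing across time steps, here is a minimal sketch using a hand-rolled Elman-style RNN whose input projection w_ih is one nn.Linear reused at every step. It records each step's highway gradient, then sums the per-step outer products over time (but not over the batch) with one einsum, and compares the result against microbatching. This toy model and its names are hypothetical and are not Opacus's recurrent layers; GRU and LSTM cells follow the same pattern with more linear transformations per step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D_in, D_h = 3, 5, 4, 6

# hypothetical minimal RNN: the same projections are applied at every time step
w_ih = nn.Linear(D_in, D_h, bias=False)
w_hh = nn.Linear(D_h, D_h, bias=False)


def run_rnn(x):
    """x: (B, T, D_in). Returns the final hidden state and the per-step outputs of w_ih."""
    h = torch.zeros(x.shape[0], D_h)
    ih_outputs = []
    for t in range(x.shape[1]):
        z = w_ih(x[:, t])  # shared parameters applied at step t
        z.retain_grad()  # keep this step's highway gradient after backward()
        ih_outputs.append(z)
        h = torch.tanh(z + w_hh(h))
    return h, ih_outputs


x = torch.randn(B, T, D_in)
h, ih_outputs = run_rnn(x)
h.pow(2).sum().backward()

highway = torch.stack([z.grad for z in ih_outputs], dim=1)  # (B, T, D_h)
# sum the outer products over time steps t, keep the batch dimension b
per_sample = torch.einsum("btn,btm->bnm", highway, x)  # (B, D_h, D_in)

# microbatching reference for the shared input projection
reference = torch.stack([
    torch.autograd.grad(run_rnn(x[i:i + 1])[0].pow(2).sum(), w_ih.weight)[0]
    for i in range(B)
])
```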

Multi-Head Attention

A refresher on multi-head attention: multi-head attention is one of the main components of a transformer. It computes queries, keys, and values by applying three linear layers to a sequence of input vectors, and returns a combination of the values weighted by the attention. The attention itself is obtained via a softmax over the (scaled) dot products between queries and keys. In PyTorch's nn.MultiheadAttention, these components are fused together for more efficient computation (for instance, the query, key, and value projections are stored as a single packed weight), so Opacus cannot hook into the individual linear layers.
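One way to see this fusion from Python is to inspect the module: in the default configuration (keys and values with the same dimension as queries), PyTorch keeps the three projections in a single in_proj_weight parameter, and the only nn.Linear submodule is the output projection. The embedding size and head count below are arbitrary choices for the sketch.

```python
import torch.nn as nn

E, num_heads = 8, 2
mha = nn.MultiheadAttention(embed_dim=E, num_heads=num_heads)

# query, key, and value projections live in one packed parameter of shape (3 * E, E)
packed_shape = tuple(mha.in_proj_weight.shape)
# the only nn.Linear submodule a per-layer hook could attach to is the output projection
linear_names = [name for name, m in mha.named_modules() if isinstance(m, nn.Linear)]
```

Since per-sample gradient hooks work per nn.Linear, the packed projection is invisible to them, which is what motivates rewriting the module.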

To support per-sample gradients, we made two changes:

  • We rewrote multi-head attention so that it is built from three explicit linear layers. Opacus automatically attaches hooks to these linear layers to compute per-sample gradients; they use einsum to compute grad samples, as discussed in the previous blog post.
  • We implemented an additional SequenceBias layer, which adds a learnable bias vector to the sequence, together with its per-sample gradient computation. Note that the main new part of the implementation is SequenceBias, which is a pretty straightforward module.
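A sequence-bias layer is simple enough to sketch in full. We assume here, based on our reading of Opacus's source (treat this as an assumption), that the bias vector is appended as one extra element at the end of each sequence; we use a batch-first layout, and the class and helper names are hypothetical. Because the bias appears only at the appended position, its per-sample gradient is just the highway gradient at that position, with no einsum needed.

```python
import torch
import torch.nn as nn


class AppendSequenceBias(nn.Module):
    """Hypothetical re-implementation: appends a learnable vector as the last element of each sequence."""

    def __init__(self, embed_dim):
        super().__init__()
        self.bias = nn.Parameter(torch.randn(embed_dim))

    def forward(self, x):  # x: (B, L, E), batch first
        b = self.bias.expand(x.shape[0], 1, -1)  # (B, 1, E)
        return torch.cat([x, b], dim=1)  # (B, L + 1, E)


def sequence_bias_per_sample_grad(highway):
    # the bias only occupies the appended last position of each sample's sequence
    return highway[:, -1]  # (B, E)


torch.manual_seed(0)
layer = AppendSequenceBias(4)
x = torch.randn(3, 5, 4)
out = layer(x)
loss = (out.sum(dim=2) ** 2).sum()  # each sample's loss depends only on its own sequence
highway = torch.autograd.grad(loss, out)[0]
vectorized = sequence_bias_per_sample_grad(highway)

# microbatching reference
reference = torch.stack([
    torch.autograd.grad((layer(x[i:i + 1]).sum(dim=2) ** 2).sum(), layer.bias)[0]
    for i in range(3)
])
```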

Normalization Layers

With Differential Privacy, batch normalization layers are prohibited because they mix information across samples of a batch. Nevertheless, other types of normalization, such as LayerNorm, InstanceNorm, or GroupNorm, are allowed and supported, because they do not normalize over the batch dimension and hence do not mix information across samples.
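For LayerNorm, the affine output is y = x̂ · weight + bias, where the normalized input x̂ is computed from each position's own features. A sample's gradient for weight is therefore its highway gradient times x̂, summed over positions, and its gradient for bias is the highway gradient summed over positions. The sketch below assumes an nn.LayerNorm(E) applied to inputs of shape (B, T, E); the helper name is hypothetical, and the result is checked against microbatching.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def layer_norm_per_sample_grads(ln, activations, highway):
    """Hypothetical helper for nn.LayerNorm(E) on inputs of shape (B, T, E).

    Returns per-sample gradients of ln.weight and ln.bias, each of shape (B, E).
    """
    # normalized input without the affine part; mean/variance are taken over E per (b, t)
    x_hat = F.layer_norm(activations, ln.normalized_shape, eps=ln.eps)
    # y = x_hat * weight + bias: sum over positions t only, never over samples b
    grad_weight = torch.einsum("bte,bte->be", highway, x_hat)
    grad_bias = highway.sum(dim=1)
    return grad_weight, grad_bias


torch.manual_seed(0)
B, T, E = 3, 4, 5
ln = nn.LayerNorm(E)
x, target = torch.randn(B, T, E), torch.randn(B, T, E)
out = ln(x)
highway = torch.autograd.grad(((out - target) ** 2).sum(), out)[0]
grad_w, grad_b = layer_norm_per_sample_grads(ln, x, highway)

# microbatching reference
ref_w, ref_b = zip(*[
    torch.autograd.grad(((ln(x[i:i + 1]) - target[i:i + 1]) ** 2).sum(), (ln.weight, ln.bias))
    for i in range(B)
])
ref_w, ref_b = torch.stack(ref_w), torch.stack(ref_b)
```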


Embedding

An embedding layer can (once again) be viewed as a special case of a linear layer whose input is one-hot encoded: multiplying a one-hot vector by the embedding matrix simply selects the embedding of the matching token.
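Taking that view literally gives a compact sketch: use the one-hot encoded tokens as the “activations” and apply the same outer-product rule as for nn.Linear, summing over the positions of each sample. Repeated tokens within a sample then accumulate their gradients automatically. The helper name is hypothetical, and the sketch assumes a dense nn.Embedding without padding_idx.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def embedding_per_sample_grad(emb, token_ids, highway):
    """Hypothetical helper: token_ids (B, T) int64, highway (B, T, D) -> per-sample grads (B, V, D)."""
    # one-hot tokens play the role of the linear layer's input activations
    one_hot = F.one_hot(token_ids, num_classes=emb.num_embeddings).to(highway.dtype)  # (B, T, V)
    # same outer-product rule as nn.Linear, summed over the T positions of each sample
    return torch.einsum("btv,btd->bvd", one_hot, highway)


torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=10, embedding_dim=3)
ids = torch.tensor([[1, 2, 1], [0, 9, 9]])  # repeated tokens accumulate their gradients
out = emb(ids)
highway = torch.autograd.grad(out.pow(2).sum(), out)[0]
grad = embedding_per_sample_grad(emb, ids, highway)

# microbatching reference
reference = torch.stack([
    torch.autograd.grad(emb(ids[i:i + 1]).pow(2).sum(), emb.weight)[0] for i in range(2)
])
```

Materializing the one-hot tensor costs B × T × V memory, so for large vocabularies a scatter-style accumulation (for example, Tensor.index_put_ with accumulate=True) computes the same result more cheaply.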


In summary, Opacus computes per-sample gradients by (1) capturing the activations and highway gradients, and then (2) efficiently performing matrix multiplications.
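The two steps can be sketched with plain PyTorch hooks: a forward hook on each nn.Linear stores its input, and a tensor hook on its output stores the highway gradient during backward; afterwards, one einsum per layer yields the per-sample gradients. The helper attach_capture_hooks is hypothetical and only mirrors the idea; it is not Opacus's hook machinery. As a sanity check, summing the per-sample gradients over the batch recovers the ordinary .grad.

```python
import torch
import torch.nn as nn


def attach_capture_hooks(model):
    """Hypothetical helper: record activations and highway gradients of every nn.Linear."""
    store = {}

    def forward_hook(module, inputs, output):
        # (1a) capture the layer input (activations) during the forward pass
        store[module] = {"activations": inputs[0].detach()}

        def save_highway(grad):
            # (1b) capture the gradient w.r.t. the layer output during the backward pass
            store[module]["highway"] = grad.detach()

        output.register_hook(save_highway)

    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.register_forward_hook(forward_hook)
    return store


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
store = attach_capture_hooks(model)
x = torch.randn(5, 4)
model(x).pow(2).sum().backward()

# (2) one einsum per layer turns the captured tensors into per-sample weight gradients
per_sample = {
    m: torch.einsum("bn,bm->bnm", s["highway"], s["activations"]) for m, s in store.items()
}
```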

Opacus supports two kinds of modules:

  1. Building block. These are “atomic” trainable modules (i.e., “default classes”) that have their own hooks and can be used directly, for example nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, and the normalization layers (nn.LayerNorm, nn.GroupNorm, and nn.InstanceNorm1d/2d/3d).
  2. Composite. These are modules that are composed of building blocks. Composite modules are supported as long as all of their trainable submodules are supported. Frozen submodules need not be supported; a module can be frozen in PyTorch by setting requires_grad to False on each of its parameters.
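The composite rule can be illustrated with a small check: walk the module tree, find every module that directly owns trainable parameters, and flag those that are not building blocks. Freezing a flagged module removes it from the list. The supported set, the Scale layer, and the helper below are all hypothetical and serve only to mirror the rule stated above; Opacus performs its own checks.

```python
import torch
import torch.nn as nn

# Hypothetical set of supported building blocks, for illustration only
SUPPORTED_BUILDING_BLOCKS = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                             nn.LayerNorm, nn.GroupNorm, nn.Embedding)


class Scale(nn.Module):
    """Hypothetical custom layer that owns a parameter but has no per-sample gradient rule."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x * self.gamma


class Composite(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8))
        self.scale = Scale(8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.scale(self.encoder(x)))


def unsupported_trainable_modules(model):
    """Names of modules that directly own trainable parameters but are not building blocks."""
    return [
        name
        for name, module in model.named_modules()
        if any(p.requires_grad for p in module.parameters(recurse=False))
        and not isinstance(module, SUPPORTED_BUILDING_BLOCKS)
    ]


model = Composite()
before = unsupported_trainable_modules(model)  # ['scale']
for p in model.scale.parameters():
    p.requires_grad = False  # freeze: unset requires_grad on each parameter
after = unsupported_trainable_modules(model)  # []
```

Containers such as nn.Sequential own no parameters themselves, so only the leaves that hold weights need a per-sample gradient rule.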


In this blog post, we explained how Opacus efficiently computes per-sample gradients for other layers: convolutions, RNNs, GRUs, LSTMs, multi-head attention, normalization layers, and embeddings. We also explained how arbitrary modules can be supported in Opacus, as long as they are composed of supported building blocks (composite modules) and any unsupported parts are frozen.




PyTorch is an open source machine learning platform that provides a seamless path from research prototyping to production deployment.