
Training BNNs

Once you have defined a good architecture for your BNN, you want to train it. Here we give an introduction to common training strategies and tricks that are popular in the field. After reading this guide you should have a good idea of how you can train your BNN.

Below we first discuss the fundamental challenges of BNN optimization. We then explain the most commonly used training strategy, which relies on latent weights, and cover the questions, tips, and tricks that tend to come up when training BNNs.

The Problem with SGD in BNNs

Stochastic Gradient Descent (SGD) is used pretty much everywhere in the field of deep learning nowadays - either in its vanilla form or as the core part of a more sophisticated algorithm like Adam.

However, when turning to BNNs, two fundamental issues arise with SGD:

  • The gradient of the binarization operation is zero almost everywhere, making the gradient \(\frac{\partial L}{\partial w}\) utterly uninformative (see the snippet after this list).
  • SGD performs optimization through small update steps that are accumulated over time. Binary weights, meanwhile, cannot absorb small updates: they can only be left alone or flipped.
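
To see the first issue concretely, here is a small TensorFlow snippet (not part of Larq itself) that shows the gradient of the sign operation vanishing:

import tensorflow as tf

x = tf.Variable([0.3, -1.2, 0.7])
with tf.GradientTape() as tape:
    y = tf.reduce_sum(tf.sign(x))

# The registered gradient of tf.sign is zero almost everywhere,
# so this prints a tensor of zeros, which is useless for learning.
print(tape.gradient(y, x))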

Another way of putting this is that the loss landscape for BNNs is very different from what you are used to for real-valued networks. Gone are the rolling hills you can simply slide down: the loss is now a discrete function, and many of the intuitions and theories developed for continuous loss landscapes no longer apply.

Luckily, there has been significant progress in solving these problems. The issue of zero gradients is resolved by replacing the gradient with a more informative alternative, which we call a 'pseudo-gradient'. The issue of applying updates can be resolved either by introducing latent weights or by opting for a custom BNN optimizer.

Choice of Pseudo-Gradient

In larq.quantizers you will find a variety of quantizers that have been introduced in different papers. Many of these quantizers behave identically during the forward pass but implement different pseudo-gradients. Studies comparing different pseudo-gradients report little difference between them. Therefore, we recommend using the classical ste_sign() as a default.
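
For intuition, a pseudo-gradient such as the straight-through estimator can be written with tf.custom_gradient. The sketch below is illustrative only and is not Larq's actual implementation of ste_sign(); the function name is made up:

import tensorflow as tf

@tf.custom_gradient
def ste_sign_sketch(x):
    # Forward pass: binarize to -1 / +1.
    y = tf.where(x >= 0, tf.ones_like(x), -tf.ones_like(x))

    def grad(dy):
        # Backward pass: pass the incoming gradient straight through,
        # but zero it where |x| > 1 (the clipped straight-through estimator).
        return dy * tf.cast(tf.abs(x) <= 1.0, dy.dtype)

    return y, grad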

Latent Weights

Suppose we take a batch of training samples and evaluate a forward and backward pass. During the backward pass we replace the gradient of the binarization operation with a pseudo-gradient, and we obtain a gradient vector for our weights. We then feed this into an optimizer and get a vector of updates for our weights.

At this point, what do we do? If we directly apply the updates to our weights, they are no longer binary. The standard solution to this problem has been to introduce real-valued latent weights. We apply our update step to this real-valued weight. During the forward pass, we use the binarized version of the latent weight.

Beware that latent weights are not really weights at all: changing the latent weights usually doesn't affect the behavior of the network, and we throw the latent weights away after training. Instead, they are best thought of as the product of the binary weight (its sign) and a positive inertia (its magnitude): the higher this inertia, the stronger the signal required to make the weight flip.

One implication of this is that the latent weights should be constrained: since an increase in inertia does not change the behavior of the network, the latent weights could otherwise grow indefinitely.
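
Conceptually, a single update step with latent weights looks like the following sketch (plain NumPy with a vanilla SGD update for illustration; the helper name and learning rate are made up and this is not Larq code):

import numpy as np

def latent_weight_step(latent_w, pseudo_grad, lr=0.01):
    # Apply the optimizer update to the real-valued latent weights.
    latent_w = latent_w - lr * pseudo_grad
    # Constrain the latent weights so the "inertia" cannot grow indefinitely.
    latent_w = np.clip(latent_w, -1.0, 1.0)
    # The forward pass only ever sees the binarized version.
    binary_w = np.where(latent_w >= 0, 1.0, -1.0)
    return latent_w, binary_w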

In Larq, it is trivial to implement this strategy. An example of a layer optimized with this method would look like:

x_out = larq.layers.QuantDense(
    512,
    input_quantizer="ste_sign",
    kernel_quantizer="ste_sign",
    kernel_constraint="weight_clip",
)(x_out)

Any optimizer you now apply will update the latent weights; after the update the latent weights are clipped to \([-1, 1]\).

Choice of Optimizer

When using a latent weight strategy, you can apply any optimizer you are familiar with from real-valued deep learning. However, due to the different nature of BNNs your intuitions may be off. We recommend using Adam: although other optimizers can achieve similar accuracies with a lot of finetuning, we and others have found that Adam is the quickest to converge and the least sensitive to the choice of hyperparameters.
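
In practice this just means compiling the model with a standard Keras Adam optimizer, assuming `model` is a BNN built with Larq layers; the learning rate, loss, and metrics below are placeholders, not tuned recommendations:

import tensorflow as tf

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)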

Retrieving the Binary Weights

When using the latent weight strategy, the weights are only quantized on the forward pass. This means that when saving the model weights, the latent weights will be saved. To access the binary weights we can use the quantized_scope context manager:

model.save("full_precision_model.h5")  # save full precision latent weights
fp_weights = model.get_weights()  # get latent weights

with larq.context.quantized_scope(True):
    model.save("binary_model.h5")  # save binary weights
    weights = model.get_weights()  # get binary weights

Alternative: Custom Optimizers

Instead of using latent weights, one can opt for a custom BNN optimizer that inherently generates binary weights. An example of such an optimizer is Bop.
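
As a rough sketch of what using Bop looks like (the exact constructor arguments and defaults may differ between Larq versions, so treat this as an assumption and check the Bop documentation):

import larq as lq
import tensorflow as tf

# Bop updates the binary weights directly; a CaseOptimizer routes all
# non-binary variables (e.g. batch norm parameters) to a regular optimizer.
optimizer = lq.optimizers.CaseOptimizer(
    (lq.optimizers.Bop.is_binary_variable, lq.optimizers.Bop()),
    default_optimizer=tf.keras.optimizers.Adam(0.01),
)
model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")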

Tips & Tricks

Here are some general tips and tricks that you may want to keep in mind:

  • BNN training is noisier due to the non-continuous nature of flipping weights; therefore, we recommend lowering your batch normalization momentum to 0.9 (see the snippet after this list).
  • Beware that BNNs tend to require many more epochs than real-valued networks to converge: 200+ epochs when training an AlexNet or ResNet-18 style network on ImageNet is not unusual.
  • Networks tend to train much quicker if they are initialized from a trained real-valued model. Importantly, this requires the overall architecture of the pretrained network to be as similar as possible to the BNN, including the placement of the activation function, which takes the place of the binarization operation in the real-valued model. Note that although convergence is faster, pretraining does not seem to improve the final accuracy.
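
For the batch normalization tip above, that simply means lowering the Keras default momentum of 0.99:

import tensorflow as tf

# A lower momentum lets the running statistics track the noisier BNN activations faster.
x_out = tf.keras.layers.BatchNormalization(momentum=0.9)(x_out)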

Further References

If you would like to learn more, we recommend checking out the following papers (starting with the most recent):