Training Neural Networks with Regularization
Regularization refers to a suite of techniques used to prevent overfitting: the tendency of highly expressive models, such as deep neural networks, to memorize details of the training data in a way that does not generalize to unseen test data. Common approaches include:
- Using the ValidationSet option of NetTrain to return the net with the best performance, as measured by the validation loss (or the validation error for classification nets).
- Using the TrainingStoppingCriterion option of NetTrain in conjunction with ValidationSet to stop training as soon as overfitting begins.
- Using regularization layers such as DropoutLayer, or layer-level features such as the "Dropout" option of LongShortTermMemoryLayer.
Before describing the solutions, we will demonstrate the problem with a simple example. We create a synthetic training dataset by taking noisy samples from a Gaussian curve and then train a net on those samples. The net has much higher capacity than needed, meaning that it can model functions far more complex than necessary to fit the Gaussian curve. Because we know the form of the true model, it is visually obvious when overfitting occurs: the trained net produces a function that is quite different from the Gaussian, as it has "learned the noise" in the original data. This can be quantified by sampling a second set of points from the Gaussian and using the trained net to predict their values. The trained net fails to generalize: while it fits the training data well, it does not approximate the new test data well.
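A minimal sketch of this setup (the noise level, layer sizes, and training length here are illustrative assumptions, not prescribed values):

    (* noisy samples from the Gaussian curve Exp[-x^2] *)
    data = Table[
      x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]],
      {x, -3, 3, 0.2}];

    (* a deliberately over-capacity net for a one-dimensional regression problem *)
    net = NetChain[{150, Tanh, 150, Tanh, 1},
      "Input" -> "Scalar", "Output" -> "Scalar"];

    (* train long enough for overfitting to become visible *)
    trained = NetTrain[net, data, MaxTrainingRounds -> 2000];

    (* compare the learned function against the true curve and the noisy samples *)
    Plot[{trained[x], Exp[-x^2]}, {x, -3, 3}, Epilog -> Point[List @@@ data]]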
The first common approach to mitigating overfitting is to measure the performance of the net on a held-out validation set (data that is not otherwise used for training) and to choose the particular net that achieved the best performance on that data across the entire history of training. This is possible with the ValidationSet option, which measures the net on the validation data after each training round. Those measurements have two consequences: they produce validation loss curves (and validation error curves when classification is being performed), and they change the selection process used by training, which now picks the intermediate net with the lowest validation loss rather than the lowest training loss.
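Continuing the sketch above, this might look as follows (the sample spacing for the validation data is again an assumed value):

    (* a second, held-out set of noisy samples from the same curve *)
    validationData = Table[
      x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]],
      {x, -3, 3, 0.1}];

    (* NetTrain now returns the intermediate net with the lowest validation loss *)
    trained = NetTrain[net, data,
      ValidationSet -> validationData,
      MaxTrainingRounds -> 2000];

Alternatively, ValidationSet -> Scaled[0.2] asks NetTrain to automatically reserve 20% of the training data for validation.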
A second common approach to mitigating overfitting is early stopping: halting training when some measurement of the net's performance starts to become worse. This is accomplished using the TrainingStoppingCriterion option of NetTrain in conjunction with the ValidationSet option. Early stopping has two potential advantages over simply using ValidationSet. First, notice that in the previous example we ended up training the net for an extra 1300 rounds. Because the net and training datasets here are both small, this was not an issue, but in general it can be very wasteful. Second, we might want to determine the best net using some measurement other than the loss.
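For example, a sketch that stops training once the validation loss has stopped improving (the patience value of 50 rounds is an assumption):

    (* stop if the validation loss has not improved for 50 consecutive rounds *)
    trained = NetTrain[net, data,
      ValidationSet -> validationData,
      TrainingStoppingCriterion -> <|"Criterion" -> "Loss", "Patience" -> 50|>,
      MaxTrainingRounds -> 2000];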
A third common regularization technique is weight decay. In this approach, the magnitude of the weights of the net is decreased slightly after each batch update, effectively moving the weights closer to zero. This is loosely equivalent to adding a loss term proportional to the squared L2 norm of the weights. Weight decay encourages the net to find a parsimonious configuration of its weights that can still adequately model the data or, equivalently, penalizes the net for the complexity incurred by fitting noise rather than signal.
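One way to apply weight decay in NetTrain is the "L2Regularization" suboption of the training method; the coefficient below is an illustrative value, not a recommendation:

    (* add an L2 penalty on the weights via the optimizer *)
    trained = NetTrain[net, data,
      Method -> {"ADAM", "L2Regularization" -> 0.001},
      MaxTrainingRounds -> 2000];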
A fourth common regularization technique is dropout. Dropout introduces noise into the hidden activations of a net, but in such a fashion that the overall statistics of the activations at a given layer do not change. The noise takes the form of a random pattern of deactivation, in which a random set of components of the input array (often referred to as units or neurons) is zeroed and the magnitudes of the remaining components are scaled up to compensate. The basic idea is that dropout prevents neurons from depending too heavily on any particular neuron in the layer below them and hence encourages the learning of more robust representations. Note that dropout is only active during training: during ordinary evaluation, DropoutLayer leaves its input unchanged.
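A sketch of the earlier net with dropout inserted after each hidden activation (the dropout probability of 0.5 is the layer's default and an assumption here):

    (* zero half of the hidden units at random during each training update *)
    netDropout = NetChain[
      {150, Tanh, DropoutLayer[0.5], 150, Tanh, DropoutLayer[0.5], 1},
      "Input" -> "Scalar", "Output" -> "Scalar"];

    trained = NetTrain[netDropout, data, MaxTrainingRounds -> 2000];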
A fifth regularization technique is batch normalization. It normalizes its input using the mean and variance of each training batch, while learning running estimates of these statistics that are used when the net is evaluated. Batch normalization has a number of useful properties in practice: it speeds up training and provides a degree of regularization. A BatchNormalizationLayer is typically inserted between a LinearLayer or ConvolutionLayer and its activation function.
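A sketch of the same net with batch normalization placed between each linear layer and its activation (the layer sizes remain illustrative assumptions):

    (* normalize the pre-activations across each batch *)
    netBN = NetChain[
      {LinearLayer[150], BatchNormalizationLayer[], Tanh,
       LinearLayer[150], BatchNormalizationLayer[], Tanh,
       LinearLayer[1]},
      "Input" -> "Scalar", "Output" -> "Scalar"];

    trained = NetTrain[netBN, data, MaxTrainingRounds -> 2000];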