Training Neural Networks with Regularization
Regularization refers to a suite of techniques used to prevent overfitting, which is the tendency of highly expressive models such as deep neural networks to memorize details of the training data in a way that does not generalize to unseen test data.
Common techniques include:
- Using the ValidationSet option of NetTrain to stop training early, as soon as the validation error or loss starts to worsen.
- Using weight decay to shrink the magnitudes of the weights toward zero during training.
- Using regularization layers such as DropoutLayer, or features such as the "Dropout" parameter of LongShortTermMemoryLayer.
Before we can describe the solutions, we will demonstrate the problem with a simple example. We create a synthetic training dataset by taking noisy samples from a Gaussian curve. Next, we train a net on those samples. The net has much higher capacity than needed, meaning that it can model functions that are far more complex than necessary to fit the Gaussian curve. Because we know the form of the true model, it is visually obvious when overfitting occurs: the trained net produces a function that is quite different from the Gaussian, as it has "learned the noise" in the original data. This can be quantified by sampling a second set of points from the Gaussian and using the trained net to predict their values. The trained net fails to generalize: while it is a good fit to the training data, it does not approximate the new test data well.
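The experiment above can be sketched conceptually in Python, with a high-degree polynomial standing in for the over-capacity net (the specific curve, noise level, and polynomial degree are illustrative assumptions, not the tutorial's actual setup):

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian(x):
    """The true underlying curve the data is sampled from."""
    return np.exp(-x ** 2)

# Noisy training samples from the Gaussian curve.
x_train = rng.uniform(-3, 3, 20)
y_train = gaussian(x_train) + rng.normal(0, 0.15, x_train.size)

# An over-capacity model: a degree-15 polynomial can represent far more
# complex functions than needed, so it "learns the noise" in the samples.
coeffs = np.polyfit(x_train, y_train, deg=15)

# A second, fresh set of samples from the same curve acts as test data.
x_test = rng.uniform(-3, 3, 20)
y_test = gaussian(x_test) + rng.normal(0, 0.15, x_test.size)

train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
test_mse = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
# The fit is much better on the training data than on the test data,
# which is the signature of overfitting.
```

The gap between the two errors quantifies the failure to generalize, exactly as described above.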
The first common approach to mitigating overfitting is to measure the performance of the net on secondary validation data (which is not otherwise used for training) and to select, from the entire history of training, the net that performed best on that data. This is possible with the ValidationSet option, which causes the net to be evaluated on the validation data after each training round. These measurements have two consequences: they produce validation loss curves (and validation error curves when classification is being performed), and they change the selection process used by training, which picks the intermediate net with the lowest validation loss as opposed to the lowest training loss.
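The selection rule can be sketched as follows (this is a conceptual Python illustration, not the NetTrain implementation; the `train_round` and `validation_loss` callables are hypothetical stand-ins for one round of training and one evaluation on the validation data):

```python
import copy

def train_with_validation(net, train_round, validation_loss,
                          max_rounds=50, patience=5):
    """Track validation loss after each round, keep the intermediate net
    with the lowest validation loss, and stop early once the loss has not
    improved for `patience` consecutive rounds."""
    best_loss = float("inf")
    best_net = copy.deepcopy(net)
    rounds_since_improvement = 0
    for _ in range(max_rounds):
        train_round(net)              # one round of training (mutates net)
        loss = validation_loss(net)   # measure on held-out validation data
        if loss < best_loss:
            best_loss = loss
            best_net = copy.deepcopy(net)
            rounds_since_improvement = 0
        else:
            rounds_since_improvement += 1
            if rounds_since_improvement >= patience:
                break  # validation loss has started to worsen
    return best_net, best_loss
```

Note that the returned net is the snapshot with the lowest validation loss, not the net from the final round.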
A second common regularization technique is called weight decay. In this approach, the magnitude of the weights of the net is decreased slightly after each batch update, effectively moving the weights closer to zero. This is loosely equivalent to adding a loss term proportional to the squared L2 norm of the weights.
Performing weight decay encourages the net to find a parsimonious configuration of its weights that can still adequately model the data or, equivalently, penalizes the net for the complexity incurred by fitting noise rather than data.
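The two views of weight decay mentioned above can be sketched side by side in Python (a conceptual illustration with assumed hyperparameter values; when the decay factor equals the learning rate times the L2 coefficient, the two updates coincide up to a small correction term, which is why the equivalence is only loose):

```python
import numpy as np

def sgd_step_with_weight_decay(weights, grad, lr=0.1, decay=0.01):
    """One SGD update followed by weight decay: after the gradient step,
    every weight is shrunk slightly toward zero."""
    weights = weights - lr * grad      # ordinary gradient step
    weights = weights * (1.0 - decay)  # decay: move weights toward zero
    return weights

def sgd_step_with_l2_penalty(weights, grad, lr=0.1, lam=0.1):
    """The loosely equivalent view: add the gradient of an L2 penalty
    (lam/2) * ||w||^2, which is lam * w, to the loss gradient."""
    return weights - lr * (grad + lam * weights)
```

Both updates pull the weights toward a parsimonious configuration; the penalty view makes explicit that complexity is being charged against the loss.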
A third common regularization technique is dropout. Dropout introduces noise into the hidden activations of a net, but in such a fashion that the overall statistics of the activations at a given layer do not change. The noise takes the form of a random pattern of deactivation, in which a random set of components of the input tensor (often referred to as units or neurons) is zeroed, and the magnitudes of the remaining components are increased to compensate. The basic idea is that dropout prevents neurons from depending too heavily on any particular neuron in the layer below them and hence encourages the learning of more robust representations.
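The deactivation pattern described above is often implemented as "inverted dropout", sketched here in Python (a minimal conceptual illustration, not the DropoutLayer implementation):

```python
import numpy as np

def dropout(activations, p=0.5, rng=None):
    """Inverted dropout: zero each unit independently with probability p,
    and scale the surviving units by 1/(1-p) so that the expected value
    of each activation is unchanged."""
    if rng is None:
        rng = np.random.default_rng()
    keep_mask = rng.random(activations.shape) >= p  # random deactivation
    return activations * keep_mask / (1.0 - p)      # rescale survivors
```

Because the survivors are rescaled, the overall statistics of the layer's activations are preserved on average; at evaluation time, dropout is simply disabled and the activations pass through unchanged.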