Training Neural Networks with Regularization

Regularization refers to a suite of techniques used to prevent overfitting, which is the tendency of highly expressive models such as deep neural networks to memorize details of the training data in a way that does not generalize to unseen test data.

There are essentially three ways to accomplish regularization in the Wolfram Language: early stopping, using the ValidationSet option to NetTrain; weight decay, using the "L2Regularization" setting of NetTrain's Method option; and dropout, using DropoutLayer. Each is demonstrated below, after an example of the overfitting problem itself.

Overfitting

Before we can describe the solutions, we will demonstrate the problem with a simple example. We create a synthetic training dataset by taking noisy samples from a Gaussian curve. Next, we train a net on those samples. The net has much higher capacity than needed, meaning that it can model functions that are far more complex than necessary to fit the Gaussian curve. Because we know the form of the true model, it is visually obvious when overfitting occurs: the trained net produces a function that is quite different from the Gaussian, as it has "learned the noise" in the original data. This can be quantified by sampling a second set of points from the Gaussian and using the trained net to predict their values. The trained net fails to generalize: while it is a good fit to the training data, it does not approximate the new test data well.

Create a noisy dataset based on a Gaussian curve:
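An illustrative sketch of this step, in which the sampling grid, the noise level of 0.15 and the name data are assumptions:

    (* noisy samples of a Gaussian curve *)
    data = Table[x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]], {x, -3, 3, 0.2}];
    (* visualize the samples as {x, y} pairs *)
    ListPlot[List @@@ data]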
Create a multilayer perceptron with a large number of hidden units:
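A possible form of such a net, assuming two hidden layers of 150 units each and scalar input and output ports:

    (* far more capacity than is needed to fit a Gaussian curve *)
    net = NetChain[
      {LinearLayer[150], Tanh, LinearLayer[150], Tanh, LinearLayer[1]},
      "Input" -> "Scalar", "Output" -> "Scalar"]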
Train the net for 30 seconds:
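A sketch of the training call; the name results is illustrative:

    (* the third argument All returns a NetTrainResultsObject rather than just the net *)
    results = NetTrain[net, data, All, TimeGoal -> 30]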
Despite the noise in the data, the final loss is very low:
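One way to inspect the training loss, assuming the results object supports the "RoundLossList" property:

    (* mean training loss of the final round *)
    Last[results["RoundLossList"]]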

The resulting net overfits the data, learning the noise in addition to the underlying function. To see this, we plot the function learned by the net alongside the original data.

Obtain the net from the NetTrainResultsObject:
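A sketch of these two steps, using the "TrainedNet" property and overlaying the learned function on the training samples:

    trained = results["TrainedNet"];
    (* the learned curve follows the noise rather than the smooth Gaussian *)
    Show[Plot[trained[x], {x, -3, 3}], ListPlot[List @@@ data]]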

A more quantitative way to demonstrate that overfitting has occurred is to test the net on data that comes from the same underlying distribution but that was not used to train the net.

Synthesize a second copy of the training data to use as a test set:
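A sketch of this step, repeating the earlier sampling procedure; the name testData is illustrative:

    (* fresh noisy samples from the same Gaussian curve *)
    testData = Table[x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]], {x, -3, 3, 0.2}];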
Create a function to measure the mean error on the new test set:
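One possible definition, using the mean squared error so that it is comparable to NetTrain's default loss for scalar regression; the name netError is illustrative:

    (* mean squared error of a net's predictions on the held-out test set *)
    netError[net_] := Mean[(net[Keys[testData]] - Values[testData])^2]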
The average loss on the test set is much higher than on the training set, showing that overfitting has occurred:
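A sketch of the comparison, reusing the helper defined above:

    (* error on the unseen test data ... *)
    netError[trained]
    (* ... versus the final loss on the training data *)
    Last[results["RoundLossList"]]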
Visually, the fitted net is also a poor match for the test samples:
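A sketch of the corresponding plot:

    (* the overfitted curve does not track the fresh samples *)
    Show[Plot[trained[x], {x, -3, 3}], ListPlot[List @@@ testData]]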

Early Stopping

The first common approach to mitigating overfitting is to measure the performance of the net on secondary test data (which is not otherwise used for training) and to select the net that performed best on that data across the entire history of training. This is possible with the ValidationSet option, which measures the net on the test data after each round. Those measurements have two consequences: they produce validation loss curves (and validation error curves when classification is being performed), and they change how the final net is chosen, with NetTrain returning the intermediate net that achieved the lowest validation loss rather than the lowest training loss.

Use the ValidationSet option to NetTrain to ensure that the net we actually obtain minimizes the validation loss. Note that we limit the training rounds to make it easier to see the portion of training before overfitting starts to occur:
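A sketch of the training call, reusing testData as the validation data; the round limit of 300 is an assumption:

    (* the data given to ValidationSet is evaluated after every training round *)
    resultsES = NetTrain[net, data, All,
      ValidationSet -> testData, MaxTrainingRounds -> 300]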
The validation loss has a minimum at around 150 rounds of training:
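One way to inspect the loss curves, assuming the "RoundLossList" and "ValidationLossList" properties are available:

    (* training loss keeps falling while validation loss bottoms out and then rises *)
    ListLogPlot[{resultsES["RoundLossList"], resultsES["ValidationLossList"]},
      Joined -> True, PlotLegends -> {"training", "validation"}]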
Extract the trained net from the NetTrainResultsObject and plot it. The result is much smoother, as NetTrain effectively took a snapshot of the net before it started to memorize the idiosyncrasies of the noise in the training set:
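A sketch of these two steps; the name trainedES is illustrative:

    trainedES = resultsES["TrainedNet"];
    (* the selected net is noticeably smoother than the overfitted one *)
    Show[Plot[trainedES[x], {x, -3, 3}], ListPlot[List @@@ data]]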
The results object also stores the average loss on the test set for the net that was eventually picked. Notice that it is much lower than the loss we computed for the overfitted net:
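Because the selected net is the one with the lowest validation loss, that value is the minimum of the validation loss list; it can also be recomputed directly with the helper defined earlier:

    (* validation loss of the net that NetTrain picked *)
    Min[resultsES["ValidationLossList"]]
    (* mean squared error of the selected net on the test set *)
    netError[trainedES]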

Weight Decay

A second common regularization technique is called weight decay. In this approach, the magnitude of the weights of the net is decreased slightly after each batch update, effectively moving the weights closer to zero. This is loosely equivalent to adding a loss term proportional to the squared L2 norm of the weights.
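As a rough sketch of that equivalence, with learning rate η and decay coefficient λ (symbols introduced here purely for illustration), each update takes approximately the form

    w → w - η ∇L(w) - η λ w

which is what plain gradient descent would produce if the penalty term (λ/2) ‖w‖² were added to the loss L.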

Performing weight decay encourages the net to find a parsimonious configuration of its weights that can still adequately model the data or, equivalently, penalizes the net for the complexity incurred by fitting noise rather than data.

In general, the optimal value for the strength of the weight decay is difficult to derive a priori, so a hyperparameter search should be performed using a validation set to find a good value.

Train a net with a small value for L2 regularization:
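A sketch of the training call; the choice of the "ADAM" method and the coefficient 0.001 are assumptions that would normally be tuned:

    (* the "L2Regularization" setting applies weight decay during training *)
    resultsWD = NetTrain[net, data, All,
      Method -> {"ADAM", "L2Regularization" -> 0.001}, TimeGoal -> 30]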
The resulting net is a good fit to the Gaussian and demonstrates better generalization than the original overfitted net:
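A sketch of how this can be checked, reusing the plotting and error-measuring steps from before; the name trainedWD is illustrative:

    trainedWD = resultsWD["TrainedNet"];
    (* the regularized fit stays close to the underlying Gaussian *)
    Show[Plot[trainedWD[x], {x, -3, 3}], ListPlot[List @@@ data]]
    (* the test error should be well below that of the overfitted net *)
    netError[trainedWD]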

Dropout

A third common regularization technique is dropout. Dropout introduces noise into the hidden activations of a net, but in such a fashion that the overall statistics of the activations at a given layer do not change. The noise takes the form of a random pattern of deactivation: at each training step, a random subset of the components of the input tensor (often referred to as units or neurons) is zeroed, and the magnitudes of the remaining components are scaled up to compensate (with dropout probability p, the surviving components are multiplied by 1/(1 - p), so the expected value of each activation is preserved). The basic idea is that dropout prevents neurons from depending too heavily on any particular neuron in the layer below them and hence encourages the learning of more robust representations.

Dropout can be introduced into a net using a DropoutLayer or by specifying the "Dropout" parameter of certain layers such as LongShortTermMemoryLayer.

Create a multilayer perceptron that includes dropout:
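A possible form of such a net, inserting a DropoutLayer after each hidden activation; the layer sizes and the dropout probability of 0.5 are assumptions:

    dropoutNet = NetChain[
      {LinearLayer[150], Tanh, DropoutLayer[0.5], LinearLayer[150], Tanh, DropoutLayer[0.5], LinearLayer[1]},
      "Input" -> "Scalar", "Output" -> "Scalar"]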
Train the net:
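A sketch of the training call, mirroring the earlier ones; the name resultsDO is illustrative:

    resultsDO = NetTrain[dropoutNet, data, All, TimeGoal -> 30]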
The resulting net is an acceptable fit to the Gaussian and demonstrates better generalization than the original overfitted net:
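A sketch of the check; DropoutLayer is only active during training, so the trained net can be evaluated directly:

    trainedDO = resultsDO["TrainedNet"];
    (* overlay the fit on the test samples and measure the test error *)
    Show[Plot[trainedDO[x], {x, -3, 3}], ListPlot[List @@@ testData]]
    netError[trainedDO]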
