Example-Weighted Neural Network Training

Example weighting is a common variant of neural network training in which different examples in the training data are given different importance. Simply put, the loss of each example is multiplied by the weight associated with that example, according it higher or lower importance in the optimization performed by NetTrain.

There are several situations in which this technique can be beneficial; two are demonstrated in this tutorial: emphasizing specific regions of the input space in a regression task, and biasing a classifier toward a particular class, for example to compensate for class imbalance or asymmetric misclassification costs.

Weighting of Examples for Regression

This example demonstrates using example weighting to emphasize specific regions of the input space.
We start by defining a function that our training net will attempt to approximate.

Define a simple function and plot it:
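For instance, one might use a softplus-shaped target (the specific function and plot range are assumptions; any function that is roughly linear on each half of the domain suits the linear model used below):

func[x_] := Log[1 + Exp[3 x]]/3;  (* illustrative choice: roughly 0 on the left, roughly x on the right *)
Plot[func[x], {x, -3, 3}]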
We create training data by sampling this function at regular points:
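A sketch of this step, using the function and interval assumed above (variable names here and below are illustrative):

xvalues = Range[-3, 3, 0.1];
data = Table[x -> func[x], {x, xvalues}];  (* rules of the form input -> output, as NetTrain expects *)
ListPlot[List @@@ data]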

We train a simple linear regression model with no example weighting, to serve as a baseline with which we can compare example-weighted training. Note that we are using MeanSquaredLossLayer as the loss function; this is actually already the default, but later we will have to explicitly construct a training net, so we are highlighting the choice of loss function now.

Create a simple linear regression model:
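A minimal scalar regression net of this kind is a single LinearLayer with scalar input and output:

net = NetChain[{LinearLayer[1]}, "Input" -> "Scalar", "Output" -> "Scalar"]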
Use NetTrain to train the model:
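In code:

trained = NetTrain[net, data]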
Visualize the results:
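For example, by plotting the target and the unweighted fit together:

Plot[{func[x], trained[x]}, {x, -3, 3}, PlotLegends -> {"target", "unweighted fit"}]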

Next, we perform weighted training. We will create two datasets that emphasize examples to the left and right of the origin, respectively. We will construct a training net that multiplies the mean squared loss we used previously by the training weight. This multiplication causes NetTrain to preferentially optimize for examples that have higher weights.

Create weighted datasets using the Exp function to bias either the left or the right side of the input space:
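One possible weighting, assuming the "Input", "Target", and "Weight" port names expected by the training net constructed below (the exponent of 2 is an illustrative choice that controls how sharply each side is favored):

leftData = Table[<|"Input" -> x, "Target" -> func[x], "Weight" -> Exp[-2 x]|>, {x, xvalues}];
rightData = Table[<|"Input" -> x, "Target" -> func[x], "Weight" -> Exp[2 x]|>, {x, xvalues}];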
Show samples from the datasets:
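For example:

RandomSample[leftData, 3]
RandomSample[rightData, 3]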
Plot the weights:
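One way to visualize the two weighting schemes:

ListLinePlot[
 {Table[{x, Exp[-2 x]}, {x, xvalues}], Table[{x, Exp[2 x]}, {x, xvalues}]},
 PlotLegends -> {"left weights", "right weights"}]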
Create a training net that uses example weighting:
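A sketch of such a training net: the prediction net feeds a MeanSquaredLossLayer, and a ThreadingLayer multiplies the resulting per-example loss by the "Weight" input (the vertex names are illustrative; embedding the prediction net in the graph strips its "Scalar" encoder and decoder automatically):

trainingNet = NetGraph[
  <|"predict" -> net,
    "loss" -> MeanSquaredLossLayer[],
    "times" -> ThreadingLayer[Times]|>,
  {NetPort["Input"] -> "predict" -> NetPort["loss", "Input"],
   NetPort["Target"] -> NetPort["loss", "Target"],
   {"loss", NetPort["Weight"]} -> "times" -> NetPort["WeightedLoss"]},
  "Target" -> "Scalar", "Weight" -> "Real"]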
For each dataset, train the net with NetTrain, specifying that the "WeightedLoss" output should be optimized directly, then extract the prediction net from the final training net:
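In outline, for each dataset; the NetReplacePart step, which restores the scalar encoder and decoder lost during embedding so the extracted nets can be evaluated directly, is an assumption of this sketch:

leftTrained = NetTrain[trainingNet, leftData, LossFunction -> "WeightedLoss"];
leftNet = NetReplacePart[NetExtract[leftTrained, "predict"],
   {"Input" -> "Scalar", "Output" -> "Scalar"}];
rightTrained = NetTrain[trainingNet, rightData, LossFunction -> "WeightedLoss"];
rightNet = NetReplacePart[NetExtract[rightTrained, "predict"],
   {"Input" -> "Scalar", "Output" -> "Scalar"}];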

Plotting the behavior of the resulting nets, we can see that the left-weighted net learned a good approximation on the left half of the input space, the right-weighted net learned a good approximation on the right half of the input space, and the unweighted net learned an approximation that does not favor either side.

Plot the predictions of the unweighted and weighted nets alongside the original function they were attempting to approximate:
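For example:

Plot[{func[x], trained[x], leftNet[x], rightNet[x]}, {x, -3, 3},
 PlotLegends -> {"target", "unweighted", "left-weighted", "right-weighted"}]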

Weighting of Examples for Classification

This example shows how to bias the classification of ambiguous examples by using higher example weights for all examples of a specific class.

First, we create a synthetic dataset consisting of two clusters with a certain degree of overlap.

Synthesize clusters from unit-variance normal distributions at -1 and 1:
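For example, assuming 1000 points per cluster (the sample size is an assumption):

cluster1 = RandomVariate[NormalDistribution[-1, 1], 1000];
cluster2 = RandomVariate[NormalDistribution[1, 1], 1000];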
Plot a histogram of the points in the clusters:
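In code:

Histogram[{cluster1, cluster2}]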
Create training data suitable for NetTrain:
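One possible encoding, labeling the clusters with classes 1 and 2 and shuffling the combined set:

data = RandomSample[Join[(# -> 1 &) /@ cluster1, (# -> 2 &) /@ cluster2]];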

We train a simple logistic regression model with no example weighting, to serve as a baseline with which we can compare example-weighted training. Note that we are using CrossEntropyLossLayer as the loss function; this is actually already the default, but later we will have to explicitly construct a training net, so we are highlighting it now.

Create a simple logistic regression model:
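A minimal logistic regression net with a "Class" decoder:

net = NetChain[{LinearLayer[2], SoftmaxLayer[]},
  "Input" -> "Scalar", "Output" -> NetDecoder[{"Class", {1, 2}}]]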
Train the net:
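In code:

trained = NetTrain[net, data]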
Evaluate the probabilities at the centers of the two clusters:
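Using the "Probabilities" property of the "Class" decoder:

trained[-1, "Probabilities"]
trained[1, "Probabilities"]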
Plot the probability of the first class as a function of x:
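For example:

Plot[trained[x, "Probabilities"][1], {x, -3, 3}]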

Next, we perform weighted training. This requires training data that emphasizes the examples belonging to the first cluster, and a training net that multiplies the cross-entropy loss we used previously by the training weight. The multiplication causes NetTrain to preferentially optimize for examples that have higher weights, in this case examples from the first cluster.

Define some class weights and assign them to the data:
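One possible choice (the specific weight values are assumptions; they favor class 1):

classWeights = <|1 -> 0.9, 2 -> 0.1|>;
weightedData = Map[
  <|"Input" -> #[[1]], "Target" -> #[[2]], "Weight" -> classWeights[#[[2]]]|> &, data];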
Show a sample of the weighted training data:
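For example:

RandomSample[weightedData, 3]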
Create a training net that uses example weighting:
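This mirrors the regression training net, with CrossEntropyLossLayer["Index"] in place of the mean squared loss (the "Index" form takes the target as a class index, matching the integer labels used above):

trainingNet = NetGraph[
  <|"predict" -> net,
    "loss" -> CrossEntropyLossLayer["Index"],
    "times" -> ThreadingLayer[Times]|>,
  {NetPort["Input"] -> "predict" -> NetPort["loss", "Input"],
   NetPort["Target"] -> NetPort["loss", "Target"],
   {"loss", NetPort["Weight"]} -> "times" -> NetPort["WeightedLoss"]},
  "Weight" -> "Real"]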
Train the net with NetTrain, specifying that the "WeightedLoss" output should be optimized directly, then extract the prediction net from the final training net:
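In outline:

weightedTrained = NetTrain[trainingNet, weightedData, LossFunction -> "WeightedLoss"];
weightedNet = NetExtract[weightedTrained, "predict"];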
Reattach the "Class" decoder, which was lost when the logistic regression net was embedded in the training net:
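A sketch of this step using NetReplacePart (restoring the scalar input encoder as well, so the net can be evaluated directly, is an extra assumption of this sketch):

weightedNet = NetReplacePart[weightedNet,
  {"Input" -> "Scalar", "Output" -> NetDecoder[{"Class", {1, 2}}]}]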
Evaluate the probabilities at the centers of the two clusters:
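As before:

weightedNet[-1, "Probabilities"]
weightedNet[1, "Probabilities"]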

By plotting the probability learned by the weighted net, we can see that the weighted data biases the predictions of the net toward the first cluster, so that the threshold at which the two classes are seen as equally likely is further to the right.

Plot the probability of the first class as a function of x:
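For example:

Plot[weightedNet[x, "Probabilities"][1], {x, -3, 3}]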

We can also observe the difference by looking at the recall and confusion matrices. The unweighted net has roughly equal recall for the two classes and a symmetric confusion matrix. The weighted net has higher recall for class 1 at the cost of class 2 and an asymmetric confusion matrix.

Use ClassifierMeasurements to calculate the recall and confusion matrix plot:
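A sketch of this comparison, assuming a fresh test set drawn from the same distributions (ClassifierMeasurements accepts nets whose output has a "Class" decoder):

testData = RandomSample[Join[
   (# -> 1 &) /@ RandomVariate[NormalDistribution[-1, 1], 1000],
   (# -> 2 &) /@ RandomVariate[NormalDistribution[1, 1], 1000]]];
cmUnweighted = ClassifierMeasurements[trained, testData];
cmWeighted = ClassifierMeasurements[weightedNet, testData];
{cmUnweighted["Recall"], cmWeighted["Recall"]}
{cmUnweighted["ConfusionMatrixPlot"], cmWeighted["ConfusionMatrixPlot"]}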
