Example-Weighted Neural Network Training
Example weighting is a common variant of neural network training in which different examples in the training data are given different importance. Simply put, the loss of each example is multiplied by the weight associated with that example, which gives the example more or less influence on the optimization performed by NetTrain.
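In symbols, the idea can be written as follows. This is a sketch of the general technique, not a statement of how NetTrain is implemented; the symbols $\ell_i$ (loss of example $i$), $w_i$ (its weight), $\theta$ (the net's parameters) and $N$ (the number of examples) are introduced here for illustration, and a constant normalization factor such as $1/N$ is omitted.

```latex
L_{\text{weighted}}(\theta) = \sum_{i=1}^{N} w_i \, \ell_i(\theta),
\qquad
\nabla_\theta L_{\text{weighted}}(\theta) = \sum_{i=1}^{N} w_i \, \nabla_\theta \ell_i(\theta)
```

Setting every $w_i = 1$ recovers ordinary unweighted training, and an example with a larger $w_i$ contributes proportionally more to each gradient step.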
Example weighting is useful in several situations:
- Correctly classifying certain examples might be more important than classifying others. Imagine a binary classifier used for fraud detection: false positives might be benign, but false negatives catastrophic. One way to address this during training is to place greater example weight on positive examples than on negative examples.
- Similarly, if we have a prior distribution for the occurrence of the classes in a classification problem, but our training data contains relatively balanced numbers of the different classes, we can incorporate this prior distribution directly into the learning task by weighting each example in proportion to the prior probability of its class.
- The training data might represent measurements containing variable amounts of noise, or there might be examples that are mislabeled. This can be addressed by placing higher example weight on examples in which there is higher confidence.
- Certain regions of the training data space might be harder for the net to learn than others. This can be addressed by emphasizing examples that fall in these regions, using higher example weights.
- In curriculum learning, the way the model is trained is changed over time as it improves. For example, the model might be trained first on easier examples and later on harder examples. One way to accomplish this is to dynamically change the example weights associated with specific sets of examples as training progresses.
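The prior-based weighting described in the list above can be sketched in a few lines. This is an illustrative Python sketch of the arithmetic only, not Wolfram Language code and not part of NetTrain; the function name `prior_weights`, the class labels "A" and "B", and the prior values are all hypothetical.

```python
from collections import Counter

def prior_weights(labels, prior):
    """Return one weight per example so that the total weight of each class
    is proportional to its prior probability (hypothetical helper)."""
    counts = Counter(labels)
    n = len(labels)
    # weight = prior probability of the class / empirical frequency of the class
    return [prior[y] / (counts[y] / n) for y in labels]

# Balanced training data, but class "B" is assumed to be four times as likely a priori.
labels = ["A"] * 50 + ["B"] * 50
prior = {"A": 0.2, "B": 0.8}
weights = prior_weights(labels, prior)

# Share of the total weight carried by class "A" now matches its prior.
share_a = sum(w for w, y in zip(weights, labels) if y == "A") / sum(weights)
```

With this choice the weights sum to the number of examples, so the overall scale of the loss is unchanged; only the relative influence of the classes shifts.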
We train a simple linear regression model with no example weighting, to serve as a baseline with which we can compare example-weighted training. Note that we are using MeanSquaredLossLayer as the loss function; this is actually already the default, but later we will have to explicitly construct a training net, so we are highlighting the choice of loss function now.
Next, we perform weighted training. We will create two datasets that emphasize examples to the left and right of the origin, respectively. We will construct a training net that multiplies the mean squared loss we used previously by the training weight. This multiplication causes NetTrain to preferentially optimize for examples that have higher weights.
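The effect of multiplying the squared loss by a weight can be reproduced outside any framework. The following is a minimal Python sketch, not the NetTrain training net itself; the target function `x**2`, the weight values 1.0 and 0.1, the step count, and the function name `train_weighted_linear` are assumptions chosen for illustration.

```python
def train_weighted_linear(xs, ys, ws, steps=2000, lr=0.1):
    """Fit y ~ a*x + b by full-batch gradient descent on the weighted squared loss."""
    a, b = 0.0, 0.0
    total = sum(ws)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for x, y, w in zip(xs, ys, ws):
            err = a * x + b - y
            # The per-example loss is w * err**2, so its gradient is scaled by w.
            grad_a += w * 2 * err * x
            grad_b += w * 2 * err
        a -= lr * grad_a / total
        b -= lr * grad_b / total
    return a, b

# Hypothetical data: a curved target that no single line can fit everywhere.
xs = [i / 20 - 1 for i in range(41)]  # 41 points on [-1, 1]
ys = [x * x for x in xs]

uniform = [1.0] * len(xs)
left = [1.0 if x < 0 else 0.1 for x in xs]   # emphasize examples left of the origin
right = [1.0 if x > 0 else 0.1 for x in xs]  # emphasize examples right of the origin

a_u, b_u = train_weighted_linear(xs, ys, uniform)
a_l, b_l = train_weighted_linear(xs, ys, left)
a_r, b_r = train_weighted_linear(xs, ys, right)
# a_u is close to 0, a_l is negative, and a_r is positive.
```

Because the line cannot match the curve on both sides at once, each set of weights pulls the fit toward the half of the input space it emphasizes, while uniform weights give a fit that favors neither side.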
Plotting the behavior of the resulting nets, we can see that the left-weighted net learned a good approximation on the left half of the input space, the right-weighted net learned a good approximation on the right half of the input space, and the unweighted net learned an approximation that does not favor either side.
We train a simple logistic regression model with no example weighting, to serve as a baseline with which we can compare example-weighted training. Note that we are using CrossEntropyLossLayer as the loss function; this is actually already the default, but later we will have to explicitly construct a training net, so we are highlighting it now.
Next, we perform weighted training. This requires training data that emphasizes the examples belonging to the first cluster and a training net that multiplies the cross-entropy loss we used previously by the training weight. The multiplication causes NetTrain to preferentially optimize for examples that have higher weights, in this case, examples from the first cluster.
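The same multiplication can be sketched for a one-dimensional logistic regression. This is an illustrative Python sketch under stated assumptions, not the NetTrain training net: the two overlapping clusters, the weight of 5 on the first cluster, the learning rate, and the function names are all hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_weighted_logistic(xs, ts, ws, steps=5000, lr=0.5):
    """Fit p(class 2 | x) = sigmoid(a*x + b) by gradient descent on the
    weighted cross-entropy loss."""
    a, b = 0.0, 0.0
    total = sum(ws)
    for _ in range(steps):
        grad_a = grad_b = 0.0
        for x, t, w in zip(xs, ts, ws):
            p = sigmoid(a * x + b)
            # Gradient of w * cross-entropy with respect to the logit is w * (p - t).
            grad_a += w * (p - t) * x
            grad_b += w * (p - t)
        a -= lr * grad_a / total
        b -= lr * grad_b / total
    return a, b

# Hypothetical data: cluster 1 centered at -1 (target 0), cluster 2 centered at +1 (target 1).
offsets = [i / 10 - 1.5 for i in range(31)]
xs = [-1 + d for d in offsets] + [1 + d for d in offsets]
ts = [0] * len(offsets) + [1] * len(offsets)

unweighted = [1.0] * len(xs)
weighted = [5.0 if t == 0 else 1.0 for t in ts]  # emphasize the first cluster

a_u, b_u = train_weighted_logistic(xs, ts, unweighted)
a_w, b_w = train_weighted_logistic(xs, ts, weighted)

# Input value at which both classes are predicted as equally likely.
threshold_u = -b_u / a_u
threshold_w = -b_w / a_w
```

In this symmetric setup the unweighted threshold sits at the midpoint between the clusters, and weighting the first cluster moves the threshold toward the second cluster, so more of the overlap region is assigned to the first class.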
By plotting the probability learned by the weighted net, we can see that the weighted data biases the predictions of the net toward the first cluster, so that the threshold at which the two classes are seen as equally likely is further to the right.
We can also observe the difference by looking at the recall and the confusion matrix of each net. The unweighted net has roughly equal recall for the two classes and a symmetric confusion matrix. The weighted net has higher recall for class 1 at the cost of class 2 and an asymmetric confusion matrix.
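For reference, recall and the confusion matrix can be computed from actual and predicted labels as sketched below. This is a plain Python illustration, not the measurement functions used in the original notebook; the label lists are made-up values that mimic a classifier biased toward class 1, and the function names are hypothetical.

```python
def confusion_matrix(actual, predicted, classes):
    """Rows index the actual class, columns index the predicted class."""
    index = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for a, p in zip(actual, predicted):
        matrix[index[a]][index[p]] += 1
    return matrix

def recall_per_class(matrix):
    """Recall of a class = correctly predicted examples / all actual examples of that class."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]

# Made-up predictions: every class 1 example is found, but half of class 2 is mislabeled.
actual    = [1, 1, 1, 1, 2, 2, 2, 2]
predicted = [1, 1, 1, 1, 1, 1, 2, 2]

cm = confusion_matrix(actual, predicted, classes=[1, 2])
recalls = recall_per_class(cm)
# cm is [[4, 0], [2, 2]] and recalls is [1.0, 0.5]
```

The off-diagonal entries differ (0 versus 2), which is the kind of asymmetry described above: errors are concentrated in the class that received less weight.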