NetTrain

NetTrain[net,{input1->output1,input2->output2,…}]

trains the specified neural net by giving the inputi as input and minimizing the discrepancy between the outputi and the actual output of the net, using an automatically chosen loss function.

NetTrain[net,port1->{data11,data12,…},port2->{…},…]

trains the specified net by supplying training data at the specified ports.

NetTrain[net,"dataset"]

trains on a named dataset from the Wolfram Data Repository.

NetTrain[net,f]

calls the function f during training to produce batches of training data.

NetTrain[net,data,"prop"]

gives data associated with a specific property prop of the training session.

NetTrain[net,data,All]

gives a NetTrainResultsObject[] that summarizes information about the training session.

Details and Options

  • Any input ports of the net whose shapes are not fixed will be inferred from the form of training data, and NetEncoder objects will be attached if the training data contains Image objects, etc.
  • Individual training data inputs can be scalars, vectors, or numeric tensors. If the net has appropriate NetEncoder objects attached, the inputs can include Image objects, strings, etc.
  • Named datasets that are commonly used as examples for neural net applications include the following:
  • "MNIST"60,000 classified handwritten digits
    "FashionMNIST"60,000 classified images of items of clothing
    "CIFAR-10","CIFAR-100"50,000 classified images of real-world objects
    "MovieReview"10,662 movie review snippets with sentiment
  • Training on a named dataset is equivalent to training on ResourceData["dataset","TrainingSet"]. If a named dataset is used for the ValidationSet option, it is equivalent to ResourceData["dataset","TestSet"].
  • Additional forms for specifying training data include {input1,input2,…}->{output1,…} and {<|port1->data1,port2->data2,…|>,<|port1->data1,…|>,…}.
  • When loss layers are automatically attached by NetTrain to output ports, their "Target" ports will be taken from the training data using the same name as the original output port.
  • When giving training data using the specification inputs->outputs, the network should not already contain any loss layers and should have precisely one input and one output port.
  • The following options are supported:
  • BatchSize    Automatic    how many examples to process in a batch
    LearningRate    Automatic    rate at which to adjust weights to minimize loss
    LearningRateMultipliers    Automatic    set relative learning rates within the net
    LossFunction    Automatic    the loss function for assessing outputs
    MaxTrainingRounds    Automatic    how many times to traverse the training data
    Method    Automatic    the training method to use
    PerformanceGoal    Automatic    favor settings with specific advantages
    TargetDevice    "CPU"    the target device on which to perform training
    TimeGoal    Automatic    number of seconds to train for
    TrainingProgressMeasurements    Automatic    measurements to monitor, track and plot during training
    TrainingProgressCheckpointing    None    how to periodically save partially trained nets
    RandomSeeding    Inherited    how to seed pseudorandom generators internally
    TrainingProgressFunction    None    function to call periodically during training
    TrainingProgressReporting    Automatic    how to report progress during training
    TrainingStoppingCriterion    None    how to automatically stop training
    ValidationSet    None    the set of data on which to evaluate the model during training
    WorkingPrecision    Automatic    precision of floating point calculations
  • If the loss is not given explicitly using LossFunction, a loss function will be chosen automatically based on the final layer or layers in the net.
  • With the default setting of MaxTrainingRounds->Automatic, training will occur for approximately 20 seconds, but never for more than 10,000 rounds.
  • With the setting of MaxTrainingRounds->n, training will occur for n rounds, where a round is defined to be a traversal of the entire training dataset.
  • The following settings for ValidationSet can be given:
  • None    use only the existing training set to estimate loss (default)
    data    validation set in the same form as training data
    Scaled[frac]    reserve a specified fraction of the training set for validation
    {spec,"Interval"->int}    specify the interval at which to calculate validation loss
  • For ValidationSet->{spec,"Interval"->int}, the interval can be an integer n, indicating that validation loss should be calculated every n training rounds, or a Quantity in units of seconds, minutes or hours (see the sketch below).
  • For a named dataset such as "MNIST", specifying ValidationSet->Automatic will use the corresponding "TestData" content element.
  • If a validation set is specified, NetTrain will return the net that produced the lowest validation loss during training with respect to this set.
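  • As a schematic illustration (net, trainingData and validationData are placeholders, not defined here), a fraction of the training data can be held out, or an explicit validation set can be supplied with a time-based interval:

    (* hold out 10% of the training data for validation *)
    NetTrain[net, trainingData, ValidationSet -> Scaled[0.1]]

    (* explicit validation data, evaluated every 30 seconds *)
    NetTrain[net, trainingData,
     ValidationSet -> {validationData, "Interval" -> Quantity[30, "Seconds"]}]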
  • In NetTrain[net,f], the function f is applied to <|"BatchSize"->n,"Round"->r|> to generate each batch of training data in the form {input1->output1,…} or <|"port1"->data1,…|> (see the sketch below).
  • NetTrain[net,{f,"RoundLength"->n}] can be used to specify that f should be applied enough times during a training round to produce approximately n examples. The default is to apply f once per training round.
  • NetTrain[net,…,ValidationSet->{g,"RoundLength"->n}] can be used to specify that the function g should be applied in an equivalent manner to NetTrain[net,{f,"RoundLength"->n}] in order to produce approximately n examples for the purposes of computing validation loss and accuracy.
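  • A minimal sketch of a generator-based call, assuming a small regression net and synthetic data (all names and sizes here are illustrative):

    (* the generator receives <|"BatchSize"->n,"Round"->r|> and returns one batch of rules *)
    gen = Function[spec,
       Table[With[{x = RandomReal[{0, 1}]}, x -> Sin[2 Pi x]],
        spec["BatchSize"]]];
    NetTrain[
     NetChain[{LinearLayer[64], Tanh, LinearLayer[]},
      "Input" -> "Scalar", "Output" -> "Scalar"],
     {gen, "RoundLength" -> 4096}]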
  • Possible settings for WorkingPrecision include:
  • "Real32"use single-precision real (32-bit)
    "Real64"use double-precision real (64-bit)
    "Mixed"use half-precision real for certain operations
  • WorkingPrecision->"Mixed" is only supported for TargetDevice->"GPU", where it can result in significant performance increases on certain devices.
  • In NetTrain[net,data,prop], the property prop can be any of the following:
  • "TrainedNet"the optimal trained network found (default)
    "BatchesPerRound"the number of batches contained in a single round
    "BatchLossList"a list of the mean losses for each batch update
    "BatchMeasurementsLists"list of training measurements associations for each batch update
    "BatchPermutation"an array of the indices from the training data used to populate each batch
    "BatchSize"the effective value of BatchSize
    "BestValidationRound"the training round corresponding to the final trained net
    "CheckpointingFiles"list of checkpointing files generated during training
    "ExampleLosses"losses taken by each example during training
    "ExamplesProcessed"total number of examples processed during training
    "FinalLearningRate"the learning rate at the end of training
    "FinalNet"the final network generated in the training process
    "FinalPlots"association of plots for all losses and measurements
    "InitialLearningRate"the learning rate at the start of training
    "LossPlot"a plot of the evolution of the mean training loss
    "MeanBatchesPerSecond"the mean number of batches processed per second
    "MeanExamplesPerSecond"the mean number of input examples processed per second
    "NetTrainInputForm"an expression representing the originating call to NetTrain
    "OptimizationMethod"the name of the optimization method used
    "Properties"the full list of available properties
    "ReasonTrainingStopped"brief description of why training stopped
    "ResultsObject"a NetTrainResultsObject[] containing a majority of the available properties in this table
    "RoundLoss"the mean loss for the most recent round
    "RoundLossList"a list of the mean losses for each round
    "RoundMeasurements"association of training measurements for the most recent round
    "RoundMeasurementsLists"list of training measurements associations for each round
    "RoundPositions"the batch numbers corresponding to each round measurement
    "TargetDevice"the device used for training
    "TotalBatches"the total number of batches encountered during training
    "TotalRounds"the total number of rounds of training performed
    "TotalTrainingTime"the total time spent training, in seconds
    "TrainingExamples"the number of examples in the training set
    "TrainingNet"the network as prepared for training
    "ValidationExamples"the number of examples in the validation set
    "ValidationLoss"the mean loss obtained on the ValidationSet for the most recent validation measurement
    "ValidationLossList"list of the mean losses on the ValidationSet for each validation measurement
    "ValidationMeasurements"association of training measurements on the ValidationSet after the most recent validation measurement
    "ValidationMeasurementsLists"list of training measurements associations on the ValidationSet for each validation measurement
    "ValidationPositions"the batch numbers corresponding to each validation measurement
    "WeightsLearningRateMultipliers"an association of the learning rate multiplier used for each weight
  • An association of the form <|"Property"->prop,"Form"->form,"Interval"->int|> can be used to specify a custom property whose value will be collected repeatedly during training.
  • For a custom property, valid settings for prop can be any of the properties available in TrainingProgressFunction, or a user-defined function that is given the association of all the properties. Valid settings for form include "List", "TransposedList" and "Plot". Valid settings for "Interval" can be "Batch", "Round" or a Quantity[]. Supported units include "Batches", "Rounds", "Percent" and time units like "Seconds", "Minutes" and "Hours".
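  • For example (sketch; net and data are placeholders), the round loss could be collected as a plot, updated every two rounds:

    NetTrain[net, data,
     <|"Property" -> "RoundLoss", "Form" -> "Plot",
      "Interval" -> Quantity[2, "Rounds"]|>]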
  • NetTrain[net,data,{prop1,prop2,…}] returns a list of the results for the propi.
  • NetTrain[net,data,All] returns a NetTrainResultsObject[] that contains values for all properties that do not require significant additional computation or memory.
  • With the default setting of ValidationSet->None, the "TrainedNet" property yields the net as it is at the end of training. When a validation set is provided, the default criterion for selecting the optimal net depends on the type of net:
  • classification net    choose the net with the lowest error rate; ties are broken using the lowest loss
    non-classification net    choose the net with the lowest loss
  • The criterion used to select the "TrainedNet" property can be customized using the TrainingStoppingCriterion option.
  • The property "BestValidationRound" gives the exact round from which the final net was selected.
  • Possible settings for Method include:
  • "ADAM"stochastic gradient descent using an adaptive learning rate that is invariant to diagonal rescaling of the gradients
    "RMSProp"stochastic gradient descent using an adaptive learning rate derived from exponentially smoothed average of gradient magnitude
    "SGD"ordinary stochastic gradient descent with momentum
    "SignSGD"stochastic gradient descent for which the maginitude of the gradient is discarded
  • Valid settings for PerformanceGoal include Automatic, "TrainingMemory", "TrainingSpeed" or a list of goals to combine.
  • Valid settings for WorkingPrecision include the default value of "Real32", indicating single-precision floating point; "Real64", indicating double-precision floating point; and "Mixed", indicating a mixture of "Real32" and half-precision. Mixed-precision training is only supported for GPUs.
  • Suboptions for specific methods can be specified using Method->{"method",opt1->val1,…}. The following suboptions are supported for all methods:
  • "LearningRate"Automaticthe size of steps to take in the direction of the derivative
    "LearningRateSchedule"Automatichow to scale the learning rate as training progresses
    "L2Regularization"Nonethe global loss associated with the L2 norm of all learned arrays
    "GradientClipping"Nonethe magnitude above which gradients should be clipped
    "WeightClipping"Nonethe magnitude above which weights should be clipped
  • With "LearningRateSchedule"->f, the learning rate for a given batch will be calculated as initial*f[batch,total], where batch is the current batch number, total is the total number of batches that will be visited during training, and initial is the initial learning rate specified using "LearningRate". The value returned by f should be a number between 0 and 1.
  • The suboptions "L2Regularization", "GradientClipping" and "WeightClipping" can be given in the following forms:
  • r    use the value r for all weights in the net
    {lspec1->r1,lspec2->r2,…}    use the value ri for the specific part lspeci of the net
  • The rules lspeciri are given in the same form as for LearningRateMultipliers.
  • For the method "SGD", the following additional suboptions are supported:
  • "Momentum"0.93how much to preserve the previous step when updating the derivative
  • For the method "ADAM", the following additional suboptions are supported:
  • "Beta1"0.9exponential decay rate for the first moment estimate
    "Beta2"0.999exponential decay rate for the second moment estimate
    "Epsilon"0.00001`stability parameter
  • For the method "RMSProp", the following additional suboptions are supported:
  • "Beta"0.95exponential decay rate for the moving average of the gradient magnitude
    "Epsilon"0.000001stability parameter
    "Momentum"0.9momentum term
  • For the method "SignSGD", the following additional suboption is supported:
  • "Momentum"0.93how much to preserve the previous step when updating the derivative
  • If a net already contains initialized or previously trained weights, these will not be reinitialized by NetTrain before training is performed.

Examples


Basic Examples  (6)

Train a single-layer linear net on input->output pairs:
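A minimal sketch of such a call (the data here is illustrative, not the original):

net = NetTrain[LinearLayer[], {1 -> 1.2, 2 -> 2.3, 3 -> 2.9, 4 -> 3.8}]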

Predict the value of a new input:
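Evaluating the trained net from the sketch above on a new input:

net[3.5]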

Make several predictions at once:
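A list of inputs gives a list of predictions:

net[{1.5, 2.5, 3.5}]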

The prediction is a linear function of the input:
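One way to see this (plot range illustrative):

Plot[net[x], {x, 0, 5}]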

Train a perceptron that classifies inputs as either True or False:
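A sketch, assuming a linear layer followed by a logistic sigmoid with a "Boolean" decoder (data illustrative):

net = NetTrain[
  NetChain[{LinearLayer[], LogisticSigmoid},
   "Input" -> "Scalar", "Output" -> NetDecoder[{"Boolean"}]],
  {1 -> False, 2 -> False, 3 -> True, 4 -> True}]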

Predict whether a new input is True or False:
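Evaluating the sketched net on a new input:

net[2.6]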

Obtain the probability of the input being True by disabling the NetDecoder:
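Giving None as the second argument bypasses the decoder, returning the underlying probability:

net[2.6, None]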

Make several predictions at once:
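As before, a list of inputs can be given at once:

net[{1.2, 2.4, 3.6}]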

Plot the probability as a function of the input:
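Using the decoder-free output (plot range illustrative):

Plot[net[x, None], {x, 0, 5}]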

Train a three-layer network to learn a 2D function:
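A sketch, training on synthetic samples of a known function (layer sizes and data illustrative):

data = Flatten@Table[{x, y} -> Sin[x y], {x, 0, 2, 0.1}, {y, 0, 2, 0.1}];
net = NetTrain[
  NetChain[{LinearLayer[32], Tanh, LinearLayer[]},
   "Input" -> 2, "Output" -> "Scalar"],
  data]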

Evaluate the network on an input:
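Using the net sketched above:

net[{1.2, 0.5}]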

Plot the prediction of the net as a function of x and y:
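For instance:

Plot3D[net[{x, y}], {x, 0, 2}, {y, 0, 2}]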

Train a recurrent network that predicts the maximum value seen in the input sequence:
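A sketch using a gated recurrent layer over variable-length sequences (data and layer sizes illustrative):

data = Table[
   With[{seq = RandomReal[1, {RandomInteger[{2, 8}], 1}]},
    seq -> Max[seq]], {2000}];
net = NetTrain[
  NetChain[{GatedRecurrentLayer[16], SequenceLastLayer[], LinearLayer[]},
   "Input" -> {"Varying", 1}, "Output" -> "Scalar"],
  data]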

Evaluate the network on an input:
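Using the net sketched above:

net[{{0.2}, {0.9}, {0.4}}]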

Plot the output of the network as one element of a sequence is varied:
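Varying the middle element of a three-element sequence (illustrative):

Plot[net[{{0.3}, {x}, {0.5}}], {x, 0, 1}]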

Train a net and produce a results object that summarizes the training process:
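Giving All as the third argument returns a NetTrainResultsObject (sketch, data illustrative):

results = NetTrain[LinearLayer[], {1 -> 1.2, 2 -> 2.3, 3 -> 2.9, 4 -> 3.8}, All]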

Obtain the trained net from the result:
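results["TrainedNet"]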

Obtain the network used to train the net:
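The "TrainingNet" property gives the net as prepared for training, including any automatically attached loss layers:

results["TrainingNet"]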

Train a net to classify handwritten digits using a named dataset and model:
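One way to do this, assuming the "LeNet" architecture from the Wolfram Neural Net Repository:

net = NetTrain[NetModel["LeNet", "UninitializedEvaluationNet"], "MNIST"]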

Classify a difficult digit:
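With digitImage standing for an Image of a handwritten digit (hypothetical placeholder):

net[digitImage]  (* digitImage: an Image of a handwritten digit *)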

Scope  (14)

Options  (27)

Properties & Relations  (2)

Possible Issues  (1)

Interactive Examples  (1)

Neat Examples  (2)

Introduced in 2016 (11.0) | Updated in 2019 (12.0)