NetTrain

NetTrain[net,{input1->output1,input2->output2,…}]

trains the specified neural net by giving the inputi as input and minimizing the discrepancy between the outputi and the actual output of the net, using an automatically chosen loss function.

NetTrain[net,<|port1->{data11,data12,…},port2->{…},…|>]

trains the specified net by supplying training data at the specified ports.

NetTrain[net,"dataset"]

trains on a named dataset from the Wolfram Data Repository.

NetTrain[net,f]

calls the function f during training to produce batches of training data.

NetTrain[net,data,"prop"]

gives data associated with a specific property prop of the training session.

NetTrain[net,data,All]

gives a NetTrainResultsObject[…] that summarizes information about the training session.

Details and Options

  • NetTrain is used to teach a neural net to recognize patterns and make predictions by adjusting its parameters based on input data and correct outputs.
  • During training, the network's parameters, such as weights and biases, are adjusted using optimization algorithms like gradient descent to minimize the difference between the predicted outputs and the actual outputs, thus improving the network's accuracy over time.
  • Any input ports of the net whose shapes are not fixed will be inferred from the form of training data, and NetEncoder objects will be attached if the training data contains Image objects etc.
  • Possible forms of data include:
  • "dataset"    a named dataset
    {input1->output1,…}    a list of Rule instances between input and output
    {input1,…}->{output1,…}    a Rule between inputs and corresponding outputs
    {<|port1->…,…|>,…}    a list of associations with inputs for the specified ports
    <|port1->{data11,data12,…},…|>    an association of lists of inputs for the specified ports
    Dataset[…]    a dataset object
    Tabular[…]    a tabular object
    f    a function that creates training batches
  • Individual training data inputs can be scalars, vectors or numeric tensors. If the net has appropriate NetEncoder objects attached, the inputs can include Image objects, strings, etc.
  • Named datasets that are commonly used as examples for neural net applications include the following:
  • "MNIST"    60,000 classified handwritten digits
    "FashionMNIST"    60,000 classified images of items of clothing
    "CIFAR-10", "CIFAR-100"    50,000 classified images of real-world objects
    "MovieReview"    10,662 movie review snippets with sentiment polarity
  • Training on a named dataset is equivalent to training on ResourceData["dataset","TrainingData"], or ExampleData[{"MachineLearning","dataset"},"TrainingData"] if ResourceObject["dataset"] does not exist. If a named dataset is used for the ValidationSet option, it is equivalent to ResourceData["dataset","TestData"] or ExampleData[{"MachineLearning","dataset"},"TestData"].
  • When giving training data using the specification {input1->output1,…}, the network should not already contain any loss layers and should have precisely one input and one output port.
  • Additional forms for specifying training data include {input1,input2,…}->{output1,…} and {<|port1->…,port2->…,…|>,<|port1->…,…|>,…}.
  • When loss layers are automatically attached by NetTrain to output ports, their "Target" ports will be taken from the training data using the same name as the original output port.
  • The following options are supported:
  • BatchSize    Automatic    how many examples to process in a batch
    LearningRate    Automatic    rate at which to adjust weights to minimize loss
    LearningRateMultipliers    Automatic    set relative learning rates within the net
    LossFunction    Automatic    the loss function for assessing outputs
    MaxTrainingRounds    Automatic    how many times to traverse the training data
    Method    Automatic    the training method to use
    PerformanceGoal    Automatic    favor settings with specific advantages
    TargetDevice    "CPU"    the target device on which to perform training
    TimeGoal    Automatic    number of seconds to train for
    TrainingProgressMeasurements    Automatic    measurements to monitor, track and plot during training
    TrainingProgressCheckpointing    None    how to periodically save partially trained nets
    RandomSeeding    1234    how to seed pseudorandom generators internally
    TrainingProgressFunction    None    function to call periodically during training
    TrainingProgressReporting    Automatic    how to report progress during training
    TrainingStoppingCriterion    None    how to automatically stop training
    TrainingUpdateSchedule    Automatic    when to update specific parts of the net
    ValidationSet    None    the set of data on which to evaluate the model during training
    WorkingPrecision    Automatic    precision of floating-point calculations
  • If the loss is not given explicitly using LossFunction, a loss function will be chosen automatically based on the final layer or layers in the net.
  • With the default setting of BatchSize->Automatic, the batch size will be chosen automatically, based on the memory requirements of the network and the memory available on the target device. The maximum batch size that will be automatically chosen is 64.
  • With the default setting of MaxTrainingRounds->Automatic, training will occur for approximately 20 seconds, but never for more than 10,000 rounds.
  • With the setting of MaxTrainingRounds->n, training will occur for n rounds, where a round is defined to be a traversal of the entire training dataset.
  • The following settings for ValidationSet can be given:
  • None    use only the existing training set to estimate loss (default)
    data    validation set in the same form as training data
    Scaled[frac]    reserve a specified fraction of the training set for validation
    {spec,"Interval"->int}    specify the interval at which to calculate validation loss
  • For ValidationSet->{spec,"Interval"->int}, the interval can be an integer n, indicating that validation loss should be calculated every n training rounds, or a Quantity in units of seconds, minutes or hours.
  • For a named dataset such as "MNIST", specifying ValidationSet->Automatic will use the corresponding "TestData" content element.
  • If a validation set is specified, NetTrain will return the net that produced the lowest validation loss during training with respect to this set.
  • In NetTrain[net,f], the function f is applied to <|"BatchSize"->n,"Round"->r|> to generate each batch of training data in the form {input1->output1,…} or <|"port1"->data,…|>.
  • NetTrain[net,{f,"RoundLength"->n}] can be used to specify that f should be applied enough times during a training round to produce approximately n examples. The default is to apply f once per training round.
  • NetTrain[net,…,ValidationSet->{g,"RoundLength"->n}] can be used to specify that the function g should be applied in an equivalent manner to NetTrain[net,{f,"RoundLength"->n}] in order to produce approximately n examples for the purposes of computing validation loss and accuracy.
  • Possible settings for TargetDevice include:
  • "CPU"    train on the CPU
    "GPU"    train on a CUDA-compatible GPU
  • The "GPU" setting is resolved to "CUDA". Other settings are currently not supported.
  • Possible settings for WorkingPrecision include:
  • "Real32"    use single-precision real (32-bit)
    "Real64"    use double-precision real (64-bit)
    "Mixed"    use half-precision real for certain operations
  • WorkingPrecision->"Mixed" is only supported for TargetDevice->"GPU", where it can result in significant performance increases on certain devices.
  • In NetTrain[net,data,prop], the property prop can be any of the following:
  • "TrainedNet"    the optimal trained network found (default)
    "BatchesPerRound"    the number of batches contained in a single round
    "BatchLossList"    a list of the mean losses for each batch update
    "BatchMeasurementsLists"    list of training measurements associations for each batch update
    "BatchPermutation"    an array of the indices from the training data used to populate each batch
    "BatchSize"    the effective value of BatchSize
    "BestValidationRound"    the training round corresponding to the final trained net
    "CheckpointingFiles"    list of checkpointing files generated during training
    "ExampleLosses"    losses taken by each example during training
    "ExamplesProcessed"    total number of examples processed during training
    "FinalLearningRate"    the learning rate at the end of training
    "FinalNet"    the most recent network generated in the training process, regardless of its performance on the validation set or other metrics
    "FinalPlots"    association of plots for all losses and measurements
    "InitialLearningRate"    the learning rate at the start of training
    "LossPlot"    a plot of the evolution of the mean training loss
    "MeanBatchesPerSecond"    the mean number of batches processed per second
    "MeanExamplesPerSecond"    the mean number of input examples processed per second
    "NetTrainInputForm"    an expression representing the originating call to NetTrain
    "OptimizationMethod"    the name of the optimization method used
    "Properties"    the full list of available properties
    "ReasonTrainingStopped"    brief description of why training stopped
    "ResultsObject"    a NetTrainResultsObject[…] containing a majority of the available properties in this table
    "RoundLoss"    the mean loss for the most recent round
    "RoundLossList"    a list of the mean losses for each round
    "RoundMeasurements"    association of training measurements for the most recent round
    "RoundMeasurementsLists"    list of training measurements associations for each round
    "RoundPositions"    the batch numbers corresponding to each round measurement
    "TargetDevice"    the device used for training
    "TotalBatches"    the total number of batches encountered during training
    "TotalRounds"    the total number of rounds of training performed
    "TotalTrainingTime"    the total time spent training, in seconds
    "TrainingExamples"    the number of examples in the training set
    "TrainingNet"    the network as prepared for training
    "TrainingUpdateSchedule"    the value of TrainingUpdateSchedule
    "ValidationExamples"    the number of examples in the validation set
    "ValidationLoss"    the mean loss obtained on the ValidationSet for the most recent validation measurement
    "ValidationLossList"    list of the mean losses on the ValidationSet for each validation measurement
    "ValidationMeasurements"    association of training measurements on the ValidationSet after the most recent validation measurement
    "ValidationMeasurementsLists"    list of training measurements associations on the ValidationSet for each validation measurement
    "ValidationPositions"    the batch numbers corresponding to each validation measurement
    "WeightsLearningRateMultipliers"    an association of the learning rate multiplier used for each weight
  • An association of the form <|"Property"->prop,"Form"->form,"Interval"->int|> can be used to specify a custom property whose value will be collected repeatedly during training.
  • For a custom property, valid settings for prop can be any of the properties available in TrainingProgressFunction, or a user-defined function that is given the association of all the properties. Valid settings for form include "List", "TransposedList" and "Plot". Valid settings for "Interval" can be "Batch", "Round" or a Quantity[…]. Supported units include "Batches", "Rounds", "Percent" and time units like "Seconds", "Minutes" and "Hours".
  • NetTrain[net,data,{prop1,prop2,…}] returns a list of the results for the propi.
  • NetTrain[net,data,All] returns a NetTrainResultsObject[…] that contains values for all properties that do not require significant additional computation or memory.
  • With the default setting of ValidationSet->None, the "TrainedNet" property yields the net as it is at the end of training. When a validation set is provided, the default criterion for selecting the optimal net depends on the type of net:
  • classification net    choose the net with the lowest error rate; ties are broken using the lowest loss
    non-classification net    choose the net with the lowest loss
  • The criterion used to select the "TrainedNet" property can be customized using the TrainingStoppingCriterion option.
  • The property "BestValidationRound" gives the exact round from which the final net was selected.
  • Possible settings for Method include:
  • "ADAM"    stochastic gradient descent using an adaptive learning rate that is invariant to diagonal rescaling of the gradients
    "RMSProp"    stochastic gradient descent using an adaptive learning rate derived from exponentially smoothed average of gradient magnitude
    "SGD"    ordinary stochastic gradient descent with momentum
    "SignSGD"    stochastic gradient descent for which the magnitude of the gradient is discarded
  • Valid settings for PerformanceGoal include Automatic, "TrainingMemory", "TrainingSpeed" or a list of goals to combine.
  • Valid settings for WorkingPrecision include the default value of "Real32", indicating single-precision floating point; "Real64", indicating double-precision floating point; and "Mixed", indicating a mixture of "Real32" and half-precision. Mixed-precision training is only supported for GPUs.
  • Suboptions for specific methods can be specified using Method->{"method",opt1->val1,…}. The following suboptions are supported for all methods:
  • "LearningRateSchedule"    Automatic    how to scale the learning rate as training progresses
    "L2Regularization"    None    the global loss associated with the L2 norm of all learned arrays
    "GradientClipping"    None    the magnitude above which gradients should be clipped
    "WeightClipping"    None    the magnitude above which weights should be clipped
  • With "LearningRateSchedule"->f, the learning rate for a given batch will be calculated as initial*f[batch,total], where batch is the current batch number, total is the total number of batches that will be visited during training, and initial is the initial learning rate specified using the LearningRate option. The value returned by f should be a number between 0 and 1.
  • The suboptions "L2Regularization", "GradientClipping" and "WeightClipping" can be given in the following forms:
  • r    use the value r for all weights in the net
    {lspec1->r1,lspec2->r2,…}    use the value ri for the specific part lspeci of the net
  • The rules lspeci->ri are given in the same form as for LearningRateMultipliers.
  • For the method "SGD", the following additional suboptions are supported:
  • "Momentum"    0.93    how much to preserve the previous step when updating the derivative
  • For the method "ADAM", the following additional suboptions are supported:
  • "Beta1"    0.9    exponential decay rate for the first moment estimate
    "Beta2"    0.999    exponential decay rate for the second moment estimate
    "Epsilon"    0.00001    stability parameter
  • For the method "RMSProp", the following additional suboptions are supported:
  • "Beta"    0.95    exponential decay rate for the moving average of the gradient magnitude
    "Epsilon"    0.000001    stability parameter
    "Momentum"    0.9    momentum term
  • For the method "SignSGD", the following additional suboption is supported:
  • "Momentum"    0.93    how much to preserve the previous step when updating the derivative
  • If a net already contains initialized or previously trained weights, these will not be reinitialized by NetTrain before training is performed.

Examples

Basic Examples  (6)

Train a single-layer linear net on input output pairs:

Predict the value of a new input:

Make several predictions at once:

The prediction is a linear function of the input:
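The four steps above can be sketched as follows. The data points are hypothetical values chosen to lie near y = 2x, and the variable names are illustrative only; the exact predictions depend on the training run.

```wolfram
(* Hypothetical toy data: outputs are roughly twice the inputs *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* Train a single linear layer; input and output shapes are inferred from the data *)
trained = NetTrain[LinearLayer[], data, TrainingProgressReporting -> None];

(* Predict the value of a new input *)
single = trained[2.5]

(* A list of scalar inputs is treated as a batch of examples *)
several = trained[{1, 1.5, 2}]

(* The prediction is a linear function of the input *)
Plot[trained[x], {x, 0, 5}]
```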

Train a perceptron that classifies inputs as either True or False:

Predict whether a new input is True or False:

Obtain the probability of the input being True by disabling the NetDecoder:

Make several predictions at once:

Plot the probability as a function of the input:
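A minimal sketch of the perceptron steps above, assuming hypothetical toy data in which small inputs are labeled False and large inputs True. The net layout shown here (a linear layer followed by a logistic sigmoid with a "Boolean" decoder) is one plausible way to build such a classifier.

```wolfram
(* Hypothetical toy data *)
data = {1 -> False, 2 -> False, 3 -> True, 4 -> True};

(* Linear layer followed by a logistic sigmoid, decoded to True/False *)
net = NetChain[{LinearLayer[], LogisticSigmoid},
   "Input" -> "Scalar", "Output" -> NetDecoder["Boolean"]];

trained = NetTrain[net, data, TrainingProgressReporting -> None];

decoded = trained[3.2]            (* True or False *)
probability = trained[3.2, None]  (* probability of True; None disables the decoder *)
batch = trained[{1, 2, 3, 4}]     (* several predictions at once *)

Plot[trained[x, None], {x, 0, 5}]
```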

Train a three-layer network to learn a 2D function:

Evaluate the network on an input:

Plot the prediction of the net as a function of x and y:

Train a recurrent network that predicts the maximum value seen in the input sequence:

Evaluate the network on an input:

Plot the output of the network as one element of a sequence is varied:

Train a net and produce a results object that summarizes the training process:

Obtain the trained net from the result:

Obtain the network used to train the net:
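As a sketch of the results-object workflow, the following trains on hypothetical toy data with All as the third argument and then queries two of the properties listed in the Details section.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* All as the third argument returns a NetTrainResultsObject *)
results = NetTrain[LinearLayer[], data, All, TrainingProgressReporting -> None];

(* The trained net *)
trained = results["TrainedNet"]

(* The network as prepared for training, including any attached loss layer *)
trainingNet = results["TrainingNet"]
```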

Train a net to classify handwritten digits using a named dataset and model:

Classify a difficult digit:

Scope  (15)

Networks  (3)

Train a single-layer network:

Train a linear chain of layers:

Train a directed graph of layers:

Data Formats  (8)

Specify training data as a list of rules, each representing an input and the corresponding target output:

Specify training data as a rule whose left-hand side is a list of inputs, and whose right-hand side is a list of the corresponding target outputs:

Specify training data as an association whose keys are port names:

Specify training data as a list of associations, each representing a single training example:
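The four specifications above can be compared side by side. This is a sketch on hypothetical data; the helper name train is an assumption introduced only to avoid repeating options, and "Input" and "Output" are the port names of a LinearLayer.

```wolfram
(* Hypothetical helper: train a fresh linear layer quietly for a fixed number of rounds *)
train[data_] := NetTrain[LinearLayer[], data,
   MaxTrainingRounds -> 200, TrainingProgressReporting -> None];

(* 1. A list of rules input -> output *)
t1 = train[{1 -> 2, 2 -> 4, 3 -> 6}];

(* 2. A single rule from a list of inputs to a list of outputs *)
t2 = train[{1, 2, 3} -> {2, 4, 6}];

(* 3. An association whose keys are port names *)
t3 = train[<|"Input" -> {1, 2, 3}, "Output" -> {2, 4, 6}|>];

(* 4. A list of associations, one per training example *)
t4 = train[{<|"Input" -> 1, "Output" -> 2|>,
    <|"Input" -> 2, "Output" -> 4|>,
    <|"Input" -> 3, "Output" -> 6|>}];
```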

Train a net using a named dataset:

Specify training data as a Dataset whose rows are individual examples:

Specify training data as a Tabular object whose rows are individual examples:

Generate training data during training. First create a network to train:

Define a generator to produce a batch of examples in the form of a list of rules:

Train using the generator with a specific BatchSize:

Specify that the generator should be called 4 times per round, to produce 64 examples per round:
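A sketch of generator-based training follows. The target function 2x + 1, the round counts and the names generator, trained and trained64 are illustrative assumptions; the argument passed to the generator is the association with "BatchSize" and "Round" keys described in the Details section.

```wolfram
(* A net with fully specified scalar input and output *)
net = NetChain[{LinearLayer[]}, "Input" -> "Scalar", "Output" -> "Scalar"];

(* The generator receives <|"BatchSize" -> n, "Round" -> r|> and returns n rules *)
generator = Function[info,
   Table[With[{x = RandomReal[{-1, 1}]}, x -> 2 x + 1], info["BatchSize"]]];

(* By default the generator is called once per round: 16 examples per round *)
trained = NetTrain[net, generator, BatchSize -> 16,
   MaxTrainingRounds -> 500, TrainingProgressReporting -> None];

(* With "RoundLength" -> 64 the generator is called 4 times per round *)
trained64 = NetTrain[net, {generator, "RoundLength" -> 64}, BatchSize -> 16,
   MaxTrainingRounds -> 200, TrainingProgressReporting -> None];
```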

Properties  (4)

Obtain a NetTrainResultsObject for a training session:

Query the results object for specific properties:

Get the original form of the call to NetTrain:

Get a list of all available properties:

Obtain plots of the evolution of the magnitude of the per-array gradients during the training of a convolutional net:

Train on the MNIST dataset and record the losses of individual examples through time:

Plot the loss associated with a single example over time:

Find the most difficult examples by calculating the mean loss for each example and taking the indices of the 20 largest such mean losses:

Show the average evolution of losses and error rates for the digits "0", "3", "8" and "9":

Create customized versions of the final plots:

Examine the original loss plot:

Create a new loss plot using only the round and validation measurements:

Create a new loss plot without the logarithmic scaling:

Options  (27)

BatchSize  (1)

Using a large batch size will typically increase the number of examples that can be evaluated per second:

A smaller batch size results in fewer examples being evaluated per second:

Depending on the task and network, larger batch sizes will allow higher learning rates to be used and can also allow for more efficient use of available hardware.

LearningRate  (1)

Train a network with a learning rate of 0.01:
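A minimal sketch of such a call on hypothetical toy data. Requesting All makes it possible to read back the "InitialLearningRate" property and confirm the setting took effect.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

results = NetTrain[LinearLayer[], data, All,
   LearningRate -> 0.01, MaxTrainingRounds -> 200,
   TrainingProgressReporting -> None];

(* The learning rate used at the start of training *)
initialRate = results["InitialLearningRate"]
```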

LearningRateMultipliers  (1)

Create and initialize a net with three layers, but train only the last layer:

Evaluate the trained net on an input:

The first layer of the initial net started with zero biases:

The biases of the first layer remain zero in the trained net:

The biases of the third layer have been trained:

LossFunction  (4)

Train a simple net using MeanSquaredLossLayer, the default loss applied to an output when it is not produced by a SoftmaxLayer:

Evaluate the trained net on a set of inputs:

Specify a different loss layer to attach to the output of the network. First create a loss layer:

The loss layer takes an input and a target and produces a loss:

Use this loss layer during training:

Evaluate the trained net on a set of inputs:

Create a net that takes a vector of length 2 and produces one of the class labels Less or Greater:

NetTrain will automatically use a CrossEntropyLossLayer object with the correct class encoder:

Evaluate the trained net on a set of inputs:

Create an explicit loss layer that expects the targets to be in the form of probabilities for each class:

The training data should now consist of vectors of probabilities instead of symbolic classes:

Evaluate the trained net on a set of inputs:

Start with an "evaluation net" to be trained:

Create a "loss net" that explicitly computes the loss of the evaluation net (here, the custom loss is equivalent to MeanSquaredLossLayer):

Train this net on some synthetic data, specifying that the output port named "Loss" should be interpreted as a loss:

Obtain the trained "evaluation" network using NetExtract:

Part syntax can also be used:

Plot the net's output on the plane:

Create a training net that computes both an output and multiple explicit losses:

This network requires an input and a target in order to produce the output and losses:

Train using a specific output as loss, ignoring all other outputs:

Measure the loss on a single pair after training:

Specify multiple losses to be trained jointly:

Use NetTake to remove all but the desired output and input:

Evaluate this network on a set of inputs:

MaxTrainingRounds  (2)

Train a network such that it visits each example exactly once:
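One way to sketch this, assuming hypothetical toy data, is to set MaxTrainingRounds to 1 and inspect the "TotalRounds" property afterward.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* A single round is one traversal of the training data *)
results = NetTrain[LinearLayer[], data, All,
   MaxTrainingRounds -> 1, TrainingProgressReporting -> None];

rounds = results["TotalRounds"]
```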

When both MaxTrainingRounds and TimeGoal are specified, the shorter of the two will be used (note that this example should be run twice to avoid the initial preprocessing overhead):

Method  (2)

Use stochastic gradient descent with momentum to train a simple network:

Specify a learning rate schedule with an initial learning rate:
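Both Method settings can be sketched as below. The data, the momentum value 0.9 and the step-shaped schedule are illustrative assumptions; the schedule function takes the current batch number and the total number of batches and returns a factor between 0 and 1 that multiplies the initial learning rate.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* SGD with an explicit momentum suboption *)
sgd = NetTrain[LinearLayer[], data,
   Method -> {"SGD", "Momentum" -> 0.9},
   LearningRate -> 0.01, MaxTrainingRounds -> 200,
   TrainingProgressReporting -> None];

(* Full rate for the first half of training, one tenth of it afterward *)
schedule = Function[{batch, total}, If[batch < total/2, 1., 0.1]];

scheduled = NetTrain[LinearLayer[], data,
   Method -> {"SGD", "LearningRateSchedule" -> schedule},
   LearningRate -> 0.01, MaxTrainingRounds -> 200,
   TrainingProgressReporting -> None];
```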

Use regularization to prevent overfitting. Create synthetic training data based on a Gaussian curve:

Train a net with a large number of parameters relative to the amount of training data:

The resulting net overfits the data, learning the noise in addition to the underlying function:

Using the "L2Regularization" option will inject a loss proportional to the square of each weight parameter. This will tend to promote sparser weight matrices and hence mitigate overfitting:
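The regularization comparison can be sketched as follows. The noise level, layer sizes, round count and regularization strength 0.01 are assumptions chosen for illustration, and the fitted curves will vary between runs because the noise is random.

```wolfram
(* Noisy samples of a Gaussian curve *)
data = Table[
   x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]],
   {x, -3, 3, 0.2}];

(* A net with many parameters relative to the amount of data *)
net = NetChain[{150, Tanh, 150, Tanh, 1},
   "Input" -> "Scalar", "Output" -> "Scalar"];

(* Without regularization the net is free to fit the noise *)
plain = NetTrain[net, data, Method -> "ADAM",
   MaxTrainingRounds -> 1000, TrainingProgressReporting -> None];

(* With an L2 penalty on all learned arrays *)
regularized = NetTrain[net, data,
   Method -> {"ADAM", "L2Regularization" -> 0.01},
   MaxTrainingRounds -> 1000, TrainingProgressReporting -> None];

(* Compare both fits against the data points *)
Show[Plot[{plain[x], regularized[x]}, {x, -3, 3}],
 ListPlot[List @@@ data]]
```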

TargetDevice  (1)

Train a net using the default system GPU, if a CUDA-enabled card is available:

If a compatible GPU is not available, a message is issued and $Failed is returned:

TimeGoal  (1)

Train a network for approximately 5 seconds:
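A minimal sketch on hypothetical toy data; the "TotalTrainingTime" property reports how long training actually took, which is only expected to be close to the goal.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* TimeGoal is given in seconds *)
results = NetTrain[LinearLayer[], data, All,
   TimeGoal -> 5, TrainingProgressReporting -> None];

elapsed = results["TotalTrainingTime"]
```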

TrainingProgressCheckpointing  (1)

Take periodic checkpoints of a convolutional network during training on the MNIST dataset:

List all created checkpoints:

Import the final checkpoint:

TrainingProgressFunction  (1)

Use TrainingProgressFunction to append information about the state of training to a file. Create a log file:

Define functions to append the batch number and loss to the log file:

Define the training data and perform training:

Read the log file:

Put the saved data into a Dataset:

Plot the loss over time:

TrainingProgressMeasurements  (1)

Examine the final validation precision and recall for LeNet trained on FashionMNIST:

Animate the confusion matrix over the training period:

TrainingProgressReporting  (6)

Show training progress interactively during training:

Print training progress periodically during training:

Show a simple progress indicator:

Perform custom reporting:

Write training progress information to a file:

Do not report progress:
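Some of the reporting modes above can be sketched as follows, assuming the settings "Panel", "Print", "ProgressIndicator" and None and using hypothetical toy data. The custom-function and file-based forms are not shown here.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* Interactive panel (requires a notebook front end) *)
NetTrain[LinearLayer[], data, MaxTrainingRounds -> 100,
  TrainingProgressReporting -> "Panel"];

(* Periodic printed progress *)
NetTrain[LinearLayer[], data, MaxTrainingRounds -> 100,
  TrainingProgressReporting -> "Print"];

(* Simple progress indicator *)
NetTrain[LinearLayer[], data, MaxTrainingRounds -> 100,
  TrainingProgressReporting -> "ProgressIndicator"];

(* No reporting at all *)
silent = NetTrain[LinearLayer[], data, MaxTrainingRounds -> 100,
   TrainingProgressReporting -> None];
```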

TrainingStoppingCriterion  (1)

Prevent overfitting by stopping training when the validation loss stops improving. Set up a simple net as well as some training and validation data:

Use TrainingStoppingCriterion to stop training when the validation loss stops improving:

Use TrainingStoppingCriterion to stop training if the validation loss does not improve by at least 0.001 for more than 5 rounds in a row:

Use a callback function to stop training. Set up the net and data:

Stop training when the validation loss is higher than 1.75:

TrainingUpdateSchedule  (1)

Train a NetGANOperator by alternating updates of the discriminator and updates of the generator:

ValidationSet  (1)

Provide a ValidationSet to NetTrain to prevent overfitting. Create synthetic training data based on a Gaussian curve:

Train a net with a large number of parameters relative to the amount of training data:

The resulting net overfits the data, learning the noise in addition to the underlying function:

Use the ValidationSet option to have NetTrain select the net that achieved the lowest validation loss during training. NetTrain will use 20% of the training data, selected at random, to create a validation set:

The result returned by NetTrain was the net that generalized best to points in the validation set, as measured by validation loss. This penalizes overfitting, as the noise present in the training data is uncorrelated with the noise present in the validation set:
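The validation workflow above can be sketched as follows. The noise level, layer sizes and round count are illustrative assumptions; Scaled[0.2] reserves 20% of the training data for validation, and "BestValidationRound" reports the round from which the returned net was taken.

```wolfram
(* Noisy samples of a Gaussian curve *)
data = Table[
   x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, 0.15]],
   {x, -3, 3, 0.2}];

net = NetChain[{150, Tanh, 150, Tanh, 1},
   "Input" -> "Scalar", "Output" -> "Scalar"];

(* Hold out a random 20% of the examples as a validation set *)
results = NetTrain[net, data, All,
   ValidationSet -> Scaled[0.2], MaxTrainingRounds -> 1000,
   TrainingProgressReporting -> None];

bestRound = results["BestValidationRound"]
trained = results["TrainedNet"];

Show[Plot[trained[x], {x, -3, 3}], ListPlot[List @@@ data]]
```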

WorkingPrecision  (2)

Train a net with 64-bit precision:

Evaluate the trained net with 64-bit precision:
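A sketch of both steps on hypothetical toy data. It assumes that WorkingPrecision can also be passed as an option when applying the trained net to an input.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* Train using double-precision arithmetic *)
trained = NetTrain[LinearLayer[], data,
   WorkingPrecision -> "Real64", MaxTrainingRounds -> 200,
   TrainingProgressReporting -> None];

(* Evaluate using double precision as well *)
value = trained[2.5, WorkingPrecision -> "Real64"]
```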

Train a net with mixed precision, taking advantage of hardware optimizations like NVIDIA Tensor Cores:

Properties & Relations  (2)

NetChain objects can be used as layers in a NetGraph:

NetGraph objects with one input and one output can be used as layers inside NetChain objects:

Possible Issues  (1)

By default, NetTrain uses RandomSeeding->1234, which will use the same random seed to initialize the net when NetTrain is called repeatedly:

Use RandomSeeding->Automatic to ensure that repeated calls to NetTrain use different initializations:
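The difference between the two seeding behaviors can be sketched as below. The data and the helper name trainOnce are hypothetical; with the default seed, the two runs are expected to start from the same initialization, while RandomSeeding->Automatic should give a different one each time.

```wolfram
(* Hypothetical toy data *)
data = {1 -> 1.9, 2 -> 4.1, 3 -> 6.0, 4 -> 8.1};

(* Hypothetical helper: a short training run with optional extra settings *)
trainOnce[opts___] := NetTrain[LinearLayer[], data, opts,
   MaxTrainingRounds -> 5, TrainingProgressReporting -> None];

(* Default seeding: repeated runs use the same seed *)
a = trainOnce[];
b = trainOnce[];
{a[2.5], b[2.5]}

(* Automatic seeding: repeated runs use different seeds *)
c = trainOnce[RandomSeeding -> Automatic];
d = trainOnce[RandomSeeding -> Automatic];
{c[2.5], d[2.5]}
```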

Interactive Examples  (1)

Monitor the solution while training a net to solve a least-squares problem. First generate training data:

Create a network to fit the data:

Replace the default progress panel with a dynamically updated plot of the current behavior of the net:

Plot the final net after 10 seconds of training:

Neat Examples  (2)

Convert a test image into a training set, in which pixel positions (x,y) are mapped to color values (r,g,b):

Create a network to predict the color based on pixel position:

Train the network:

Use the network to predict the entire original image:

A high-dimensional embedding may provide a better prediction:

Train this alternative network:

Predict the entire image using this new net:

Use the association form of property argument to plot the intermediate curves fitted to a least-squares problem. First generate training data:

Create a network to fit the data:

Train the net, specifying that the property to return is the net evaluated every 100 rounds:

Animate the list to show the solution converging over time:

Wolfram Research (2016), NetTrain, Wolfram Language function, https://reference.wolfram.com/language/ref/NetTrain.html (updated 2025).
