NetStateObject

NetStateObject[net]

creates an object derived from net that represents a neural net with additional stored state information that is updated when the net is applied to data.

NetStateObject[net,seed]

creates an object in which additional stored state information is initialized using seed.

Details

  • NetStateObject[…][data] updates stored state information in the NetStateObject.
  • State information is associated with the state ports of recurrent net layers such as LongShortTermMemoryLayer.
  • NetStateObject will not store the state of layers whose state ports are initialized from other layers in a NetGraph.
  • When a seed is not provided, initial values for recurrent states will consist of arrays of zeros.
  • The current value of the stored states is given by NetExtract[NetStateObject[…],"States"].

Examples

Basic Examples  (3)

Create a recurrent net:
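
A minimal sketch of such a net, assuming a single LongShortTermMemoryLayer wrapped in a NetChain. The sizes (2-vectors in, 3-vectors out) are illustrative choices, not taken from the original example:

```wolfram
(* Illustrative sizes: each sequence element is a 2-vector, each output element is a 3-vector *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]
]
```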

Create a state object from the recurrent net:
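
A sketch of this step, assuming the same kind of small LSTM net (the sizes are illustrative):

```wolfram
(* Assumed example net: an LSTM over variable-length sequences of 2-vectors *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]];

(* Wrap the net so that its recurrent states persist between evaluations *)
obj = NetStateObject[net]
```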

Evaluate the state object on some data:
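
A sketch of an evaluation, assuming an illustrative LSTM net and made-up input values:

```wolfram
(* Assumed example net and its state object *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]];
obj = NetStateObject[net];

(* Apply the object to a sequence of two 2-vectors; the result is a sequence of two 3-vectors *)
obj[{{0.1, 0.2}, {0.3, 0.4}}]
```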

Due to the presence of the stored states, the behavior of the state object can change between evaluations, even on the same input:
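
One way to observe this, sketched with an illustrative LSTM net and made-up data, is to apply the object twice to the same input and compare the results:

```wolfram
(* Assumed example net and its state object *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]];
obj = NetStateObject[net];

data = {{0.1, 0.2}, {0.3, 0.4}};
first = obj[data];   (* starts from the initial (zero) states *)
second = obj[data];  (* starts from the states left behind by the first call *)

(* The two outputs are expected to differ, because the stored state carried over *)
first == second
```

By contrast, applying the plain net to the same input twice gives identical results, since it keeps no state between calls.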

Create a state object with a specified initial state:
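
A hedged sketch of seeding. It assumes that the seed has the same format as the value returned by NetExtract[…,"States"], which this page does not state explicitly. To avoid guessing that format, the seed here is taken from another state object:

```wolfram
(* Assumed example net *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]];

(* Run one state object on some data to obtain nonzero states *)
warm = NetStateObject[net];
warm[{{0.1, 0.2}, {0.3, 0.4}}];
seed = NetExtract[warm, "States"];

(* Assumption: a seed in the format returned by NetExtract is accepted here *)
obj = NetStateObject[net, seed]
```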

Create a classifier that predicts the next element of a sequence:
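
A sketch of one possible classifier. It assumes that sequence elements are the integers 1, 2 and 3, presented to the net as one-hot 3-vectors. The layer sizes and the one-hot input convention are illustrative choices, not the original example:

```wolfram
(* Input: a variable-length sequence of one-hot 3-vectors; output: the predicted class 1, 2 or 3 *)
net = NetChain[
  {GatedRecurrentLayer[10], SequenceLastLayer[], LinearLayer[3], SoftmaxLayer[]},
  "Input" -> {"Varying", 3},
  "Output" -> NetDecoder[{"Class", {1, 2, 3}}]
]
```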

Train the classifier on a set of input sequences:
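
A sketch of training, assuming a one-hot classifier over the integers 1, 2 and 3 and a made-up training set built from the repeating pattern 1, 2, 3. The helper oneHot and the data are hypothetical:

```wolfram
(* Hypothetical helper: encode an integer 1..3 as a one-hot vector *)
oneHot[i_Integer] := UnitVector[3, i];

(* Assumed classifier over sequences of one-hot 3-vectors *)
net = NetChain[
  {GatedRecurrentLayer[10], SequenceLastLayer[], LinearLayer[3], SoftmaxLayer[]},
  "Input" -> {"Varying", 3},
  "Output" -> NetDecoder[{"Class", {1, 2, 3}}]];

(* Hypothetical data: windows of length 1 to 5 from the pattern, each paired with the next element *)
pattern = Flatten[ConstantArray[{1, 2, 3}, 20]];
data = Flatten[
   Table[(oneHot /@ pattern[[i ;; i + len - 1]]) -> pattern[[i + len]],
    {len, 1, 5}, {i, 1, 30}], 1];

trained = NetTrain[net, data]
```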

Create a state object and use it to efficiently generate a maximum-likelihood sequence, starting from a single 1:
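
A sketch of the generation loop, under the same assumptions as a one-hot classifier over 1, 2 and 3 (the helper, net sizes and training data are hypothetical). Only the most recent element is fed in at each step; the earlier elements are remembered through the stored state, which is what makes this cheaper than re-running the net on the whole sequence:

```wolfram
(* Hypothetical helper and classifier, as assumptions for this sketch *)
oneHot[i_Integer] := UnitVector[3, i];
net = NetChain[
  {GatedRecurrentLayer[10], SequenceLastLayer[], LinearLayer[3], SoftmaxLayer[]},
  "Input" -> {"Varying", 3},
  "Output" -> NetDecoder[{"Class", {1, 2, 3}}]];
pattern = Flatten[ConstantArray[{1, 2, 3}, 20]];
data = Flatten[
   Table[(oneHot /@ pattern[[i ;; i + len - 1]]) -> pattern[[i + len]],
    {len, 1, 5}, {i, 1, 30}], 1];
trained = NetTrain[net, data];

(* Generate 10 further elements, starting from a single 1 *)
obj = NetStateObject[trained];
seq = {1};
Do[
  (* Feed the last element as a length-1 sequence; the decoder returns the most likely class *)
  AppendTo[seq, obj[{oneHot[Last[seq]]}]],
  {10}];
seq
```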

Applications  (1)

Train an English character-level language model. First, create 300,000 training examples of 25 characters each from two novels:
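
A sketch of how such examples might be built. The helper makeExamples is hypothetical, and a short repeated sentence stands in for the text of the novels, which this page does not name:

```wolfram
(* Hypothetical helper: draw n random windows of len characters,
   each paired with the single character that follows the window *)
makeExamples[text_String, n_Integer, len_Integer] :=
 Module[{starts = RandomInteger[{1, StringLength[text] - len}, n]},
  Map[StringTake[text, {#, # + len - 1}] -> StringTake[text, {# + len}] &, starts]]

(* Stand-in corpus; the original uses the text of two novels and n = 300000 *)
text = StringRepeat["the quick brown fox jumps over the lazy dog. ", 50];
trainingData = makeExamples[text, 1000, 25];

(* Look at a few examples *)
RandomSample[trainingData, 3]
```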

The data takes the form of a classification problem: given a sequence of characters, predict the next one. A sample of the data:

Obtain the list of all characters in the text:
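
A sketch of this step, using a short stand-in string in place of the actual training text:

```wolfram
(* Stand-in corpus; in the application this would be the full training text *)
text = "the quick brown fox jumps over the lazy dog. ";

(* Sorted list of the distinct characters that occur in the text *)
characters = Union[Characters[text]]
```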

Define a net that takes in a string of characters and returns a prediction for the next character:
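
A sketch of one possible architecture. The two GatedRecurrentLayer sizes of 128 are assumptions, and the character list comes from a stand-in string. The encoder turns a string into a sequence of character indices and the decoder maps the output probabilities back to a character:

```wolfram
(* Stand-in character list; in the application this comes from the training text *)
characters = Union[Characters["the quick brown fox jumps over the lazy dog. "]];

net = NetChain[{
   UnitVectorLayer[],          (* character indices -> one-hot vectors *)
   GatedRecurrentLayer[128],   (* assumed size *)
   GatedRecurrentLayer[128],   (* assumed size *)
   SequenceLastLayer[],        (* keep only the final sequence element *)
   LinearLayer[],
   SoftmaxLayer[]},
  "Input" -> NetEncoder[{"Characters", characters}],
  "Output" -> NetDecoder[{"Class", characters}]]
```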

Train the net. This can take up to an hour on a CPU; it is recommended to specify a GPU if possible:

Predict the next character, given a sequence of characters:

Generate 100 characters of text, given starting text:
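
A sketch of a generation helper. The name generateText is hypothetical, and trained is assumed to be a trained net that maps a string to its predicted next character. The starting text is fed in once, and after that each predicted character is fed back in on its own, with the stored state carrying the context:

```wolfram
(* Hypothetical helper: extend start by n characters, one prediction at a time *)
generateText[trained_, start_String, n_Integer] :=
 Module[{obj = NetStateObject[trained]},
  (* NestList returns {start, c1, c2, ..., cn}; join the pieces into one string *)
  StringJoin[NestList[obj, start, n]]]
```

With a trained net available, generateText[trained, "Alice", 100] would produce the starting text followed by 100 generated characters.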

More varied text can be generated by sampling from the predicted probability distribution instead of always taking the most likely character:
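
A sketch of sampled generation. The name sampleText is hypothetical, and it is an assumption that a NetStateObject accepts the "RandomSample" property of the "Class" decoder as a second argument, in the same way a net does:

```wolfram
(* Hypothetical helper: like greedy generation, but each character is drawn at random
   according to the predicted probabilities *)
sampleText[trained_, start_String, n_Integer] :=
 Module[{obj = NetStateObject[trained]},
  StringJoin[NestList[obj[#, "RandomSample"] &, start, n]]]
```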

Properties & Relations  (3)

If the initial value of a recurrent layer's state is provided by a connection in a NetGraph, that state will not be stored by NetStateObject.

Create a graph that uses a connection to provide the initial value of the state of a BasicRecurrentLayer:
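
A sketch of such a graph. The sizes are assumptions and "Init" is a hypothetical name for the graph input that supplies the initial state. The two-argument form NetPort[1,"State"] is assumed to be available in the version in use:

```wolfram
(* The state port of layer 1 is fed from a new graph input called "Init" *)
net = NetGraph[
  {BasicRecurrentLayer[3]},
  {NetPort["Init"] -> NetPort[1, "State"]},
  "Input" -> {"Varying", 2}
]
```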

This graph cannot be used inside a NetStateObject, as there are no states left to store:
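
A sketch of the attempt, using an assumed graph whose only recurrent state is supplied through a hypothetical "Init" input:

```wolfram
(* Assumed graph: the state of the only recurrent layer is provided by a connection *)
net = NetGraph[
  {BasicRecurrentLayer[3]},
  {NetPort["Init"] -> NetPort[1, "State"]},
  "Input" -> {"Varying", 2}];

(* According to the text above, this is rejected because no state is left to store *)
NetStateObject[net]
```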

The current value of the stored states can be obtained using NetExtract.

First create a NetStateObject:

Apply the object to some data:

Extract the current value of the states:
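
The three steps above can be sketched as follows, assuming an illustrative LSTM net and made-up input data:

```wolfram
(* Assumed example net and its state object *)
net = NetInitialize[
  NetChain[{LongShortTermMemoryLayer[3]}, "Input" -> {"Varying", 2}]];
obj = NetStateObject[net];

(* Apply the object so that the stored states move away from their initial zeros *)
obj[{{0.1, 0.2}, {0.3, 0.4}}];

(* Current values of the stored recurrent states *)
NetExtract[obj, "States"]
```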

For recurrent nets, using a NetStateObject is equivalent to manually keeping track of the recurrent states via NetPort[All,"States"].

To see this, create a classifier that predicts the next element of a sequence:

Train the classifier on a set of input sequences:

Create a state object and use it to efficiently generate a maximum-likelihood sequence, starting from a single 1:

Generate from the trained net using NetPort[All,"States"] to set and get the recurrent states, which yields the same result:

Introduced in 2018 (11.3)