NetStateObject
NetStateObject[net]
creates an object derived from net that represents a neural net with additional stored state information that is updated when the net is applied to data.
NetStateObject[net,seed]
creates an object in which additional stored state information is initialized using seed.
Details
- NetStateObject[…][data] applies the net to data and updates the stored state information in the NetStateObject.
- State information is associated with the state ports of recurrent net layers such as LongShortTermMemoryLayer.
- NetStateObject will not store the state of layers whose state ports are initialized from other layers in a NetGraph.
- When a seed is not provided, initial values for recurrent states will consist of arrays of zeros.
- The current value of the stored states is given by NetExtract[NetStateObject[…],"States"].
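The rules above can be illustrated with a small Python sketch. This is not how NetStateObject is implemented. `ToyRecurrentCell`, `StateObject` and the `states` property are hypothetical names, and the cell is an assumed one-layer tanh recurrence. The sketch shows three behaviors: states start as zeros unless a seed is given, each call updates the stored state, and the current state can be read back.

```python
import math
from typing import List, Optional


class ToyRecurrentCell:
    """Hypothetical recurrent cell: h_new[i] = tanh(w_in * x + w_rec * h[i])."""

    def __init__(self, size: int, w_in: float = 0.5, w_rec: float = 0.9):
        self.size = size
        self.w_in = w_in
        self.w_rec = w_rec

    def step(self, x: float, h: List[float]) -> List[float]:
        return [math.tanh(self.w_in * x + self.w_rec * hi) for hi in h]


class StateObject:
    """Sketch of a stateful wrapper that keeps the cell's state between calls."""

    def __init__(self, cell: ToyRecurrentCell, seed: Optional[List[float]] = None):
        self.cell = cell
        # Without a seed, the stored state starts as an array of zeros.
        self._state = list(seed) if seed is not None else [0.0] * cell.size

    def __call__(self, sequence: List[float]) -> List[float]:
        # Start from the stored state, run over the input, then store the final state.
        h = self._state
        for x in sequence:
            h = self.cell.step(x, h)
        self._state = h
        return h

    @property
    def states(self) -> List[float]:
        # Plays the role of NetExtract[NetStateObject[…], "States"] in this sketch.
        return list(self._state)
```

For example, `StateObject(ToyRecurrentCell(2)).states` is `[0.0, 0.0]`, and passing `seed=[0.3, -0.3]` makes that the starting state instead.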
Examples
Basic Examples (3)
Create a state object from the recurrent net:
Evaluate the state object on some data:
Due to the presence of the stored states, the behavior of the state object can change between evaluations, even on the same input:
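The reason the same input can give different results is that each call starts from whatever state the previous call left behind. The hypothetical one-unit model below (plain Python, with assumed weights) shows this:

- two consecutive calls on the same input differ;
- a fresh object reproduces the first result exactly.

```python
import math


class StatefulScalarRNN:
    """Hypothetical one-unit recurrent model whose output depends on a stored state."""

    def __init__(self, w_in: float = 0.8, w_rec: float = 0.5):
        self.w_in = w_in
        self.w_rec = w_rec
        self.h = 0.0  # stored state, initialized to zero

    def __call__(self, sequence):
        # Each call continues from the state left by the previous call.
        for x in sequence:
            self.h = math.tanh(self.w_in * x + self.w_rec * self.h)
        return self.h  # the final hidden value doubles as the output


model = StatefulScalarRNN()
first = model([1.0, 1.0])
second = model([1.0, 1.0])  # same input, different starting state
```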
Create a state object with a specified initial state:
Create a classifier that predicts the next element of a sequence:
Train the classifier on a set of input sequences:
Create a state object and use it to efficiently generate a maximum-likelihood sequence, starting from a single 1:
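Here is one way to see why a state object makes generation efficient. With a stored state, each step only processes the newest element. Without one, the whole prefix has to be re-run from a zero state at every step.

The Python sketch below compares both strategies. It uses a hypothetical scalar cell and readout, not the trained classifier from the example. Both strategies produce the same sequence, but the step counts differ: 10 versus 1 + 2 + ... + 10 = 55.

```python
import math

CLASSES = [1, 2, 3]  # hypothetical set of sequence elements


def cell(x, h):
    """Hypothetical recurrent update with one scalar hidden unit."""
    return math.tanh(0.9 * x - 1.2 * h)


def predict(h):
    """Hypothetical readout: the class with the largest score (maximum likelihood)."""
    scores = [-(k - (2.0 + 1.5 * h)) ** 2 for k in CLASSES]
    return CLASSES[scores.index(max(scores))]


def generate_stateful(start, n):
    """Carry the hidden state forward; only the newest element is processed."""
    seq, h, steps = [start], 0.0, 0
    for _ in range(n):
        h = cell(seq[-1], h)
        steps += 1
        seq.append(predict(h))
    return seq, steps


def generate_rerun(start, n):
    """No stored state: re-run the entire prefix from zero at every step."""
    seq, steps = [start], 0
    for _ in range(n):
        h = 0.0
        for x in seq:
            h = cell(x, h)
            steps += 1
        seq.append(predict(h))
    return seq, steps
```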
Applications (1)
Training an English character-level language model. First, create 300,000 training examples of 25 characters each from two novels:
The data takes the form of a classification problem: given a sequence of characters, predict the next one. A sample of the data:
Obtain the list of all characters in the text:
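The data preparation in the previous steps can be sketched in Python as follows:

- cut the text into fixed-length contexts, each labeled with the character that follows it;
- collect the character set.

The function names are hypothetical. This sketch takes every window in order. The original example's choice of 300,000 positions (for instance, random sampling) is not shown, so that part is not reproduced here.

```python
from typing import List, Tuple


def make_examples(text: str, length: int = 25) -> List[Tuple[str, str]]:
    """Turn text into (context, next character) classification examples.

    Each context is `length` consecutive characters; the label is the
    character that immediately follows it.
    """
    return [(text[i:i + length], text[i + length])
            for i in range(len(text) - length)]


def character_list(text: str) -> List[str]:
    """All distinct characters in the text, in sorted order."""
    return sorted(set(text))
```

For example, `make_examples("hello world", length=4)` starts with `("hell", "o")`.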
Define a net that takes in a string of characters and returns a prediction for the next character:
Train the net. This can take up to an hour on a CPU; if a supported GPU is available, it is recommended to set the option TargetDevice to "GPU":
Predict the next character, given a sequence of characters:
Generate 100 characters of text, given starting text:
One can get more interesting text by sampling from the probability distribution of predictions:
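The two decoding strategies differ in how they pick the next character:

- **Greedy decoding** always chooses the most probable character, so it can fall into repetitive loops.
- **Sampling** draws each character in proportion to its predicted probability.

The Python sketch below is hypothetical. The dictionary `probs` stands in for the net's predicted distribution over the character list.

```python
import random
from typing import Dict


def most_likely(probs: Dict[str, float]) -> str:
    """Greedy decoding: always return the highest-probability character."""
    return max(probs, key=probs.get)


def sample(probs: Dict[str, float], rng: random.Random) -> str:
    """Draw one character at random, weighted by its predicted probability."""
    chars = list(probs)
    return rng.choices(chars, weights=[probs[c] for c in chars], k=1)[0]


probs = {"e": 0.5, "a": 0.3, " ": 0.2}  # hypothetical prediction
rng = random.Random(0)
draws = [sample(probs, rng) for _ in range(1000)]
```

Greedy decoding returns `"e"` every time for this distribution. Sampling also produces `"a"` and `" "`, which is what makes the generated text more varied.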
Properties & Relations (3)
If the initial value of a recurrent layer's state is provided by a connection in a NetGraph, that state will not be stored by NetStateObject.
Create a graph that uses a connection to provide the initial value of the state of a BasicRecurrentLayer:
This graph cannot be used inside a NetStateObject, as there are no states left to store:
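The rule in this example can be modeled as a filter: a state port is stored only if no graph connection supplies its initial value. If nothing remains, there is nothing to store.

The Python sketch below is a hypothetical model of that rule, not Wolfram's implementation. Layers are given as plain dictionaries, and the port names "State" and "CellState" are used only as illustrative labels.

```python
from typing import Dict, List, Set, Tuple

StatePort = Tuple[str, str]  # (layer name, state port name)


def stored_states(layer_states: Dict[str, List[str]],
                  connected: Set[StatePort]) -> List[StatePort]:
    """Return the state ports a stateful wrapper would keep.

    Ports whose initial value comes from a connection are excluded;
    if no ports remain, there is nothing to store.
    """
    stored = [(layer, port)
              for layer, ports in layer_states.items()
              for port in ports
              if (layer, port) not in connected]
    if not stored:
        raise ValueError("no recurrent states left to store")
    return stored
```

For example, connecting only `("lstm", "State")` leaves `[("lstm", "CellState")]` to be stored. Connecting the single state of a one-state layer raises an error, matching the situation above.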
The current value of the stored states can be obtained using NetExtract.
First create a NetStateObject:
Apply the object to some data:
Extract the current value of the states:
For recurrent nets, using a NetStateObject is equivalent to manually keeping track of the recurrent states via NetPort[All,"States"].
To see this, create a classifier that predicts the next element of a sequence:
Train the classifier on a set of input sequences:
Create a state object and use it to efficiently generate a maximum-likelihood sequence, starting from a single 1:
Generate from the trained net using NetPort[All,"States"] to set and get the recurrent states, which yields the same result:
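The same equivalence can be sketched outside the Wolfram Language. Write the net as a pure function that takes a state and returns an updated one, in the spirit of setting and getting NetPort[All,"States"]. A wrapper that stores that state between calls then produces exactly the outputs you get by passing the state along by hand.

`rnn`, `StateObject` and `run_manual` below are hypothetical Python names, and the weights and readout are assumptions.

```python
import math
from typing import List, Tuple


def rnn(sequence: List[float], h0: float = 0.0) -> Tuple[float, float]:
    """Pure function: returns (output, final state); the caller manages the state."""
    h = h0
    for x in sequence:
        h = math.tanh(0.7 * x + 0.4 * h)
    return 2.0 * h, h  # hypothetical readout, plus the new state


class StateObject:
    """Stateful wrapper: stores the final state and reuses it on the next call."""

    def __init__(self, h0: float = 0.0):
        self.h = h0

    def __call__(self, sequence: List[float]) -> float:
        out, self.h = rnn(sequence, self.h)
        return out


def run_manual(chunks: List[List[float]]) -> List[float]:
    """Thread the state by hand: pass it in, read it back out, repeat."""
    h, outs = 0.0, []
    for chunk in chunks:
        out, h = rnn(chunk, h)
        outs.append(out)
    return outs
```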
Text
Wolfram Research (2018), NetStateObject, Wolfram Language function, https://reference.wolfram.com/language/ref/NetStateObject.html.
CMS
Wolfram Language. 2018. "NetStateObject." Wolfram Language & System Documentation Center. Wolfram Research. https://reference.wolfram.com/language/ref/NetStateObject.html.
APA
Wolfram Language. (2018). NetStateObject. Wolfram Language & System Documentation Center. Retrieved from https://reference.wolfram.com/language/ref/NetStateObject.html