带有不确定性的回归

混合密度网络
该章节演示了使用混合密度网络模拟回归问题的不确定性. 这种网络通过接受输入 x 值,产生近似 的混合分布的输出参数模拟后验分布 .
以下代码创建了对应 xy 值格式的合成训练数据. 数据不代表普通函数 ,因为对于每个 x 值,可以有多个 y 值. 另外,数据包含相当数量的噪声:
当给出输入值 x(其中,xy 可以是标量、向量、矩阵等),普通递归网络预测单个 y 值. 密度网络的基本思想是计算 y 值的分布. 网络按 x 函数学习该分布的参数. 混合密度网络学习简单分布的混合. 例如,我们使用 6 个高斯混合.
构建一个接受输入数并使用多层感知产生三个独立向量的网络. 每个向量包含 6 个数字表示 6 个独立高斯分量的参数. 这些向量的两个("mean" 和 "stddev")表示高斯的平均和标准差. 最终向量 ("weight") 是概率向量表示如何混合这 6 个高斯产生单个分布. 注意,我们使用 Exp 激活确保标准差是正的,并使用 SoftmaxLayer 确保权和为 1:
接下来,我们训练一个更大的网络培训这个参数网络. 这个更大的网络从我们的数据分布中接受实际 xy 值并计算负的对数似然性,它是度量我们参数网络表示的模型中数据的可能性. 通过最小化负的对数似然性,我们有效地最大化实际数据的似然性,它是训练概率模型的常用技术.
训练网计算由参数网产生的六个高斯下的单个 y 值的似然性. 为了将这些单独的似然性合并为高斯混合的单个似然性,我们使用权向量执行加权和. 最后,我们接受负对数.
构建培训网络,使用 ThreadingLayer 计算高斯似然性,DotLayer 接受 6 个高斯似然性的加权和,以及一个 ElementwiseLayer 把这个变成负的对数-似然性:
我们看一下单个数据点上随机初始化网络的损失,以确保工作正常.
NetInitialize 随机初始化网络并把之应用于单个输入:
我们现在可以训练模型,这对应于同时最大化生成数据集中每个点的模型的似然性. 训练完成后,我们将从训练网内提取参数网. 我们不再需要训练网络,因为我们不需要再次计算训练数据的负对数-似然性. 当给定 x 值时,参数网产生均值、标准偏差和权重的关联.
Train the net with用 NetTrain 训练网络 3000 轮. 指明 LossFunction"Loss" 以确保 NetTrain 直接从 "Loss" 端口最小化输出,而不是试着自动附加损失层:
使用 NetExtract 从最终的训练网络中提取已训练的参数网络:
把已训练的参数网络应用到输入:
当给定 x 值,定义一个函数构建一个 MixtureDistribution
在指定的 x 值上应用函数:
从该分布采样新的 y 值:
绘制该分布的 PDF. 这是后验
对于原始数据中的每个 x 值,从由模型计算的后验分布 中取一个样本:
重叠原始数据和模型采样的图:
我们已经学习的密度模型,因为它已经足够计算指定值 xy 的概率密度 . 我们可以从已训练的网络中删除计算似然性的负对数的层,产生计算似然性的网络.
使用 NetDelete 删除计算似然性的负对数的 ElementwiseLayer
由多种方式可以可视化该密度模型的行为. 最简单的是在 xy 值得密度网格上采样似然性来产生密度图. 我们也可以可视化单个分量和它们的均值和权值如何随着 x 函数而变化.
使用 CoordinateBoundsArray 创建 {x,y} 对的网格,然后分别展平成 xy 值列表:
在单个批次中使用似然性网络有效计算这些值的概率,然后不展平概率,再次形成矩阵:
绘制矩阵:
计算 x 值范围上的混合参数:
变量 "means" 和 "stddevs" 现在包含每个 x 位置 6 个高斯分量的参数. 混合权由 "weights" 包含:
将各个混合分量绘制为包络线,其中实线是每个分量的平均值,阴影区域显示由标准偏差覆盖的 y 值的范围.
通过绘制线 mean-stddevmeanmean+stddev,以及它们之间适当的阴影定义绘制单个分量的函数. 然后把这些图组合成单张图:
正如您所看到的,当相应的混合权重接近于零时,与给定分量相关的标准差变得非常大,因为该分量对模型没有贡献并因此造成损失.
接下来,我们以 x 值得函数绘制混合权重,其中,颜色匹配之前图中显示的分量.
使用 StackedListPlot 作为 x 值函数彼此堆叠的 6 个权重值. 在每个 x 值,它们总和为 1,这归功于参数网中的 SoftmaxLayer[]
最后,我们尝试同时显示混合分量的平均值和混合权重. 随着混合权重的减少,我们使与该分量相关的线条淡出. 将其与原始数据集进行比较,可以更容易地看到每个 x 值处的主混合分量如何在该 x 值处的原始数据集上反映 y 值的聚类.
使用 ColorFunction 选项允许每个 ListLinePlot 淡出对应的平均线,根据权重列表查询插值函数中的淡出量: