窝牛号

控制神经网络的另一条途径——推理启发式方程

随着训练数据的规模和覆盖范围的增加,深度神经网络(DNN)可提供更准确的结果。虽然投资于高质量和大规模的标记数据集是改进模型的一种途径,但另一种方法是利用先验知识,简称为“规则”——推理启发式方程、关联逻辑或约束。考虑一个物理学中的常见示例,其中模型的任务是预测双摆系统中的下一个状态。虽然模型可以仅从经验数据中学会估计系统在给定时间点的总能量,但它会经常高估能量,除非还提供反映已知物理约束的方程,例如能量守恒. 该模型本身无法捕捉到如此完善的物理规则。如何有效地教授这样的规则,以便 DNN 吸收相关知识,而不仅仅是从数据中学习?

在“使用规则表示控制神经网络”中,发表于NeurIPS 2021,Google提出了具有可控规则表示的深度神经网络 (DeepCTRL),这是一种用于为模型提供规则的方法,该模型与数据类型和模型架构无关,可应用于为输入和输出定义的任何类型的规则。DeepCTRL 的主要优势在于它不需要重新训练来适应规则强度。在推理时,用户可以根据所需的准确操作点调整规则强度。Google还提出了一种新颖的输入扰动方法,它有助于将 DeepCTRL 推广到不可逾越的约束。在结合规则至关重要的现实世界领域(例如物理和医疗保健)中,Google展示了 DeepCTRL 在深度学习规则方面的有效性。DeepCTRL 确保模型更严格地遵循规则,同时在下游任务中提供准确性增益,从而提高训练模型的可靠性和用户信任度。此外,DeepCTRL 支持新的用例,例如数据样本规则的假设测试和基于数据集之间共享规则的无监督适应。

从规则中学习的好处是多方面的:

规则可以为数据最少的案例提供额外的信息,从而提高测试的准确性。

广泛使用 DNN 的一个主要瓶颈是缺乏对其推理和不一致背后的基本原理的理解。通过最大限度地减少不一致,规则可以提高 DNN 的可靠性和用户信任度。

DNN 对人类无法察觉的细微的变化很敏感。有了规则,这些变化的影响可以最小化,因为模型搜索空间被进一步限制以减少规格不足。

从规则和任务中共同学习

执行规则 的传统方法通过将它们包含在损失计算中来将它们结合起来。Google旨在解决这种方法的三个局限性:(i)需要在学习之前定义规则强度(因此,经过训练的模型无法根据数据满足规则的程度灵活操作);( ii ) 如果与训练设置有任何不匹配,则规则强度不适用于推理时的目标数据;( iii ) 基于规则的目标需要在可学习参数方面是可区分的(以便能够从标记数据中学习)。

DeepCTRL 通过创建规则表示以及数据表示来修改规范训练,这是在推理时控制规则强度的关键。在训练期间,这些表示与控制参数(由α表示)随机连接成单个表示。通过增加α的值,可以提高输出决策规则的强度。通过在推理时修改α,用户可以控制模型的行为以适应看不见的数据。

DeepCTRL 将数据编码器和规则编码器配对,产生两个潜在表示,并与相应的目标相耦合。控制参数 α 在推理时是可调整的,以控制每个编码器的相对权重。

通过输入扰动集成规则

训练与基于规则的目标要求目标相对于模型的可学习参数是可微的。有许多有价值的规则在输入方面是不可微的。例如,“血压高于 140 可能会导致心血管疾病”是很难与传统 DNN 相结合的规则。Google还引入了一种新颖的输入扰动方法,通过向输入特征引入小扰动(随机噪声)并根据结果是否在所需方向构建基于规则的约束,将 DeepCTRL 推广到不可逾越的约束。

在已知物理学原理的情况下提高可靠性

Google用验证率来量化模型的可靠性,验证率是满足规则的输出样本的比例。以更好的验证率运行可能是有益的,尤其是在已知规则始终有效的情况下,例如在自然科学中。通过调整控制参数α,可以实现更高的规则验证率,从而实现更可靠的预测。

为了证明这一点,Google考虑了从给定初始状态的摩擦产生的双摆动力学的时间序列数据。Google将任务定义为从当前状态预测双摆的下一个状态,同时施加能量守恒规则。为了量化规则学习了多少,Google评估了验证率。

DeepCTRL 可以在学习后控制模型的行为,但无需重新训练。对于双摆的例子,传统的学习没有施加任何约束来确保模型遵循物理定律,例如能量守恒。规则强度较低的 DeepCTRL 的情况类似。因此,在时间t +1(蓝色)预测的系统总能量有时可能大于在时间 t(红色)测得的能量,这在物理上是不允许的(左下角)。如果 DeepCTRL 中的规则强度很高,模型可能会遵循给定的规则但会失去准确性(红色和蓝色之间的差异较大;右下角)。如果规则强度介于两个极端之间,模型可能会获得更高的准确性(蓝色曲线接近红色)并正确遵循规则(蓝色曲线低于红色曲线)。

Google将 DeepCTRL 在此任务上的性能与传统的训练基线进行了比较,该基线具有固定的基于规则的约束作为添加到目标λ的正则化项。这些正则化系数中最高的提供了最高的验证率(如下图第二张图中的绿线所示),但是,预测误差比λ = 0.1(橙线)略差。Google发现固定基线的最低预测误差与 DeepCTRL 相当,但固定基线的最高验证率仍然较低,这意味着 DeepCTRL 可以在遵循能量守恒定律的同时提供准确的预测。此外,Google考虑施加规则约束的基准Lagrangian Dual Framework (LDF) 并展示了两个结果,其中超参数由验证集上的最低平均绝对误差 (LDF-MAE) 和最高规则验证比 (LDF-Ratio) 选择。LDF 方法的性能对主要约束是什么高度敏感,其输出不可靠(黑色和粉色虚线)。

双摆任务的实验结果,显示了基于任务的平均绝对误差 (MAE),它测量了地面实况和模型预测之间的差异,而 DeepCTRL 作为控制参数α的函数。TaskOnly 没有规则约束,并且 Task & Rule 具有不同的规则强度 ( λ )。LDF 通过解决约束优化问题来强制执行规则。

同上,但显示了不同模型的验证率。

双摆任务的实验结果分别显示了时间t和t + 1 的当前能量和预测能量。

此外,上图说明了 DeepCTRL 相对于传统方法的优势。例如,将规则强度λ从 0.1 增加到 1.0 会提高验证率(从 0.7 到 0.9),但不会提高平均绝对误差。任意增加λ将继续推动验证比率接近 1,但会导致更差的准确性。因此,找到λ的最佳值将需要通过基线模型进行多次训练,而 DeepCTRL 可以更快地找到控制参数α的最佳值。

适应医疗保健的分配变化

某些规则的强度可能在数据子集之间有所不同。例如,在疾病预测中,老年患者的心血管疾病与高血压之间的相关性强于年轻患者。在这种情况下,当任务是共享的,但数据集之间的数据分布和规则的有效性不同时,DeepCTRL 可以通过控制α来适应分布的变化。

探索这个例子,Google专注于使用心血管疾病数据集 预测心血管疾病是否存在的任务。鉴于已知较高的收缩压与心血管疾病密切相关,Google考虑以下规则:“如果收缩压较高,则风险较高”。基于此,Google将患者分为两组:(1)异常,患者有高血压但没有疾病或血压降低但有疾病;(2)普通,即有高血压病或低血压但无病。

Google在下面证明源数据并不总是遵循规则,因此合并规则的效果可能取决于源数据。测试交叉熵,表示分类准确度(较低的交叉熵更好),与具有不同通常/异常比率的源或目标数据集的规则强度在下面可视化。误差随着α → 1单调增加,因为不能准确反映源数据的强加规则的执行变得更加严格。

测试通常/异常比率为 0.30的源数据集的交叉熵与规则强度。

当一个训练好的模型转移到目标域时,可以通过控制α来减少误差。为了证明这一点,Google展示了三个特定领域的数据集,Google称之为目标 1、2 和 3。在目标 1 中,大多数患者来自普通组,随着α的增加,基于规则的表示具有更多权重,由此产生的误差单调递减。

如上所述,但对于目标数据集 (1),通常/异常比率为 0.77。

当目标 2 和 3 中普通患者的比例降低时,最佳α是介于 0 和 1 之间的中间值。这些证明了通过α适应训练模型的能力。

如上所述,但对于目标 2,通常/异常比率为 0.50。

如上所述,但对于目标 3,通常/异常比率为 0.40。

结论

从规则中学习对于构建可解释、稳健和可靠的 DNN 至关重要。Google提出了 DeepCTRL,这是一种用于将规则合并到数据学习 DNN 中的新方法。DeepCTRL 无需重新训练即可在推理时控制规则强度。Google提出了一种新颖的基于扰动的规则编码方法,将任意规则集成到有意义的表示中。Google展示了 DeepCTRL 的三个用例:在已知原则的情况下提高可靠性、检查候选规则以及使用规则强度进行域适应

本站所发布的文字与图片素材为非商业目的改编或整理,版权归原作者所有,如侵权或涉及违法,请联系我们删除

窝牛号 wwww.93ysy.com   沪ICP备2021036305号-1