Beam search

引入

Beam search是经常应用在Seq2seq任务中的解码方法, 算法本身复杂度不高, 但往往能大幅提高解码质量.

Beam search只用在测试样本的解码过程中, 在训练时是用不到的.

对比

Seq2seq任务中, 对于一个输入, 我们的任务是具有最大概率的解码序列. 可以用公式如下表示:

argmaxyP(y<1>,y<2>,y<3>,y<Ty>x<1>,x<2>,x<Tx>)\arg \max _{y} P\left(y^{<1>}, y^{<2>}, y^{<3>}, \ldots y^{<T_{y}>} | x^{<1>}, x^{<2>}, \ldots x^{<T_{x}>}\right)

如果不使用任何解码方法, 在seq2seq模型中我们使用的就是贪心搜索(greedy search). 具体来说, 在tt时刻, 根据所有的输入(encoder的输入, 来自上一个元素的输出作为当前decoder的输入), 在字典中挑选出条件概率最大的词y<t>y^{<t>}, 之后的每个时刻一次类推.

因此, 贪心算法在每个时刻, 始终是选择最大概率的词, 但这样选择出的词拼接在一起组成的序列, 往往不是上式概率最大的情况, 甚至相差很大, 而我们需要的, 正是上式概率最大的序列. 这种情况的反面例子可以参考参考资料中的第一篇的Why not a greedy search?部分.

如果完全从序列的角度出发, 一定能找到最优的序列, 但是字典中所有单词组成的序列数量, 是无法枚举判断得到最优序列的.

Beam search就是一种折中的方案, 在tt时刻确定输出时, 会考虑之前时刻的输出序列, 并且会考虑多种前置序列, 但也使用了超参数beam width, 保证了搜索范围的高效.

Beam search虽然可能不会找到最优的方案, 但已经能够保证在高效的前提下, 找到接近于最优答案, 甚至就是最优答案的结果.

原理

具体的原理可以参考资料中第一篇的Beam Search部分, 以及第二篇. 这两篇文章介绍算法时都集合了具体的例子进行推导, 形象易懂.

代码

这是结合在seq2seq模型中的一个beam search算法的代码. Beam search算法的唯一超参数就是上文中提到的beam width, 也就是代码中的topK.

从投入开始标志<s>(代码对应的字典中的index为2)为起始, 逐个预测之后的每个元素. 在预测过程中, 除了需要保存现有的topK个候选序列(保存在变量target_seq中), 还要存储对应序列的整体概率(保存在topk_prob中).

另外需要注意的是, 代码中关于概率计算终止符</s>的处理方法.

参考资料

最后更新于

这有帮助吗?