{
"cells": [
{
"cell_type": "markdown",
"id": "06969ee4",
"metadata": {
"origin_pos": 0
},
"source": [
"# Beam Search\n",
":label:`sec_beam-search`\n",
"\n",
"In :numref:`sec_seq2seq`, we predicted the output sequence token by token\n",
"until the special end-of-sequence token “<eos>” was predicted.\n",
"This section first formalizes this *greedy search* strategy\n",
"and examines its problems, then compares it with two alternatives:\n",
"*exhaustive search* and *beam search*.\n",
"\n",
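"Before formally introducing greedy search, let us define the search problem\n",
"using the same mathematical notation as in :numref:`sec_seq2seq`.\n",
"At any time step $t'$, the probability of the decoder output $y_{t'}$\n",
"is conditional on the output subsequence $y_1, \\ldots, y_{t'-1}$ before $t'$\n",
"and on the context variable $\\mathbf{c}$ that encodes the information of the input sequence.\n",
"To quantify computational cost, denote by $\\mathcal{Y}$ the output vocabulary,\n",
"which includes “<eos>”,\n",
"so the cardinality $\\left|\\mathcal{Y}\\right|$ of this set is the vocabulary size.\n",
"We also specify the maximum number of tokens in an output sequence as $T'$.\n",
"Our goal is therefore to find an ideal output among all $\\mathcal{O}(\\left|\\mathcal{Y}\\right|^{T'})$\n",
"possible output sequences.\n",
"For any of these output sequences, the portion after “<eos>”\n",
"is discarded in the actual output.\n",
"\n",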
"## Greedy Search\n",
"\n",
"First, let us look at a simple strategy, *greedy search*,\n",
"which was used for sequence prediction in :numref:`sec_seq2seq`.\n",
"At any time step $t'$ of the output sequence,\n",
"greedy search selects the token with the highest conditional probability from $\\mathcal{Y}$:\n",
"\n",
"$$y_{t'} = \\operatorname*{argmax}_{y \\in \\mathcal{Y}} P(y \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$$\n",
"\n",
"The output is complete once it contains “<eos>” or reaches its maximum length $T'$.\n",
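"\n",
"As a rough illustration of this rule, the following Python sketch runs greedy search on a hypothetical toy model.\n",
"The function `next_probs` and its probability table are made up for illustration. They stand in for\n",
"$P(y \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$, with $\\mathbf{c}$ left implicit.\n",
"This is not the decoder code from :numref:`sec_seq2seq`.\n",
"\n",
"```python\n",
"# Hypothetical toy decoder: next-token distribution given the prefix (c is implicit).\n",
"VOCAB = ['A', 'B', '<eos>']\n",
"TABLE = {\n",
"    (): [0.5, 0.4, 0.1],\n",
"    ('A',): [0.35, 0.25, 0.4],\n",
"    ('B',): [0.9, 0.05, 0.05],\n",
"}\n",
"DEFAULT = [0.15, 0.05, 0.8]  # made-up distribution for every other prefix\n",
"\n",
"\n",
"def next_probs(prefix):\n",
"    return dict(zip(VOCAB, TABLE.get(tuple(prefix), DEFAULT)))\n",
"\n",
"\n",
"def greedy_search(model, max_len):\n",
"    seq, prob = [], 1.0\n",
"    while len(seq) < max_len:\n",
"        probs = model(seq)\n",
"        token = max(probs, key=probs.get)  # argmax over the vocabulary\n",
"        seq.append(token)\n",
"        prob *= probs[token]\n",
"        if token == '<eos>':  # the output is complete\n",
"            break\n",
"    return seq, prob\n",
"\n",
"\n",
"seq, prob = greedy_search(next_probs, max_len=3)\n",
"print(seq, round(prob, 3))  # ['A', '<eos>'] 0.2\n",
"```\n",
"\n",
"On this toy model, greedy search commits to “A” because it has the highest probability at time step $1$, then stops at “<eos>”.\n",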
"\n",
"\n",
":label:`fig_s2s-prob1`\n",
"\n",
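"As shown in :numref:`fig_s2s-prob1`,\n",
"suppose that the output contains four tokens: “A”, “B”, “C”, and “<eos>”.\n",
"The four numbers under each time step are the conditional probabilities\n",
"of generating “A”, “B”, “C”, and “<eos>”, respectively, at that time step.\n",
"At each time step, greedy search selects the token with the highest conditional probability.\n",
"Therefore, the output sequence “A”, “B”, “C”, “<eos>” is predicted\n",
"in :numref:`fig_s2s-prob1`.\n",
"The conditional probability of this output sequence is\n",
"$0.5\\times0.4\\times0.4\\times0.6 = 0.048$.\n",
"\n",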
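"So what is the problem with greedy search?\n",
"The *optimal sequence* is the output sequence that maximizes\n",
"$\\prod_{t'=1}^{T'} P(y_{t'} \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c})$,\n",
"the conditional probability of generating the output sequence given the input sequence.\n",
"Greedy search, however, is not guaranteed to find the optimal sequence.\n",
"\n",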
"\n",
":label:`fig_s2s-prob2`\n",
"\n",
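"Another example in :numref:`fig_s2s-prob2` illustrates this problem.\n",
"Unlike in :numref:`fig_s2s-prob1`, at time step $2$\n",
"we select the token “C” in :numref:`fig_s2s-prob2`,\n",
"which has the *second* highest conditional probability.\n",
"Time step $3$ is conditioned on the output subsequence at time steps $1$ and $2$.\n",
"That subsequence has changed from “A” and “B” in :numref:`fig_s2s-prob1`\n",
"to “A” and “C” in :numref:`fig_s2s-prob2`,\n",
"so the conditional probability of each token at time step $3$ also changes in :numref:`fig_s2s-prob2`.\n",
"Suppose that we choose the token “B” at time step $3$.\n",
"Time step $4$ is then conditioned on the output subsequence “A”, “C”, “B” of the first three time steps,\n",
"which differs from “A”, “B”, “C” in :numref:`fig_s2s-prob1`.\n",
"Therefore, the conditional probability of generating each token at time step $4$\n",
"in :numref:`fig_s2s-prob2` also differs from that in :numref:`fig_s2s-prob1`.\n",
"As a result, the conditional probability of the output sequence\n",
"“A”, “C”, “B”, “<eos>” in :numref:`fig_s2s-prob2` is\n",
"$0.5\\times0.3 \\times0.6\\times0.6=0.054$,\n",
"which is greater than that of greedy search in :numref:`fig_s2s-prob1`.\n",
"This example shows that the output sequence\n",
"“A”, “B”, “C”, “<eos>” obtained by greedy search\n",
"is not necessarily the optimal sequence.\n",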
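"\n",
"The two conditional probabilities above can be checked with a few lines of Python.\n",
"The per-step values are the ones quoted in the text for :numref:`fig_s2s-prob1` and :numref:`fig_s2s-prob2`.\n",
"`math.prod` requires Python 3.8 or later.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Per-step conditional probabilities quoted in the text above.\n",
"greedy_steps = [0.5, 0.4, 0.4, 0.6]  # A, B, C, <eos> (greedy search)\n",
"other_steps = [0.5, 0.3, 0.6, 0.6]   # A, C, B, <eos>\n",
"\n",
"p_greedy = math.prod(greedy_steps)\n",
"p_other = math.prod(other_steps)\n",
"print(round(p_greedy, 3), round(p_other, 3))  # 0.048 0.054\n",
"print(p_other > p_greedy)  # True: greedy search missed a more probable sequence\n",
"```\n",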
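"\n",
"## Exhaustive Search\n",
"\n",
"If the goal is to obtain the optimal sequence,\n",
"we may consider *exhaustive search*:\n",
"enumerate all possible output sequences together with their conditional probabilities,\n",
"then output the one with the highest conditional probability.\n",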
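"\n",
"Here is a minimal sketch of exhaustive search, assuming a hypothetical toy model.\n",
"`next_probs` and its probability table are made up for illustration.\n",
"A sequence counts as complete once it emits “<eos>” or reaches `max_len` tokens.\n",
"The code enumerates every complete sequence with a depth-first stack and keeps the most probable one.\n",
"\n",
"```python\n",
"# Hypothetical toy decoder: next-token distribution given the prefix (c is implicit).\n",
"VOCAB = ['A', 'B', '<eos>']\n",
"TABLE = {\n",
"    (): [0.5, 0.4, 0.1],\n",
"    ('A',): [0.35, 0.25, 0.4],\n",
"    ('B',): [0.9, 0.05, 0.05],\n",
"}\n",
"DEFAULT = [0.15, 0.05, 0.8]  # made-up distribution for every other prefix\n",
"\n",
"\n",
"def next_probs(prefix):\n",
"    return dict(zip(VOCAB, TABLE.get(tuple(prefix), DEFAULT)))\n",
"\n",
"\n",
"def exhaustive_search(model, max_len):\n",
"    best_seq, best_prob = None, 0.0\n",
"    stack = [([], 1.0)]  # (prefix, joint probability)\n",
"    while stack:\n",
"        seq, prob = stack.pop()\n",
"        if (seq and seq[-1] == '<eos>') or len(seq) == max_len:\n",
"            if prob > best_prob:  # complete sequence: compare with the best so far\n",
"                best_seq, best_prob = seq, prob\n",
"            continue\n",
"        for token, p in model(seq).items():\n",
"            stack.append((seq + [token], prob * p))\n",
"    return best_seq, best_prob\n",
"\n",
"\n",
"seq, prob = exhaustive_search(next_probs, max_len=3)\n",
"print(seq, round(prob, 3))  # ['B', 'A', '<eos>'] 0.288\n",
"```\n",
"\n",
"On the same table, greedy search would stop at “A”, “<eos>” with probability $0.5 \\times 0.4 = 0.2$.\n",
"Exhaustive search instead finds “B”, “A”, “<eos>” with probability $0.4 \\times 0.9 \\times 0.8 = 0.288$, at the price of visiting every path.\n",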
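"\n",
"Although exhaustive search finds the optimal sequence,\n",
"its computational cost $\\mathcal{O}(\\left|\\mathcal{Y}\\right|^{T'})$ is likely to be prohibitively high.\n",
"For example, when $|\\mathcal{Y}|=10000$ and $T'=10$,\n",
"we would need to evaluate $10000^{10} = 10^{40}$ sequences.\n",
"This number is so large that it is nearly impossible for any existing computer to evaluate them all.\n",
"By contrast, the computational cost of greedy search,\n",
"$\\mathcal{O}(\\left|\\mathcal{Y}\\right|T')$,\n",
"is far smaller than that of exhaustive search.\n",
"For the same $|\\mathcal{Y}|=10000$ and $T'=10$,\n",
"greedy search only computes $10000\\times10=10^5$ conditional probabilities.\n",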
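"\n",
"Python's arbitrary-precision integers can confirm the counts in this example.\n",
"`vocab_size` and `max_len` are simply names chosen here for $|\\mathcal{Y}|$ and $T'$.\n",
"\n",
"```python\n",
"vocab_size, max_len = 10000, 10\n",
"\n",
"exhaustive = vocab_size ** max_len  # |Y|^T' candidate sequences\n",
"greedy = vocab_size * max_len       # |Y| probabilities at each of T' steps\n",
"print(exhaustive == 10 ** 40)  # True\n",
"print(greedy)                  # 100000\n",
"```\n",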
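"\n",
"## Beam Search\n",
"\n",
"So which sequence search strategy should we choose?\n",
"If accuracy matters most, exhaustive search is the obvious choice.\n",
"If computational cost matters most, greedy search is.\n",
"Practical applications of beam search lie between these two extremes.\n",
"\n",
"*Beam search* is an improved version of greedy search.\n",
"It has a hyperparameter called the *beam size*, $k$.\n",
"At time step $1$, we select the $k$ tokens with the highest conditional probabilities.\n",
"Each of them becomes the first token of one of $k$ candidate output sequences.\n",
"At each subsequent time step, we start from the $k$ candidate output sequences of the previous time step.\n",
"Among their $k\\left|\\mathcal{Y}\\right|$ possible extensions,\n",
"we select the $k$ candidate output sequences with the highest conditional probabilities.\n",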
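"\n",
"One time step of this procedure can be sketched as follows, assuming a hypothetical toy model.\n",
"`next_probs` and its probability table are made up for illustration.\n",
"Candidates are ranked by their joint log-probability, the sum of per-step log conditional probabilities.\n",
"This gives the same ordering as the product of probabilities but avoids underflow on long sequences.\n",
"The sketch does not treat “<eos>” specially.\n",
"\n",
"```python\n",
"import heapq\n",
"import math\n",
"\n",
"# Hypothetical toy decoder: next-token distribution given the prefix (c is implicit).\n",
"VOCAB = ['A', 'B', '<eos>']\n",
"TABLE = {\n",
"    (): [0.5, 0.4, 0.1],\n",
"    ('A',): [0.35, 0.25, 0.4],\n",
"    ('B',): [0.9, 0.05, 0.05],\n",
"}\n",
"DEFAULT = [0.15, 0.05, 0.8]  # made-up distribution for every other prefix\n",
"\n",
"\n",
"def next_probs(prefix):\n",
"    return dict(zip(VOCAB, TABLE.get(tuple(prefix), DEFAULT)))\n",
"\n",
"\n",
"def beam_step(beam, model, k):\n",
"    # beam: list of (tokens, joint log-probability) pairs.\n",
"    expansions = [\n",
"        (tokens + [tok], lp + math.log(p))\n",
"        for tokens, lp in beam\n",
"        for tok, p in model(tokens).items()\n",
"    ]  # at most k * |Y| possible choices\n",
"    return heapq.nlargest(k, expansions, key=lambda h: h[1])\n",
"\n",
"\n",
"beam = beam_step([([], 0.0)], next_probs, k=2)  # time step 1\n",
"print([tokens for tokens, _ in beam])  # [['A'], ['B']]\n",
"beam = beam_step(beam, next_probs, k=2)         # time step 2: 2 * 3 = 6 choices\n",
"print([tokens for tokens, _ in beam])  # [['B', 'A'], ['A', '<eos>']]\n",
"```\n",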
"\n",
"\n",
":label:`fig_beam-search`\n",
"\n",
":numref:`fig_beam-search` demonstrates the process of beam search.\n",
"Suppose that the output vocabulary contains only five elements,\n",
"$\\mathcal{Y} = \\{A, B, C, D, E\\}$,\n",
"one of which is “<eos>”.\n",
"Let the beam size be $2$ and the maximum length of an output sequence be $3$.\n",
"At time step $1$, suppose that the tokens with the highest conditional probabilities\n",
"$P(y_1 \\mid \\mathbf{c})$ are $A$ and $C$.\n",
"At time step $2$, for all $y_2 \\in \\mathcal{Y}$ we compute\n",
"\n",
"$$\\begin{aligned}P(A, y_2 \\mid \\mathbf{c}) = P(A \\mid \\mathbf{c})P(y_2 \\mid A, \\mathbf{c}),\\\\ P(C, y_2 \\mid \\mathbf{c}) = P(C \\mid \\mathbf{c})P(y_2 \\mid C, \\mathbf{c}),\\end{aligned}$$\n",
"\n",
"and pick the largest two among these ten values,\n",
"say $P(A, B \\mid \\mathbf{c})$ and $P(C, E \\mid \\mathbf{c})$.\n",
"Then at time step $3$, for all $y_3 \\in \\mathcal{Y}$ we compute\n",
"\n",
"$$\\begin{aligned}P(A, B, y_3 \\mid \\mathbf{c}) = P(A, B \\mid \\mathbf{c})P(y_3 \\mid A, B, \\mathbf{c}),\\\\P(C, E, y_3 \\mid \\mathbf{c}) = P(C, E \\mid \\mathbf{c})P(y_3 \\mid C, E, \\mathbf{c}),\\end{aligned}$$\n",
"\n",
"and again pick the largest two among these ten values,\n",
"namely $P(A, B, D \\mid \\mathbf{c})$ and $P(C, E, D \\mid \\mathbf{c})$.\n",
"As a result, we obtain six candidate output sequences:\n",
"(1) $A$; (2) $C$; (3) $A, B$; (4) $C, E$; (5) $A, B, D$; (6) $C, E, D$.\n",
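"\n",
"The walk-through above can be reproduced in miniature.\n",
"The sketch below uses a hypothetical three-token vocabulary (“A”, “B”, “<eos>”) with a made-up `next_probs` table, not the five-token example in :numref:`fig_beam-search`.\n",
"It records every hypothesis selected at every time step as a candidate, as in the six-candidate list.\n",
"As an additional assumption, a hypothesis that has emitted “<eos>” stays a candidate but is not expanded further.\n",
"\n",
"```python\n",
"import heapq\n",
"import math\n",
"\n",
"# Hypothetical toy decoder: next-token distribution given the prefix (c is implicit).\n",
"VOCAB = ['A', 'B', '<eos>']\n",
"TABLE = {\n",
"    (): [0.5, 0.4, 0.1],\n",
"    ('A',): [0.35, 0.25, 0.4],\n",
"    ('B',): [0.9, 0.05, 0.05],\n",
"}\n",
"DEFAULT = [0.15, 0.05, 0.8]  # made-up distribution for every other prefix\n",
"\n",
"\n",
"def next_probs(prefix):\n",
"    return dict(zip(VOCAB, TABLE.get(tuple(prefix), DEFAULT)))\n",
"\n",
"\n",
"def beam_candidates(model, k, max_len):\n",
"    # Returns every hypothesis selected at any time step, with its joint log-probability.\n",
"    beam, candidates = [([], 0.0)], []\n",
"    for _ in range(max_len):\n",
"        expansions = [\n",
"            (tokens + [tok], lp + math.log(p))\n",
"            for tokens, lp in beam\n",
"            for tok, p in model(tokens).items()\n",
"        ]\n",
"        selected = heapq.nlargest(k, expansions, key=lambda h: h[1])\n",
"        candidates.extend(selected)\n",
"        # Assumption: hypotheses ending in '<eos>' remain candidates but are not expanded.\n",
"        beam = [h for h in selected if h[0][-1] != '<eos>']\n",
"        if not beam:\n",
"            break\n",
"    return candidates\n",
"\n",
"\n",
"cands = beam_candidates(next_probs, k=2, max_len=3)\n",
"for tokens, lp in cands:\n",
"    print(tokens, round(math.exp(lp), 3))\n",
"# ['A'] 0.5\n",
"# ['B'] 0.4\n",
"# ['B', 'A'] 0.36\n",
"# ['A', '<eos>'] 0.2\n",
"# ['B', 'A', '<eos>'] 0.288\n",
"# ['B', 'A', 'A'] 0.054\n",
"```\n",
"\n",
"In the figure, both hypotheses kept at time step $2$ are still unfinished.\n",
"Here, “A”, “<eos>” finishes at time step $2$, so only one hypothesis is expanded at time step $3$.\n",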
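"\n",
"Finally, we obtain the set of final candidate output sequences from these six sequences\n",
"(for example, by discarding the portion including and after “<eos>”).\n",
"We then choose the sequence with the highest value of the following score as the output sequence:\n",
"\n",
"$$ \\frac{1}{L^\\alpha} \\log P(y_1, \\ldots, y_{L}\\mid \\mathbf{c}) = \\frac{1}{L^\\alpha} \\sum_{t'=1}^L \\log P(y_{t'} \\mid y_1, \\ldots, y_{t'-1}, \\mathbf{c}),$$\n",
":eqlabel:`eq_beam-search-score`\n",
"\n",
"where $L$ is the length of the final candidate sequence\n",
"and $\\alpha$ is usually set to $0.75$.\n",
"Every log-probability term is at most zero, so a longer sequence accumulates more negative terms\n",
"in the summation of :eqref:`eq_beam-search-score` and would otherwise be unfairly penalized.\n",
"The $L^\\alpha$ in the denominator offsets this bias against long sequences.\n",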
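"\n",
"The score in :eqref:`eq_beam-search-score` takes one line to compute.\n",
"In the sketch below, the two candidates and their joint probabilities are hypothetical.\n",
"$L$ is taken to count “<eos>” when it is present; counting it is an assumption of this sketch.\n",
"The comparison shows how normalization changes the ranking.\n",
"On raw log-probability the one-token candidate wins, while the normalized score favors the longer sequence.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def beam_score(log_prob, length, alpha=0.75):\n",
"    # (1 / L**alpha) * log P(y_1, ..., y_L | c)\n",
"    return log_prob / length ** alpha\n",
"\n",
"\n",
"# Hypothetical candidates: (tokens, joint conditional probability).\n",
"candidates = [(['A'], 0.5), (['B', 'A', '<eos>'], 0.288)]\n",
"for tokens, p in candidates:\n",
"    print(tokens, round(math.log(p), 3), round(beam_score(math.log(p), len(tokens)), 3))\n",
"# ['A'] -0.693 -0.693\n",
"# ['B', 'A', '<eos>'] -1.245 -0.546\n",
"\n",
"best_raw = max(candidates, key=lambda c: math.log(c[1]))[0]\n",
"best_norm = max(candidates, key=lambda c: beam_score(math.log(c[1]), len(c[0])))[0]\n",
"print(best_raw, best_norm)  # ['A'] ['B', 'A', '<eos>']\n",
"```\n",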
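"\n",
"The computational cost of beam search is $\\mathcal{O}(k\\left|\\mathcal{Y}\\right|T')$,\n",
"which lies between the costs of greedy search and exhaustive search.\n",
"In fact, greedy search can be viewed as a special case of beam search with a beam size of $1$.\n",
"Through its choice of beam size, beam search trades off accuracy against computational cost.\n",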
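"\n",
"The following sketch runs both greedy search and beam search with $k=1$ on a hypothetical toy model, showing that they coincide.\n",
"`next_probs` and its table are made up for illustration.\n",
"This beam-search variant is an assumption, not necessarily the one used elsewhere in the book.\n",
"Finished hypotheses are set aside, and the answer is the finished or length-$T'$ hypothesis with the best score from :eqref:`eq_beam-search-score`.\n",
"Each time step scores at most $k\\left|\\mathcal{Y}\\right|$ expansions, which gives the $\\mathcal{O}(k\\left|\\mathcal{Y}\\right|T')$ cost.\n",
"\n",
"```python\n",
"import heapq\n",
"import math\n",
"\n",
"# Hypothetical toy decoder: next-token distribution given the prefix (c is implicit).\n",
"VOCAB = ['A', 'B', '<eos>']\n",
"TABLE = {\n",
"    (): [0.5, 0.4, 0.1],\n",
"    ('A',): [0.35, 0.25, 0.4],\n",
"    ('B',): [0.9, 0.05, 0.05],\n",
"}\n",
"DEFAULT = [0.15, 0.05, 0.8]  # made-up distribution for every other prefix\n",
"\n",
"\n",
"def next_probs(prefix):\n",
"    return dict(zip(VOCAB, TABLE.get(tuple(prefix), DEFAULT)))\n",
"\n",
"\n",
"def greedy_search(model, max_len):\n",
"    seq = []\n",
"    while len(seq) < max_len and '<eos>' not in seq:\n",
"        probs = model(seq)\n",
"        seq.append(max(probs, key=probs.get))  # argmax at each step\n",
"    return seq\n",
"\n",
"\n",
"def beam_search(model, k, max_len, alpha=0.75):\n",
"    beam, finished = [([], 0.0)], []\n",
"    for _ in range(max_len):\n",
"        expansions = [(t + [tok], lp + math.log(p))\n",
"                      for t, lp in beam for tok, p in model(t).items()]\n",
"        selected = heapq.nlargest(k, expansions, key=lambda h: h[1])\n",
"        finished += [h for h in selected if h[0][-1] == '<eos>']\n",
"        beam = [h for h in selected if h[0][-1] != '<eos>']\n",
"        if not beam:\n",
"            break\n",
"    pool = finished + beam  # any hypotheses left in beam have reached max_len\n",
"    return max(pool, key=lambda h: h[1] / len(h[0]) ** alpha)[0]\n",
"\n",
"\n",
"print(greedy_search(next_probs, 3))   # ['A', '<eos>']\n",
"print(beam_search(next_probs, 1, 3))  # ['A', '<eos>']\n",
"print(beam_search(next_probs, 2, 3))  # ['B', 'A', '<eos>']\n",
"```\n",
"\n",
"With $k=1$, the beam keeps only the single most probable extension at each step, which is exactly the greedy choice.\n",
"With $k=2$, it recovers “B”, “A”, “<eos>”, which greedy search misses on this toy model.\n",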
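"\n",
"## Summary\n",
"\n",
"* Sequence search strategies include greedy search, exhaustive search, and beam search.\n",
"* Greedy search has the lowest computational cost, but the sequence it selects is relatively less accurate.\n",
"* Exhaustive search selects the most accurate sequence, but at the highest computational cost.\n",
"* Beam search trades off accuracy against computational cost through its choice of beam size.\n",
"\n",
"## Exercises\n",
"\n",
"1. Can we treat exhaustive search as a special type of beam search? Why or why not?\n",
"1. Apply beam search to the machine translation problem in :numref:`sec_seq2seq`.\n",
"   How does the beam size affect the prediction speed and the results?\n",
"1. In :numref:`sec_rnn_scratch`, we used a language model to generate text following a user-provided prefix.\n",
"   Which search strategy does that example use? Can you improve it?\n",
"\n",
"[Discussions](https://discuss.d2l.ai/t/5768)\n"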
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"required_libs": []
},
"nbformat": 4,
"nbformat_minor": 5
}