从训练和预测的角度来理解Transformer中Masked Self-Attention的原理


Transformer模型结构图

在Transformer中Decoder会先经过一个masked self-attention层

使用Masked Self-Attention层可以解决下文提到的训练阶段和预测阶段Decoder可能遇到的所有问题。

什么是Masked Self-attention层

你只需要记住:masked self-attention层就是下面的网络连线(如果实现这样的神经元连接,你只要记住一个sequence mask,让右侧的注意力系数$\alpha_{ij}=0$,那么就可以达到这个效果)

训练阶段:

训练时,你有decoder的target句子,你会直接输入到masked self-attention层。(对于同一个src句子,只有这一个target输入序列)

测试阶段:

预测时,你会有很多的序列(对于同一个src句子,会有多个target输入序列,慢慢生成的)

  1. 你会先把<Start>作为序列,输入到masked self-attention层,预测结果是y1
  2. 然后把<Start> y1作为序列,输入到masked self-attention层(和训练时一样,都会用到mask矩阵来实现masked self-attention层的神经元连接方式),预测结果是y1, y2(由于可能有dropout,这个y1可能与第一步的y1稍微有点不同)
  3. <Start> y1 y2作为序列,输入到masked self-attention层,每个位置上的预测结果是y1, y2, y3

可以看到预测阶段我们希望是增量更新的,对于重复的单词,我们希望预测的结果是一样的,而且$y_i$永远只是用到它和它左侧的decoder输入信息,不会用到右侧的。

  • greedy_decoder代码预测的输入序列是一个个生成的

      def greedy_decoder(model, enc_input, start_symbol):
          """贪心编码
          For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
          target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
          Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
          :param model: Transformer Model
          :param enc_input: The encoder input
          :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
          :return: The target input
          """
          enc_outputs, enc_self_attns = model.encoder(enc_input)
          dec_input = torch.zeros(1, 0).type_as(enc_input.data)
          terminal = False
          next_symbol = start_symbol
          while not terminal:
              # 预测阶段:dec_input序列会一点点变长(每次添加一个新预测出来的单词)
              dec_input = torch.cat([dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
                                    -1)
              dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
              projected = model.projection(dec_outputs)
              prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
              next_word = prob.data[-1]  # 拿出当前预测的单词(数字)
              next_symbol = next_word
              if next_symbol == tgt_vocab["E"]:
                  terminal = True
              # print(next_word)
    
          # greedy_dec_predict = torch.cat(
          #     [dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
          #     -1)
          greedy_dec_predict = dec_input[:, 1:]
          return greedy_dec_predict

按照上面的说法,如果我在代码中关掉了dropout,那么当预测序列是$x’_1,x’_2,x’_3$时的输出结果,应该是和预测序列时$x’_1,x’_2,x’_3,x’_4$的前3个位置结果是一样的(增量更新)

验证:关掉位置编码中dropout后,你会发现之前的输入x’1,x’2,x’3经过decoder网络的结果果然是不变的。

为什么需要Masked Self-attention层

对于这个疑问,我在知乎上看到别人也有这样的困惑,在测试或者预测时,Transformer里decoder为什么还需要seq mask?

Transformer原论文中是这样解释的:

反正我知道他说的意思,但是又好像没完全懂。

直到后来看到一个博客深入理解transformer源码,才理解透彻了这个问题。

这个问题我们分两个角度来看:

训练阶段为什么要用masked?

这个比较好理解,因为你训练的时候算loss,是用当前decoder输入所有单词对应位置的输出$y_1,y_2,…y_t$与真实的翻译结果ground truth去分别算cross entropy loss,然后把t个loss加起来的,如果你用的是self-attention,那么$y_1$这个输出里面是包含了$x’_1$右侧的单词信息的(特别是包含了$x’_2$这个你要预测的下一个单词的信息),这是用到了未来的信息,模型是在作弊,属于信息泄露。在实际推理过程中,我们显然是不可能提前知道未来信息的。

我们可以看下面Transformer训练过程的代码:

训练是用$y_1,y_2,…y_t$与真实的翻译结果ground truth去分别算cross entropy loss,然后把t个loss加起来,得到loss的(那显然是需要用masked self attention的,否则$y_1$是包含右侧的信息的,那你还把他加到loss里面,它都作弊了,显然这个位置算出来的loss是会比较低,但有什么意义呢,真正推理的时候你哪来未来的信息啊)

预测阶段为什么也要用masked?

很多博客说是Transformer使用sequence masked是在模拟预测时的情况,因为预测结果是迭代生成的,这是为了不让模型偷看到未来的内容,这样解释也没有错,但是没有讲清楚预测阶段为什么要用masked。

预测阶段还使用Masked的原因:

  • 原因1:预测阶段要保持重复的单词预测结果是一样的,这样不仅合理,而且可以增量更新(我们在预测是会选择性忽略重复的预测的词,只摘取最新预测的单词拼接到输入序列中)

    如果我在代码中关掉了dropout,那么当预测序列是$x’_1,x’_2,x’_3$时的输出结果,应该是和预测序列时$x’_1,x’_2,x’_3,x’_4$的前3个位置结果是一样的(增量更新)

原因2:恰好也可以与训练时的模型架构保持一致,前向传播的方式是一致的

附:Transformer预测阶段为什么要保持重复的单词预测结果是一样的?

摘自:blog

简单来说就是,你如果用self-attention,那么$x’_1,x’_2$的输出是$y_1,y_2$,而$x’_1,x’_2,x’_3$的输出是$z_1,z_2,z_3$(这里面$y_1$会包含$x’_2$的信息,$z_1$更离谱会包含$x’_2,x’_3$的信息,这tm属于信息泄露了)。而且我们希望的是增量更新,前面的同一个单词预测结果,我们希望是一样的。

你想想我们在作机器翻译的时候,是一个个把输出结果加到最终结果里的。但我们是不会因为$z_1$预测更准(作弊了),就用这个单词去替换之前已经预测出来的单词$y_1$的


实战:Transformer预测阶段输入序列$x’_1,x’_2,x’_3$,我们是会得到$y_1,y_2,y_3$的,但是因为有masked机制的存在,我们基本可以保证此时的$y_1,y_2$和之前的输入序列$x’_1,x’_2$的预测结果$y_1,y_2$是相同的(代码中关掉dropout的情况下)

所以一般预测的时候我们是直接用当前最后一个位置$x’_t$对应的概率向量去预测下一个单词$y_t$


Author: SHWEI
Reprint policy: All articles in this blog are used except for special statements CC BY 4.0 reprint polocy. If reproduced, please indicate source SHWEI !
评论
 Previous
深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用 深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用
深入理解BatchNorm的原理、代码实现以及BN在CNN中的应用 BatchNorm是算法岗面试中几乎必考题,本文将带你理解BatchNorm的原理和代码实现,以及详细介绍BatchNorm在CNN中的应用。 一、BatchNorm论文
2020-11-18
Next 
【ICML 2020联邦学习论文解读】SCAFFOLD: Stochastic Controlled Averaging for Federated Learning 【ICML 2020联邦学习论文解读】SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
论文解读:SCAFFOLD: Stochastic Controlled Averaging for Federated Learning 作者:Sai Praneeth Karimireddy 主页, Satyen Kale, Mehry
2020-10-28
  TOC