内容纲要

欢迎转载,作者:Ling,注明出处:深度学习:前沿技术-Attention:一个实例说明Attention机制

 

Attention机制早在一两年前就有所耳闻,它作为一般NN,CNN和RNN(LSTM)等深度学习的一个加强技术,当时已经成为NLP领域的研究热点。随着Attention机制在机器翻译、图片描述、语义蕴涵、语音识别和文本摘要等各大领域取得成功,使得它成为现在成为一个不可不学习的技术。

本文将由浅入深,通过一个简单例子介绍Attention的机制原理。

 

预备知识:假设你已经对深度学习已经熟悉,并且使用过Keras。

 

下面分一个实例介绍Attention机制。

问题:给定一些实例,每个实例有32个特征,每个特征是一个数,每个实例都属于一个类别(0,1表示),共两个类别,训练一个神经网络对这些实例进行分类,这是一个典型的二分类问题。

起因

对于某一个实例,假设每个特征是一个数,例如实例对应的特征相应值的数组为 inputs= [0.1, 1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.5…], 假设其对应类别为1,从这个实例我们可以看出,该实例类别由第二个特征就能决定(第二个特征为1),我们假设所有实例的类别都由其第二个特征决定,和其他特征无关。通过神经网络,我们可以自动学习参数,可以在一定程度上让第二个特征的权值大一些,但是这样做还不直接,我们是否有更直接的办法让第二个特征在所有特征中的比重加强?

解决:Attention机制。

假设我们可以得到一个概率数组,理想情况下attention_probs=[0, 1, 0, 0, 0…], 也就是第二个特征的概率为1,其他都为0,那么问题就得到了解决,将inputs和attention_probs按位相乘 = [0, 1, 0, 0…],只留下了第二个特征,这样就可以很好地进行分类了。理想的p很难得到,我们如果可以得到一个近似的attention_probs,比如attention_probs=[0.1, 0.5, 0.1, 0.1, …],只要能让第二个特征概率远大于其他,也可以解决问题。问题来了,如何得到这样一个attention_probs呢?

实现:基于Keras的一个简单实现

 

 

  1. def build_model():

 

 

  1.     inputs = Input(shape=(input_dim,))

  2.     # ATTENTION PART STARTS HERE
        attention_probs = Dense(input_dim, activation='softmax', name='attention_vec')(inputs)
  3.     attention_mul = merge([inputs, attention_probs], output_shape=32, name='attention_mul', mode='mul')
  4.     # ATTENTION PART FINISHES HERE
  5.     attention_mul = Dense(64)(attention_mul)
  6.     output = Dense(1, activation='sigmoid')(attention_mul)
  7.     model = Model(input=[inputs], output=output)
  8.     return model

解释:

  • 第二行是模型的输入,实例中输入维度就是32
  • 第三行经过一个全连接层,通过一个softmax得到同样32维的一个输出,这些输出就是attention的attention_probs,由每个特征的概率构成
  • 第四行通过按位相乘,让每个输入乘上其概率,得到新的输出,输出还是32维,不变。
  • 第六行增加一个全连接隐含层
  • 第七行通过一个sigmoid进行二分类

整个模型结构图如下

frontier_attention01

这样整个模型就建好后,放入训练数据训练,自动就可以学习得到attention的attention_probs数组(vector),该数组值分布如图:

frontier_attention02

可以看到第二个特征的attention概率远大于其他。

同理,LSTM类似,只是要学习得到一个针对时间步的概率数组。

我写的这篇博文应该是理解Attention机制最简明的教程,希望对大家有帮助。

 

其他扩展材料

https://www.tuicool.com/articles/A7Nj63V

https://distill.pub/2016/augmented-rnns/