.. _sec_multihead-attention:
Atenção Multi-Head
==================
Na prática, dado o mesmo conjunto de consultas, chaves e valores podemos
querer que nosso modelo combine conhecimento de diferentes
comportamentos do mesmo mecanismo de atenção, como capturar dependências
de vários intervalos (por exemplo, intervalo mais curto vs. intervalo
mais longo) dentro de uma sequência. Desse modo, pode ser benéfico
permitir nosso mecanismo de atenção para usar em conjunto diferentes
subespaços de representação de consultas, chaves e valores.
Para este fim, em vez de realizar um único agrupamento de atenção,
consultas, chaves e valores podem ser transformados com :math:`h`
projeções lineares aprendidas independentemente. Então, essas :math:`h`
consultas, chaves e valores projetados são alimentados em agrupamento de
atenção em paralelo. No fim, :math:`h` resultados de concentração de
atenção são concatenados e transformados com outra projeção linear
aprendida para produzir a saída final. Este design é chamado de *atenção
multi-head*, onde cada uma das saídas de concentração de :math:`h` é um
*head* :cite:`Vaswani.Shazeer.Parmar.ea.2017`. Usando camadas
totalmente conectadas para realizar transformações lineares que podem
ser aprendidas, :numref:`fig_multi-head-attention` descreve a atenção
de *multi-head*.
.. _fig_multi-head-attention:
.. figure:: ../img/multi-head-attention.svg
Multi-head attention, where multiple heads are concatenated then
linearly transformed.
Modelo
------
Antes de fornecer a implementação da atenção *multi-head*, vamos
formalizar este modelo matematicamente. Dada uma consulta
:math:`\mathbf{q} \in \mathbb{R}^{d_q}`, uma chave
:math:`\mathbf{k} \in \mathbb{R}^{d_k}`, e um valor
:math:`\mathbf{v} \in \mathbb{R}^{d_v}`, cada *head* de atenção
:math:`\mathbf{h}_i` (:math:`i = 1, \ldots, h`) é calculado como
.. math:: \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},
onde parâmetros aprendíveis
:math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`,
:math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}` e
:math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}`, e :math:`f` é
concentração de atenção, tal como atenção aditiva e atenção de produto
escalonado em :numref:`sec_attention-scoring-functions`. A saída de
atenção *multi-head* é outra transformação linear via parâmetros
aprendíveis :math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}` da
concatenação de :math:`h` cabeças:
.. math:: \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.
Com base neste design, cada cabeça pode atender a diferentes partes da
entrada. Funções mais sofisticadas do que a média ponderada simples
podem ser expressadas.
.. raw:: html
.. raw:: html
.. code:: python
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l
npx.set_np()
.. raw:: html
.. raw:: html
.. code:: python
import math
import torch
from torch import nn
from d2l import torch as d2l
.. raw:: html
.. raw:: html
Implementação
-------------
Em nossa implementação, nós escolhemos a atenção do produto escalonado
para cada *head* da atenção de várias cabeças. Para evitar um
crescimento significativo de custo computacional e custo de
parametrização, montamos :math:`p_q = p_k = p_v = p_o / h`. Observe que
:math:`h` *heads* pode ser calculado em paralelo se definirmos o número
de saídas de transformações lineares para a consulta, chave e valor a
:math:`p_q h = p_k h = p_v h = p_o`. Na implementação a seguir,
:math:`p_o` é especificado através do argumento ``num_hiddens``.
.. raw:: html
.. raw:: html
.. code:: python
#@save
class MultiHeadAttention(nn.Block):
def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = valid_lens.repeat(self.num_heads, axis=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
.. raw:: html
.. raw:: html
.. code:: python
#@save
class MultiHeadAttention(nn.Module):
def __init__(self, key_size, query_size, value_size, num_hiddens,
num_heads, dropout, bias=False, **kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.num_heads = num_heads
self.attention = d2l.DotProductAttention(dropout)
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)
def forward(self, queries, keys, values, valid_lens):
# Shape of `queries`, `keys`, or `values`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`)
# Shape of `valid_lens`:
# (`batch_size`,) or (`batch_size`, no. of queries)
# After transposing, shape of output `queries`, `keys`, or `values`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
queries = transpose_qkv(self.W_q(queries), self.num_heads)
keys = transpose_qkv(self.W_k(keys), self.num_heads)
values = transpose_qkv(self.W_v(values), self.num_heads)
if valid_lens is not None:
# On axis 0, copy the first item (scalar or vector) for
# `num_heads` times, then copy the next item, and so on
valid_lens = torch.repeat_interleave(
valid_lens, repeats=self.num_heads, dim=0)
# Shape of `output`: (`batch_size` * `num_heads`, no. of queries,
# `num_hiddens` / `num_heads`)
output = self.attention(queries, keys, values, valid_lens)
# Shape of `output_concat`:
# (`batch_size`, no. of queries, `num_hiddens`)
output_concat = transpose_output(output, self.num_heads)
return self.W_o(output_concat)
.. raw:: html
.. raw:: html
Para permitir o cálculo paralelo de várias *heads* a classe
``MultiHeadAttention`` acima usa duas funções de transposição, conforme
definido abaixo. Especificamente, a função ``transpose_output`` reverte
a operação da função ``transpose_qkv``.
.. raw:: html
.. raw:: html
.. code:: python
#@save
def transpose_qkv(X, num_heads):
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.transpose(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.transpose(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
.. raw:: html
.. raw:: html
.. code:: python
#@save
def transpose_qkv(X, num_heads):
# Shape of input `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_hiddens`).
# Shape of output `X`:
# (`batch_size`, no. of queries or key-value pairs, `num_heads`,
# `num_hiddens` / `num_heads`)
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
# Shape of output `X`:
# (`batch_size`, `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
X = X.permute(0, 2, 1, 3)
# Shape of `output`:
# (`batch_size` * `num_heads`, no. of queries or key-value pairs,
# `num_hiddens` / `num_heads`)
return X.reshape(-1, X.shape[2], X.shape[3])
#@save
def transpose_output(X, num_heads):
"""Reverse the operation of `transpose_qkv`"""
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
X = X.permute(0, 2, 1, 3)
return X.reshape(X.shape[0], X.shape[1], -1)
.. raw:: html
.. raw:: html
Vamos testar nossa classe ``MultiHeadAttention`` implementada usando um
exemplo de brinquedo em que as chaves e os valores são iguais. Como
resultado, a forma da saída de atenção *multi-head* é
(``batch_size``,\ ``num_queries``, ``num_hiddens``).
.. raw:: html
.. raw:: html
.. code:: python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, np.array([3, 2])
X = np.ones((batch_size, num_queries, num_hiddens))
Y = np.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
.. parsed-literal::
:class: output
(2, 4, 100)
.. raw:: html
.. raw:: html
.. code:: python
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
.. parsed-literal::
:class: output
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
.. code:: python
batch_size, num_queries, num_kvpairs, valid_lens = 2, 4, 6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape
.. parsed-literal::
:class: output
torch.Size([2, 4, 100])
.. raw:: html
.. raw:: html
Resumo
------
- A atenção *multi-head* combina o conhecimento do mesmo agrupamento de
atenção por meio de diferentes subespaços de representação de
consultas, chaves e valores.
- Para calcular várias *heads* de atenção de *multi-heads* em paralelo,
é necessária a manipulação adequada do tensor.
Exercícios
----------
1. Visualize o peso da atenção *multi-head* neste experimento.
2. Suponha que temos um modelo treinado com base na atenção *multi-head*
e queremos podar as *heads* menos importantes para aumentar a
velocidade de previsão. Como podemos projetar experimentos para medir
a importância de uma *head* de atenção?
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
`Discussions `__
.. raw:: html
.. raw:: html
.. raw:: html