.. _sec_attention-scoring-functions: Funções de Pontuação de Atenção =============================== Em :numref:`sec_nadaraya-waston`, usamos um kernel gaussiano para modelar interações entre consultas e chaves. Tratando o expoente do kernel gaussiano em :eq:`eq_nadaraya-waston-gaussian` como uma *função de pontuação de atenção* (ou *função de pontuação* para abreviar), os resultados desta função foram essencialmente alimentados em uma operação softmax. Como resultado, Nós obtivemos uma distribuição de probabilidade (pesos de atenção) sobre valores que estão emparelhados com chaves. No fim, a saída do *pooling* de atenção é simplesmente uma soma ponderada dos valores com base nesses pesos de atenção. Em alto nível, podemos usar o algoritmo acima para instanciar a estrutura de mecanismos de atenção em :numref:`fig_qkv`. Denotando uma função de pontuação de atenção por :math:`a`, :numref:`fig_attention_output` ilustra como a saída do *pooling* de atenção pode ser calculado como uma soma ponderada de valores. Uma vez que os pesos de atenção são uma distribuição de probabilidade, a soma ponderada é essencialmente uma média ponderada. .. _fig_attention_output: .. figure:: ../img/attention-output.svg Calculando a saída do *pooling* de atenção como uma média ponderada de valores. Matematicamente, suponha que temos uma consulta :math:`\mathbf{q} \in \mathbb{R}^q` e :math:`m` pares de valores-chave :math:`(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)`, onde qualquer :math:`\mathbf{k}_i \in \mathbb{R}^k` e qualquer :math:`\mathbf{v}_i \in \mathbb{R}^v`. O *pooling* de atenção :math:`f` é instanciado como uma soma ponderada dos valores: .. math:: f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v, :label: eq_attn-pooling Onde o peso da atenção (escalar) para a consulta :math:`\mathbf{q}` e a chave :math:`\mathbf{k}_i` é calculado pela operação softmax de uma função de pontuação de atenção :math:`a` que mapeia dois vetores para um escalar: .. math:: \alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}. :label: eq_attn-scoring-alpha Como podemos ver, diferentes escolhas da função de pontuação de atenção :math:`a` levam a diferentes comportamentos de concentração de atenção. Nesta secção, apresentamos duas funções populares de pontuação que usaremos para desenvolver mais mecanismos sofisticados de atenção posteriormente. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python import math from mxnet import np, npx from mxnet.gluon import nn from d2l import mxnet as d2l npx.set_np() .. raw:: html
.. raw:: html
.. code:: python import math import torch from torch import nn from d2l import torch as d2l .. raw:: html
.. raw:: html
Operação *Softmax* Mascarada ---------------------------- Como acabamos de mencionar, uma operação softmax é usada para produzir uma distribuição de probabilidade como pesos de atenção. Em alguns casos, nem todos os valores devem ser incluídos no agrupamento de atenção. Por exemplo, para processamento eficiente de minibatch em :numref:`sec_machine_translation`, algumas sequências de texto são preenchidas com tokens especiais que não possuem significado. Para obter um *pooling* de atenção sobre apenas tokens significativos como valores, podemos especificar um comprimento de sequência válido (em número de tokens) para filtrar aqueles que estão além deste intervalo especificado ao calcular softmax. Desta maneira, podemos implementar tal *operação de softmax mascarada* na seguinte função ``masked_softmax``, onde qualquer valor além do comprimento válido é mascarado como zero. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python #@save def masked_softmax(X, valid_lens): """Perform softmax operation by masking elements on the last axis.""" # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor if valid_lens is None: return npx.softmax(X) else: shape = X.shape if valid_lens.ndim == 1: valid_lens = valid_lens.repeat(shape[1]) else: valid_lens = valid_lens.reshape(-1) # On the last axis, replace masked elements with a very large negative # value, whose exponentiation outputs 0 X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True, value=-1e6, axis=1) return npx.softmax(X).reshape(shape) .. raw:: html
.. raw:: html
.. code:: python #@save def masked_softmax(X, valid_lens): """Perform softmax operation by masking elements on the last axis.""" # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor if valid_lens is None: return nn.functional.softmax(X, dim=-1) else: shape = X.shape if valid_lens.dim() == 1: valid_lens = torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) # On the last axis, replace masked elements with a very large negative # value, whose exponentiation outputs 0 X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) return nn.functional.softmax(X.reshape(shape), dim=-1) .. raw:: html
.. raw:: html
Para demonstrar como essa função funciona, considere um minibatch de dois exemplos de matriz :math:`2 \times 4`, onde os comprimentos válidos para esses dois exemplos são dois e três, respectivamente. Como resultado da operação mascarada softmax, valores além dos comprimentos válidos são todos mascarados como zero. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([2, 3])) .. parsed-literal:: :class: output array([[[0.488994 , 0.511006 , 0. , 0. ], [0.4365484 , 0.56345165, 0. , 0. ]], [[0.288171 , 0.3519408 , 0.3598882 , 0. ], [0.29034296, 0.25239873, 0.45725837, 0. ]]]) .. raw:: html
.. raw:: html
.. code:: python masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3])) .. parsed-literal:: :class: output tensor([[[0.3520, 0.6480, 0.0000, 0.0000], [0.5525, 0.4475, 0.0000, 0.0000]], [[0.2764, 0.4460, 0.2776, 0.0000], [0.3825, 0.3849, 0.2327, 0.0000]]]) .. raw:: html
.. raw:: html
Da mesma forma, também podemos use um tensor bidimensional para especificar comprimentos válidos para cada linha em cada exemplo de matriz. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python masked_softmax(np.random.uniform(size=(2, 2, 4)), np.array([[1, 3], [2, 4]])) .. parsed-literal:: :class: output array([[[1. , 0. , 0. , 0. ], [0.35848376, 0.3658879 , 0.27562833, 0. ]], [[0.54370314, 0.45629686, 0. , 0. ], [0.19598778, 0.25580427, 0.19916739, 0.3490406 ]]]) .. raw:: html
.. raw:: html
.. code:: python masked_softmax(torch.rand(2, 2, 4), torch.tensor([[1, 3], [2, 4]])) .. parsed-literal:: :class: output tensor([[[1.0000, 0.0000, 0.0000, 0.0000], [0.4436, 0.2773, 0.2791, 0.0000]], [[0.4437, 0.5563, 0.0000, 0.0000], [0.2422, 0.3533, 0.2061, 0.1984]]]) .. raw:: html
.. raw:: html
.. _subsec_additive-attention: Atenção Aditiva --------------- Em geral, quando as consultas e as chaves são vetores de comprimentos diferentes, podemos usar atenção aditiva como a função de pontuação. Dada uma consulta :math:`\mathbf{q} \in \mathbb{R}^q` e uma chave :raw-latex:`\mathbf{k}` :raw-latex:`\in `:raw-latex:`\mathbb{R}`^k$, a função de pontuação *atenção aditiva* .. math:: a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R}, :label: eq_additive-attn onde parâmetros aprendíveis :math:`\mathbf W_q\in\mathbb R^{h\times q}`, :math:`\mathbf W_k\in\mathbb R^{h\times k}`, e :math:`\mathbf w_v\in\mathbb R^{h}`. Equivalente a :eq:`eq_additive-attn`, a consulta e a chave são concatenadas e alimentado em um MLP com uma única camada oculta cujo número de unidades ocultas é :math:`h`, um hiperparâmetro. Usando :math:`\tanh` como a função de ativação e desativando termos de *bias*, implementamos atenção aditiva a seguir. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python #@save class AdditiveAttention(nn.Block): """Additive attention.""" def __init__(self, num_hiddens, dropout, **kwargs): super(AdditiveAttention, self).__init__(**kwargs) # Use `flatten=False` to only transform the last axis so that the # shapes for the other axes are kept the same self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False) self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False) self.w_v = nn.Dense(1, use_bias=False, flatten=False) self.dropout = nn.Dropout(dropout) def forward(self, queries, keys, values, valid_lens): queries, keys = self.W_q(queries), self.W_k(keys) # After dimension expansion, shape of `queries`: (`batch_size`, no. of # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1, # no. of key-value pairs, `num_hiddens`). Sum them up with # broadcasting features = np.expand_dims(queries, axis=2) + np.expand_dims( keys, axis=1) features = np.tanh(features) # There is only one output of `self.w_v`, so we remove the last # one-dimensional entry from the shape. Shape of `scores`: # (`batch_size`, no. of queries, no. of key-value pairs) scores = np.squeeze(self.w_v(features), axis=-1) self.attention_weights = masked_softmax(scores, valid_lens) # Shape of `values`: (`batch_size`, no. of key-value pairs, value # dimension) return npx.batch_dot(self.dropout(self.attention_weights), values) .. raw:: html
.. raw:: html
.. code:: python #@save class AdditiveAttention(nn.Module): def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs): super(AdditiveAttention, self).__init__(**kwargs) self.W_k = nn.Linear(key_size, num_hiddens, bias=False) self.W_q = nn.Linear(query_size, num_hiddens, bias=False) self.w_v = nn.Linear(num_hiddens, 1, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, queries, keys, values, valid_lens): queries, keys = self.W_q(queries), self.W_k(keys) # After dimension expansion, shape of `queries`: (`batch_size`, no. of # queries, 1, `num_hiddens`) and shape of `keys`: (`batch_size`, 1, # no. of key-value pairs, `num_hiddens`). Sum them up with # broadcasting features = queries.unsqueeze(2) + keys.unsqueeze(1) features = torch.tanh(features) # There is only one output of `self.w_v`, so we remove the last # one-dimensional entry from the shape. Shape of `scores`: # (`batch_size`, no. of queries, no. of key-value pairs) scores = self.w_v(features).squeeze(-1) self.attention_weights = masked_softmax(scores, valid_lens) # Shape of `values`: (`batch_size`, no. of key-value pairs, value # dimension) return torch.bmm(self.dropout(self.attention_weights), values) .. raw:: html
.. raw:: html
Vamos demonstrar a classe ``AdditiveAttention`` acima com um exemplo de brinquedo, onde formas (tamanho do lote, número de etapas ou comprimento da sequência em tokens, tamanho da *feature*) de consultas, chaves e valores são (:math:`2`, :math:`1`, :math:`20`), (:math:`2`, :math:`10`, :math:`2`), e (:math:`2`, :math:`10`, :math:`4`), respectivamente. A saída de concentração de atenção tem uma forma de (tamanho do lote, número de etapas para consultas, tamanho do *feature* para valores). .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python queries, keys = np.random.normal(0, 1, (2, 1, 20)), np.ones((2, 10, 2)) # The two value matrices in the `values` minibatch are identical values = np.arange(40).reshape(1, 10, 4).repeat(2, axis=0) valid_lens = np.array([2, 6]) attention = AdditiveAttention(num_hiddens=8, dropout=0.1) attention.initialize() attention(queries, keys, values, valid_lens) .. parsed-literal:: :class: output array([[[ 2. , 3. , 4. , 5. ]], [[10. , 11. , 12.000001, 13. ]]]) .. raw:: html
.. raw:: html
.. code:: python queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2)) # The two value matrices in the `values` minibatch are identical values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat( 2, 1, 1) valid_lens = torch.tensor([2, 6]) attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8, dropout=0.1) attention.eval() attention(queries, keys, values, valid_lens) .. parsed-literal:: :class: output tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]], [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=) .. raw:: html
.. raw:: html
Embora a atenção aditiva contenha parâmetros que podem ser aprendidos, uma vez que cada chave é a mesma neste exemplo, os pesos de atenção são uniformes, determinados pelos comprimentos válidos especificados. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)), xlabel='Keys', ylabel='Queries') .. figure:: output_attention-scoring-functions_2a8fdc_57_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)), xlabel='Keys', ylabel='Queries') .. figure:: output_attention-scoring-functions_2a8fdc_60_0.svg .. raw:: html
.. raw:: html
Atenção de Produto Escalar em Escala ------------------------------------ Um design mais eficiente do ponto de vista computacional para a função de pontuação pode ser simplesmente o produto escalar. No entanto, a operação de produto escalar requer que a consulta e a chave tenham o mesmo comprimento de vetor, digamos :math:`d`. Suponha que todos os elementos da consulta e a chave sejam variáveis aleatórias independentes com média zero e variância unitária. O produto escalar de ambos os vetores tem média zero e variância de :math:`d`. Para garantir que a variação do produto escalar ainda permaneça um, independentemente do comprimento do vetor, a função de pontuação de *atenção ao produto escalar em escala* .. math:: a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d} divide o produto escalar por :math:`\sqrt{d}`. Na prática, geralmente pensamos em minibatches para eficiência, como computação de atenção para :math:`n` consultas e :math:`m` pares de valor-chave, onde consultas e chaves têm comprimento :math:`d` e os valores têm comprimento :math:`v`. A atenção do produto escalar das consultas :math:`\mathbf Q\in\mathbb R^{n\times d}`, chaves :math:`\mathbf K\in\mathbb R^{m\times d}`, e valores :math:`\mathbf V\in\mathbb R^{m\times v}` é .. math:: \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}. :label: eq_softmax_QK_V Na implementação a seguir da atenção ao produto escalar, usamos o *dropout* para regularização do modelo. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python #@save class DotProductAttention(nn.Block): """Scaled dot product attention.""" def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) # Shape of `queries`: (`batch_size`, no. of queries, `d`) # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`) # Shape of `values`: (`batch_size`, no. of key-value pairs, value # dimension) # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries) def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] # Set `transpose_b=True` to swap the last two dimensions of `keys` scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return npx.batch_dot(self.dropout(self.attention_weights), values) .. raw:: html
.. raw:: html
.. code:: python #@save class DotProductAttention(nn.Module): """Scaled dot product attention.""" def __init__(self, dropout, **kwargs): super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) # Shape of `queries`: (`batch_size`, no. of queries, `d`) # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`) # Shape of `values`: (`batch_size`, no. of key-value pairs, value # dimension) # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries) def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] # Set `transpose_b=True` to swap the last two dimensions of `keys` scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values) .. raw:: html
.. raw:: html
Para demonstrar a classe ``DotProductAttention`` acima, usamos as mesmas chaves, valores e comprimentos válidos do exemplo de brinquedo anterior para atenção aditiva. Para a operação de produto escalar, fazemos o tamanho da *feature* de consultas o mesmo que o das chaves. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python queries = np.random.normal(0, 1, (2, 1, 2)) attention = DotProductAttention(dropout=0.5) attention.initialize() attention(queries, keys, values, valid_lens) .. parsed-literal:: :class: output array([[[ 2. , 3. , 4. , 5. ]], [[10. , 11. , 12.000001, 13. ]]]) .. raw:: html
.. raw:: html
.. code:: python queries = torch.normal(0, 1, (2, 1, 2)) attention = DotProductAttention(dropout=0.5) attention.eval() attention(queries, keys, values, valid_lens) .. parsed-literal:: :class: output tensor([[[ 2.0000, 3.0000, 4.0000, 5.0000]], [[10.0000, 11.0000, 12.0000, 13.0000]]]) .. raw:: html
.. raw:: html
Da mesma forma que na demonstração de atenção aditiva, uma vez que ``keys`` contém o mesmo elemento que não pode ser diferenciado por nenhuma consulta, pesos uniformes de atenção são obtidos. .. raw:: html
mxnetpytorch
.. raw:: html
.. code:: python d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)), xlabel='Keys', ylabel='Queries') .. figure:: output_attention-scoring-functions_2a8fdc_84_0.svg .. raw:: html
.. raw:: html
.. code:: python d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)), xlabel='Keys', ylabel='Queries') .. figure:: output_attention-scoring-functions_2a8fdc_87_0.svg .. raw:: html
.. raw:: html
Resumo ------ - Podemos calcular a saída do *pooling* de atenção como uma média ponderada de valores, onde diferentes escolhas da função de pontuação de atenção levam a diferentes comportamentos de agrupamento de atenção. - Quando consultas e chaves são vetores de comprimentos diferentes, podemos usar a função de pontuação de atenção aditiva. Quando são iguais, a função de pontuação de atenção do produto escalonado é mais eficiente do ponto de vista computacional. Exercícios ---------- 1. Modifique as chaves no exemplo do brinquedo e visualize o peso da atenção. A atenção aditiva e a atenção de produto escalar em escala ainda geram os mesmos pesos de atenção? Por que ou por que não? 2. Usando apenas multiplicações de matrizes, você pode projetar uma nova função de pontuação para consultas e chaves com diferentes comprimentos de vetor? 3. Quando as consultas e as chaves têm o mesmo comprimento de vetor, a soma de vetores é um design melhor do que o produto escalar para a função de pontuação? Por que ou por que não? .. raw:: html
mxnetpytorch
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
`Discussions `__ .. raw:: html
.. raw:: html
.. raw:: html