Pythonでアテンション機構を理解する

プログラミング

Pythonにおけるアテンション機構の理解

アテンション機構とは

アテンション機構(Attention Mechanism)は、近年の深層学習、特に自然言語処理(NLP)や画像認識の分野で目覚ましい成果を上げている重要な技術です。その核心は、モデルが入力データ全体を均等に処理するのではなく、タスクにとってより重要と思われる部分に「注意」を向けるように学習することにあります。

例えば、機械翻訳において、ある単語を生成する際に、入力文の全ての単語が等しく関連するわけではありません。アテンション機構は、出力単語を生成する際に、入力文中のどの単語に最も焦点を当てるべきかを動的に決定します。これにより、文脈をより効果的に捉え、翻訳精度を向上させることができます。

アテンション機構は、主に以下の3つの要素で構成されます。

  • Query (Q): 現在処理している要素(例:デコーダーの現在の状態)。
  • Key (K): 入力データ中の各要素(例:エンコーダーの出力)。
  • Value (V): 入力データ中の各要素に対応する値(通常、Keyと同じか、Keyから変換されたもの)。

アテンション機構の基本的な考え方は、QueryとKeyの類似度を計算し、その類似度に基づいてValueに重みを付けて合計することで、文脈ベクトル(Context Vector)を生成することです。この文脈ベクトルが、現在のタスク処理に最も関連性の高い情報を集約したものとなります。

アテンション機構の種類

アテンション機構にはいくつかのバリエーションがありますが、代表的なものを以下に紹介します。

加法アテンション(Additive Attention / Bahdanau Attention)

加法アテンションは、QueryとKeyの類似度を計算するために、ニューラルネットワーク(通常はフィードフォワードネットワーク)を使用します。QueryとKeyを連結し、それを活性化関数に通すことで、アテンションスコアを算出します。

数式で表すと、アテンションスコア $e_{ij}$ は以下のように計算されます。

$e_{ij} = v_a^T tanh(W_a q_i + U_a k_j)$

ここで、$q_i$ はQuery、$k_j$ はKey、$v_a$ と $W_a, U_a$ は学習可能なパラメータです。

乗法アテンション(Multiplicative Attention / Luong Attention)

乗法アテンションは、QueryとKeyの類似度を計算するために、内積(dot product)を使用します。これは、加法アテンションよりも計算量が少なく、一般的に効率が良いとされています。

代表的な計算方法として、以下の3つがあります。

  • dot: $e_{ij} = q_i^T k_j$
  • general: $e_{ij} = q_i^T W_a k_j$ (ここで $W_a$ は学習可能な行列)
  • concat: $e_{ij} = q_i^T k_j$ (QueryとKeyを連結したものを入力とする場合。これは加法アテンションに近くなります)

これらのスコアは、その後Softmax関数を通して正規化され、各Valueに対する重み(アテンションウェイト)となります。

$alpha_{ij} = frac{exp(e_{ij})}{sum_k exp(e_{ik})}$

そして、文脈ベクトル $c_i$ は、これらの重みとValueの加重平均として計算されます。

$c_i = sum_j alpha_{ij} v_j$

自己アテンション(Self-Attention)

自己アテンションは、入力シーケンス内の異なる位置にある要素間の関連性を学習する機構です。Query, Key, Valueが全て同じ入力シーケンスから生成されます。これは、Transformerモデルの基盤となる技術であり、シーケンス内の長期的な依存関係を捉えるのに非常に効果的です。

具体的には、入力ベクトル $X$ から、学習可能な行列 $W^Q, W^K, W^V$ を用いて、Query行列 $Q=XW^Q$、Key行列 $K=XW^K$、Value行列 $V=XW^V$ を生成します。アテンションスコアは $QK^T$ で計算され、Softmaxで正規化された後、Valueと掛け合わせられます。

$Attention(Q, K, V) = text{softmax}(frac{QK^T}{sqrt{d_k}}) V$

ここで $sqrt{d_k}$ はスケーリングファクターで、勾配消失を防ぐために用いられます。

マルチヘッドアテンション(Multi-Head Attention)

マルチヘッドアテンションは、自己アテンションを複数回並列して実行する機構です。各「ヘッド」は異なるQuery, Key, Valueの線形変換を行い、異なる表現空間でアテンションを計算します。これにより、モデルは入力シーケンスの異なる側面や関連性を同時に捉えることができます。各ヘッドの出力は連結され、再度線形変換されて最終的な出力となります。

Pythonでの実装例(PyTorch)

Pythonでアテンション機構を実装する際には、PyTorchやTensorFlowといった深層学習フレームワークが便利です。ここでは、PyTorchを用いた基本的なアテンション機構(乗法アテンションをベースにしたもの)の実装例を示します。


import torch
import torch.nn as nn
import torch.nn.functional as F

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, encoder_outputs):
        # query: [batch_size, hidden_size]
        # encoder_outputs: [batch_size, seq_len, hidden_size]

        # query を broadcasting して seq_len の次元を追加
        # [batch_size, 1, hidden_size] -> [batch_size, seq_len, hidden_size]
        expanded_query = query.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)

        # 加法アテンションの計算
        # e = v_a * tanh(W_a * q + U_a * k)
        f = torch.tanh(self.Wa(expanded_query) + self.Ua(encoder_outputs))
        # [batch_size, seq_len, hidden_size]

        # アテンションスコアの計算
        # [batch_size, seq_len, 1]
        energy = self.Va(f)
        # アテンションウェイトの計算 (softmax)
        # [batch_size, seq_len, 1]
        attention_weights = F.softmax(energy.squeeze(2), dim=1).unsqueeze(2)

        # 文脈ベクトルの計算
        # [batch_size, hidden_size, 1]
        context_vector = torch.bmm(encoder_outputs.transpose(1, 2), attention_weights)
        # [batch_size, hidden_size, 1] -> [batch_size, hidden_size]
        context_vector = context_vector.squeeze(2)

        return context_vector, attention_weights

class LuongAttention(nn.Module):
    def __init__(self, hidden_size, attention_type='dot'):
        super(LuongAttention, self).__init__()
        self.attention_type = attention_type

        if self.attention_type == 'general':
            self.Wa = nn.Linear(hidden_size, hidden_size)
        elif self.attention_type == 'concat':
            self.Wa = nn.Linear(hidden_size * 2, hidden_size)
            self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, encoder_outputs):
        # query: [batch_size, hidden_size]
        # encoder_outputs: [batch_size, seq_len, hidden_size]

        if self.attention_type == 'dot':
            # energy = q * k^T
            # [batch_size, seq_len]
            energy = torch.bmm(query.unsqueeze(1), encoder_outputs.transpose(1, 2)).squeeze(1)
        elif self.attention_type == 'general':
            # energy = q * W_a * k^T
            # [batch_size, hidden_size] -> [batch_size, hidden_size]
            query_transformed = self.Wa(query)
            # [batch_size, seq_len]
            energy = torch.bmm(query_transformed.unsqueeze(1), encoder_outputs.transpose(1, 2)).squeeze(1)
        elif self.attention_type == 'concat':
            # energy = v_a * tanh(W_a * [q, k])
            # query を broadcasting
            # [batch_size, 1, hidden_size] -> [batch_size, seq_len, hidden_size]
            expanded_query = query.unsqueeze(1).repeat(1, encoder_outputs.size(1), 1)
            # [batch_size, seq_len, hidden_size * 2]
            concat_inputs = torch.cat((expanded_query, encoder_outputs), dim=2)
            # [batch_size, seq_len, hidden_size]
            energy = torch.tanh(self.Wa(concat_inputs))
            # [batch_size, seq_len, 1]
            energy = self.Va(energy)
            # [batch_size, seq_len]
            energy = energy.squeeze(2)

        # アテンションウェイトの計算 (softmax)
        # [batch_size, seq_len]
        attention_weights = F.softmax(energy, dim=1).unsqueeze(2)

        # 文脈ベクトルの計算
        # [batch_size, hidden_size, 1]
        context_vector = torch.bmm(encoder_outputs.transpose(1, 2), attention_weights)
        # [batch_size, hidden_size, 1] -> [batch_size, hidden_size]
        context_vector = context_vector.squeeze(2)

        return context_vector, attention_weights

class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        # values, keys, query: [batch_size, seq_len, embed_size]
        # mask: [batch_size, 1, seq_len] or [batch_size, seq_len, seq_len]

        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        # [batch_size, seq_len, embed_size] -> [batch_size, seq_len, heads, head_dim]
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Linear transformations for each head
        # [batch_size, seq_len, heads, head_dim] -> [batch_size, seq_len, heads, head_dim]
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum for efficient matrix multiplication
        # queries shape: [N, query_len, heads, head_dim]
        # keys shape: [N, key_len, heads, head_dim]
        # Transpose keys: [N, heads, head_dim, key_len]
        # energy shape: [N, heads, query_len, key_len]
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Attention weights
        # [N, heads, query_len, key_len]
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Attended values
        # attention shape: [N, heads, query_len, key_len]
        # values shape: [N, key_len, heads, head_dim]
        # transpose values: [N, key_len, heads, head_dim] -> [N, heads, key_len, head_dim]
        # out shape: [N, heads, query_len, head_dim]
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values])

        # Concatenate heads and apply final linear layer
        # [N, query_len, heads, head_dim] -> [N, query_len, embed_size]
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        # [N, query_len, embed_size] -> [N, query_len, embed_size]
        out = self.fc_out(out)

        return out

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        self.dropout = nn.Dropout(0.1)

    def forward(self, values, keys, query, mask):
        # values, keys, query: [batch_size, seq_len, embed_size]
        # mask: [batch_size, 1, seq_len] or [batch_size, seq_len, seq_len]

        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        # [batch_size, seq_len, embed_size] -> [batch_size, seq_len, heads, head_dim]
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Linear transformations for each head
        # [batch_size, seq_len, heads, head_dim] -> [batch_size, seq_len, heads, head_dim]
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum for efficient matrix multiplication
        # queries shape: [N, query_len, heads, head_dim]
        # keys shape: [N, key_len, heads, head_dim]
        # Transpose keys: [N, heads, head_dim, key_len]
        # energy shape: [N, heads, query_len, key_len]
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Attention weights
        # [N, heads, query_len, key_len]
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Apply dropout to attention weights
        attention = self.dropout(attention)

        # Attended values
        # attention shape: [N, heads, query_len, key_len]
        # values shape: [N, key_len, heads, head_dim]
        # transpose values: [N, key_len, heads, head_dim] -> [N, heads, key_len, head_dim]
        # out shape: [N, heads, query_len, head_dim]
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values])

        # Concatenate heads and apply final linear layer
        # [N, query_len, heads, head_dim] -> [N, query_len, embed_size]
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        # [N, query_len, embed_size] -> [N, query_len, embed_size]
        out = self.fc_out(out)
        # Apply dropout to the output
        out = self.dropout(out)

        return out

上記のコードは、加法アテンション、乗法アテンション(dot, general, concat)、自己アテンション、マルチヘッドアテンションの基本的な実装を示しています。実際の応用では、これらの機構をエンコーダー・デコーダーモデルやTransformerモデルなどに組み込んで使用します。

アテンション機構の応用例

アテンション機構は、その強力な表現能力から、様々なタスクで成功を収めています。

  • 機械翻訳: 入力文のどの単語に注意を払うかを決定し、より自然で正確な翻訳を実現します。
  • 画像キャプション生成: 画像のどの領域に注目してキャプションを生成するかを学習します。
  • 質問応答: 文章中から質問に関連する箇所に焦点を当て、回答を抽出します。
  • テキスト要約: 長文の中から重要な部分に注意を向け、簡潔な要約を作成します。
  • 音声認識: 音声信号のどの部分がどの文字に対応するかを学習します。

特にTransformerモデルは、RNNやCNNといった従来のシーケンスモデルの制約を克服し、アテンション機構(主に自己アテンション)のみで構成されているにも関わらず、多くのNLPタスクで最先端の性能を達成しています。

まとめ

アテンション機構は、深層学習モデルが入力データの中から重要な情報に動的に焦点を当てることを可能にする革新的な技術です。Query, Key, Valueの概念と、それらを基にした類似度計算、重み付け和による文脈ベクトル生成がその核心です。加法アテンション、乗法アテンション、自己アテンション、マルチヘッドアテンションなど、様々なバリエーションが存在し、PyTorchのようなフレームワークを用いることで比較的容易に実装できます。機械翻訳から画像生成、質問応答まで、幅広い分野でその効果が証明されており、現代のAI技術を支える重要な柱となっています。