LESSON

Attention機構

田中VPoE:「CNN で画像分類は上手くいった。次は自然言語処理(NLP)の領域に入ろう。NetShop には商品レビューが大量にある。これを感情分析して、ユーザー体験の改善に活かしたい。」

あなた:「テキストデータは画像とはまた違うアプローチが必要ですよね。RNN とか LSTM とか聞いたことがあります。」

田中VPoE:「RNN 系は長い文の依存関係を捉えるのが苦手だった。それを劇的に改善したのが Attention 機構だ。現代の NLP を支える最も重要な概念だから、しっかり理解しよう。」

系列データ処理の課題

テキストなどの系列データを処理する際、従来の RNN(再帰型ニューラルネットワーク)には以下の課題がありました。

RNN の課題:
入力: "この商品はデザインが良くて品質も高いのでとても満足しています"

[この] → [商品] → [は] → ... → [満足] → [しています]
  ↓        ↓       ↓              ↓          ↓
  h1  →    h2  →   h3  → ... →   hn-1  →    hn

問題: "この商品" の情報が "満足" まで伝わる間に薄れてしまう(長距離依存の問題)

Attention とは

Attention 機構は「入力のどの部分に注目すべきか」を動的に決定する仕組みです。

従来の RNN:
  最後の隠れ状態のみ使用 → 情報のボトルネック

Attention 付き RNN:
  すべての隠れ状態を参照し、重要な部分に注目 → ボトルネック解消

Attention の3要素

Attention は Query(Q)、Key(K)、Value(V)の3つの要素で構成されます。

Q(Query): 「何を探しているか」(検索クエリ)
K(Key):   「何があるか」(検索キー)
V(Value): 「実際の内容」(検索結果の値)

類似度 = Q と K のマッチング → 重みとして V を加重和

具体例:レビューの感情分析

レビュー: "この商品はデザインが素晴らしく品質も高い"

"満足度" に注目(Query)したとき:
  "デザイン"  → 関連度: 0.15(Key-Query 類似度)
  "素晴らしく" → 関連度: 0.35(高い)
  "品質"      → 関連度: 0.10
  "高い"      → 関連度: 0.30(高い)
  その他      → 関連度: 0.10

→ "素晴らしく" と "高い" に注目した表現が得られる

Attention の計算

Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, seq_len_q, d_k)
    K: (batch, seq_len_k, d_k)
    V: (batch, seq_len_k, d_v)
    """
    d_k = Q.size(-1)

    # 1. Q と K の内積で類似度を計算
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # 2. マスク(オプション)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # 3. Softmax で重みに変換
    attention_weights = F.softmax(scores, dim=-1)

    # 4. 重み付き和で出力を計算
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

計算の流れ

1. スコア計算:  scores = Q * K^T / sqrt(d_k)
2. 正規化:      weights = softmax(scores)
3. 加重和:      output = weights * V

sqrt(d_k) で割る理由:
  → 内積が大きくなりすぎると softmax の勾配が消失するため、スケーリングする

Multi-Head Attention

単一の Attention では捉えられない多様な関係性を、複数の「ヘッド」で並列に計算します。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # 線形変換 → ヘッドに分割
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention
        output, weights = scaled_dot_product_attention(Q, K, V, mask)

        # ヘッドを結合
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        output = self.W_o(output)

        return output, weights
Multi-Head Attention の利点:
  Head 1: 文法的な関係(主語-述語)を捉える
  Head 2: 意味的な関係(形容詞-名詞)を捉える
  Head 3: 位置的な近接関係を捉える
  ...
  → 多角的な関係性を同時に学習できる

Self-Attention

Self-Attention は、同じ系列の中で各要素が他のすべての要素との関係を計算する仕組みです。Q、K、V がすべて同じ入力から生成されます。

入力: "この 商品 は とても 良い"

Self-Attention の結果:
  "良い" → "商品" に強く注目(何が良いのか)
  "良い" → "とても" に強く注目(程度の修飾)
  "この" → "商品" に強く注目(指示対象)

Attention の可視化

def visualize_attention(sentence, attention_weights):
    """Attention の重みをヒートマップで可視化"""
    import matplotlib.pyplot as plt
    import seaborn as sns

    tokens = sentence.split()
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        attention_weights.detach().numpy(),
        xticklabels=tokens,
        yticklabels=tokens,
        cmap='YlOrRd',
        annot=True,
        fmt='.2f',
        ax=ax
    )
    ax.set_title('Self-Attention Weights')
    plt.tight_layout()
    plt.show()

まとめ

  • Attention 機構は「入力のどこに注目すべきか」を動的に決定する仕組み
  • Query、Key、Value の3要素で構成される
  • Scaled Dot-Product Attention でスコアを計算し、Softmax で正規化する
  • Multi-Head Attention により多角的な関係性を並列に学習できる
  • Self-Attention は同一系列内の要素間の関係を捉える

チェックリスト

  • Attention 機構が解決する問題(長距離依存)を説明できる
  • Query、Key、Value の役割を理解した
  • Scaled Dot-Product Attention の計算式を理解した
  • Multi-Head Attention と Self-Attention の違いを説明できる

推定読了時間: 30分