LESSON

画像+テキスト融合

「画像だけ、テキストだけでは不十分なケースがある。」

田中VPoEが事例を見せる。

「商品画像と説明文を組み合わせた品質判定、X線画像と問診票を統合した診断支援。マルチモーダル融合が、より精度の高い意思決定を可能にする。」

マルチモーダル融合の戦略

融合方式説明適用場面
早期融合入力段階で画像特徴量とテキスト特徴量を結合特徴量間の相互作用が重要
遅延融合各モダリティを独立に処理し、最終段で統合モダリティの独立性が高い
注意融合Cross-Attentionでモダリティ間の関連を学習複雑な相互参照が必要

早期融合の実装

import torch
import torch.nn as nn
from torchvision import models

class EarlyFusionModel(nn.Module):
    """早期融合によるマルチモーダルモデル"""

    def __init__(self, text_dim, n_classes):
        super().__init__()
        # 画像エンコーダー
        self.image_encoder = models.resnet50(pretrained=True)
        self.image_encoder.fc = nn.Identity()  # 2048次元の特徴量

        # テキストエンコーダー
        self.text_encoder = nn.Sequential(
            nn.Linear(text_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
        )

        # 融合層
        self.fusion = nn.Sequential(
            nn.Linear(2048 + 256, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        )

    def forward(self, image, text_features):
        img_feat = self.image_encoder(image)      # (B, 2048)
        txt_feat = self.text_encoder(text_features)  # (B, 256)
        combined = torch.cat([img_feat, txt_feat], dim=1)  # (B, 2304)
        return self.fusion(combined)

遅延融合の実装

class LateFusionModel(nn.Module):
    """遅延融合によるマルチモーダルモデル"""

    def __init__(self, text_dim, n_classes):
        super().__init__()
        # 画像分類器
        self.image_classifier = nn.Sequential(
            models.resnet50(pretrained=True),
            nn.Linear(1000, n_classes),
        )

        # テキスト分類器
        self.text_classifier = nn.Sequential(
            nn.Linear(text_dim, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes),
        )

        # 融合重み(学習可能)
        self.fusion_weight = nn.Parameter(torch.tensor(0.5))

    def forward(self, image, text_features):
        img_logits = self.image_classifier(image)
        txt_logits = self.text_classifier(text_features)

        # 加重平均
        w = torch.sigmoid(self.fusion_weight)
        fused_logits = w * img_logits + (1 - w) * txt_logits
        return fused_logits

Cross-Attention融合

class CrossAttentionFusion(nn.Module):
    """Cross-Attentionによるマルチモーダル融合"""

    def __init__(self, img_dim=2048, txt_dim=768, hidden_dim=512):
        super().__init__()
        self.img_proj = nn.Linear(img_dim, hidden_dim)
        self.txt_proj = nn.Linear(txt_dim, hidden_dim)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            batch_first=True,
        )

        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, img_features, txt_features):
        img_proj = self.img_proj(img_features).unsqueeze(1)
        txt_proj = self.txt_proj(txt_features).unsqueeze(1)

        # テキストで画像に注意を向ける
        attn_output, _ = self.cross_attn(
            query=txt_proj, key=img_proj, value=img_proj,
        )

        return self.classifier(attn_output.squeeze(1))

まとめ

項目ポイント
早期融合特徴量を結合して共同学習、相互作用を捉える
遅延融合独立に処理し最終段で統合、実装が簡単
Cross-Attentionモダリティ間の関連を動的に学習
選定基準データ量、モダリティ間の相関度で判断

チェックリスト

  • 3つの融合方式の違いを説明できる
  • 早期融合モデルを実装できる
  • 遅延融合の学習可能な重みを理解した
  • Cross-Attentionの仕組みを説明できる

次のステップへ

画像+テキスト融合を理解した。次はVLMの業務応用を学ぼう。

推定読了時間: 30分