画像+テキスト融合
「画像だけ、テキストだけでは不十分なケースがある。」
田中VPoEが事例を見せる。
「商品画像と説明文を組み合わせた品質判定、X線画像と問診票を統合した診断支援。マルチモーダル融合が、より精度の高い意思決定を可能にする。」
マルチモーダル融合の戦略
| 融合方式 | 説明 | 適用場面 |
|---|---|---|
| 早期融合 | 入力段階で画像特徴量とテキスト特徴量を結合 | 特徴量間の相互作用が重要 |
| 遅延融合 | 各モダリティを独立に処理し、最終段で統合 | モダリティの独立性が高い |
| 注意融合 | Cross-Attentionでモダリティ間の関連を学習 | 複雑な相互参照が必要 |
早期融合の実装
import torch
import torch.nn as nn
from torchvision import models
class EarlyFusionModel(nn.Module):
"""早期融合によるマルチモーダルモデル"""
def __init__(self, text_dim, n_classes):
super().__init__()
# 画像エンコーダー
self.image_encoder = models.resnet50(pretrained=True)
self.image_encoder.fc = nn.Identity() # 2048次元の特徴量
# テキストエンコーダー
self.text_encoder = nn.Sequential(
nn.Linear(text_dim, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
)
# 融合層
self.fusion = nn.Sequential(
nn.Linear(2048 + 256, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 128),
nn.ReLU(),
nn.Linear(128, n_classes),
)
def forward(self, image, text_features):
img_feat = self.image_encoder(image) # (B, 2048)
txt_feat = self.text_encoder(text_features) # (B, 256)
combined = torch.cat([img_feat, txt_feat], dim=1) # (B, 2304)
return self.fusion(combined)
遅延融合の実装
class LateFusionModel(nn.Module):
"""遅延融合によるマルチモーダルモデル"""
def __init__(self, text_dim, n_classes):
super().__init__()
# 画像分類器
self.image_classifier = nn.Sequential(
models.resnet50(pretrained=True),
nn.Linear(1000, n_classes),
)
# テキスト分類器
self.text_classifier = nn.Sequential(
nn.Linear(text_dim, 256),
nn.ReLU(),
nn.Linear(256, n_classes),
)
# 融合重み(学習可能)
self.fusion_weight = nn.Parameter(torch.tensor(0.5))
def forward(self, image, text_features):
img_logits = self.image_classifier(image)
txt_logits = self.text_classifier(text_features)
# 加重平均
w = torch.sigmoid(self.fusion_weight)
fused_logits = w * img_logits + (1 - w) * txt_logits
return fused_logits
Cross-Attention融合
class CrossAttentionFusion(nn.Module):
"""Cross-Attentionによるマルチモーダル融合"""
def __init__(self, img_dim=2048, txt_dim=768, hidden_dim=512):
super().__init__()
self.img_proj = nn.Linear(img_dim, hidden_dim)
self.txt_proj = nn.Linear(txt_dim, hidden_dim)
self.cross_attn = nn.MultiheadAttention(
embed_dim=hidden_dim,
num_heads=8,
batch_first=True,
)
self.classifier = nn.Linear(hidden_dim, 2)
def forward(self, img_features, txt_features):
img_proj = self.img_proj(img_features).unsqueeze(1)
txt_proj = self.txt_proj(txt_features).unsqueeze(1)
# テキストで画像に注意を向ける
attn_output, _ = self.cross_attn(
query=txt_proj, key=img_proj, value=img_proj,
)
return self.classifier(attn_output.squeeze(1))
まとめ
| 項目 | ポイント |
|---|---|
| 早期融合 | 特徴量を結合して共同学習、相互作用を捉える |
| 遅延融合 | 独立に処理し最終段で統合、実装が簡単 |
| Cross-Attention | モダリティ間の関連を動的に学習 |
| 選定基準 | データ量、モダリティ間の相関度で判断 |
チェックリスト
- 3つの融合方式の違いを説明できる
- 早期融合モデルを実装できる
- 遅延融合の学習可能な重みを理解した
- Cross-Attentionの仕組みを説明できる
次のステップへ
画像+テキスト融合を理解した。次はVLMの業務応用を学ぼう。
推定読了時間: 30分