Grad-CAMで判断根拠を可視化する
「モデルが95%の精度を出しました、では通らない。なぜその判断に至ったかを説明できないとダメだ。」
田中VPoEが過去の失敗事例を示す。
「以前、外観検査AIが99%の精度を出したが、実は製品ではなく背景の色で判断していた。Grad-CAMで初めてそれに気づいた。」
Grad-CAMとは
Grad-CAM(Gradient-weighted Class Activation Mapping)は、CNNが画像のどの領域に注目して判断を下したかを可視化する手法である。
仕組み
1. ターゲットクラスに対する損失を計算
2. 最終畳み込み層の特徴マップに対する勾配を取得
3. 勾配をGlobal Average Poolingで重みに変換
4. 重み付き特徴マップの和を計算しReLUを適用
5. 入力画像にヒートマップを重畳
Grad-CAMの実装
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
class GradCAM:
def __init__(self, model, target_layer):
self.model = model
self.model.eval()
self.gradients = None
self.activations = None
# フック関数の登録
target_layer.register_forward_hook(self._forward_hook)
target_layer.register_full_backward_hook(self._backward_hook)
def _forward_hook(self, module, input, output):
self.activations = output.detach()
def _backward_hook(self, module, grad_input, grad_output):
self.gradients = grad_output[0].detach()
def generate(self, input_tensor, target_class=None):
"""Grad-CAMヒートマップを生成"""
output = self.model(input_tensor)
if target_class is None:
target_class = output.argmax(dim=1).item()
# 勾配の計算
self.model.zero_grad()
one_hot = torch.zeros_like(output)
one_hot[0][target_class] = 1
output.backward(gradient=one_hot)
# 重みの計算(Global Average Pooling)
weights = self.gradients.mean(dim=[2, 3], keepdim=True)
# 重み付き特徴マップの和
cam = (weights * self.activations).sum(dim=1, keepdim=True)
cam = F.relu(cam)
# 正規化
cam = cam - cam.min()
cam = cam / (cam.max() + 1e-8)
return cam.squeeze().cpu().numpy(), target_class
可視化関数
def visualize_gradcam(image_path, model, target_layer,
class_names=None):
"""Grad-CAMの可視化"""
# 画像の読み込みと前処理
img = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225]),
])
input_tensor = transform(img).unsqueeze(0)
# Grad-CAMの生成
grad_cam = GradCAM(model, target_layer)
heatmap, pred_class = grad_cam.generate(input_tensor)
# ヒートマップの重畳
img_resized = np.array(img.resize((224, 224))) / 255.0
heatmap_resized = cv2.resize(heatmap, (224, 224))
heatmap_color = cv2.applyColorMap(
np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET
) / 255.0
heatmap_color = cv2.cvtColor(
heatmap_color.astype(np.float32), cv2.COLOR_BGR2RGB
)
overlay = 0.6 * img_resized + 0.4 * heatmap_color
# 表示
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(img_resized)
axes[0].set_title("元画像")
axes[1].imshow(heatmap_resized, cmap="jet")
axes[1].set_title("Grad-CAMヒートマップ")
axes[2].imshow(overlay)
class_label = class_names[pred_class] if class_names else pred_class
axes[2].set_title(f"重畳表示(予測: {class_label})")
for ax in axes:
ax.axis("off")
plt.tight_layout()
plt.show()
Grad-CAMによるモデル診断
正しい判断根拠の確認
# Chest X-Rayモデルの場合
# 正しい根拠: 肺野の浸潤影領域に注目
# 誤った根拠: 画像端のラベルや機器情報に注目
def diagnose_model(model, target_layer, test_images, class_names):
"""複数画像でモデルの注目領域を確認"""
for img_path in test_images:
print(f"\n--- {img_path} ---")
visualize_gradcam(img_path, model, target_layer, class_names)
# ResNet-50の場合、最終畳み込み層はlayer4
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pth"))
target_layer = model.layer4[-1]
class_names = ["NORMAL", "PNEUMONIA"]
ショートカット学習の検出パターン
注意すべきパターン:
1. 画像の四隅に高い活性化 → 施設ラベルや撮影情報に注目
2. 画像全体に均一な活性化 → 意味のある特徴を学習していない
3. 肺野以外(肩、横隔膜下)に注目 → ドメイン外の特徴に依存
4. 片側のみに偏った注目 → 位置バイアスの可能性
Grad-CAM++とスコアCAM
# Grad-CAM++: より精密な局所化
# 勾配の2次微分も使用して重みを計算
# 複数オブジェクトの可視化に優れる
# Score-CAM: 勾配不要の手法
# 各チャネルの特徴マップを直接マスクとして使用
# 勾配消失の影響を受けない
| 手法 | 精度 | 速度 | 特徴 |
|---|---|---|---|
| Grad-CAM | 良好 | 高速 | 最も広く使われる |
| Grad-CAM++ | より精密 | やや遅い | 複数オブジェクト向き |
| Score-CAM | 最も精密 | 遅い | 勾配不要 |
まとめ
| 項目 | ポイント |
|---|---|
| Grad-CAMの原理 | 勾配による重み付き特徴マップの可視化 |
| ショートカット学習 | Grad-CAMで誤った判断根拠を検出可能 |
| モデル診断 | 複数画像で注目領域のパターンを確認 |
| 発展手法 | Grad-CAM++、Score-CAMでより精密な可視化 |
チェックリスト
- Grad-CAMの仕組みを説明できる
- PyTorchでGrad-CAMを実装できる
- ショートカット学習をGrad-CAMで検出できる
- Grad-CAMの結果から改善点を特定できる
次のステップへ
Grad-CAMによる解釈性を確保したところで、次はタスクに最適なモデルを選定する方法を学ぼう。
推定読了時間: 30分