LESSON

Grad-CAMで判断根拠を可視化する

「モデルが95%の精度を出しました、では通らない。なぜその判断に至ったかを説明できないとダメだ。」

田中VPoEが過去の失敗事例を示す。

「以前、外観検査AIが99%の精度を出したが、実は製品ではなく背景の色で判断していた。Grad-CAMで初めてそれに気づいた。」

Grad-CAMとは

Grad-CAM(Gradient-weighted Class Activation Mapping)は、CNNが画像のどの領域に注目して判断を下したかを可視化する手法である。

仕組み

1. ターゲットクラスに対する損失を計算
2. 最終畳み込み層の特徴マップに対する勾配を取得
3. 勾配をGlobal Average Poolingで重みに変換
4. 重み付き特徴マップの和を計算しReLUを適用
5. 入力画像にヒートマップを重畳

Grad-CAMの実装

import torch
import torch.nn.functional as F
import numpy as np
import cv2
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.model.eval()
        self.gradients = None
        self.activations = None

        # フック関数の登録
        target_layer.register_forward_hook(self._forward_hook)
        target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, target_class=None):
        """Grad-CAMヒートマップを生成"""
        output = self.model(input_tensor)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # 勾配の計算
        self.model.zero_grad()
        one_hot = torch.zeros_like(output)
        one_hot[0][target_class] = 1
        output.backward(gradient=one_hot)

        # 重みの計算(Global Average Pooling)
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)

        # 重み付き特徴マップの和
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # 正規化
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.squeeze().cpu().numpy(), target_class

可視化関数

def visualize_gradcam(image_path, model, target_layer,
                      class_names=None):
    """Grad-CAMの可視化"""
    # 画像の読み込みと前処理
    img = Image.open(image_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                           [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(img).unsqueeze(0)

    # Grad-CAMの生成
    grad_cam = GradCAM(model, target_layer)
    heatmap, pred_class = grad_cam.generate(input_tensor)

    # ヒートマップの重畳
    img_resized = np.array(img.resize((224, 224))) / 255.0
    heatmap_resized = cv2.resize(heatmap, (224, 224))
    heatmap_color = cv2.applyColorMap(
        np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET
    ) / 255.0
    heatmap_color = cv2.cvtColor(
        heatmap_color.astype(np.float32), cv2.COLOR_BGR2RGB
    )

    overlay = 0.6 * img_resized + 0.4 * heatmap_color

    # 表示
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    axes[0].imshow(img_resized)
    axes[0].set_title("元画像")
    axes[1].imshow(heatmap_resized, cmap="jet")
    axes[1].set_title("Grad-CAMヒートマップ")
    axes[2].imshow(overlay)
    class_label = class_names[pred_class] if class_names else pred_class
    axes[2].set_title(f"重畳表示(予測: {class_label})")

    for ax in axes:
        ax.axis("off")
    plt.tight_layout()
    plt.show()

Grad-CAMによるモデル診断

正しい判断根拠の確認

# Chest X-Rayモデルの場合
# 正しい根拠: 肺野の浸潤影領域に注目
# 誤った根拠: 画像端のラベルや機器情報に注目

def diagnose_model(model, target_layer, test_images, class_names):
    """複数画像でモデルの注目領域を確認"""
    for img_path in test_images:
        print(f"\n--- {img_path} ---")
        visualize_gradcam(img_path, model, target_layer, class_names)

# ResNet-50の場合、最終畳み込み層はlayer4
model = models.resnet50(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("best_model.pth"))

target_layer = model.layer4[-1]
class_names = ["NORMAL", "PNEUMONIA"]

ショートカット学習の検出パターン

注意すべきパターン:
1. 画像の四隅に高い活性化 → 施設ラベルや撮影情報に注目
2. 画像全体に均一な活性化 → 意味のある特徴を学習していない
3. 肺野以外(肩、横隔膜下)に注目 → ドメイン外の特徴に依存
4. 片側のみに偏った注目 → 位置バイアスの可能性

Grad-CAM++とスコアCAM

# Grad-CAM++: より精密な局所化
# 勾配の2次微分も使用して重みを計算
# 複数オブジェクトの可視化に優れる

# Score-CAM: 勾配不要の手法
# 各チャネルの特徴マップを直接マスクとして使用
# 勾配消失の影響を受けない
手法精度速度特徴
Grad-CAM良好高速最も広く使われる
Grad-CAM++より精密やや遅い複数オブジェクト向き
Score-CAM最も精密遅い勾配不要

まとめ

項目ポイント
Grad-CAMの原理勾配による重み付き特徴マップの可視化
ショートカット学習Grad-CAMで誤った判断根拠を検出可能
モデル診断複数画像で注目領域のパターンを確認
発展手法Grad-CAM++、Score-CAMでより精密な可視化

チェックリスト

  • Grad-CAMの仕組みを説明できる
  • PyTorchでGrad-CAMを実装できる
  • ショートカット学習をGrad-CAMで検出できる
  • Grad-CAMの結果から改善点を特定できる

次のステップへ

Grad-CAMによる解釈性を確保したところで、次はタスクに最適なモデルを選定する方法を学ぼう。

推定読了時間: 30分