LESSON

タスクに最適なモデルを選定する

「精度だけでモデルを選ぶな。推論速度、モデルサイズ、解釈性のバランスが重要だ。」

田中VPoEがモデル比較表を見せる。

「農業テック部門ではスマートフォンで動かしたいという要件がある。エッジ推論を見据えた選定をしてくれ。」

モデル選定の評価軸

4つの評価軸

1. 精度(Accuracy/AUC)  → タスクの要求精度を満たすか
2. 推論速度(Latency)    → リアルタイム性の要件を満たすか
3. モデルサイズ(Params) → デプロイ環境の制約を満たすか
4. 解釈性(Interpretability) → Grad-CAMの品質は十分か

timmライブラリによるモデル比較

import timm

# 利用可能なモデル一覧
models_list = timm.list_models("*efficientnet*")
print(f"EfficientNet系モデル数: {len(models_list)}")

# モデルの比較実験
candidates = {
    "resnet50": {"params": "25.6M", "imagenet_acc": "80.9%"},
    "efficientnet_b0": {"params": "5.3M", "imagenet_acc": "77.7%"},
    "efficientnet_b3": {"params": "12.2M", "imagenet_acc": "82.0%"},
    "convnext_tiny": {"params": "28.6M", "imagenet_acc": "82.1%"},
    "mobilenetv3_large_100": {"params": "5.4M", "imagenet_acc": "75.8%"},
}

def create_model_from_timm(model_name, num_classes=2):
    """timmからモデルを作成"""
    model = timm.create_model(
        model_name,
        pretrained=True,
        num_classes=num_classes
    )
    return model

ベンチマーク実験

import time

def benchmark_model(model, input_size=(1, 3, 224, 224), n_runs=100):
    """モデルの推論速度を計測"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    dummy_input = torch.randn(input_size).to(device)

    # ウォームアップ
    for _ in range(10):
        with torch.no_grad():
            model(dummy_input)

    # 計測
    if device.type == "cuda":
        torch.cuda.synchronize()

    start = time.time()
    for _ in range(n_runs):
        with torch.no_grad():
            model(dummy_input)

    if device.type == "cuda":
        torch.cuda.synchronize()

    elapsed = (time.time() - start) / n_runs * 1000  # ms
    return elapsed

# 各モデルのベンチマーク
for name in candidates:
    model = create_model_from_timm(name)
    latency = benchmark_model(model)
    params = sum(p.numel() for p in model.parameters()) / 1e6
    print(f"{name}: {latency:.1f}ms, {params:.1f}M params")

用途別の推奨モデル

用途推奨モデル理由
高精度が必要(医療)EfficientNet-B3/B4高精度、Grad-CAM品質良好
エッジ推論(農業)MobileNetV3 / EfficientNet-B0軽量、高速
バランス重視(品質検査)ResNet-50安定性、エコシステム充実
最新の高精度ConvNeXt-TinyTransformer的設計、高精度

学習率スケジューリング

# Cosine Annealing with Warm Restart
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=5,       # 最初の周期の長さ
    T_mult=2,    # 周期を2倍に延長
    eta_min=1e-6 # 最小学習率
)

# OneCycleLR(推奨)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=20,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,   # 全体の30%でウォームアップ
    anneal_strategy="cos"
)

モデルのアンサンブル

class EnsembleModel(nn.Module):
    def __init__(self, models, weights=None):
        super().__init__()
        self.models = nn.ModuleList(models)
        if weights is None:
            weights = [1.0 / len(models)] * len(models)
        self.weights = weights

    def forward(self, x):
        outputs = []
        for model, weight in zip(self.models, self.weights):
            out = torch.softmax(model(x), dim=1)
            outputs.append(out * weight)
        return torch.stack(outputs).sum(dim=0)

# ResNet + EfficientNetのアンサンブル
ensemble = EnsembleModel([
    resnet_model,
    efficientnet_model
], weights=[0.4, 0.6])

まとめ

項目ポイント
評価軸精度、推論速度、モデルサイズ、解釈性の4軸
timmライブラリ数百のモデルを統一インターフェースで利用可能
用途別選定デプロイ環境と要求精度に応じて選択
アンサンブル複数モデルの組み合わせで精度向上

チェックリスト

  • モデル選定の4つの評価軸を説明できる
  • timmライブラリの使い方を理解した
  • 用途に応じたモデル選定ができる
  • アンサンブルの仕組みと効果を理解した

次のステップへ

モデル選定の知識を身につけたところで、次は演習で実際に画像分類モデルを構築しよう。

推定読了時間: 30分