タスクに最適なモデルを選定する
「精度だけでモデルを選ぶな。推論速度、モデルサイズ、解釈性のバランスが重要だ。」
田中VPoEがモデル比較表を見せる。
「農業テック部門ではスマートフォンで動かしたいという要件がある。エッジ推論を見据えた選定をしてくれ。」
モデル選定の評価軸
4つの評価軸
1. 精度(Accuracy/AUC) → タスクの要求精度を満たすか
2. 推論速度(Latency) → リアルタイム性の要件を満たすか
3. モデルサイズ(Params) → デプロイ環境の制約を満たすか
4. 解釈性(Interpretability) → Grad-CAMの品質は十分か
timmライブラリによるモデル比較
import timm
# 利用可能なモデル一覧
models_list = timm.list_models("*efficientnet*")
print(f"EfficientNet系モデル数: {len(models_list)}")
# モデルの比較実験
candidates = {
"resnet50": {"params": "25.6M", "imagenet_acc": "80.9%"},
"efficientnet_b0": {"params": "5.3M", "imagenet_acc": "77.7%"},
"efficientnet_b3": {"params": "12.2M", "imagenet_acc": "82.0%"},
"convnext_tiny": {"params": "28.6M", "imagenet_acc": "82.1%"},
"mobilenetv3_large_100": {"params": "5.4M", "imagenet_acc": "75.8%"},
}
def create_model_from_timm(model_name, num_classes=2):
"""timmからモデルを作成"""
model = timm.create_model(
model_name,
pretrained=True,
num_classes=num_classes
)
return model
ベンチマーク実験
import time
def benchmark_model(model, input_size=(1, 3, 224, 224), n_runs=100):
"""モデルの推論速度を計測"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
dummy_input = torch.randn(input_size).to(device)
# ウォームアップ
for _ in range(10):
with torch.no_grad():
model(dummy_input)
# 計測
if device.type == "cuda":
torch.cuda.synchronize()
start = time.time()
for _ in range(n_runs):
with torch.no_grad():
model(dummy_input)
if device.type == "cuda":
torch.cuda.synchronize()
elapsed = (time.time() - start) / n_runs * 1000 # ms
return elapsed
# 各モデルのベンチマーク
for name in candidates:
model = create_model_from_timm(name)
latency = benchmark_model(model)
params = sum(p.numel() for p in model.parameters()) / 1e6
print(f"{name}: {latency:.1f}ms, {params:.1f}M params")
用途別の推奨モデル
| 用途 | 推奨モデル | 理由 |
|---|---|---|
| 高精度が必要(医療) | EfficientNet-B3/B4 | 高精度、Grad-CAM品質良好 |
| エッジ推論(農業) | MobileNetV3 / EfficientNet-B0 | 軽量、高速 |
| バランス重視(品質検査) | ResNet-50 | 安定性、エコシステム充実 |
| 最新の高精度 | ConvNeXt-Tiny | Transformer的設計、高精度 |
学習率スケジューリング
# Cosine Annealing with Warm Restart
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=5, # 最初の周期の長さ
T_mult=2, # 周期を2倍に延長
eta_min=1e-6 # 最小学習率
)
# OneCycleLR(推奨)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=1e-3,
epochs=20,
steps_per_epoch=len(train_loader),
pct_start=0.3, # 全体の30%でウォームアップ
anneal_strategy="cos"
)
モデルのアンサンブル
class EnsembleModel(nn.Module):
def __init__(self, models, weights=None):
super().__init__()
self.models = nn.ModuleList(models)
if weights is None:
weights = [1.0 / len(models)] * len(models)
self.weights = weights
def forward(self, x):
outputs = []
for model, weight in zip(self.models, self.weights):
out = torch.softmax(model(x), dim=1)
outputs.append(out * weight)
return torch.stack(outputs).sum(dim=0)
# ResNet + EfficientNetのアンサンブル
ensemble = EnsembleModel([
resnet_model,
efficientnet_model
], weights=[0.4, 0.6])
まとめ
| 項目 | ポイント |
|---|---|
| 評価軸 | 精度、推論速度、モデルサイズ、解釈性の4軸 |
| timmライブラリ | 数百のモデルを統一インターフェースで利用可能 |
| 用途別選定 | デプロイ環境と要求精度に応じて選択 |
| アンサンブル | 複数モデルの組み合わせで精度向上 |
チェックリスト
- モデル選定の4つの評価軸を説明できる
- timmライブラリの使い方を理解した
- 用途に応じたモデル選定ができる
- アンサンブルの仕組みと効果を理解した
次のステップへ
モデル選定の知識を身につけたところで、次は演習で実際に画像分類モデルを構築しよう。
推定読了時間: 30分