データ拡張で汎化性能を向上させる
「データが少ないなら、増やせばいい。ただし、意味のある増やし方をしないと逆効果だ。」
田中VPoEがデータ拡張の効果を示すグラフを見せる。
「医療画像や農業画像では、やっていい拡張とやってはいけない拡張がある。ドメイン知識が重要だ。」
データ拡張の基本
データ拡張(Data Augmentation)は、学習データに変換を加えることで実質的なデータ量を増やし、モデルの汎化性能を向上させる手法である。
基本的な画像変換
from torchvision import transforms
# 基本的なデータ拡張パイプライン
train_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.2,
contrast=0.2,
saturation=0.1,
hue=0.05
),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
ドメイン別の拡張戦略
| 拡張手法 | 商品画像 | 医療画像 | 農業画像 |
|---|---|---|---|
| 水平反転 | OK | 注意(左右が重要な場合あり) | OK |
| 垂直反転 | 注意 | NG(上下は意味を持つ) | OK |
| 回転 | 小角度OK | 小角度OK | OK |
| 色調変化 | OK | 注意(色が診断情報) | 控えめに |
| ランダムクロップ | OK | 注意(病変部を含むか) | OK |
| ぼかし | OK | 控えめに | OK |
高度なデータ拡張
Albumentationsライブラリ
import albumentations as A
from albumentations.pytorch import ToTensorV2
# Chest X-Ray向けの拡張パイプライン
medical_transform = A.Compose([
A.Resize(224, 224),
A.HorizontalFlip(p=0.5),
A.ShiftScaleRotate(
shift_limit=0.1,
scale_limit=0.15,
rotate_limit=15,
p=0.5
),
A.OneOf([
A.GaussNoise(var_limit=(10.0, 50.0), p=1),
A.GaussianBlur(blur_limit=(3, 5), p=1),
], p=0.3),
A.RandomBrightnessContrast(
brightness_limit=0.2,
contrast_limit=0.2,
p=0.5
),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
ToTensorV2(),
])
# Plant Pathology向けの拡張パイプライン
plant_transform = A.Compose([
A.Resize(224, 224),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomRotate90(p=0.5),
A.ShiftScaleRotate(
shift_limit=0.1,
scale_limit=0.2,
rotate_limit=45,
p=0.5
),
A.OneOf([
A.RandomBrightnessContrast(p=1),
A.HueSaturationValue(p=1),
], p=0.5),
A.CoarseDropout(
max_holes=8,
max_height=20,
max_width=20,
p=0.3
),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
ToTensorV2(),
])
Mixup と CutMix
import numpy as np
def mixup_data(x, y, alpha=0.2):
"""Mixup: 2つの画像とラベルを線形補間で混合"""
lam = np.random.beta(alpha, alpha)
batch_size = x.size(0)
index = torch.randperm(batch_size).to(x.device)
mixed_x = lam * x + (1 - lam) * x[index]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
"""Mixup用の損失関数"""
return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
クラス不均衡への対応
Chest X-Rayデータセットは肺炎画像が多く、クラス不均衡が存在する。
from torch.utils.data import WeightedRandomSampler
def create_balanced_sampler(dataset):
"""クラス不均衡を解消するためのサンプラーを作成"""
targets = [sample[1] for sample in dataset.samples]
class_counts = np.bincount(targets)
class_weights = 1.0 / class_counts
sample_weights = [class_weights[t] for t in targets]
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
return sampler
# 使用例
sampler = create_balanced_sampler(train_dataset)
train_loader = DataLoader(
train_dataset, batch_size=32, sampler=sampler
)
重み付き損失関数
# クラス数に応じた重み付け
class_counts = [1341, 3875] # [NORMAL, PNEUMONIA]
total = sum(class_counts)
class_weights = torch.tensor(
[total / c for c in class_counts],
dtype=torch.float
).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
まとめ
| 項目 | ポイント |
|---|---|
| 基本拡張 | 反転、回転、色調変化はドメインに応じて選択 |
| 高度な拡張 | Albumentationsで柔軟なパイプライン構築 |
| Mixup/CutMix | 正則化効果があり、過学習を抑制 |
| クラス不均衡 | WeightedRandomSampler、重み付き損失関数で対応 |
チェックリスト
- ドメイン別の拡張戦略の違いを説明できる
- Albumentationsでカスタムパイプラインを構築できる
- Mixupの仕組みと効果を理解した
- クラス不均衡への対処法を2つ以上挙げられる
次のステップへ
データ拡張の手法を学んだところで、次はGrad-CAMを使ってモデルの判断根拠を可視化する方法を学ぼう。
推定読了時間: 30分