LESSON

データ拡張で汎化性能を向上させる

「データが少ないなら、増やせばいい。ただし、意味のある増やし方をしないと逆効果だ。」

田中VPoEがデータ拡張の効果を示すグラフを見せる。

「医療画像や農業画像では、やっていい拡張とやってはいけない拡張がある。ドメイン知識が重要だ。」

データ拡張の基本

データ拡張(Data Augmentation)は、学習データに変換を加えることで実質的なデータ量を増やし、モデルの汎化性能を向上させる手法である。

基本的な画像変換

from torchvision import transforms

# 基本的なデータ拡張パイプライン
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.1,
        hue=0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

ドメイン別の拡張戦略

拡張手法商品画像医療画像農業画像
水平反転OK注意(左右が重要な場合あり)OK
垂直反転注意NG(上下は意味を持つ)OK
回転小角度OK小角度OKOK
色調変化OK注意(色が診断情報)控えめに
ランダムクロップOK注意(病変部を含むか)OK
ぼかしOK控えめにOK

高度なデータ拡張

Albumentationsライブラリ

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Chest X-Ray向けの拡張パイプライン
medical_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.15,
        rotate_limit=15,
        p=0.5
    ),
    A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), p=1),
        A.GaussianBlur(blur_limit=(3, 5), p=1),
    ], p=0.3),
    A.RandomBrightnessContrast(
        brightness_limit=0.2,
        contrast_limit=0.2,
        p=0.5
    ),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    ToTensorV2(),
])

# Plant Pathology向けの拡張パイプライン
plant_transform = A.Compose([
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.2,
        rotate_limit=45,
        p=0.5
    ),
    A.OneOf([
        A.RandomBrightnessContrast(p=1),
        A.HueSaturationValue(p=1),
    ], p=0.5),
    A.CoarseDropout(
        max_holes=8,
        max_height=20,
        max_width=20,
        p=0.3
    ),
    A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    ToTensorV2(),
])

Mixup と CutMix

import numpy as np

def mixup_data(x, y, alpha=0.2):
    """Mixup: 2つの画像とラベルを線形補間で混合"""
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size(0)
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup用の損失関数"""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

クラス不均衡への対応

Chest X-Rayデータセットは肺炎画像が多く、クラス不均衡が存在する。

from torch.utils.data import WeightedRandomSampler

def create_balanced_sampler(dataset):
    """クラス不均衡を解消するためのサンプラーを作成"""
    targets = [sample[1] for sample in dataset.samples]
    class_counts = np.bincount(targets)
    class_weights = 1.0 / class_counts
    sample_weights = [class_weights[t] for t in targets]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    return sampler

# 使用例
sampler = create_balanced_sampler(train_dataset)
train_loader = DataLoader(
    train_dataset, batch_size=32, sampler=sampler
)

重み付き損失関数

# クラス数に応じた重み付け
class_counts = [1341, 3875]  # [NORMAL, PNEUMONIA]
total = sum(class_counts)
class_weights = torch.tensor(
    [total / c for c in class_counts],
    dtype=torch.float
).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)

まとめ

項目ポイント
基本拡張反転、回転、色調変化はドメインに応じて選択
高度な拡張Albumentationsで柔軟なパイプライン構築
Mixup/CutMix正則化効果があり、過学習を抑制
クラス不均衡WeightedRandomSampler、重み付き損失関数で対応

チェックリスト

  • ドメイン別の拡張戦略の違いを説明できる
  • Albumentationsでカスタムパイプラインを構築できる
  • Mixupの仕組みと効果を理解した
  • クラス不均衡への対処法を2つ以上挙げられる

次のステップへ

データ拡張の手法を学んだところで、次はGrad-CAMを使ってモデルの判断根拠を可視化する方法を学ぼう。

推定読了時間: 30分