LESSON

Dataset/DataLoader

田中VPoE:「Tensor の基本は分かったね。次は、大量のデータを効率的にモデルに供給する仕組みを学ぼう。PyTorch の Dataset と DataLoader だ。」

あなた:「データのバッチ処理やシャッフルを自動化してくれるんですね。」

田中VPoE:「そうだ。NetShop の商品データは数十万件ある。全データを一度にメモリに載せるのは非効率だし、バッチ処理の実装も毎回書くのは面倒だ。DataLoader がそのあたりを全部やってくれる。」

Dataset クラス

torch.utils.data.Dataset はデータセットの抽象クラスで、以下の2つのメソッドを実装する必要があります。

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd

class NetShopDataset(Dataset):
    """NetShop の顧客データセット"""

    def __init__(self, features, labels, transform=None):
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)
        self.transform = transform

    def __len__(self):
        """データセットのサイズを返す"""
        return len(self.labels)

    def __getitem__(self, idx):
        """インデックスに対応するサンプルを返す"""
        x = self.features[idx]
        y = self.labels[idx]

        if self.transform:
            x = self.transform(x)

        return x, y

# 使用例
np.random.seed(42)
X = np.random.randn(1000, 10).astype(np.float32)
y = (np.random.random(1000) > 0.5).astype(np.float32)

dataset = NetShopDataset(X, y)
print(f"データセットサイズ: {len(dataset)}")
print(f"1サンプル: features={dataset[0][0].shape}, label={dataset[0][1].shape}")

CSV ファイルから読み込む Dataset

class CSVDataset(Dataset):
    """CSV ファイルから読み込むデータセット"""

    def __init__(self, csv_path, target_col, feature_cols=None):
        df = pd.read_csv(csv_path)

        if feature_cols is None:
            feature_cols = [c for c in df.columns if c != target_col]

        self.features = torch.tensor(
            df[feature_cols].values, dtype=torch.float32
        )
        self.labels = torch.tensor(
            df[target_col].values, dtype=torch.float32
        ).unsqueeze(1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

DataLoader

DataLoader はデータセットからミニバッチを自動生成し、シャッフルや並列読み込みを行います。

# DataLoader の作成
train_loader = DataLoader(
    dataset,
    batch_size=32,       # バッチサイズ
    shuffle=True,        # エポックごとにシャッフル
    num_workers=2,       # 並列読み込みワーカー数
    drop_last=True       # 最後の不完全バッチを捨てる
)

# イテレーション
for batch_idx, (features, labels) in enumerate(train_loader):
    print(f"Batch {batch_idx}: features={features.shape}, labels={labels.shape}")
    if batch_idx >= 2:
        break

DataLoader の主要パラメータ

パラメータ説明推奨値
batch_size1バッチのサンプル数32, 64, 128, 256
shuffleデータのシャッフル学習: True, 評価: False
num_workers並列読み込み数CPU コア数の半分程度
drop_last最後の不完全バッチBatchNorm 使用時は True
pin_memoryGPU 転送の高速化GPU 使用時は True

データ前処理

transforms を使った前処理パイプライン

from torchvision import transforms

# 画像データの前処理例
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),          # サイズ統一
    transforms.RandomHorizontalFlip(),      # ランダム水平反転
    transforms.ToTensor(),                  # Tensor 変換 (0-1)
    transforms.Normalize(                   # 正規化
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

テーブルデータの前処理

class StandardScalerTransform:
    """標準化 Transform"""

    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def fit(self, data):
        self.mean = data.mean(dim=0)
        self.std = data.std(dim=0)
        self.std[self.std == 0] = 1  # ゼロ除算防止
        return self

    def __call__(self, x):
        return (x - self.mean) / self.std

学習/検証/テストの分割

from torch.utils.data import random_split

# データセット全体
full_dataset = NetShopDataset(X, y)

# 分割(60% 学習, 20% 検証, 20% テスト)
total = len(full_dataset)
train_size = int(0.6 * total)
val_size = int(0.2 * total)
test_size = total - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# 各データセットの DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"学習: {len(train_dataset)}, 検証: {len(val_dataset)}, テスト: {len(test_dataset)}")

まとめ

  • Dataset__len____getitem__ を実装してデータを定義する
  • DataLoader はバッチ処理、シャッフル、並列読み込みを自動化する
  • transforms でデータ前処理をパイプライン化できる
  • random_split でデータセットを学習/検証/テストに分割できる

チェックリスト

  • カスタム Dataset クラスを作成できる
  • DataLoader の主要パラメータを理解した
  • transforms で前処理パイプラインを構築できる
  • データセットの分割と各 DataLoader の作成ができる

推定読了時間: 30分