LESSON

Two-Towerモデル

「大規模な推薦システムでは、全アイテムをリアルタイムでスコアリングできない。」

田中VPoEがGoogleの推薦システム論文を見せる。

「Two-Towerモデルは、ユーザーとアイテムを別々のネットワークで埋め込み、内積で高速にマッチングする。YouTubeやGoogleでも使われているアーキテクチャだ。」

Two-Towerアーキテクチャ

ユーザータワー           アイテムタワー
┌──────────┐           ┌──────────┐
│ユーザーID │           │アイテムID │
│年齢       │           │カテゴリ  │
│性別       │           │価格帯    │
│行動履歴   │           │説明文    │
└─────┬────┘           └─────┬────┘
      ↓                      ↓
  [Dense層]              [Dense層]
      ↓                      ↓
  [Dense層]              [Dense層]
      ↓                      ↓
  ユーザー埋め込み       アイテム埋め込み
  (128次元)              (128次元)
      └────────┬──────────┘
             内積

          マッチスコア

実装

import torch
import torch.nn as nn

class UserTower(nn.Module):
    """ユーザータワー"""

    def __init__(self, n_users, n_features, embed_dim=128):
        super().__init__()
        self.user_embed = nn.Embedding(n_users, 64)
        self.feature_layer = nn.Linear(n_features, 64)
        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embed_dim),
        )

    def forward(self, user_ids, user_features):
        user_emb = self.user_embed(user_ids)
        feat_emb = self.feature_layer(user_features)
        combined = torch.cat([user_emb, feat_emb], dim=1)
        return self.fc(combined)


class ItemTower(nn.Module):
    """アイテムタワー"""

    def __init__(self, n_items, n_features, embed_dim=128):
        super().__init__()
        self.item_embed = nn.Embedding(n_items, 64)
        self.feature_layer = nn.Linear(n_features, 64)
        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, embed_dim),
        )

    def forward(self, item_ids, item_features):
        item_emb = self.item_embed(item_ids)
        feat_emb = self.feature_layer(item_features)
        combined = torch.cat([item_emb, feat_emb], dim=1)
        return self.fc(combined)


class TwoTowerModel(nn.Module):
    """Two-Tower推薦モデル"""

    def __init__(self, n_users, n_items, user_feat_dim, item_feat_dim, embed_dim=128):
        super().__init__()
        self.user_tower = UserTower(n_users, user_feat_dim, embed_dim)
        self.item_tower = ItemTower(n_items, item_feat_dim, embed_dim)

    def forward(self, user_ids, user_features, item_ids, item_features):
        user_emb = self.user_tower(user_ids, user_features)
        item_emb = self.item_tower(item_ids, item_features)

        # L2正規化
        user_emb = nn.functional.normalize(user_emb, dim=1)
        item_emb = nn.functional.normalize(item_emb, dim=1)

        # 内積でスコア計算
        scores = (user_emb * item_emb).sum(dim=1)
        return scores

学習

def train_two_tower(model, train_loader, optimizer, n_epochs=10):
    """Two-Towerモデルの学習"""
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            user_ids, user_feats, item_ids, item_feats, labels = batch

            scores = model(user_ids, user_feats, item_ids, item_feats)
            loss = criterion(scores, labels.float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")

推論の高速化(ANN検索)

import faiss

def build_item_index(model, item_ids, item_features, embed_dim=128):
    """アイテム埋め込みのインデックスを構築"""
    model.eval()
    with torch.no_grad():
        item_embeddings = model.item_tower(
            torch.tensor(item_ids),
            torch.tensor(item_features, dtype=torch.float32)
        ).numpy()

    # FAISSインデックスの構築
    index = faiss.IndexFlatIP(embed_dim)  # 内積ベース
    faiss.normalize_L2(item_embeddings)
    index.add(item_embeddings)

    return index

def recommend_fast(model, user_id, user_features, index, k=10):
    """ANN検索で高速推薦"""
    model.eval()
    with torch.no_grad():
        user_emb = model.user_tower(
            torch.tensor([user_id]),
            torch.tensor([user_features], dtype=torch.float32)
        ).numpy()

    faiss.normalize_L2(user_emb)
    scores, indices = index.search(user_emb, k)

    return indices[0], scores[0]

まとめ

項目ポイント
アーキテクチャユーザー/アイテムを独立したタワーで埋め込み
スコア計算埋め込みの内積でマッチ度を算出
高速推論FAISS等のANN検索で数百万アイテムにも対応
特徴候補生成(Retrieval)フェーズに最適

チェックリスト

  • Two-Towerの仕組みと利点を説明できる
  • ユーザータワーとアイテムタワーを実装できる
  • FAISSを使った高速推論の流れを理解した
  • Two-Towerが候補生成に適している理由を説明できる

次のステップへ

Two-Towerモデルを理解した。次は特徴量エンジニアリングを学ぼう。

推定読了時間: 30分