Two-Towerモデル
「大規模な推薦システムでは、全アイテムをリアルタイムでスコアリングできない。」
田中VPoEがGoogleの推薦システム論文を見せる。
「Two-Towerモデルは、ユーザーとアイテムを別々のネットワークで埋め込み、内積で高速にマッチングする。YouTubeやGoogleでも使われているアーキテクチャだ。」
Two-Towerアーキテクチャ
ユーザータワー アイテムタワー
┌──────────┐ ┌──────────┐
│ユーザーID │ │アイテムID │
│年齢 │ │カテゴリ │
│性別 │ │価格帯 │
│行動履歴 │ │説明文 │
└─────┬────┘ └─────┬────┘
↓ ↓
[Dense層] [Dense層]
↓ ↓
[Dense層] [Dense層]
↓ ↓
ユーザー埋め込み アイテム埋め込み
(128次元) (128次元)
└────────┬──────────┘
内積
↓
マッチスコア
実装
import torch
import torch.nn as nn
class UserTower(nn.Module):
"""ユーザータワー"""
def __init__(self, n_users, n_features, embed_dim=128):
super().__init__()
self.user_embed = nn.Embedding(n_users, 64)
self.feature_layer = nn.Linear(n_features, 64)
self.fc = nn.Sequential(
nn.Linear(128, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, embed_dim),
)
def forward(self, user_ids, user_features):
user_emb = self.user_embed(user_ids)
feat_emb = self.feature_layer(user_features)
combined = torch.cat([user_emb, feat_emb], dim=1)
return self.fc(combined)
class ItemTower(nn.Module):
"""アイテムタワー"""
def __init__(self, n_items, n_features, embed_dim=128):
super().__init__()
self.item_embed = nn.Embedding(n_items, 64)
self.feature_layer = nn.Linear(n_features, 64)
self.fc = nn.Sequential(
nn.Linear(128, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, embed_dim),
)
def forward(self, item_ids, item_features):
item_emb = self.item_embed(item_ids)
feat_emb = self.feature_layer(item_features)
combined = torch.cat([item_emb, feat_emb], dim=1)
return self.fc(combined)
class TwoTowerModel(nn.Module):
"""Two-Tower推薦モデル"""
def __init__(self, n_users, n_items, user_feat_dim, item_feat_dim, embed_dim=128):
super().__init__()
self.user_tower = UserTower(n_users, user_feat_dim, embed_dim)
self.item_tower = ItemTower(n_items, item_feat_dim, embed_dim)
def forward(self, user_ids, user_features, item_ids, item_features):
user_emb = self.user_tower(user_ids, user_features)
item_emb = self.item_tower(item_ids, item_features)
# L2正規化
user_emb = nn.functional.normalize(user_emb, dim=1)
item_emb = nn.functional.normalize(item_emb, dim=1)
# 内積でスコア計算
scores = (user_emb * item_emb).sum(dim=1)
return scores
学習
def train_two_tower(model, train_loader, optimizer, n_epochs=10):
"""Two-Towerモデルの学習"""
criterion = nn.BCEWithLogitsLoss()
for epoch in range(n_epochs):
model.train()
total_loss = 0
for batch in train_loader:
user_ids, user_feats, item_ids, item_feats, labels = batch
scores = model(user_ids, user_feats, item_ids, item_feats)
loss = criterion(scores, labels.float())
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")
推論の高速化(ANN検索)
import faiss
def build_item_index(model, item_ids, item_features, embed_dim=128):
"""アイテム埋め込みのインデックスを構築"""
model.eval()
with torch.no_grad():
item_embeddings = model.item_tower(
torch.tensor(item_ids),
torch.tensor(item_features, dtype=torch.float32)
).numpy()
# FAISSインデックスの構築
index = faiss.IndexFlatIP(embed_dim) # 内積ベース
faiss.normalize_L2(item_embeddings)
index.add(item_embeddings)
return index
def recommend_fast(model, user_id, user_features, index, k=10):
"""ANN検索で高速推薦"""
model.eval()
with torch.no_grad():
user_emb = model.user_tower(
torch.tensor([user_id]),
torch.tensor([user_features], dtype=torch.float32)
).numpy()
faiss.normalize_L2(user_emb)
scores, indices = index.search(user_emb, k)
return indices[0], scores[0]
まとめ
| 項目 | ポイント |
|---|---|
| アーキテクチャ | ユーザー/アイテムを独立したタワーで埋め込み |
| スコア計算 | 埋め込みの内積でマッチ度を算出 |
| 高速推論 | FAISS等のANN検索で数百万アイテムにも対応 |
| 特徴 | 候補生成(Retrieval)フェーズに最適 |
チェックリスト
- Two-Towerの仕組みと利点を説明できる
- ユーザータワーとアイテムタワーを実装できる
- FAISSを使った高速推論の流れを理解した
- Two-Towerが候補生成に適している理由を説明できる
次のステップへ
Two-Towerモデルを理解した。次は特徴量エンジニアリングを学ぼう。
推定読了時間: 30分