Условие
Опишите two-tower архитектуру для retrieval-стадии recsys. Зачем нужны две башни, как обучать?
Решение
Подход
Two-tower (Google, 2019 — YouTube DNN, потом TikTok, Pinterest):
- User tower: user features (profile, recent activity) →
user_emb(например, 128-dim). - Item tower: item features (text, image, meta, category) →
item_emb(128-dim). - Similarity:
score = user_emb · item_emb(dot product) или cosine.
Обучается end-to-end на (user, positive_item, negative_items) с in-batch softmax / sampled softmax loss.
Зачем «две башни»?
- Item-side embedding можно прекомпьютить для всех items офлайн.
- На запрос user: один forward user tower + ANN поиск (FAISS).
- Latency милли-секундная для миллионов items.
- Cold start новых items решается через side features.
Реализация
import torch
import torch.nn as nn
class UserTower(nn.Module):
def __init__(self, n_users, n_cats, emb_dim=128):
super().__init__()
self.user_emb = nn.Embedding(n_users, 64)
self.cat_emb = nn.Embedding(n_cats, 16)
self.mlp = nn.Sequential(
nn.Linear(64 + 16 + 10, 256), nn.ReLU(),
nn.Linear(256, emb_dim),
)
def forward(self, user_id, fav_category, num_features):
x = torch.cat([self.user_emb(user_id), self.cat_emb(fav_category), num_features], dim=-1)
return nn.functional.normalize(self.mlp(x), dim=-1)
class ItemTower(nn.Module):
def __init__(self, n_items, n_cats, emb_dim=128):
super().__init__()
self.item_emb = nn.Embedding(n_items, 64)
self.cat_emb = nn.Embedding(n_cats, 16)
self.text_proj = nn.Linear(384, 64) # из sentence-bert
self.mlp = nn.Sequential(
nn.Linear(64+16+64, 256), nn.ReLU(),
nn.Linear(256, emb_dim),
)
def forward(self, item_id, category, text_emb):
x = torch.cat([self.item_emb(item_id), self.cat_emb(category), self.text_proj(text_emb)], dim=-1)
return nn.functional.normalize(self.mlp(x), dim=-1)Обучение: in-batch softmax
def train_step(user_batch, pos_item_batch, item_tower, user_tower):
u = user_tower(user_batch) # (B, D)
i = item_tower(pos_item_batch) # (B, D)
logits = u @ i.T # (B, B): диагональ — positive
labels = torch.arange(len(u)).cuda()
loss = nn.functional.cross_entropy(logits, labels)
return lossКаждый positive item в batch служит negative для других users в том же batch → бесплатные negatives.
Inference
# offline: precompute item embeddings
all_item_emb = item_tower(all_items_batched) # (N_items, D)
faiss_index = faiss.IndexFlatIP(D)
faiss_index.add(all_item_emb.cpu().numpy())
# online: user request
def retrieve(user_features, top_k=200):
u_emb = user_tower(user_features)
D, I = faiss_index.search(u_emb.cpu().numpy(), top_k)
return I # top-k item_idsПосле retrieve → ranking-этап с тяжёлой моделью (DeepFM/Transformer) на 200 кандидатов.
Sampled softmax correction (popularity bias)
In-batch negatives over-represent popular items. Поправка:
logit_ij −= log P(item_j)
где P(item_j) — frequency item_j. Это log-Q correction. Помогает модели не недо-rate популярные.
Подводные камни
- Hard negatives: in-batch negatives часто слишком easy. Нужны explicit hard negatives — items, которые user видел но не click.
- Popularity bias: log-Q correction обязательна на больших каталогах.
- Embedding drift: item embedding меняется при retraining → ANN index надо перестраивать. Schedule daily или incremental.
- Cold items без history: side features (text, category) обязательны для cold-start.
- User-side history aggregation: average / attention / transformer над recent actions. Простой average underperforms.
- Dot product vs cosine: dot favours большие emb-нормы. Нормируйте для retrieval; для ranking лучше raw dot.
Эталонный ответ
Two-tower: user-tower(features) → user_emb, item-tower(features) → item_emb, score = dot(user_emb, item_emb). Item embeddings precomputed, online — forward user + ANN (FAISS). Train: in-batch softmax + log-Q correction для popularity. После retrieve → отдельный ranking-stage на heavy model.