Собесов

Сценарий ML: two-tower архитектура для recsys

ML / Data ScienceRecsysСложнаяSenior

Условие

Опишите two-tower архитектуру для retrieval-стадии recsys. Зачем нужны две башни, как обучать?

Решение

Подход

Two-tower (Google, 2019 — YouTube DNN, потом TikTok, Pinterest):

  • User tower: user features (profile, recent activity) → user_emb (например, 128-dim).
  • Item tower: item features (text, image, meta, category) → item_emb (128-dim).
  • Similarity: score = user_emb · item_emb (dot product) или cosine.

Обучается end-to-end на (user, positive_item, negative_items) с in-batch softmax / sampled softmax loss.

Зачем «две башни»?

  • Item-side embedding можно прекомпьютить для всех items офлайн.
  • На запрос user: один forward user tower + ANN поиск (FAISS).
  • Latency милли-секундная для миллионов items.
  • Cold start новых items решается через side features.

Реализация

import torch
import torch.nn as nn
 
class UserTower(nn.Module):
    def __init__(self, n_users, n_cats, emb_dim=128):
        super().__init__()
        self.user_emb = nn.Embedding(n_users, 64)
        self.cat_emb = nn.Embedding(n_cats, 16)
        self.mlp = nn.Sequential(
            nn.Linear(64 + 16 + 10, 256), nn.ReLU(),
            nn.Linear(256, emb_dim),
        )
 
    def forward(self, user_id, fav_category, num_features):
        x = torch.cat([self.user_emb(user_id), self.cat_emb(fav_category), num_features], dim=-1)
        return nn.functional.normalize(self.mlp(x), dim=-1)
 
class ItemTower(nn.Module):
    def __init__(self, n_items, n_cats, emb_dim=128):
        super().__init__()
        self.item_emb = nn.Embedding(n_items, 64)
        self.cat_emb = nn.Embedding(n_cats, 16)
        self.text_proj = nn.Linear(384, 64)  # из sentence-bert
        self.mlp = nn.Sequential(
            nn.Linear(64+16+64, 256), nn.ReLU(),
            nn.Linear(256, emb_dim),
        )
 
    def forward(self, item_id, category, text_emb):
        x = torch.cat([self.item_emb(item_id), self.cat_emb(category), self.text_proj(text_emb)], dim=-1)
        return nn.functional.normalize(self.mlp(x), dim=-1)

Обучение: in-batch softmax

def train_step(user_batch, pos_item_batch, item_tower, user_tower):
    u = user_tower(user_batch)              # (B, D)
    i = item_tower(pos_item_batch)          # (B, D)
    logits = u @ i.T                         # (B, B): диагональ — positive
    labels = torch.arange(len(u)).cuda()
    loss = nn.functional.cross_entropy(logits, labels)
    return loss

Каждый positive item в batch служит negative для других users в том же batch → бесплатные negatives.

Inference

# offline: precompute item embeddings
all_item_emb = item_tower(all_items_batched)  # (N_items, D)
faiss_index = faiss.IndexFlatIP(D)
faiss_index.add(all_item_emb.cpu().numpy())
 
# online: user request
def retrieve(user_features, top_k=200):
    u_emb = user_tower(user_features)
    D, I = faiss_index.search(u_emb.cpu().numpy(), top_k)
    return I  # top-k item_ids

После retrieve → ranking-этап с тяжёлой моделью (DeepFM/Transformer) на 200 кандидатов.

Sampled softmax correction (popularity bias)

In-batch negatives over-represent popular items. Поправка:

logit_ij −= log P(item_j)

где P(item_j) — frequency item_j. Это log-Q correction. Помогает модели не недо-rate популярные.

Подводные камни

  1. Hard negatives: in-batch negatives часто слишком easy. Нужны explicit hard negatives — items, которые user видел но не click.
  2. Popularity bias: log-Q correction обязательна на больших каталогах.
  3. Embedding drift: item embedding меняется при retraining → ANN index надо перестраивать. Schedule daily или incremental.
  4. Cold items без history: side features (text, category) обязательны для cold-start.
  5. User-side history aggregation: average / attention / transformer над recent actions. Простой average underperforms.
  6. Dot product vs cosine: dot favours большие emb-нормы. Нормируйте для retrieval; для ranking лучше raw dot.

Эталонный ответ

Two-tower: user-tower(features) → user_emb, item-tower(features) → item_emb, score = dot(user_emb, item_emb). Item embeddings precomputed, online — forward user + ANN (FAISS). Train: in-batch softmax + log-Q correction для popularity. После retrieve → отдельный ranking-stage на heavy model.

Хочешь увидеть разбор?

Зарегистрируйся бесплатно — откроется развёрнутое решение этой задачи и ещё 4 на выбор.

Зарегистрироваться и увидеть разбор
Уже есть аккаунт? Войти