Собесов

Сценарий ML: Focal Loss для imbalanced классов

ML / Data ScienceКласс дисбалансСложнаяSenior

Условие

Объясните Focal Loss: что это, чем отличается от weighted cross-entropy, какие гиперпараметры.

Решение

Подход

Focal Loss (Lin et al., 2017 — RetinaNet):

FL(p_t) = −α_t · (1 − p_t)^γ · log(p_t)

где p_t = p если y=1 иначе 1−p. Множитель (1−p_t)^γ понижает вес легких примеров (с уверенным правильным предсказанием) и подсвечивает трудные (близкие к границе).

При γ=0 → обычный weighted CE. При γ=2 (default) → классические настройки RetinaNet.

α_t — балансирующий фактор класса (как class_weight).

Сравнение с weighted CE

Случай Weighted CE Focal Loss
Easy negative (p=0.99, y=0) loss ≈ 0.01·w_0 (1−0.99)^γ·0.01·w_0 → почти 0
Hard negative (p=0.6, y=0) 0.92·w_0 (1−0.4)^γ·0.92·w_0 = 0.36·w_0
Hard positive (p=0.3, y=1) 1.2·w_1 (1−0.3)^γ·1.2·w_1 = 0.59·w_1

Focal сужает фокус на трудных примерах, эффективно «выключая» простые negatives которых много при imbalanced.

Реализация (PyTorch)

import torch
import torch.nn as nn
 
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
 
    def forward(self, logits, targets):
        # binary
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        loss = -alpha_t * (1 - p_t) ** self.gamma * torch.log(p_t.clamp(min=1e-8))
        return loss.mean()

Когда

  • Object detection с тысячами anchor boxes, большинство — easy background.
  • Сильно imbalanced классификация (1:1000+).
  • Когда class_weight недостаточно эффективен.

Когда НЕ нужно

  • Сбалансированные данные: γ=0 = обычный CE.
  • Малая выборка: focal loss с γ=2 «выключает» большую часть примеров → недообучение.
  • Когда нужен calibrated predict_proba — focal сильнее искажает probability.

Подводные камни

  1. γ слишком большой (5+) → почти все примеры с уверенной prediction игнорируются → стагнация loss.
  2. Initialization bias: для bounded class в начале training модель predicts 0.5 — focal loss выше; иногда полезен bias initialization (Lin et al.).
  3. Focal loss с softmax (multiclass) — менее стандартная форма; обычно binary one-vs-rest.
  4. Калибровка после focal loss сложнее, чем после CE; нужно более длинное calibration set.
  5. α балансирует классы — не забывайте подбирать вместе с γ, а не «α=0.25 потому что в paper».

Эталонный ответ

Focal Loss −α·(1−p_t)^γ·log(p_t): понижает вес легких примеров через (1−p_t)^γ, подсвечивает трудные. γ=2 default, α балансирует класс. Лучше weighted CE на severe imbalanced (object detection, fraud), хуже при сбалансированных данных. Калибровка после — обязательна.

Хочешь увидеть разбор?

Зарегистрируйся бесплатно — откроется развёрнутое решение этой задачи и ещё 4 на выбор.

Зарегистрироваться и увидеть разбор
Уже есть аккаунт? Войти