Условие
Объясните Focal Loss: что это, чем отличается от weighted cross-entropy, какие гиперпараметры.
Решение
Подход
Focal Loss (Lin et al., 2017 — RetinaNet):
FL(p_t) = −α_t · (1 − p_t)^γ · log(p_t)
где p_t = p если y=1 иначе 1−p. Множитель (1−p_t)^γ понижает вес легких примеров (с уверенным правильным предсказанием) и подсвечивает трудные (близкие к границе).
При γ=0 → обычный weighted CE. При γ=2 (default) → классические настройки RetinaNet.
α_t — балансирующий фактор класса (как class_weight).
Сравнение с weighted CE
| Случай | Weighted CE | Focal Loss |
|---|---|---|
| Easy negative (p=0.99, y=0) | loss ≈ 0.01·w_0 | (1−0.99)^γ·0.01·w_0 → почти 0 |
| Hard negative (p=0.6, y=0) | 0.92·w_0 | (1−0.4)^γ·0.92·w_0 = 0.36·w_0 |
| Hard positive (p=0.3, y=1) | 1.2·w_1 | (1−0.3)^γ·1.2·w_1 = 0.59·w_1 |
Focal сужает фокус на трудных примерах, эффективно «выключая» простые negatives которых много при imbalanced.
Реализация (PyTorch)
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, logits, targets):
# binary
p = torch.sigmoid(logits)
p_t = p * targets + (1 - p) * (1 - targets)
alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
loss = -alpha_t * (1 - p_t) ** self.gamma * torch.log(p_t.clamp(min=1e-8))
return loss.mean()Когда
- Object detection с тысячами anchor boxes, большинство — easy background.
- Сильно imbalanced классификация (1:1000+).
- Когда class_weight недостаточно эффективен.
Когда НЕ нужно
- Сбалансированные данные: γ=0 = обычный CE.
- Малая выборка: focal loss с γ=2 «выключает» большую часть примеров → недообучение.
- Когда нужен calibrated
predict_proba— focal сильнее искажает probability.
Подводные камни
γслишком большой (5+) → почти все примеры с уверенной prediction игнорируются → стагнация loss.- Initialization bias: для bounded class в начале training модель predicts 0.5 — focal loss выше; иногда полезен bias initialization (Lin et al.).
- Focal loss с softmax (multiclass) — менее стандартная форма; обычно binary one-vs-rest.
- Калибровка после focal loss сложнее, чем после CE; нужно более длинное calibration set.
αбалансирует классы — не забывайте подбирать вместе с γ, а не «α=0.25 потому что в paper».
Эталонный ответ
Focal Loss −α·(1−p_t)^γ·log(p_t): понижает вес легких примеров через (1−p_t)^γ, подсвечивает трудные. γ=2 default, α балансирует класс. Лучше weighted CE на severe imbalanced (object detection, fraud), хуже при сбалансированных данных. Калибровка после — обязательна.