Собесов

khangich: batch normalization

ML / Data ScienceНейросетиСредняяMiddle

Условие

Что делает batch normalization? Чем отличается от layer normalization? Когда какая нормализация подходит?

Решение

Подход

Batch Normalization (BN).

Для каждого нейрона в слое нормируем активации по батчу:

  1. Считаем μ_B, σ²_B по батчу.
  2. x̂ = (x − μ_B) / sqrt(σ²_B + ε).
  3. Аффинно преобразуем: y = γ x̂ + β, где γ, β — обучаемые.

На train используются батчевые статистики, на inference — скользящие средние, накопленные в процессе обучения.

Зачем.

  • Стабилизирует распределение активаций → можно использовать большие learning rate.
  • Ускоряет сходимость.
  • Действует как лёгкая регуляризация (шум батча).

Layer Normalization (LN).

Нормируем активации одного примера по фичам (а не по батчу). Не зависит от batch size, поэтому хорошо работает в RNN/Transformer, где батчи могут быть маленькими или последовательной длины.

Когда что.

Сценарий BN LN
CNN, классификация изображений редко
Transformer / RNN редко
Маленький батч (batch=1, online) ломается
GAN часто LN или Instance Norm
import torch.nn as nn
 
# CNN со BN
conv = nn.Sequential(
    nn.Conv2d(3, 64, 3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
 
# Transformer-блок с LN
norm = nn.LayerNorm(d_model)

Подводные камни

  1. model.train() / model.eval() — BN ведёт себя по-разному; забыть переключить = сломанный inference.
  2. Маленький batch (1-4): BN-статистики шумные, метрика нестабильна. Решения: LN, GroupNorm, sync-BN на нескольких GPU.
  3. Combining BN + dropout требует осторожности по порядку слоёв.
  4. Distillation/transfer learning. При файнтюнинге часто замораживают BN-статистики, чтобы не сместить распределение под новые данные.
  5. Распределённое обучение. BN, посчитанный на одном GPU из 8, видит только 1/8 батча. SyncBN объединяет статистики между GPU — иначе градиент шумный.

Эталонный ответ

BN нормирует активации по батчу, LN — по фичам одного примера. BN стандарт для CNN; LN — для Transformer/RNN и маленьких батчей. На inference у BN используются скользящие средние, накопленные на train, и обязательно model.eval().

Хочешь увидеть разбор?

Зарегистрируйся бесплатно — откроется развёрнутое решение этой задачи и ещё 4 на выбор.

Зарегистрироваться и увидеть разбор
Уже есть аккаунт? Войти