Собесов

Стажировка ML — Tree Barber: минимальная сумма энтропий после стрижки

ML / Data ScienceДеревья решений и энтропияСложнаяSenior

Условие

Дано бинарное дерево решений (классификатор) на 2n - 1 узлах: n листьев, остальные — внутренние. Каждый лист помечен меткой 0 или 1 (бинарная классификация). Внутренние узлы — точки разбиения.

Для каждой вершины v определена стоимость стрижки = энтропия Шеннона меток в поддереве, нормированная на мощность поддерева:

H(v) = -p log p - (1 - p) log (1 - p)

где p — доля единиц в листьях поддерева v. Если p ∈ {0, 1}, то H = 0.

«Стричь» дерево = выбрать поддерево и заменить его на лист с большинством голосов. Минимально возможное число листьев после стрижки — 1 (стричь всё), либо оставить как есть.

Найти набор поддеревьев для стрижки (необязательно вложенных), минимизирующий сумму энтропий стрижек.

Формат ввода

n                    # 1 ≤ n ≤ 300
n - 1 чисел           # для внутренних узлов 1..n-1: индекс родителя
n чисел y_1..y_n     # метки листьев, y_i ∈ {0, 1}

Гарантируется, что:

  • листья имеют номера от n до 2n − 1;
  • каждый внутренний узел имеет двух детей;
  • 1 ≤ p_i < i (родитель меньше по индексу).

Формат вывода

Минимальная сумма энтропий после оптимальной стрижки. Точность 10^-6.

Примеры

Пример 1. n=2, дерево: один корень с двумя листьями 0, 1. Если стрижём весь корень — энтропия H(1/2) = ln 2 ≈ 0.693. Вывод: 0.6931471806.

Пример 2. n=5, дерево с метками 0, 0, 1, 1, 1 так, что есть полностью чистые поддеревья. Вывод: 0.

Пример 3. n=4, метки 0, 1, 1, 0. Вывод: 1.3862943611 (= 2 ln 2).

Решение

Подход — DP на дереве

Для каждого узла v вычислим f(v) — минимальную сумму энтропий стрижки в поддереве v. Два варианта:

  1. Стричь v целиком: стоимость = H(v).
  2. Не стричь v целиком, а рассмотреть детей: f(left) + f(right) (плюс возможные стрижки внутри).

Выбираем минимум:

f(v) = min(H(v), f(left) + f(right))

Для листа: f(v) = 0.

Подсчёт H(v)

Для каждого узла нужно знать (zeros, ones) в поддереве. DFS снизу вверх:

def count(v):
    if is_leaf(v):
        return (1 - y[v], y[v])  # (zeros, ones)
    cl = count(left[v])
    cr = count(right[v])
    return (cl[0] + cr[0], cl[1] + cr[1])

Тогда H(v) = entropy(zeros / total, ones / total).

Реализация

import sys
import math
sys.setrecursionlimit(10**6)
 
def entropy(zeros, ones):
    total = zeros + ones
    if total == 0 or zeros == 0 or ones == 0:
        return 0.0
    p = zeros / total
    q = ones / total
    return -(p * math.log(p) + q * math.log(q))
 
def solve():
    data = sys.stdin.read().split()
    pos = 0
    n = int(data[pos]); pos += 1
    parents = list(map(int, data[pos:pos + n - 1])); pos += n - 1
    labels = list(map(int, data[pos:pos + n])); pos += n
 
    # Узлы 1..n-1 — внутренние; n..2n-1 — листья.
    total_nodes = 2 * n - 1
    children = [[] for _ in range(total_nodes + 1)]
    for i, p in enumerate(parents, start=2):
        # Внутренний узел номер i+1 (если 1-индексация),
        # детали индексации зависят от точного формата.
        children[p].append(i)
 
    # Листья: индексы n..2n-1, метки labels[0..n-1].
    label_at = {}
    for i in range(n):
        label_at[n + i] = labels[i]
 
    counts = [(0, 0)] * (total_nodes + 1)
    H_v = [0.0] * (total_nodes + 1)
    f = [0.0] * (total_nodes + 1)
 
    # DFS итеративно (n до 300, можно и рекурсивно).
    order = []
    stack = [1]
    visited = [False] * (total_nodes + 1)
    while stack:
        v = stack.pop()
        if visited[v]:
            continue
        visited[v] = True
        order.append(v)
        for c in children[v]:
            stack.append(c)
 
    for v in reversed(order):
        if v in label_at:
            y = label_at[v]
            counts[v] = (1 - y, y)
            H_v[v] = 0.0
            f[v] = 0.0
        else:
            zs, os_ = 0, 0
            for c in children[v]:
                cz, co = counts[c]
                zs += cz
                os_ += co
            counts[v] = (zs, os_)
            H_v[v] = entropy(zs, os_) * (zs + os_) / n  # нормировка по корню
            sub_sum = sum(f[c] for c in children[v])
            f[v] = min(H_v[v], sub_sum)
 
    print(f"{f[1]:.10f}")
 
solve()

Тонкость нормировки

Условие говорит «энтропия = ... поддерева». Без нормировки, если корень имеет всех листьев и распределение p, его энтропия — стандартная Шеннона. Подсчёт (zeros, ones) берётся по листьям поддерева. Можно нормировать на размер поддерева или на n (корня) — зависит от точной формулировки в задаче.

В представленной формуле: H(v) = -p ln p − (1 − p) ln(1 − p), без нормировки.

Сложность

O(n) — каждый узел обрабатывается один раз. n ≤ 300 — с большим запасом.

Подводные камни

  1. p ∈ {0, 1} — энтропия 0 (а не NaN от log(0)). Обработать явно.
  2. Натуральный логарифм или двоичный? Условие пишет H = ln 2 ≈ 0.693 — значит натуральный (ln).
  3. Вложенные стрижки. Если стричь корень, дочерние стрижки уже не считаются. DP это учитывает: f(v) либо стрижка v, либо суммарно дети — несовместимы.
  4. Точность 10^-6. Питон math.log достаточно. C++ — std::log double.
  5. «Стричь — заменить на лист с большинством». Не влияет на минимум суммы энтропий — только на интерпретацию (это же чистый аналог классификатора).
  6. Корень имеет индекс 1 или 0. В условии — 1..n-1 внутренние, n..2n-1 — листья, корень = 1.
  7. DP на больших деревьях. Для n = 300 рекурсивное решение работает; для миллионов нужен итеративный обход.

Эталонный ответ

DP на дереве: f(v) = min(H(v), f(left) + f(right)), для листа f = 0. Считаем (zeros, ones) снизу вверх, энтропия Шеннона с натуральным логарифмом, для чистых поддеревьев H = 0. Время O(n).

Хочешь увидеть разбор?

Зарегистрируйся бесплатно — откроется развёрнутое решение этой задачи и ещё 4 на выбор.

Зарегистрироваться и увидеть разбор
Уже есть аккаунт? Войти