Editorial for Colourful Christmas Bauble Burglary


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Author: frexe

Cutting an edge disconnects a subtree from the root (the solutions below root the tree at node 1). So we compute the value of every subtree, and report it as the answer for the edge that connects that subtree's root to its parent.

An important observation is that adding another bauble does not increase the value by a constant amount. The value is the sum over colours of the squared count of that colour, so adding a bauble of a colour we already have $k$ of increases the value by $(k+1)^2 - k^2 = 2k + 1$. So to calculate the value accurately, we need to know the count of each colour in each subtree.

We can do this recursively. Suppose we want to compute the count of each colour, stored in a container such as a std::map (or a dictionary in Python), together with the value of the subtree of a node, and we already know both for each of its children. We can take the std::map of one of the children and add the entries of all the others to it, updating the total value as we go. Essentially, we are merging the std::maps of all the children.

Naively, this takes $O(n^2 \log n)$: a single merge of two std::maps can cost $O(n \log n)$, and we perform up to $n$ merges. However, choosing intelligently which std::map to merge into improves this.

Specifically, we should always merge the smaller std::map into the larger one, keeping the larger map's already-computed value. Intuitively this does less work, but does it actually improve the time complexity?

Suppose the total size after merging is $n$. The worst case for a single merge is when the two std::maps have equal size, which gives the recurrence

$T(n) = 2T(n/2) + O(n \log n)$, which, by the extended Master theorem, solves to $O(n \log^2 n)$.

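Note: both solutions below use an iterative DFS with an explicit stack rather than recursion. In Python this is necessary because of the default recursion limit; in C++ it avoids a stack overflow on deep (e.g. path-shaped) trees.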

C++ Solution:

#include <bits/stdc++.h>

using namespace std;

#define ll long long

// Reads t test cases. Each gives n bauble colours and the n-1 edges of a tree
// rooted at node 1. For every edge (in input order) prints the value of the
// subtree cut off from the root by removing that edge, where the value of a
// set of baubles is the sum over colours of (number of that colour)^2.
// Answers are space-separated with a trailing space, one test case per line.
int main() {
    // Large input: decouple iostreams from C stdio (this program uses no stdio).
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    ll t; cin >> t;

    while (t--) {
        ll n; cin >> n;
        vector<ll> cols(n);
        for (auto &c : cols) cin >> c;

        // graph[v] = list of (neighbour, index of the connecting edge in input order)
        vector<vector<pair<ll,ll>>> graph(n);
        for (ll i = 0; i < n - 1; i++) {
            ll a, b; cin >> a >> b;
            a--; b--;
            graph[a].push_back({b, i});
            graph[b].push_back({a, i});
        }

        // Iterative DFS from the root (index 0). incoming[v] is the index of the
        // edge from v to its parent (-1 for the root); dfs_order lists every
        // non-root node after its parent, so its reverse is children-first.
        vector<ll> incoming(n, -1);
        vector<bool> vis(n, false);
        vector<ll> dfs_order;
        dfs_order.reserve(n);
        stack<ll> stk;
        stk.push(0);
        vis[0] = true;
        while (!stk.empty()) {
            const ll cur = stk.top();
            stk.pop();
            for (auto [nxt, edge] : graph[cur]) {
                if (vis[nxt]) continue;
                vis[nxt] = true;
                incoming[nxt] = edge;
                stk.push(nxt);
                dfs_order.push_back(nxt);
            }
        }

        // counts[v]: colour -> number of baubles of that colour in v's subtree.
        // value[v]:  that subtree's value, i.e. sum of count^2 over counts[v].
        // The maps are owned by this vector (no new/delete), so every map --
        // including those of the root's children, which are never merged into
        // a parent -- is released when the test case ends.
        vector<map<ll,ll>> counts(n);
        vector<ll> value(n, 0);
        vector<ll> ans_v(n - 1);

        for (auto it = dfs_order.rbegin(); it != dfs_order.rend(); ++it) {
            const ll cur = *it;
            map<ll,ll> &mine = counts[cur];
            ll &val = value[cur];

            for (auto [nxt, edge] : graph[cur]) {
                if (edge == incoming[cur]) continue;
                map<ll,ll> &child = counts[nxt];
                // Small-to-large: keep the bigger map together with its known
                // value, then fold the smaller one into it entry by entry.
                if (mine.size() < child.size()) {
                    mine.swap(child);
                    val = value[nxt];
                }
                for (const auto &[c, a] : child) {
                    ll &cnt = mine[c];
                    val += 2 * cnt * a + a * a;  // (cnt + a)^2 - cnt^2
                    cnt += a;
                }
                child.clear();  // fully merged; free its nodes now
            }

            // Add the bauble at cur itself: a count going k -> k+1 adds 2k+1.
            ll &own = mine[cols[cur]];
            val += 2 * own + 1;
            own += 1;

            ans_v[incoming[cur]] = val;
        }

        for (ll a : ans_v) cout << a << " ";
        cout << "\n";
    }
}

Python solution:

import sys

def main():
    """Solve every test case read from stdin and write the answers to stdout.

    Each test case gives n bauble colours and the n-1 edges of a tree rooted at
    node 1. For every edge (in input order) the answer is the value of the
    subtree cut off from the root by removing that edge, where the value of a
    set of baubles is the sum over colours of (number of that colour) ** 2.

    Subtree values are computed bottom-up by merging per-subtree
    colour -> count dicts small-to-large.

    Output: one line per test case with the edge answers separated by spaces,
    followed by a trailing space.
    """
    # Token-based reading: a single read, robust to arbitrary whitespace.
    data = sys.stdin.read().split()
    pos = 0
    t = int(data[pos]); pos += 1
    out = []
    for _ in range(t):
        n = int(data[pos]); pos += 1
        cols = [int(x) for x in data[pos:pos + n]]
        pos += n

        # graph[v] = list of (neighbour, index of the connecting edge)
        graph = [[] for _ in range(n)]
        for i in range(n - 1):
            a = int(data[pos]) - 1
            b = int(data[pos + 1]) - 1
            pos += 2
            graph[a].append((b, i))
            graph[b].append((a, i))

        # Iterative DFS from the root (index 0); recursion would hit Python's
        # recursion limit on deep trees. incoming[v] is the edge to v's parent,
        # and dfs_order lists each non-root node after its parent.
        vis = [False] * n
        incoming = [-1] * n
        stack = [0]
        dfs_order = []
        while stack:
            cur = stack.pop()
            if vis[cur]:
                continue
            vis[cur] = True
            for nxt, edge in graph[cur]:
                if not vis[nxt]:
                    incoming[nxt] = edge
                    stack.append(nxt)
                    dfs_order.append(nxt)

        # counts[v]: colour -> count dict for v's subtree; score[v]: its value.
        counts = [None] * n
        score = [0] * n
        ans_v = [0] * (n - 1)

        # Children are processed before parents.
        for cur in reversed(dfs_order):
            best = None
            ans = 0
            smaller = []
            for nxt, edge in graph[cur]:
                if edge == incoming[cur]:
                    continue
                child_map = counts[nxt]
                # Drop the reference so merged dicts can be garbage-collected.
                counts[nxt] = None
                if best is None or len(child_map) > len(best):
                    if best is not None:
                        smaller.append(best)
                    best = child_map
                    ans = score[nxt]
                else:
                    smaller.append(child_map)

            if best is None:
                best = {}
            # Fold every smaller dict into the largest one, updating the value
            # incrementally: count k -> k + cnt changes it by (k+cnt)^2 - k^2.
            for child_map in smaller:
                for c, cnt in child_map.items():
                    before = best.get(c, 0)
                    new_val = before + cnt
                    ans += new_val * new_val - before * before
                    best[c] = new_val

            # Add this node's own bauble: k -> k+1 adds 2k+1.
            color = cols[cur]
            before = best.get(color, 0)
            ans += 2 * before + 1
            best[color] = before + 1

            counts[cur] = best
            score[cur] = ans
            ans_v[incoming[cur]] = ans

        # Same format as the C++ solution: trailing space, then newline.
        out.append(" ".join(map(str, ans_v)) + " \n")

    sys.stdout.write("".join(out))

if __name__ == "__main__":
    main()

Comments

There are no comments at the moment.