Editorial for Colourful Christmas Bauble Burglary
Submitting an official solution before solving the problem yourself is a bannable offence.
Author:
Cutting an edge disconnects a subtree from the root, so we compute the answer for every subtree and report it for the edge that connects that subtree's root to its parent.
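An important observation is that adding another bauble does not increase the value by a constant amount; the increase depends on how many baubles of that colour we already have. The value of a subtree is the sum, over all colours, of the squared number of baubles of that colour, so going from $k$ to $k+1$ baubles of one colour adds $(k+1)^2 - k^2 = 2k+1$. To compute the score accurately, we therefore need the count of every colour in each subtree.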
We can do this recursively. Suppose that, for a node, we want the count of each colour in its subtree (stored in a container such as a `std::map`, or a dictionary in Python) together with the value of that subtree, and that we already know these for each of its children. We can take the `std::map` of one child and add the entries of all the other children to it, updating the total value as we go, and finally add the node's own colour. In other words, we merge the `std::map`s of all of the children.
Naively, this would take $O(n^2 \log n)$: merging two `std::map`s can take $O(n \log n)$, and we do so $n$ times.
However, choosing carefully which `std::map` to merge into improves this. Specifically, we always merge the smaller `std::map` into the larger one: the child map with the most entries is reused directly, and the entries of the others are inserted into it. It is easy to see that this is never worse, but does it actually improve the time complexity?
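Let the total merged size be $n$. The worst case for a merge is when the two `std::map`s have equal size, which gives the recurrence $T(n) = 2T(n/2) + O(n \log n)$. By the Master Theorem (its extended second case, since $n \log n = \Theta(n^{\log_2 2} \log n)$), this solves to $T(n) = O(n \log^2 n)$.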
Note: in Python, perform the DFS with an explicit stack rather than recursion, because a deep tree would exceed Python's default recursion limit.
C++ Solution:
#include <bits/stdc++.h>
using namespace std;
// Type alias instead of a macro: scoped and visible to the type system.
using ll = long long;

// For a tree rooted at node 0, returns, for every edge index e in [0, n-1),
// the score of the subtree that is cut off from the root when edge e is
// removed: the sum over colours of (number of nodes of that colour)^2.
//
// cols[v] is the colour of node v; graph[v] lists (neighbour, edge index)
// pairs. Colour counts are merged small-to-large: each node takes over the
// map of its child with the most distinct colours in O(1) (swap), and only
// the smaller maps are re-inserted. All maps are owned by a vector, so every
// allocation is released when the function returns.
static vector<ll> subtree_scores(const vector<ll> &cols,
                                 const vector<vector<pair<ll, ll>>> &graph) {
    const size_t n = cols.size();
    if (n == 0) return {};

    // Iterative DFS from the root. incoming[v] is the index of the edge from
    // v to its parent (-1 for the root); `order` lists every non-root node
    // after its parent, so walking it backwards visits children first.
    vector<ll> incoming(n, -1);
    vector<char> seen(n, 0);
    vector<ll> order;
    order.reserve(n);
    vector<ll> stk{0};
    seen[0] = 1;
    while (!stk.empty()) {
        const ll cur = stk.back();
        stk.pop_back();
        for (const auto &[nxt, edge] : graph[cur]) {
            if (seen[nxt]) continue;
            seen[nxt] = 1;
            incoming[nxt] = edge;
            stk.push_back(nxt);
            order.push_back(nxt);
        }
    }

    vector<map<ll, ll>> counts(n);  // colour -> count within v's subtree
    vector<ll> score(n, 0);         // sum of squared counts within v's subtree
    vector<ll> ans(n - 1);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const ll cur = *it;

        // Child whose map has the most entries (first one on ties).
        ll heavy = -1;
        for (const auto &[nxt, edge] : graph[cur]) {
            if (edge == incoming[cur]) continue;  // skip the parent
            if (heavy == -1 || counts[heavy].size() < counts[nxt].size()) heavy = nxt;
        }

        map<ll, ll> &acc = counts[cur];
        ll total = 0;
        if (heavy != -1) {
            acc.swap(counts[heavy]);  // O(1): adopt the largest child map
            total = score[heavy];
            for (const auto &[nxt, edge] : graph[cur]) {
                if (edge == incoming[cur] || nxt == heavy) continue;
                for (const auto &[colour, cnt] : counts[nxt]) {
                    ll &have = acc[colour];
                    total += (have + cnt) * (have + cnt) - have * have;
                    have += cnt;
                }
                counts[nxt].clear();  // fully merged; release its nodes now
            }
        }

        // Add the node's own bauble: k -> k+1 adds (k+1)^2 - k^2 = 2k+1.
        ll &own = acc[cols[cur]];
        total += 2 * own + 1;
        ++own;

        score[cur] = total;
        ans[incoming[cur]] = total;
    }
    return ans;
}

// Reads t test cases from stdin. Each consists of n, then n colours, then
// n-1 edges "a b" (1-indexed). For each test case, prints the answer for
// every edge in input order, each followed by a space, then a newline.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    ll t;
    cin >> t;
    while (t--) {
        ll n;
        cin >> n;
        vector<ll> cols(n);
        for (ll &c : cols) cin >> c;

        // graph[u] holds (neighbour, edge index) pairs.
        vector<vector<pair<ll, ll>>> graph(n);
        for (ll i = 0; i + 1 < n; ++i) {
            ll a, b;
            cin >> a >> b;
            --a;
            --b;
            graph[a].push_back({b, i});
            graph[b].push_back({a, i});
        }

        for (ll a : subtree_scores(cols, graph)) cout << a << " ";
        cout << "\n";
    }
}
Python solution:
import sys
def main():
    """Solve every test case on stdin and write one line of answers per case.

    Input is whitespace-separated: t, then for each test case n, the n
    colours, and n-1 edges "a b" (1-indexed). The tree is rooted at node 1.
    For each edge, in input order, the output is the sum over colours of
    count**2 for the subtree cut off from the root when that edge is removed.
    Each answer is followed by a space, and each test case ends with a newline.
    """
    # Tokenise the whole input once. This is faster than per-line reads and
    # does not depend on how the values are split across lines.
    tokens = iter(sys.stdin.read().split())
    output = []
    t = int(next(tokens))
    for _ in range(t):
        n = int(next(tokens))
        cols = [int(next(tokens)) for _ in range(n)]

        # Adjacency list of (neighbour, edge_index) pairs.
        graph = [[] for _ in range(n)]
        for i in range(n - 1):
            a = int(next(tokens)) - 1
            b = int(next(tokens)) - 1
            graph[a].append((b, i))
            graph[b].append((a, i))

        # Iterative DFS from the root (an explicit stack avoids the recursion
        # limit). incoming[v] is the edge from v to its parent; `order` lists
        # every non-root node after its parent.
        incoming = [-1] * n
        seen = [False] * n
        seen[0] = True
        order = []
        stack = [0]
        while stack:
            cur = stack.pop()
            for nxt, edge in graph[cur]:
                if not seen[nxt]:
                    seen[nxt] = True
                    incoming[nxt] = edge
                    stack.append(nxt)
                    order.append(nxt)

        # Small-to-large merge of colour-count dicts, children before parents.
        counts = [None] * n  # colour -> count for each finished subtree
        score = [0] * n      # sum of squared counts for each finished subtree
        ans_v = [0] * (n - 1)
        for cur in reversed(order):
            parent_edge = incoming[cur]
            children = [nxt for nxt, edge in graph[cur] if edge != parent_edge]
            if children:
                # Reuse the largest child dict and insert the others into it.
                heavy = max(children, key=lambda v: len(counts[v]))
                best = counts[heavy]
                ans = score[heavy]
                counts[heavy] = None
                for child in children:
                    if child == heavy:
                        continue
                    for colour, cnt in counts[child].items():
                        before = best.get(colour, 0)
                        after = before + cnt
                        ans += after * after - before * before
                        best[colour] = after
                    # Drop the merged dict so its memory can be reclaimed now.
                    counts[child] = None
            else:
                best = {}
                ans = 0

            # Add this node's bauble: k -> k+1 adds (k+1)^2 - k^2 = 2k+1.
            colour = cols[cur]
            before = best.get(colour, 0)
            best[colour] = before + 1
            ans += 2 * before + 1

            counts[cur] = best
            score[cur] = ans
            ans_v[incoming[cur]] = ans

        # Same bytes as print(" ".join(...) + " "): trailing space, newline.
        output.append(" ".join(map(str, ans_v)) + " \n")
    sys.stdout.write("".join(output))


if __name__ == "__main__":
    main()
Comments