Editorial for Tree Shifting


Remember to use this editorial only when stuck, and not to copy-paste code from it. Please be respectful to the problem author and editorialist.
Submitting an official solution before solving the problem yourself is a bannable offence.

Author: Loctildore

This solution depends on

  • Segment Tree (segtree) Walk
  • Lowest Common Ancestor (LCA)
  • Heavy-Light decomposition (HLD)

If you don't know any of these techniques, I recommend reading about them on CP-Algorithms or USACO Guide. Familiarity with the concept of "centroids" of a tree will also help in understanding the solution.

First some observations:

  • When we are considering the solution for a subtree, we only care about paths that have one endpoint inside the subtree and one outside the subtree. All the other paths won't be affected by an operation.
  • For these paths, we call the endpoint inside the subtree _active_. The new root of the subtree (the node that is reconnected to the parent) should be a node that minimises the total distance to all active endpoints.
  • We call these potential new roots _weighted centroids_ of the subtree. These nodes have the property that no adjacent edge can lead to more than half the active endpoints.
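  • For every edge, let its _score_ be the number of active endpoints below it minus the number of active endpoints above it. Moving the root down an edge brings every active endpoint below it one step closer and every other active endpoint one step farther away, so it reduces the total distance by exactly the edge's score. Starting from the subtree root, we want to walk down an edge if its score is positive. If we continue this process until we can't anymore, we will end at a weighted centroid. The total distance we save by rerooting to the weighted centroid is the sum of the scores of the edges we walked down.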

We can try to solve this problem by calculating the answers for a node's children before calculating the answer for the node itself (that is, evaluating the nodes in postorder). The advantage of this approach is that we can reuse the information left over from evaluating the children to evaluate the answer at the parent quickly. In particular, if we have the scores left from evaluating a child, we can make them correct for the parent by subtracting a single value from the child's entire subtree. Of course, this ignores how nodes become active and then inactive as we consider subtrees in postorder; we'll fix that later.

We associate an edge's score with its lower node. We can perform these subtractions quickly by storing the scores in preorder: every subtree then forms a consecutive range, so a segtree can subtract a value from all the scores inside a subtree at once.

An endpoint becomes active when we evaluate the answer at its node and becomes inactive when we evaluate the answer at the LCA of its path. When an endpoint becomes active, we just need to do one subtree subtract operation. When it becomes inactive, we need to do a subtree add and a path subtract operation: every score in the LCA's subtree goes up by 1 (one fewer active endpoint in total), and the scores on the path from the endpoint up to the LCA go down by 2 (the endpoint no longer lies below those edges). We can handle the path subtract operation with Heavy-Light Decomposition. As HLD can be implemented as a segtree over a preorder (refer to the CP-Algorithms implementation), we can handle the subtree add/subtract operations in the same segtree.

The complexity so far is O(M \cdot \log^2 N), the bottleneck being the HLD path operation when each endpoint becomes inactive. Now that we have all the scores in an HLD structure, we need to figure out how to walk down positive edges quickly. Since a heavy path occupies consecutive positions in the HLD preorder, we can walk down all the positive edges of a heavy path with one segtree walk. After each heavy path walk, we need to consider walking down to a light child. We can find this light child by storing the scores of the light children of each node in some data structure that supports max queries in O(\log N).

Subtree add/subtract operations shift the scores of all children of a node by the same amount, so we don't need to update the data structure for them. We only need to keep the stored scores correct relative to a common constant and then check the actual score of the child with the highest relative score. To maintain this, we can store the scores in a multiset and update a child's entry whenever a path subtract operation passes through it; for a light child, this happens exactly when the path crosses the light edge above it. The sample solution below uses a priority_queue with lazy updates, but it computes the same thing.

When walking down the tree, we will encounter at most O(\log N) heavy paths. Each segtree walk takes O(\log N) time, and picking the light child to go down (and checking its score) takes O(\log N) amortised time, so evaluating the answer at one node costs O(\log^2 N).

The total complexity of this algorithm is therefore O((N + M) \cdot \log^2 N).

C++ Solution:

#include <bits/stdc++.h>
using namespace std;
// trans rights
#define int long long
#define f first
#define s second
#define endl '\n'
#define all(x) begin(x), end(x)

struct node {
    int l, r;
    node *lft, *rht;
    int sum, mini, offst;
    node(int tl, int tr): l(tl), r(tr), sum(0), mini(0), offst(0) {
        if (l + 1 != r) {
            lft = new node(l, (l + r) / 2);
            rht = new node((l + r) / 2, r);
        }
        else lft = rht = NULL;
    }
    void add(int tl, int tr, int val) {
        if (tr <= tl) return;
        if (tr <= l || r <= tl) return;
        if (tl <= l && r <= tr) {
            offst += val;
            sum += (r - l) * val;
            mini += val;
            return;
        }
        lft->add(tl, tr, val);
        rht->add(tl, tr, val);
        sum = lft->sum + rht->sum + (r - l) * offst;
        mini = min(lft->mini, rht->mini) + offst;
    }
    void clean() {
        if (offst && lft) {
            lft->add(l, r, offst);
            rht->add(l, r, offst);
            offst = 0;
        }
    }
    int qry(int x) {
        if (x < l || x >= r) return 0;
        if (l + 1 == r) return sum;
        clean();
        return lft->qry(x) + rht->qry(x);
    }
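    // Starting at index x, take the run of positive values going right.
    // Returns {sum of the run, last index of the run}; the index is x - 1 if
    // the value at x is not positive. Inside the recursion, a returned index
    // of r - 1 means the run has not ended within this segment.
    pair<int, int> walk(int x) {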
        clean();
        if (r <= x) return {0, r - 1};
        if (x <= l) {
            if (mini > 0) return {sum, r - 1};
        }
        if (l + 1 == r) return {0, l - 1};
        auto lw = lft->walk(x);
        if (lw.s != lft->r - 1) return lw;
        auto rw = rht->walk(x);
        return {lw.f + rw.f, rw.s};
    }
}* root;

int n, q, bans, ans[200007];
int par[200007], jmp[200007][20], head[200007];
vector<int> chd[200007];
priority_queue<pair<int, int>> mchd[200007];
int ldel[200007];
int dep[200007], sz[200007];
int pin[200007], pout[200007], pnxt, rord[200007];
int qadd[200007]; vector<int> qrmv[200007];

void dfs0(int x) {
    jmp[x][0] = par[x];
    for (int i = 1; i < 20; i++) {
        jmp[x][i] = jmp[jmp[x][i - 1]][i - 1];
    }
    for (auto i : chd[x]) {
        dep[i] = dep[x] + 1;
        dfs0(i);
    }
    sz[x] = 1;
    for (auto i : chd[x]) {
        sz[x] += sz[i];
    }
    sort(all(chd[x]), [](int lhs, int rhs) {return sz[lhs] > sz[rhs];});
}
void dfs_hld(int x, int h) {
    head[x] = h; pin[x] = pnxt++;
    rord[pin[x]] = x;
    if (chd[x].size()) dfs_hld(chd[x][0], h);
    for (auto i : chd[x]) if (i != chd[x][0]) {
        dfs_hld(i, i);
    }
    pout[x] = pnxt;
}
int lca(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    for (int i = 19; ~i; i--) {
        if (dep[jmp[y][i]] >= dep[x]) y = jmp[y][i];
    }
    for (int i = 19; ~i; i--) {
        if (jmp[x][i] != jmp[y][i]) {
            x = jmp[x][i];
            y = jmp[y][i];
        }
    }
    return x == y ? x : jmp[x][0];
}
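// Subtract 2 from the score of every node on the path from y up to, but not
// including, its ancestor x. Every light edge crossed is recorded in ldel so
// the priority queue at that edge's upper node can be corrected lazily.
void hldsub(int x, int y) {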
    while (head[x] != head[y]) {
        root->add(pin[head[y]], pin[y] + 1, -2);
        ldel[head[y]] += 2;
        y = par[head[y]];
    }
    root->add(pin[x] + 1, pin[y] + 1, -2);
}
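// Computes ans[] for every node in x's subtree (children first) and returns
// the number of endpoints in x's subtree that are still active for x's
// parent, i.e. whose path's LCA lies above x.
int solve(int x) {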
    vector<int> v; int sumv = qadd[x];
    for (auto i : chd[x]) v.push_back(solve(i));
    for (auto i : v) sumv += i;
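    // Make every score in each child's subtree relative to x's subtree, and
    // keep the light children's scores in mchd[x].
    for (int i = 0; i < chd[x].size(); i++) {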
        int c = chd[x][i];
        root->add(pin[c], pin[c] + 1, v[i]);
        root->add(pin[c], pout[c], v[i] - sumv);
        if (i) mchd[x].push({root->qry(pin[c]), c});
    }
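    // Endpoints whose path has its LCA at x become inactive here: every score
    // below x goes up by 1 and the scores on the path from i to x go down by 2.
    for (auto i : qrmv[x]) {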
        sumv--;
        root->add(pin[x] + 1, pout[x], 1);
        hldsub(x, i);
    }
    int cur = x; ans[x] = bans;
    while (true) {
        if (chd[cur].empty()) break;
        if (root->qry(pin[chd[cur][0]]) > 0) {
            cur = chd[cur][0];
        }
        else {
            if (mchd[cur].empty()) break;
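            while (ldel[mchd[cur].top().s]) {
                // Lazily correct the top entry: remove the stale copy and
                // re-insert it with its pending path-subtract decrease applied.
                auto tmp = mchd[cur].top();
                mchd[cur].pop();
                tmp.f -= ldel[tmp.s];
                ldel[tmp.s] = 0;
                mchd[cur].push(tmp);
            }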
            if (root->qry(pin[mchd[cur].top().s]) > 0) {
                cur = mchd[cur].top().s;
            }
            else break;
        }
        auto tmp = root->walk(pin[cur]);
        ans[x] -= tmp.f;
        cur = rord[tmp.s];
    }
    return sumv;
}
signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(NULL);
    cin>>n>>q;
    root = new node(0, n);
    for (int i = 1; i < n; i++) {
        cin>>par[i]; par[i]--;
        chd[par[i]].push_back(i);
    }
    dfs0(0);
    dfs_hld(0, 0);
    for (int i = 0; i < q; i++) {
        int a, b; cin>>a>>b; a--;b--;
        int l = lca(a, b);
        qadd[a]++; qadd[b]++;
        qrmv[l].push_back(a); qrmv[l].push_back(b);
        bans += dep[a] + dep[b] - 2 * dep[l];
    }
    solve(0);
    for (int i = 0; i < n; i++) cout<<ans[i]<<endl;
    return 0;
}
