# 834. 树中距离之和 (Hard)

## 问题描述

[834. 树中距离之和][link] (Hard)

[link]: https://leetcode.cn/problems/sum-of-distances-in-tree/

给定一个无向、连通的树。树中有 `n` 个标记为 `0...n-1` 的节点以及 `n-1` 条边 。

给定整数 `n` 和数组 `edges` ， `edges[i] = [aᵢ, bᵢ]` 表示树中的节点 `aᵢ` 和 `bᵢ` 之间有一条边。

返回长度为 `n` 的数组 `answer` ，其中 `answer[i]` 是树中第 `i` 个节点与所有其他节点之间的距离之和。

**示例 1:**

![](https://pic-upyun.zwyyy456.tech/smms/2023-12-26-065413.jpg)

```
输入: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 树如图所示。
我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此，answer[0] = 8，以此类推。

```

**示例 2:**

![](https://pic-upyun.zwyyy456.tech/smms/2023-12-26-065415.jpg)

```
输入: n = 1, edges = []
输出: [0]

```

**示例 3:**

![](https://pic-upyun.zwyyy456.tech/smms/2023-12-26-065416.jpg)

```
输入: n = 2, edges = [[1,0]]
输出: [1,1]

```

**提示:**

- `1 <= n <= 3 * 10⁴`
- `edges.length == n - 1`
- `edges[i].length == 2`
- `0 <= aᵢ, bᵢ < n`
- `aᵢ != bᵢ`
- 给定的输入保证为有效的树

## 解题思路

求单个结点（例如 $0$）的距离之和 $dp[0]$，我们可以很容易利用 dfs 以 $O(n)$ 的时间复杂度求出，求 $n$ 个点就是 $O(n^2)$，明显会超时。

但是，不难看出，父结点 $j$ 的 $dp[j]$ 和子结点 $i$ 的 $dp[i]$ 之间存在某种递推关系，即 $dp[i] = dp[j] - cnt[i] + n - cnt[i]$（因为 $i, j$ 直接相连）。

那么，问题就只剩下如何求 `cnt[i]` 了，即在无向图表示的树中，如何求以当前结点为根结点的子树的结点数，参见 [无向图形式组织的树](https://blog.zwyyy456.tech/zh/posts/tech/undirected-graph-tree/)。

## 代码

```cpp
class Solution {
  public:
    int count(vector<vector<int>> &tree, vector<int> &dis, vector<int> &cnt, int pa, int grandpa) {
        int res = 1;
        for (int child : tree[pa]) {
            if (child == grandpa) { // 防止重复遍历，保证 dfs 遍历时的单向性
                continue;
            }
            dis[child] = dis[pa] + 1;
            res += count(tree, dis, cnt, child, pa);
        }
        cnt[pa] = res;
        return res;
    }
    vector<int> sumOfDistancesInTree(int n, vector<vector<int>> &edges) {
        vector<vector<int>> tree(n);
        for (auto &vec : edges) {
            tree[vec[0]].push_back(vec[1]);
            tree[vec[1]].push_back(vec[0]); // 注意，建立无向图要 push_back 两次！
        }
        vector<int> cnt(n);
        vector<int> dp(n);
        vector<int> dis(n); // 表示结点 0 到其他结点的最短距离
        count(tree, dis, cnt, 0, -1);
        for (int i = 0; i < n; ++i) {
            dp[0] += dis[i];
        }
        queue<pair<int, int>> q;
        q.push({0, -1}); // pa, grandpa
        while (!q.empty()) {
            auto [pa, grandpa] = q.front();
            q.pop();
            for (int child : tree[pa]) {
                if (child == grandpa) { // 保证 bfs 遍历时的单向性
                    continue;
                }
                dp[child] = dp[pa] + n - 2 * cnt[child];
                q.push({child, pa});
            }
        }
        return dp;
    }
};
```

