Solution
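- Root the tree at node $1$, and let $S(u)$ denote the sum of the integers written on all nodes in the subtree of $u$ (including $u$ itself).
- We are given the array $a$, where $a_i = S(i)$.
- Notice that $S(u)$ equals the integer written on $u$ plus the sum of $S(i)$ over all children $i$ of $u$: $S(u) = \mathrm{val}(u) + \sum_{i \in \mathrm{children}(u)} S(i)$.
- Since we are given $S(u)$ for every node, we can traverse the tree with DFS or BFS and subtract $S(i)$ from $S(u)$ for every child $i$ of $u$; what remains is exactly $\mathrm{val}(u)$. Each subtraction must use the original $S(i)$, so a child's sum is subtracted from its parent before the child itself is processed. The whole traversal takes $O(n)$ time.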
Implementation in C++
#include <iostream>
#include <vector>
using namespace std;
// adj[u]: 0-based neighbours of node u for the current test case.
vector<int> adj[100000];
// a[i]: on input S(i), the subtree sum of node i; dfs turns it into the integer on node i.
int a[100000];

// Converts subtree sums into node values for every node reachable from `node`
// without stepping back into `parent`: a[u] becomes S(u) minus S(c) summed over
// the children c of u. Pass any id that is not a neighbour of `node` (e.g. -1)
// as `parent` for the tree root.
//
// Uses an explicit stack instead of recursion, so a path-shaped tree of 100000
// nodes cannot overflow the call stack.
void dfs(int node, int parent) {
    vector<int> nodes(1, node);
    vector<int> parents(1, parent);
    while (!nodes.empty()) {
        int u = nodes.back();
        int p = parents.back();
        nodes.pop_back();
        parents.pop_back();
        for (size_t k = 0; k < adj[u].size(); k++) {
            int v = adj[u][k];
            if (v != p) {
                // v has not been popped yet, so a[v] still holds the original S(v).
                a[u] -= a[v];
                nodes.push_back(v);
                parents.push_back(u);
            }
        }
    }
}
// Reads t test cases. Each gives n, the subtree sums a[0..n-1], and n-1 edges
// as 1-based node pairs. For each case it prints the integer written on every
// node, space-separated, on one line.
int main() {
    // Decouple C++ streams from C stdio; much faster for large inputs.
    ios::sync_with_stdio(false);
    cin.tie(0);
    int t;
    cin >> t;
    while (t--) {
        int n;
        cin >> n;
        for (int i = 0; i < n; i++) {
            cin >> a[i];
            adj[i].clear();
        }
        for (int i = 1; i < n; i++) {
            int x, y;
            cin >> x >> y;
            adj[x - 1].push_back(y - 1);
            adj[y - 1].push_back(x - 1);
        }
        // Root at node 0; -1 is never a neighbour. Skip the traversal for an
        // empty tree so stale adjacency from a previous case is not touched.
        if (n > 0)
            dfs(0, -1);
        for (int i = 0; i < n; i++)
            cout << a[i] << " ";
        cout << "\n";
    }
    return 0;
}
Implementation in Java
import java.io.*;
import java.util.*;
public class Solution {
    /** Adjacency lists of the current test case's tree, indexed by 0-based node id. */
    @SuppressWarnings("unchecked") // generic array creation; every slot only holds ArrayList<Integer>
    public static ArrayList<Integer>[] adj = new ArrayList[100000];

    /**
     * Per-node values: on input {@code a[i]} is S(i), the sum of the subtree rooted at node i;
     * after {@link #dfs} it is the integer written on node i.
     */
    public static int[] a = new int[100000];

    /**
     * Converts subtree sums into node values for every node reachable from {@code node}
     * without stepping back into {@code parent}: {@code a[u]} becomes S(u) minus S(c) summed
     * over the children c of u.
     *
     * <p>Implemented with an explicit stack instead of recursion so that deep trees (e.g. a path
     * of 100000 nodes) cannot overflow the JVM call stack.
     *
     * @param node the node to start from (the tree root)
     * @param parent the neighbour of {@code node} to skip; any id that is not a neighbour of
     *     {@code node} (e.g. -1) works for the root
     */
    public static void dfs(int node, int parent) {
        ArrayDeque<int[]> stack = new ArrayDeque<>();
        stack.push(new int[] {node, parent});
        while (!stack.isEmpty()) {
            int[] frame = stack.pop();
            int u = frame[0];
            int p = frame[1];
            for (int v : adj[u]) {
                if (v != p) {
                    // v has not been popped yet, so a[v] still holds the original S(v).
                    a[u] -= a[v];
                    stack.push(new int[] {v, u});
                }
            }
        }
    }

    /**
     * Reads t test cases (n, the n subtree sums, then n-1 edges as 1-based node pairs) from
     * standard input and prints, per test case, the integer written on every node.
     */
    public static void main(String[] args) {
        TokenReader in = new TokenReader(System.in);
        // Collect all output and write it once; per-number System.out.print is very slow.
        StringBuilder out = new StringBuilder();
        int t = in.nextInt();
        while (t-- > 0) {
            int n = in.nextInt();
            for (int i = 0; i < n; i++) {
                a[i] = in.nextInt();
                adj[i] = new ArrayList<>();
            }
            for (int i = 1; i < n; i++) {
                int x = in.nextInt() - 1;
                int y = in.nextInt() - 1;
                adj[x].add(y);
                adj[y].add(x);
            }
            // Skip the traversal for an empty tree; adj[0] may be stale or unallocated.
            if (n > 0) {
                dfs(0, -1);
            }
            for (int i = 0; i < n; i++) {
                out.append(a[i]).append(' ');
            }
            out.append(System.lineSeparator());
        }
        System.out.print(out);
        System.out.flush();
    }

    /** Minimal whitespace-separated integer reader over a buffered stream. */
    private static final class TokenReader {
        private final BufferedReader reader;
        private StringTokenizer tokens = new StringTokenizer("");

        TokenReader(InputStream stream) {
            reader = new BufferedReader(
                    new InputStreamReader(stream, java.nio.charset.StandardCharsets.UTF_8));
        }

        /**
         * Returns the next integer token, reading further lines as needed.
         *
         * @throws NoSuchElementException if the input ends before another token is found
         * @throws UncheckedIOException if reading from the stream fails
         */
        int nextInt() {
            while (!tokens.hasMoreTokens()) {
                String line;
                try {
                    line = reader.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                tokens = new StringTokenizer(line);
            }
            return Integer.parseInt(tokens.nextToken());
        }
    }
}
Implementation in Python
# Module-level state shared with the driver below.
a = []    # a[i]: S(i) on input; dfs rewrites it in place into the integer on node i
adj = []  # adj[u]: 0-based neighbours of node u


def dfs(node, parent):
    """Turn subtree sums into node values for the subtree reached from `node`.

    For every visited node u, a[u] becomes S(u) minus S(c) summed over the
    children c of u, i.e. the integer written on u. `parent` is the neighbour
    of `node` to skip; any id that is not a neighbour of `node` works for the root.

    Uses an explicit stack instead of recursion: a path-shaped tree with more
    than about 1000 nodes would exceed CPython's default recursion limit.
    """
    stack = [(node, parent)]
    while stack:
        u, p = stack.pop()
        for v in adj[u]:
            if v != p:
                # v has not been popped yet, so a[v] still holds the original S(v).
                a[u] -= a[v]
                stack.append((v, u))
# One (initially empty) adjacency list per possible node id.
adj.extend([] for _ in range(100000))

for _ in range(int(input())):
    n = int(input())
    # a[i] starts as S(i); dfs rewrites it in place into the value on node i.
    a = list(map(int, input().split()))
    for node in range(n):
        adj[node] = []
    # A tree on n nodes has exactly n - 1 edges, given as 1-based pairs.
    for _ in range(n - 1):
        x, y = map(int, input().split())
        adj[x - 1].append(y - 1)
        adj[y - 1].append(x - 1)
    dfs(0, 0)
    print(*a)