def solve():
    """Read a weighted tree from stdin and print how many nodes get colored.

    Input, as consumed by the reads below:
      * line 1: ``n``, the number of nodes;
      * line 2: the node values for nodes ``1..n``, whitespace separated;
      * next ``n - 1`` lines: one edge ``u v`` per line.

    Starting from node 1, nodes are handled children-first. A node is paired
    with its parent (both become colored) when neither of them is colored
    yet and the product of their two values is a perfect square. The total
    number of colored nodes is printed.
    """
    # Local import keeps this block self-contained.
    from math import isqrt

    n = int(input())
    # Pad index 0 so node ids (1-based) can be used directly as indices.
    w = [0, *map(int, input().split())]

    # Adjacency lists of the undirected tree.
    g = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        g[u].append(v)
        g[v].append(u)

    def is_perfect_square(num):
        """Return True if *num* is the square of an integer."""
        # Exact integer square root: a float-based ``num ** 0.5`` loses
        # precision on large products and fails outright on negatives.
        if num < 0:
            return False
        root = isqrt(num)
        return root * root == num

    colored = [False] * (n + 1)
    ans = 0

    # Post-order traversal with an explicit stack instead of recursion, so
    # deep trees (e.g. a long chain) cannot exceed the recursion limit.
    # Each frame keeps a live iterator over the node's neighbours, which
    # makes the visiting order identical to a recursive DFS.
    stack = [(1, 0, iter(g[1]))]
    while stack:
        u, parent, neighbours = stack[-1]
        for v in neighbours:
            if v != parent:
                # Descend into the next unvisited child first.
                stack.append((v, u, iter(g[v])))
                break
        else:
            # All children of u are done: try to pair u with its parent.
            stack.pop()
            if (
                parent != 0
                and not colored[u]
                and not colored[parent]
                and is_perfect_square(w[u] * w[parent])
            ):
                colored[u] = colored[parent] = True
                ans += 2

    print(ans)
# Script entry point: solve() reads the problem from stdin and prints the answer.
solve()