import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
/**
 * Reads a weighted tree from standard input and prints, summed over all nodes,
 * the number of positive divisors of the product of weights accumulated at each node.
 *
 * <p>Input format: {@code n}, then {@code n} weights for nodes {@code 1..n}, then
 * {@code n - 1} pairs {@code a b}, each recorded as "{@code b} is a child of {@code a}".
 *
 * <p>Nodes are processed from {@code n} down to {@code 1}; a node's product includes a
 * child's full accumulated product only when that child has a larger index than the node.
 * NOTE(review): this presumes the input lists the parent first and numbers children after
 * their parents - confirm against the problem statement.
 */
public class Main {

    /**
     * Program entry point: parses the tree from {@code System.in} and prints the sum of
     * divisor counts on a single line.
     *
     * <p>Products are never materialised as numbers. Each node keeps the prime exponents of
     * its accumulated product, so multiplying weights cannot overflow. Weights are assumed
     * to be positive; a weight below 2 contributes no prime factors.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();

        // exponents.get(i) is the prime factorisation of node i's current product.
        List<Map<Integer, Integer>> exponents = new ArrayList<>(Math.max(n + 1, 1));
        exponents.add(new HashMap<>()); // slot 0 is unused, nodes are 1-based
        for (int i = 1; i <= n; i++) {
            exponents.add(primeExponents(in.nextInt()));
        }

        Map<Integer, List<Integer>> children = new HashMap<>();
        for (int i = 0; i < n - 1; i++) {
            int a = in.nextInt();
            int b = in.nextInt();
            children.computeIfAbsent(a, key -> new ArrayList<>()).add(b);
        }

        long sum = 0L;
        for (int i = n; i >= 1; i--) {
            List<Integer> childList = children.get(i);
            if (childList != null) {
                // Multiply the node's own weight by every child's current product by
                // adding exponents; the node is updated only after all children are merged.
                Map<Integer, Integer> merged = new HashMap<>(exponents.get(i));
                for (int child : childList) {
                    for (Map.Entry<Integer, Integer> entry : exponents.get(child).entrySet()) {
                        merged.merge(entry.getKey(), entry.getValue(), Integer::sum);
                    }
                }
                exponents.set(i, merged);
            }
            sum += divisorCount(exponents.get(i));
        }
        System.out.println(sum);
    }

    /**
     * Counts the positive divisors of {@code n}.
     *
     * @param n the number to analyse
     * @return the number of positive divisors of {@code n}; {@code 1} for any {@code n < 2}
     */
    public static int getCount(int n) {
        int count = 1;
        for (int exponent : primeExponents(n).values()) {
            count *= exponent + 1;
        }
        return count;
    }

    /**
     * Factorises {@code n} by trial division.
     *
     * @param n the number to factorise
     * @return a map from each prime factor to its exponent; empty when {@code n < 2}
     */
    private static Map<Integer, Integer> primeExponents(int n) {
        Map<Integer, Integer> result = new HashMap<>();
        int rest = n;
        // Compare p * p in long arithmetic so the bound cannot overflow near Integer.MAX_VALUE.
        for (int p = 2; (long) p * p <= rest; p++) {
            int exponent = 0;
            while (rest % p == 0) {
                rest /= p;
                exponent++;
            }
            if (exponent > 0) {
                result.put(p, exponent);
            }
        }
        // Whatever is left above 1 is a single prime larger than the square root.
        if (rest > 1) {
            result.put(rest, 1);
        }
        return result;
    }

    /**
     * Computes the divisor count of a number given its prime exponents.
     *
     * @param exponents map from prime factor to exponent
     * @return the product of {@code exponent + 1} over all entries, as a {@code long}
     */
    private static long divisorCount(Map<Integer, Integer> exponents) {
        long count = 1L;
        for (int exponent : exponents.values()) {
            count *= exponent + 1L;
        }
        return count;
    }
}
// Tag: NetEase written test