首页 > 试题广场 >

小美的树上染色

[编程题]小美的树上染色
  • 热度指数:2396 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
小美拿到了一棵树,每个节点有一个权值。初始每个节点都是白色。
小美有若干次操作,每次操作可以选择两个相邻的节点,如果它们都是白色且权值的乘积是完全平方数,小美就可以把这两个节点同时染红。
小美想知道,自己最多可以染红多少个节点?

输入描述:
第一行输入一个正整数n,代表节点的数量。
第二行输入n个正整数a_i,代表每个节点的权值。
接下来的n-1行,每行输入两个正整数u,v,代表节点u和节点v有一条边连接。
1\leq n \leq 10^5
1\leq a_i \leq 10^9
1\leq u,v \leq n


输出描述:
输出一个整数,表示最多可以染红的节点数量。
示例1

输入

3
3 3 12
1 2
2 3

输出

2

说明

可以染红第二个和第三个节点。
请注意,此时不能再染红第一个和第二个节点,因为第二个节点已经被染红。
因此,最多染红 2 个节点。

#include <cmath>
#include <iostream>
#include <vector>
#include <map>
using namespace std;

vector<vector<int>> G;
vector<int> val(100005, 0);
vector<int> dp(100005, 0);

bool isQrt(int n) {
    int x = sqrt(n);
    if (x*x != n) return false;
    return true;
}

int dfs(int root, int fa) {
    if (dp[root] > 0) return dp[root];
    // 不选root
    int notSelect = 0;
    for (auto u : G[root]) if (u != fa) {
        notSelect += dfs(u, root);
    }
    // 选root
    int selected = 0;
    for (auto u : G[root]) if (u != fa) {
        if (isQrt(val[root] * val[u])) {
            int tmp = 2;
            for (auto v : G[u]) if (v != root) tmp += dfs(v, u);
            selected = max(selected, notSelect-dfs(u, root)+tmp);
        }
    }
    // if (selected > 0) selected++;

    dp[root] = max(notSelect, selected);

    return dp[root];
}

int main() {
    int n;
    int a, b;
    cin >> n;
    G.resize(n + 5);
    for (int i = 0; i < n; ++i) { // 注意 while 处理多个 case
        cin >> a;
        val[i+1] = a;
    }

    int root = -1;
    --n;
    while (n--) {
        cin >> a >> b;
        G[a].push_back(b);
        G[b].push_back(a);
        if (root == -1) root = a;
    }

    // cout << "root = " << root<<endl;
    int ans = dfs(root, -1);

    cout << ans  <<endl;
}


// 64 位输出请用 printf("%lld")

发表于 2023-08-17 18:22:43 回复(3)
# 【备战秋招】每日一题:2023.08.12-美团机试-第五题-树上染色_塔子哥学算法的博客-CSDN博客 import math

n = int(input())
val = list(map(int, input().strip().split()))

# 邻接列表
matrix = [[] for _ in range(n)]
for _ in range(n - 1):
    i, j = list(map(int, input().strip().split()))
    matrix[i - 1].append(j - 1)
    matrix[j - 1].append(i - 1)

dp = [[0] * 2 for _ in range(n)]

def dfs(cur, fa):
    # 第cur个节点不染色
    for i in matrix[cur]:
        if i == fa:continue
        dfs(i, cur)
        dp[cur][0] += max(dp[i][0], dp[i][1])

    # 第cur个节点染色
    for i in matrix[cur]:
        if i == fa:continue
        tmp = int(math.sqrt(val[cur] * val[i]))
        if tmp * tmp != val[cur] * val[i]:continue
        dp[cur][1] = max(dp[cur][1], dp[cur][0] - max(dp[i][0], dp[i][1]) + dp[i][0] + 2)

dfs(0, -1)
print(max(dp[0][0], dp[0][1]))

发表于 2023-08-26 10:29:53 回复(2)
#include <bits/stdc++.h>
using namespace std;

using i64 = long long;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);

    int n; cin >> n;
    int w[n + 1];
    vector<int> g[n + 1];
    for (int i = 1; i <= n; ++i)
        cin >> w[i];
    for (int i = 0; i < n - 1; ++i) {
        int a, b; cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    function<bool(int, int)> check = [&](int u, int v) -> bool {
        i64 x = (i64)w[u] * w[v];
        i64 y = sqrtl(x);
        return y * y == x;
    };

    int dp[n + 1][2];
    memset(dp, -1, sizeof dp);
    function<int(int, int, int)> dfs = [&](int u, int state, int fa) -> int {
        if (dp[u][state] != -1) return dp[u][state];
        int res = 0, ans = 0;
        // 假如父亲与儿子都不染色
        for (int v : g[u])
            if (v != fa)
                ans += dfs(v, false, u);
        // 如果父亲已经被染过色了,那直接返回
        if (state) return dp[u][state] = ans;
        //    如果父没有被染
        // 1. 首先还是可以选择一个儿子都不染
        // 2. 选一个儿子与父亲染色,其余儿子都不染。
        for (int v : g[u])
            if (v != fa and check(u, v))
                // 这里 dfs(v, true) + 2 是染色; ans - dfs(v, false, u) 是"其余儿子都不染",要从中剔除染色的儿子。
                res = max(res, dfs(v, true, u) + 2 + ans - dfs(v, false, u));
        return dp[u][state] = max(res, ans);    // 从 1. 和 2. 中取较大者
    };

    cout << dfs(1, false, -1);

    return 0;
}


编辑于 2023-08-18 15:10:08 回复(0)
从叶节点考虑,拓扑排序
#include<iostream>
#include<vector>
#include<queue>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;

inline bool check(int x,int y){
    ll res=(ll)x*y,z=(ll)sqrt(res);
    return z*z==res;
}

int solve(int n){
    int u,v;
    vector<int> value(n+1),degrees(n+1),flag(n+1),color(n+1);
    vector<vector<int>> edges(n+1);
    for(int i=1;i<=n;i++){
        cin>>value[i];
    }
    for(int i=1;i<n;i++){
        cin>>u>>v;
        edges[u].push_back(v);
        edges[v].push_back(u);
        degrees[u]++,degrees[v]++;
    }
    queue<int> q;
    for(int i=1;i<=n;i++){
        if(degrees[i]==1){
            q.push(i);
        }
    }
    int ans=0;
    while(!q.empty()){
        int top=q.front();
        q.pop();
        flag[top]=1;
        for(auto& p:edges[top]){
            if(!flag[p]){
                if(!color[top]&&!color[p]&&check(value[top],value[p])){
                    color[top]=1,color[p]=1;
                    ans+=2;
                }
                if(--degrees[p]==1){
                    q.push(p);
                }
            }
        }
    }
    return ans;
}

//优先考虑叶节点,拓扑排序
int main() {
    int n;
    while(cin>>n){
        cout<<solve(n)<<endl;
    }
}
// 64 位输出请用 printf("%lld")


发表于 2023-08-16 21:17:14 回复(0)
import java.util.Scanner;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int [][]nums = new int[n][2];
        for (int i = 0; i < n; i++) {//赋节点权值
            nums[i][0] = in.nextInt();//节点值
            nums[i][1] = 0;//0白色,1红***r />         }
        int [][]conn = new int[n - 1][2];//创建节点关系二维数组
        for (int i = 0; i < n - 1;i++) { //总共n-1行节点关系 0 - n-2    (n-2+1=n-1)
            conn[i][0] = in.nextInt();
            conn[i][1] = in.nextInt();
            conn[i][0] -= 1;//对应索引要减一
            conn[i][1] -= 1;
        }
        int sum = 0; //红色节点个数
        int i = n - 1;
        while (--i >= 0) {
            if (nums[conn[i][0]][1] == 0 && nums[conn[i][1]][1] == 0) {
                if (Math.pow((int)Math.sqrt(nums[conn[i][0]][0] * nums[conn[i][1]][0]),
                             2) == (nums[conn[i][0]][0] * nums[conn[i][1]][0])) {
                    nums[conn[i][0]][1] = 1; //节点值设为1,红***r />                     nums[conn[i][1]][1] = 1;
                    sum += 2;
                }
            }
        }
        System.out.println(sum);
    }
}
发表于 2023-10-05 22:34:48 回复(2)
可以用贪心的思路解决,优先处理叶子节点更容易获得最大值。这里使用度来判断是否为叶子结点:
#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
#include <queue>
#include <unordered_set>
#include <cmath>

using namespace std;

bool check(long n) {
    long t = sqrt(n);
    return t * t == n;
}

void test() {
    int n, w, u, v;
    while (cin >> n) {
		vector<int> weight(n);
		vector<bool> color(n);
        for (int i = 0; i < n; ++i) {
            cin >> w;
            weight[i] = w;
            color[i] = true;  // true 表示可以涂色
        }
        vector<unordered_set<int>> edge(n, unordered_set<int>());
        vector<int> degree(n);
        for (int i = 0; i < n - 1; ++i) {
            cin >> u >> v;
            --u, --v;
            edge[u].insert(v);
            edge[v].insert(u);
            ++degree[u];
            ++degree[v];
        }
        queue<int> q;
        for (int i = 0; i < n; ++i) {
            if (degree[i] == 1) q.push(i);
        }
        int ans = 0;
        while (!q.empty()) {
            int sz = q.size();
            for (int i = 0; i < sz; ++i) {
                int u = q.front();
                q.pop();
                if (degree[u] == 0) continue;
                v = *edge[u].begin();
                --degree[u];
                --degree[v];
                edge[v].erase(u);
                if (!color[v] || !color[u]) continue;
                if (check(1LL * weight[u] * weight[v])) {
                    ans += 2;
                    color[v] = false;
                }
            }
            for (int i = 0; i < n; ++i) {
                if (degree[i] == 1) q.push(i);
            }
        }
        cout << ans << endl;
    }
}

int main() {
    test();
}

 
编辑于 2024-03-30 00:16:52 回复(0)
#include <bits/stdc++.h>
using namespace std;

int n;
vector<int> val;
vector<vector<int>> son;

vector<int> ansColor;
vector<int> ansNoColor;

bool judge(int i, int j) {
    int tmp = val[i] * val[j];
    int root = sqrt(tmp);
    return root * root == tmp;
}

void dfs(int cur, int far) {
    if (son[cur].size() == 1 && son[cur][0] == far) return;

    int tmpNoColor = 0;
    for (int i : son[cur]) {
        if (i == far) continue;
        dfs(i, cur);
        tmpNoColor += max(ansColor[i], ansNoColor[i]);
    }
    ansNoColor[cur] = tmpNoColor;

    for (int i : son[cur]) {
        if (i == far) continue;
        if (judge(cur, i)) {
            int tmp = tmpNoColor - max(ansColor[i], ansNoColor[i]);
            tmp += 2;
            tmp += ansNoColor[i];
            ansColor[cur] = max(ansColor[cur], tmp);
        }
    }
}

int main() {
    cin >> n;
    ansColor = vector<int>(n, 0);
    ansNoColor = vector<int>(n, 0);
    for (int i = 0; i < n; ++i) {
        int tmp; cin >> tmp;
        val.push_back(tmp);
    }
    son = vector<vector<int>>(n, vector<int>());
    for (int i = 0; i < n - 1; ++i) {
        int tmp1, tmp2; cin >> tmp1 >> tmp2;
        tmp1--; tmp2--;
        son[tmp1].push_back(tmp2);
        son[tmp2].push_back(tmp1);
    }
    dfs(0, -1);
    cout << max(ansColor[0], ansNoColor[0]) << endl;
}
编辑于 2024-03-11 23:14:10 回复(0)
#include <iostream>
#include<vector>
#include<cmath>
using namespace std;
const int N = 1e5 + 10;
vector<int>g[N];
int a[N];
int color[N];
int ans;
bool check(long long x,long long y){
    x *= y;
    long long z = sqrt(x);
    if(z * z == x)return 1;
    return 0;
}
void dfs(int u,int fa){
    for(int j : g[u]){
        if(j == fa)continue;
        dfs(j,u);
        if(color[j] || color[u])continue;
        if(check(a[u],a[j])){
            ans += 2;
            color[u] = color[j] = 1;
        }
    }
}
int main() {
    int n;
    scanf("%d",&n);
    for(int i =  1; i <= n ;i ++){
        scanf("%d",&a[i]);
    }
    for(int i = 1 ; i < n ;i ++){
        int u,v;
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    cout << ans;
    return 0;
}
// 64 位输出请用 printf("%lld")

貌似直接贪心就可以,从儿子节点开始。 不知道是不是数据水了
发表于 2023-09-23 09:11:33 回复(2)
#include <iostream>
#include <vector>
#include <algorithm>
#include <functional>
#include <cmath>
using namespace std;

int cmp(pair<int,int>a,pair<int,int>b){
    if (a.first!=b.first) return a.first<b.first;
    return a.second<b.second;
}

int main() {
    int n,t,a,b;
    vector<int> value;
    vector<pair<int,int>> edge;
    cin>>n;
    vector<vector<int>> f(n,vector<int>(2,-1));
    for (int i=0;i<n;i++) {
        cin>>t;
        value.push_back(t);
    }
    for (int i=0;i<n-1;i++) {
        cin>>a>>b;
        edge.push_back(make_pair(a-1,b-1));
    }
    sort(edge.begin(),edge.end(),cmp);
    vector<int> start;
    int now = 0;
    for (int i=0;i<n-1;i++){
        if (now>=n) break;
        while (edge[i].first>=now){
            start.push_back(i);
            now++;
        }
    }
    while (now<n){
        start.push_back(n-2);
        now ++;
    }
    // for (int i=0;i<n-1;i++){
    //     cout<<edge[i].first<<" "<<edge[i].second<<endl;
    // }
    // for (int i=0;i<n;i++){
    //     cout<<start[i]<<endl;
    // }
    function<int(int,int)> dfs=[&](int root,int fa)->int{
        int l;
        l = start[root];
        f[root][0] = 0;
        f[root][1] = 0;
        int son,vl;
        int last = 0;
        bool pick=false;
        while (l<edge.size()&&edge[l].first==root){
            son = edge[l].second;
            if (son==fa){
                l++;
                continue;
            }
            vl = value[root]*value[son];
            dfs(son,root);
            if (pow((int)sqrt(vl),2)==vl){
                // cout<<"value: "<<vl<<" "<<sqrt(vl)<<" "<<value[root]<<" "<<value[son]<<endl;
                // cout<<root<<" "<<son<<endl;
                if (last<=f[son][0]&&f[son][0]+2-last>f[son][1]){
                    f[root][1]+=f[son][0]-last;
                    last = f[son][0];
                    pick = true;
                }
                else{
                    f[root][1]+=max(f[son][1],f[son][0]);
                }
            }
            else{
                f[root][1]+=max(f[son][0],f[son][1]);
            }
            // f[root][0] = max(f[root][0],max(f[son][0],f[son][1]));
            f[root][0] += max(f[son][0],f[son][1]);
            l++;
        }
        if (pick) f[root][1]+=2;
        return 1;
    };
    dfs(0,-1);
    cout<<max(f[0][0],f[0][1])<<endl;
}
// 64 位输出请用 printf("%lld")


发表于 2023-08-22 20:35:31 回复(0)