首页 > 试题广场 >

小欧皇

[编程题]小欧皇
  • 热度指数:936 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
小欧正在扮演一个中世纪的皇帝。地图上有n个城市,其中有m条道路,每条道路连接了两个城市。
小欧占领了其中一些城市。如果两个城市可以通过若干条道路互相到达,且这些道路经过的城市都是小欧占领的,那么这两个城市之间就可以通过经商获得收益1。请注意,每两个城市之间的收益只会被计算一次。
现在,小欧准备占领一个未被占领的城市,使得总收益最大化。你能帮帮她吗?

输入描述:
第一行输入两个正整数nm,代表城市数量和道路数量。
第二行输入一个长度为n的 01 串。第i个字符为'0'代表小欧未占领该城市,'1'代表小欧占领了该城市。
接下来的m行,每行输入两个正整数uv,代表城市u和城市v有一条道路连接。
1\leq n,m \leq 10^5
u \neq v


输出描述:
输出两个整数,第一个整数代表占领的城市编号,第二个整数代表占领后的收益。
请保证收益的最大化。如果有多种方案收益最大,请输出占领编号最小的城市。
示例1

输入

5 5
01010
1 2
1 3
1 4
4 5
1 5

输出

1 3

说明

占领 1 号城市后,总收益为 3。
1 号城市和 2 号城市经商,1 号城市和 4 号城市经商,2 号城市和 4 号城市经商。

from collections import defaultdict

def solve(n, m, occupied, roads):
    # 创建邻接表
    graph = defaultdict(list)
    for u, v in roads:
        graph[u-1].append(v-1)
        graph[v-1].append(u-1)
    
    # 使用并查集来跟踪连通分量
    parent = list(range(n))
    size = [1 if occupied[i] else 0 for i in range(n)]

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])
        return parent[x]

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            if size[px] < size[py]:
                px, py = py, px
            parent[py] = px
            size[px] += size[py]

    # 初始化连通分量
    for i in range(n):
        if occupied[i]:
            for j in graph[i]:
                if occupied[j]:
                    union(i, j)

    # 计算当前收益
    current_profit = sum(size[i] * (size[i] - 1) // 2 for i in range(n) if i == find(i) and size[i] > 0)

    max_profit_increase = 0
    best_city = 0

    # 尝试占领每个未占领的城市
    for city in range(n):
        if not occupied[city]:
            connected_components = set()
            total_size = 1  # 包括新占领的城市

            for neighbor in graph[city]:
                if occupied[neighbor]:
                    root = find(neighbor)
                    if root not in connected_components:
                        connected_components.add(root)
                        total_size += size[root]

            new_profit = total_size * (total_size - 1) // 2
            old_profit = sum(size[root] * (size[root] - 1) // 2 for root in connected_components)
            profit_increase = new_profit - old_profit

            if profit_increase > max_profit_increase:
                max_profit_increase = profit_increase
                best_city = city + 1

    return best_city, current_profit + max_profit_increase

# 读取输入
n, m = map(int, input().split())
occupied = [c == '1' for c in input().strip()]
roads = [tuple(map(int, input().split())) for _ in range(m)]

# 解决问题并输出结果
city, profit = solve(n, m, occupied, roads)
print(f"{city} {profit}")

发表于 2024-07-16 19:53:49 回复(0)
补一个C++代码
#include "bits/stdc++.h"

using namespace std;

struct dsu {
    vector<int> _pa, _size;
    explicit dsu(int size): _pa(size), _size(size, 1) {iota(_pa.begin(), _pa.end(), 0);}
    int find(int x) {return _pa[x] == x ? x : (_pa[x] = find(_pa[x])); }
    void unite(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return;
        if (_size[x] < _size[y]) swap(x, y);
        _pa[y] = x;
        _size[x] += _size[y];
    }
};

vector<bool> occupied;
vector<vector<int>> graph;

int main() {
    int n, m;
    cin >> n >> m;
    graph.resize(n);
    occupied.resize(n);
    string s;
    cin >> s;
    for (int i = 0; i < n; i++) {
        occupied[i] = (s[i] == '1');
    }
    int u, v;
    for (int i = 0; i < m; i++) {
        cin >> u >> v;
        u--, v--;
        graph[u].push_back(v);
        graph[v].push_back(u);
    }

    dsu uu(n);
    for (int i = 0; i < n; i++) {
        if (occupied[i]) {
            for (const int j: graph[i]) {
                if (occupied[j]) {
                    uu.unite(i, j);
                }
            }
        }
    }
    int cur_profit = 0;
    for (int i = 0; i < n; i++) {
        if (occupied[i] && uu._pa[i] == i) {
            cur_profit += uu._size[i] * (uu._size[i] - 1) / 2;
        }
    }

    int max_profit_inc = 0;
    int city = 0;
    for (int i = 0; i < n; i++) {
        if (occupied[i]) continue;
        int total_size = 1;
        set<int> connected_components;
        for (const int j: graph[i]) {
            if (occupied[j]) {
                connected_components.insert(uu.find(j));
            }
        }
        int old_profit = 0;
        for (const int j: connected_components) {
            old_profit += uu._size[j] * (uu._size[j] - 1) / 2;
            total_size += uu._size[j];
        }
        int new_profit = total_size * (total_size - 1) / 2;
        int profit_inc = new_profit - old_profit;
        if (max_profit_inc < profit_inc) {
            max_profit_inc = profit_inc;
            city = i + 1;
        }
    }

    cout << city << " " << cur_profit + max_profit_inc;
    return 0;
}


编辑于 2024-08-20 20:56:04 回复(1)
#include <iostream>
#include <vector>
#include <list>
using namespace std;
 
void visit(int group_number, string & cities, vector <list<int>> & graph, int i, vector <int> & group, vector <int> & group_mass){
    if (group[i])
         return;
    group[i] = group_number;
    if (cities[i-1]=='1')
    {
        group_mass[group_number]++;
    }
    for (int adjacent:graph[i])
        if (cities[adjacent-1] == '1')
            visit(group_number, cities, graph, adjacent, group, group_mass);
}
int main() {
    int n, m;
    cin>>n>>m;
    string cities;
    cin>>cities;
    vector <int> income_city(n+1);
    vector <list<int>> graph(n+1);
    vector <int> group(n+1);
    vector <int> group_mass(n+1);
    vector <int> st(n+1);
    for(int i=0;i<=n;i++)
    {
        income_city[i]=0;
        group[i]=0;
        group_mass[i]=0;
        st[i] = 0;
    }
    int a,b;
    while (cin >> a >> b) { // 注意 while 处理多个 case
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    int group_number = 1;
    for(int i=1;i<=n;i++)
    {
        if (group[i] || (cities[i-1]=='0'))
            continue;
        visit(group_number, cities, graph, i, group, group_mass);
        group_number++;
         
    }
     
    int max = 0;
    int max_index = 0;
     
    for(int i=1;i<=n;i++)
    {
        int self_degree = 0;
        int sum = 0;
        int g = 0;
        int deduce = 0;
        if(cities[i-1]=='1')
            continue;
        else
            for (int adjacent:graph[i])
            {
                if(cities[adjacent-1]=='1')
                {
                    g = group[adjacent];
                    if(!st[g])
                    {
                        st[g] = 1;
                        sum += group_mass[g];
                        deduce += group_mass[g] * (group_mass[g]-1)/2;
                    }
                }
            }
            for (int adjacent:graph[i])
            {
                g = group[adjacent];
                st[g] = 0;
            }
        if(sum*(sum+1)/2 - deduce> max)
        {
            max=sum*(sum+1)/2 - deduce;
            max_index=i;
        }
    }
    int total = 0;
    for (int i=1;i<=n;i++)
    {
        if(!st[i])
        {
            total = total + group_mass[i] * (group_mass[i]-1)/2;
        }
    }
    cout << max_index << ' ' << max + total;
}
 
// 64 位输出请用 printf("%lld"
用到的中间变量有点多。这要怎么在真的考试的时候debug出来啊,真的难绷。
发表于 2024-10-28 13:24:04 回复(0)
  • 观察1:处于同一个联通分量当中的两两节点都可以做生意。假设某个联通分量大小为k,则联通分量内对答案的贡献为C(k, 2)。
  • 观察2:考虑一个尚未占领的点占领后对答案的贡献。假设该点叫u,则该点的贡献为所有与u邻接的已经被占领的点的联通分量大小之和(每个被占领的点都可以和u做生意)与 已经被占领的联通分量之间的大小乘积之和(任意两个联通分量间此时均因为u的加入而被联通,所以可以互相做生意,用乘法原理计算即可)
算法:
  1. 并查集维护所有被占领点的联通分量与其大小
  2. 将被占领的联通分量内分别按观察1算贡献
  3. 枚举要占领的点,使用观察2算贡献
import java.util.*;

// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    static int[] fa, size;

    static int find(int x) {
        return x == fa[x] ? x : (fa[x] = find(fa[x]));
    }

    static void union(int x, int y) {
        int fx = find(x), fy = find(y);
        if (fx == fy) return;
        if (size[fx] < size[fy]) {
            size[fy] += size[fx];
            fa[fx] = fy;
        } else {
            size[fx] += size[fy];
            fa[fy] = fx;
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        sc.nextLine();
        fa = new int[n + 1];
        size = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
            size[i] = 1;
        }

        String occ = sc.nextLine();
        char[] chs = occ.toCharArray();
        List<Integer>[] G = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) G[i] = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            G[u].add(v);
            G[v].add(u);
            if (chs[u - 1] == '1' && chs[v - 1] == '1') union(v, u);
        }

        long ans = 0;
        boolean[] used = new boolean[n + 1];
        for (int u = 1; u <= n; u++) {
            int f = find(u);
            if (used[f]) continue;
            used[f] = true;
            ans += (long) size[f] * (size[f] - 1) / 2;
        }
        // System.out.println(ans);

        long maxProfit = 0;
        int idx = -1;
        Arrays.fill(used, false);
        for (int u = 1; u <= n; u++) {
            if (chs[u - 1] == '0') {
                long curProfit = 0, pref = 1;
                for (int v : G[u]) {
                    int f = find(v);
                    if (chs[v - 1] == '0' || used[f]) continue;
                    used[f] = true;
                    curProfit += pref * size[f];
                    pref += size[f];
                }
                for (int v : G[u]) {
                    int f = find(v);
                    if (chs[v - 1] == '0' || !used[f]) continue;
                    used[f] = false;
                }
                if (curProfit > maxProfit) {
                    maxProfit = curProfit;
                    idx = u;
                }
            }
        }
        System.out.println(idx + " " + (ans + maxProfit));
    }
}


发表于 2024-09-11 01:18:41 回复(0)
C++代码加一
所求的最大值是 配对数,多少对城市之间相互联通。显然就是每个联通块的值是 C(n,2), n 是联通块的节点数量,然后相加。

求联通块大小最好的方法就是并查集算法。
然后
1. 主要是要知道增加一个节点带来的增益是什么,很容易想到就是新形成更大图,然后我们用新的图的收益减去原来各个独立的图的受益和,就是这个节点的增益。

2.最后遍历找到最大的,就ok 了

#include <iostream>
#include<bits/stdc++.h>
using namespace std;


int f[1000000];
int v[1000000];

int fa(int x)
{
    if(f[x]!=x) f[x]=fa(f[x]);
    return f[x];
}

void link(int x,int y){
    int fx=fa(x),fy=fa(y);
    if(fx!=fy){
        f[fx]=fy;
        v[fy]=v[fy]+v[fx];
        
    }
    return;
}
long long   cc[1000000];

void init(){
    for (int i=0;i<=1e5+100;i++){
        f[i]=i;
        v[i]=1;
        cc[i]=1;//1 point
    }
    
    return ;
}

vector<vector<int>> g(100005);

long long  fv(int x){

    return x*(x-1)/2;
}


int main() {

    init();
    // int a, b;
    int n,m;
    cin>>n>>m;
    string s;
    bool e[1000000];
    
    cin>>s;
    for(int i=0;i<s.size();i++) e[i+1]=s[i]-'0';

    
    for (int i=0;i<m;i++){
        int x,y;
        cin>>x>>y;
        if (e[x]&&e[y])
        link(x,y);
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // return 0;

    int maxnode=1;
    int maxvalue=0;
    for(int i=1;i<=n;i++){
        if(!e[i]){
            set<int> lset;
            for( auto node : g[i]){
                if (e[node]) //occupied
                  lset.insert(fa(node));
            }
        if (lset.empty()) continue;;
        int num=1;
        long long  oscore=0;
        for (auto node:lset){
            int nn=v[node];
            num+=nn;
            oscore+=fv(nn);
        }
        cc[i]=fv(num)-oscore;
        if (cc[i]>maxvalue){ maxvalue=cc[i];maxnode=i;}
        }
        
    }
    
    int ans=0;

    for (int i=1;i<=n;i++){
        if(e[i]&&f[i]==i) ans+= fv(v[i]);
    }

    ans+=maxvalue;

    cout<<maxnode<<" "<<ans;
    return 0;
}
// 64 位输出请用 printf("%lld")
发表于 2024-10-23 18:03:41 回复(0)
C(82,2)=3321 C(83,2)=3403 所以这个第一个案例3350是哪个的组合数,我城市算对了收益算错了
发表于 2024-10-09 15:24:18 回复(0)