小欧占领了其中一些城市。如果两个城市可以通过若干条道路互相到达,且这些道路经过的城市都是小欧占领的,那么这两个城市之间就可以通过经商获得收益。每一对这样的城市贡献 1 点收益,总收益即为所有这样的城市对的数量。
现在,小欧准备占领一个未被占领的城市,使得总收益最大化。你能帮帮她吗?
第一行输入两个正整数 $n$ 和 $m$,代表城市数量和道路数量。
第二行输入一个长度为 $n$ 的 01 串。第 $i$ 个字符为'0'代表小欧未占领城市 $i$,'1'代表小欧占领了城市 $i$。
接下来的 $m$ 行,每行输入两个正整数 $u$ 和 $v$,代表城市 $u$ 和城市 $v$ 之间有一条道路连接。
输出两个整数,第一个整数代表占领的城市编号,第二个整数代表占领后的收益。请保证收益的最大化。如果有多种方案收益最大,请输出占领编号最小的城市。
5 5
01010
1 2
1 3
1 4
4 5
1 5
1 3
占领 1 号城市后,总收益为 3。1 号城市和 2 号城市经商,1 号城市和 4 号城市经商,2 号城市和 4 号城市经商。
import sys
from collections import defaultdict


def solve(n, m, occupied, roads):
    """Choose the unoccupied city whose capture maximizes total trading profit.

    Two occupied cities trade (one unit of profit per pair) when a path of
    roads connects them using only occupied cities. An occupied connected
    component with s cities therefore contributes s * (s - 1) // 2.

    Args:
        n: number of cities, numbered 1..n.
        m: number of roads (kept for interface compatibility; ``roads`` is
            what is actually iterated).
        occupied: sequence of n booleans; occupied[i] is True when city i + 1
            is already occupied.
        roads: iterable of (u, v) pairs of 1-based city numbers.

    Returns:
        (city, profit): the 1-based city to occupy and the total profit after
        occupying it. Ties are broken toward the smallest city number, and a
        zero-gain optimum still names the smallest unoccupied city. If no city
        is unoccupied, city is 0 and profit is the current profit.
    """
    # Adjacency list over 0-based city indices.
    graph = defaultdict(list)
    for u, v in roads:
        graph[u - 1].append(v - 1)
        graph[v - 1].append(u - 1)

    # Union-find over occupied cities; size[root] counts occupied members.
    parent = list(range(n))
    size = [1 if occupied[i] else 0 for i in range(n)]

    def find(x):
        # Iterative path halving: no recursion-limit risk on long chains.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(x, y):
        px, py = find(x), find(y)
        if px != py:
            if size[px] < size[py]:
                px, py = py, px
            parent[py] = px
            size[px] += size[py]

    # Merge every road whose endpoints are both occupied.
    for i in range(n):
        if occupied[i]:
            for j in graph[i]:
                if occupied[j]:
                    union(i, j)

    current_profit = sum(
        size[i] * (size[i] - 1) // 2
        for i in range(n)
        if i == find(i) and size[i] > 0
    )

    best_city = 0
    max_profit_increase = 0
    for city in range(n):
        if occupied[city]:
            continue
        # Distinct occupied components adjacent to this city get merged with it.
        roots = {find(nb) for nb in graph[city] if occupied[nb]}
        total_size = 1 + sum(size[r] for r in roots)  # +1 for the city itself
        new_profit = total_size * (total_size - 1) // 2
        old_profit = sum(size[r] * (size[r] - 1) // 2 for r in roots)
        profit_increase = new_profit - old_profit
        # best_city == 0 means "no candidate yet": the first unoccupied city is
        # always accepted, so a zero-gain optimum still returns a real city.
        if best_city == 0 or profit_increase > max_profit_increase:
            max_profit_increase = profit_increase
            best_city = city + 1

    return best_city, current_profit + max_profit_increase


def main():
    """Read the problem from stdin (whitespace-separated tokens) and print the answer."""
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    occupied = [c == '1' for c in tokens[2]]
    nums = list(map(int, tokens[3:3 + 2 * m]))
    roads = list(zip(nums[0::2], nums[1::2]))
    city, profit = solve(n, m, occupied, roads)
    print(f"{city} {profit}")


if __name__ == "__main__":
    main()
#include <iostream>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>
using namespace std;

// Disjoint-set union with path compression and union by size.
struct dsu {
    vector<int> _pa, _size;
    explicit dsu(int size): _pa(size), _size(size, 1) { iota(_pa.begin(), _pa.end(), 0); }
    int find(int x) { return _pa[x] == x ? x : (_pa[x] = find(_pa[x])); }
    void unite(int x, int y) {
        x = find(x);
        y = find(y);
        if (x == y) return;
        if (_size[x] < _size[y]) swap(x, y);
        _pa[y] = x;
        _size[x] += _size[y];
    }
};

vector<bool> occupied;
vector<vector<int>> graph;

// Reads n, m, the 0/1 occupation string and m roads (1-based), then prints the
// unoccupied city whose capture maximizes the number of trading pairs (smallest
// index on ties, including the all-zero-gain case) and the resulting profit.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    cin >> n >> m;
    graph.resize(n);
    occupied.resize(n);
    string s;
    cin >> s;
    for (int i = 0; i < n; i++) {
        occupied[i] = (s[i] == '1');
    }
    for (int i = 0; i < m; i++) {
        int u, v;
        cin >> u >> v;
        u--, v--;
        graph[u].push_back(v);
        graph[v].push_back(u);
    }

    // Only roads between two occupied cities join components; every set that
    // contains an occupied city therefore consists solely of occupied cities.
    dsu uu(n);
    for (int i = 0; i < n; i++) {
        if (!occupied[i]) continue;
        for (const int j : graph[i]) {
            if (occupied[j]) uu.unite(i, j);
        }
    }

    // Pair counts reach ~n^2/2, which overflows 32-bit int for large n.
    auto pairs = [](long long k) { return k * (k - 1) / 2; };

    long long cur_profit = 0;
    for (int i = 0; i < n; i++) {
        if (occupied[i] && uu.find(i) == i) {
            cur_profit += pairs(uu._size[i]);
        }
    }

    long long max_profit_inc = 0;
    int city = 0;  // 0 = no candidate seen yet
    for (int i = 0; i < n; i++) {
        if (occupied[i]) continue;
        set<int> connected_components;
        for (const int j : graph[i]) {
            if (occupied[j]) connected_components.insert(uu.find(j));
        }
        long long total_size = 1;  // the newly occupied city itself
        long long old_profit = 0;
        for (const int root : connected_components) {
            old_profit += pairs(uu._size[root]);
            total_size += uu._size[root];
        }
        long long profit_inc = pairs(total_size) - old_profit;
        // Accept the first candidate unconditionally so a zero-gain optimum
        // still reports the smallest unoccupied city instead of 0.
        if (city == 0 || profit_inc > max_profit_inc) {
            max_profit_inc = profit_inc;
            city = i + 1;
        }
    }

    cout << city << " " << cur_profit + max_profit_inc << '\n';
    return 0;
}
#include <iostream>
#include <list>
#include <string>
#include <vector>
using namespace std;

// Labels every city reachable from `i` through occupied cities with
// `group_number`, and counts the occupied ones in group_mass[group_number].
// Uses an explicit stack (mark on push) so long chains of occupied cities
// cannot overflow the call stack. Cities are 1-based; cities[k-1] is city k.
void visit(int group_number, string & cities, vector <list<int>> & graph, int i, vector <int> & group, vector <int> & group_mass) {
    if (group[i]) return;
    group[i] = group_number;
    vector<int> pending{i};
    while (!pending.empty()) {
        int cur = pending.back();
        pending.pop_back();
        if (cities[cur - 1] == '1') {
            group_mass[group_number]++;
        }
        for (int adjacent : graph[cur]) {
            if (cities[adjacent - 1] == '1' && !group[adjacent]) {
                group[adjacent] = group_number;
                pending.push_back(adjacent);
            }
        }
    }
}

// Reads n, m, the 0/1 occupation string and exactly m roads, then prints the
// unoccupied city whose capture maximizes the number of trading pairs (smallest
// index on ties, including the all-zero-gain case) and the resulting profit.
// Author's note (translated from Chinese): this approach uses many intermediate
// variables, which makes it hard to debug under real exam conditions.
int main() {
    int n, m;
    cin >> n >> m;
    string cities;
    cin >> cities;

    vector <list<int>> graph(n + 1);
    vector <int> group(n + 1, 0);       // component id per city, 0 = unoccupied
    vector <int> group_mass(n + 1, 0);  // occupied-city count per component id
    vector <int> st(n + 1, 0);          // st[g] != 0: component g already merged for the current candidate

    for (int k = 0; k < m; k++) {
        int a, b;
        cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
    }

    int group_number = 1;
    for (int i = 1; i <= n; i++) {
        if (group[i] || (cities[i - 1] == '0')) continue;
        visit(group_number, cities, graph, i, group, group_mass);
        group_number++;
    }

    // Pair counts reach ~n^2/2, which overflows 32-bit int for large n.
    long long best_gain = 0;
    int max_index = 0;  // 0 = no candidate seen yet
    for (int i = 1; i <= n; i++) {
        if (cities[i - 1] == '1') continue;
        long long sum = 0;     // occupied cities merged with city i
        long long deduce = 0;  // pairs those components already had
        for (int adjacent : graph[i]) {
            if (cities[adjacent - 1] != '1') continue;
            int g = group[adjacent];
            if (!st[g]) {
                st[g] = 1;
                long long mass = group_mass[g];
                sum += mass;
                deduce += mass * (mass - 1) / 2;
            }
        }
        for (int adjacent : graph[i]) {
            st[group[adjacent]] = 0;
        }
        // New component has sum + 1 cities: (sum + 1) * sum / 2 pairs.
        long long gain = sum * (sum + 1) / 2 - deduce;
        // Accept the first candidate unconditionally so a zero-gain optimum
        // still reports the smallest unoccupied city instead of 0.
        if (max_index == 0 || gain > best_gain) {
            best_gain = gain;
            max_index = i;
        }
    }

    long long total = 0;
    for (int g = 1; g < group_number; g++) {
        long long mass = group_mass[g];
        total += mass * (mass - 1) / 2;
    }
    cout << max_index << ' ' << best_gain + total << '\n';
    return 0;
}
import java.util.*;

// The class must be named Main and must not declare any package (online-judge requirement).
public class Main {
    // Union-find over cities 1..n: fa = parent pointer, size = set size (meaningful at roots).
    static int[] fa, size;

    static int find(int x) {
        return x == fa[x] ? x : (fa[x] = find(fa[x]));
    }

    static void union(int x, int y) {
        int fx = find(x), fy = find(y);
        if (fx == fy) {
            return;
        }
        if (size[fx] < size[fy]) {
            size[fy] += size[fx];
            fa[fx] = fy;
        } else {
            size[fx] += size[fy];
            fa[fy] = fx;
        }
    }

    /**
     * Picks the unoccupied city whose capture maximizes the number of trading pairs.
     *
     * <p>Two occupied cities trade when a path of occupied cities connects them, so an
     * occupied component of size s yields s*(s-1)/2 pairs.
     *
     * @param n number of cities, numbered 1..n
     * @param occ string of length n; character i-1 is '1' when city i is occupied
     * @param roads pairs {u, v} of 1-based city numbers
     * @return {city, profit}: the smallest-numbered unoccupied city achieving the maximum
     *     (even when every gain is 0), or -1 if no city is unoccupied, and the total profit
     */
    static long[] solve(int n, String occ, int[][] roads) {
        boolean[] occupied = new boolean[n + 1];
        for (int i = 1; i <= n; i++) {
            occupied[i] = occ.charAt(i - 1) == '1';
        }
        fa = new int[n + 1];
        size = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
            size[i] = 1;
        }

        List<List<Integer>> graph = new ArrayList<>(n + 1);
        for (int i = 0; i <= n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] road : roads) {
            int u = road[0], v = road[1];
            graph.get(u).add(v);
            graph.get(v).add(u);
            // Only occupied-occupied roads connect trading components.
            if (occupied[u] && occupied[v]) {
                union(u, v);
            }
        }

        long base = 0;
        boolean[] used = new boolean[n + 1];
        for (int u = 1; u <= n; u++) {
            if (!occupied[u]) {
                continue;
            }
            int f = find(u);
            if (used[f]) {
                continue;
            }
            used[f] = true;
            base += (long) size[f] * (size[f] - 1) / 2;
        }

        Arrays.fill(used, false);
        long bestGain = 0;
        int best = -1; // -1 = no candidate seen yet
        for (int u = 1; u <= n; u++) {
            if (occupied[u]) {
                continue;
            }
            // Each newly joined component of size s pairs with everything merged so far.
            long gain = 0, merged = 1;
            for (int v : graph.get(u)) {
                if (!occupied[v]) {
                    continue;
                }
                int f = find(v);
                if (used[f]) {
                    continue;
                }
                used[f] = true;
                gain += merged * size[f];
                merged += size[f];
            }
            // Reset marks touched for this candidate only.
            for (int v : graph.get(u)) {
                if (occupied[v]) {
                    used[find(v)] = false;
                }
            }
            // Accept the first candidate unconditionally so a zero-gain optimum
            // still reports a real city; strict '>' keeps the smallest index on ties.
            if (best == -1 || gain > bestGain) {
                bestGain = gain;
                best = u;
            }
        }
        return new long[] {best, base + bestGain};
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        // Token-based read: works whether or not the 0/1 string is on its own line.
        String occ = sc.next();
        int[][] roads = new int[m][2];
        for (int i = 0; i < m; i++) {
            roads[i][0] = sc.nextInt();
            roads[i][1] = sc.nextInt();
        }
        long[] result = solve(n, occ, roads);
        System.out.println(result[0] + " " + result[1]);
    }
}