第一行输入两个正整数 $n, m$,代表城市数量和道路数量。
第二行输入一个长度为 $n$ 的 01 串。第 $i$ 个字符为 '0' 代表小欧未占领第 $i$ 个城市,'1' 代表小欧已经占领了该城市。
接下来的 $m$ 行,每行输入两个正整数 $u, v$,代表城市 $u$ 和城市 $v$ 有一条道路连接。
输出一行两个空格隔开的整数,第一个整数代表占领的城市编号,第二个整数代表占领后的收益。
请保证收益的最大化。如果有多种方案收益最大,小欧会优先占领编号最小的城市。
样例输入:
5 5
01010
1 2
1 3
1 4
4 5
1 5
样例输出:
1 3
占领 1 号城市后,总收益为 3。1 号城市和 2 号城市经商,1 号城市和 4 号城市经商,2 号城市和 4 号城市经商。
from collections import defaultdict
def solve(n, m, occupied, roads):
    """Choose the unoccupied city whose capture maximises total trading profit.

    Occupied cities that are connected through other occupied cities form one
    trading group; a group of k cities contributes k * (k - 1) // 2 profit
    (one unit per pair of cities in it).

    Args:
        n: number of cities; cities are numbered 1..n in ``roads``.
        m: number of roads (unused, kept for interface compatibility).
        occupied: sequence of n truthy/falsy flags, ``occupied[i]`` for city i + 1.
        roads: iterable of (u, v) pairs of 1-based city numbers.

    Returns:
        (city, profit): the 1-based city to occupy (the smallest number among
        equally good choices, or 0 if every city is already occupied) and the
        total profit after occupying it.
    """
    # Adjacency list over 0-based city indices.
    graph = defaultdict(list)
    for u, v in roads:
        graph[u - 1].append(v - 1)
        graph[v - 1].append(u - 1)

    # Union-find over occupied cities; size[r] = occupied cities in root r's set.
    parent = list(range(n))
    size = [1 if occupied[i] else 0 for i in range(n)]

    def find(x):
        # Iterative two-pass path compression (no recursion-depth concerns).
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(x, y):
        rx, ry = find(x), find(y)
        if rx == ry:
            return
        if size[rx] < size[ry]:
            rx, ry = ry, rx
        parent[ry] = rx
        size[rx] += size[ry]

    # Only roads between two occupied cities join trading groups.
    for i in range(n):
        if occupied[i]:
            for j in graph[i]:
                if occupied[j]:
                    union(i, j)

    def pairs(k):
        return k * (k - 1) // 2

    current_profit = sum(pairs(size[i]) for i in range(n) if occupied[i] and find(i) == i)

    best_city = 0   # stays 0 only when every city is already occupied
    best_gain = -1  # below any real gain, so a zero-gain city is still selected
    for city in range(n):
        if occupied[city]:
            continue
        # Distinct groups adjacent to this city all merge with it.
        roots = {find(nb) for nb in graph[city] if occupied[nb]}
        merged = 1 + sum(size[r] for r in roots)
        gain = pairs(merged) - sum(pairs(size[r]) for r in roots)
        if gain > best_gain:  # strict: the earlier (smaller) city wins ties
            best_gain = gain
            best_city = city + 1
    return best_city, current_profit + max(best_gain, 0)
# Read input as whitespace-separated tokens, so it parses correctly whether the
# values are spread over several lines or packed onto fewer lines; reading all
# of stdin at once is also much faster than calling input() per road.
import sys

data = sys.stdin.read().split()
n, m = int(data[0]), int(data[1])
occupied = [c == '1' for c in data[2]]
edge_tokens = list(map(int, data[3:3 + 2 * m]))
roads = list(zip(edge_tokens[0::2], edge_tokens[1::2]))
# Solve and print "<city> <profit>".
city, profit = solve(n, m, occupied, roads)
print(f"{city} {profit}")
#include "bits/stdc++.h"
using namespace std;
/// Disjoint-set union (union-find) over the elements 0 .. size-1,
/// using union by size and full path compression.
struct dsu {
    vector<int> _pa;    // parent of each element; a root is its own parent
    vector<int> _size;  // number of elements in the set; meaningful only at roots

    explicit dsu(int size) : _pa(size), _size(size, 1) {
        for (int k = 0; k < size; ++k) {
            _pa[k] = k;
        }
    }

    /// Returns the representative of x's set and points every node on the
    /// path directly at it.
    int find(int x) {
        int root = x;
        while (_pa[root] != root) {
            root = _pa[root];
        }
        while (_pa[x] != root) {
            const int next = _pa[x];
            _pa[x] = root;
            x = next;
        }
        return root;
    }

    /// Merges the sets containing x and y; the smaller set goes under the larger root.
    void unite(int x, int y) {
        int big = find(x);
        int small = find(y);
        if (big == small) {
            return;
        }
        if (_size[big] < _size[small]) {
            swap(big, small);
        }
        _pa[small] = big;
        _size[big] += _size[small];
    }
};
// Shared state for main(); both are sized once the city count is known.
vector<vector<int>> graph{};  // graph[i]: 0-based neighbours of city i (edges stored both ways)
vector<bool> occupied{};      // occupied[i]: true when city i + 1 is marked '1'
int main() {
int n, m;
cin >> n >> m;
graph.resize(n);
occupied.resize(n);
string s;
cin >> s;
for (int i = 0; i < n; i++) {
occupied[i] = (s[i] == '1');
}
int u, v;
for (int i = 0; i < m; i++) {
cin >> u >> v;
u--, v--;
graph[u].push_back(v);
graph[v].push_back(u);
}
dsu uu(n);
for (int i = 0; i < n; i++) {
if (occupied[i]) {
for (const int j: graph[i]) {
if (occupied[j]) {
uu.unite(i, j);
}
}
}
}
int cur_profit = 0;
for (int i = 0; i < n; i++) {
if (occupied[i] && uu._pa[i] == i) {
cur_profit += uu._size[i] * (uu._size[i] - 1) / 2;
}
}
int max_profit_inc = 0;
int city = 0;
for (int i = 0; i < n; i++) {
if (occupied[i]) continue;
int total_size = 1;
set<int> connected_components;
for (const int j: graph[i]) {
if (occupied[j]) {
connected_components.insert(uu.find(j));
}
}
int old_profit = 0;
for (const int j: connected_components) {
old_profit += uu._size[j] * (uu._size[j] - 1) / 2;
total_size += uu._size[j];
}
int new_profit = total_size * (total_size - 1) / 2;
int profit_inc = new_profit - old_profit;
if (max_profit_inc < profit_inc) {
max_profit_inc = profit_inc;
city = i + 1;
}
}
cout << city << " " << cur_profit + max_profit_inc;
return 0;
} #include <iostream>
#include <vector>
#include <list>
using namespace std;
// Labels with `group_number` the city i and every city reachable from it through
// occupied ('1') cities, and counts the occupied ones in group_mass[group_number].
// City numbers are 1-based (cities[k - 1] is city k's flag). Does nothing if i is
// already labelled. Uses an explicit stack, so a long chain of occupied cities
// cannot overflow the call stack the way a recursive DFS could.
void visit(int group_number, string & cities, vector <list<int>> & graph, int i, vector <int> & group, vector <int> & group_mass){
    if (group[i])
        return;
    vector<int> pending{i};
    group[i] = group_number;  // mark on push so every city is expanded once
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        if (cities[cur - 1] == '1')
            group_mass[group_number]++;
        for (int adjacent : graph[cur]) {
            if (cities[adjacent - 1] == '1' && !group[adjacent]) {
                group[adjacent] = group_number;
                pending.push_back(adjacent);
            }
        }
    }
}
// Reads the map, labels each group of connected occupied cities (via visit), then
// evaluates every unoccupied city: occupying city i merges i with all groups adjacent
// to it. A group of k cities contributes k*(k-1)/2 profit. Prints the best city
// (the smallest number among equally good ones) and the resulting total profit.
int main() {
    int n, m;
    cin >> n >> m;
    string cities;
    cin >> cities;
    vector<list<int>> graph(n + 1);
    vector<int> group(n + 1, 0);       // group id of each labelled city, 0 = unlabelled
    vector<int> group_mass(n + 1, 0);  // number of occupied cities in each group
    vector<char> seen(n + 1, 0);       // per-candidate marks: group already counted
    int a, b;
    for (int e = 0; e < m && (cin >> a >> b); e++) {
        graph[a].push_back(b);
        graph[b].push_back(a);
    }
    int group_number = 1;
    for (int i = 1; i <= n; i++) {
        if (group[i] || cities[i - 1] == '0')
            continue;
        visit(group_number, cities, graph, i, group, group_mass);
        group_number++;
    }
    // Pair counts exceed INT_MAX for groups of ~46k cities, so use 64-bit arithmetic.
    long long best_gain = -1;  // below any real gain, so a zero-gain city is still chosen
    int best_city = 0;         // stays 0 only if every city is already occupied
    for (int i = 1; i <= n; i++) {
        if (cities[i - 1] == '1')
            continue;
        long long merged = 0;  // occupied cities that join i's new group
        long long lost = 0;    // pairs already counted inside those groups
        for (int adjacent : graph[i]) {
            if (cities[adjacent - 1] != '1')
                continue;
            const int g = group[adjacent];
            if (seen[g])
                continue;
            seen[g] = 1;
            merged += group_mass[g];
            lost += 1LL * group_mass[g] * (group_mass[g] - 1) / 2;
        }
        for (int adjacent : graph[i])
            seen[group[adjacent]] = 0;
        // New group has merged + 1 cities: C(merged + 1, 2) = merged * (merged + 1) / 2.
        const long long gain = merged * (merged + 1) / 2 - lost;
        if (gain > best_gain) {  // strict: the smaller city number wins ties
            best_gain = gain;
            best_city = i;
        }
    }
    long long total = 0;
    for (int g = 1; g < group_number; g++)
        total += 1LL * group_mass[g] * (group_mass[g] - 1) / 2;
    cout << best_city << ' ' << total + (best_gain > 0 ? best_gain : 0) << '\n';
    return 0;
}
// 注:64 位整数输出请用 printf("%lld")。上面的解法用到的中间变量有点多,真到考试的时候要怎么 debug 出来啊,真的难绷。
import java.util.*;
// 注意类名必须为 Main, 不要有任何 package xxx 信息
public class Main {
    // Union-find over cities 1..n: fa = parent pointer, size = set size (valid at roots).
    static int[] fa, size;

    // Returns the root of x's set, compressing the path on the way back.
    static int find(int x) {
        return x == fa[x] ? x : (fa[x] = find(fa[x]));
    }

    // Union by size: the smaller set is attached under the larger set's root.
    static void union(int x, int y) {
        int fx = find(x), fy = find(y);
        if (fx == fy) return;
        if (size[fx] < size[fy]) {
            size[fy] += size[fx];
            fa[fx] = fy;
        } else {
            size[fx] += size[fy];
            fa[fy] = fx;
        }
    }

    // Reads the map and prints the unoccupied city whose capture maximises total
    // profit (the smallest number among equally good ones) and that profit. A group of
    // k occupied cities connected through occupied cities yields k*(k-1)/2 profit.
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        // Read the 01 string as a token rather than a line, so it is parsed correctly
        // whether or not it sits alone on its own line.
        char[] chs = sc.next().toCharArray();
        fa = new int[n + 1];
        size = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
            size[i] = 1;
        }
        @SuppressWarnings("unchecked")
        List<Integer>[] G = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++) G[i] = new ArrayList<>();
        for (int i = 0; i < m; i++) {
            int u = sc.nextInt(), v = sc.nextInt();
            G[u].add(v);
            G[v].add(u);
            // Only roads between two occupied cities join trading groups.
            if (chs[u - 1] == '1' && chs[v - 1] == '1') union(v, u);
        }
        // Current profit; unoccupied cities remain singletons and contribute 0.
        long ans = 0;
        boolean[] used = new boolean[n + 1];
        for (int u = 1; u <= n; u++) {
            int f = find(u);
            if (used[f]) continue;
            used[f] = true;
            ans += (long) size[f] * (size[f] - 1) / 2;
        }
        // Gain of occupying u = new pairs formed when u joins its distinct adjacent
        // groups, accumulated as pref * size with pref = cities merged so far (incl. u).
        long maxProfit = -1; // below any real gain, so a zero-gain city can still win
        int idx = -1;        // stays -1 only if every city is already occupied
        Arrays.fill(used, false);
        for (int u = 1; u <= n; u++) {
            if (chs[u - 1] != '0') continue;
            long curProfit = 0, pref = 1;
            for (int v : G[u]) {
                int f = find(v);
                if (chs[v - 1] == '0' || used[f]) continue;
                used[f] = true;
                curProfit += pref * size[f];
                pref += size[f];
            }
            // Reset only the marks set above, keeping the scan linear in total degree.
            for (int v : G[u]) {
                int f = find(v);
                if (chs[v - 1] == '0' || !used[f]) continue;
                used[f] = false;
            }
            if (curProfit > maxProfit) { // strict: the smaller city number wins ties
                maxProfit = curProfit;
                idx = u;
            }
        }
        System.out.println(idx + " " + (ans + Math.max(maxProfit, 0)));
    }
}