邮局 100分代码(dfs+多重剪枝)

蓝桥杯真题-邮局

#include<iostream>
#include<algorithm>
#include<set>
#include<string>
#include<cstring>
#include<vector>
#include<cmath>
using namespace std;

const double inf = 1.7e300;

struct node
{
    int x, y;
}point[55], post[30];

double mp[55][55];
int flag[55];
int n, m, k;
double res = inf;
int cur[55];
int ans[55];


double dis(node a, node b)
{
    return sqrt((a.x - b.x)*(a.x - b.x) + (a.y - b.y)*(a.y - b.y));
}
//st表示当前考虑的邮局编号,sum_k表示的是已经加入的post总数,tmp_r[]表示
void dfs(int st, int sum_k, double sum_r, double tmp_r[])
{
    if ((k - (sum_k-1)) > (m - (st-1))) return;//剪枝 1:左边表示还需要加入的邮局,右边表示剩余的邮局,邮局数量不够
    if (st > m && sum_k <= k) return; //剪枝 2:已经没有剩余的邮局了,并且k个邮局还未加满
//  if (sum_k > k+1) return;
    if (sum_k == k+1)
    {
        //cout <<" shuchu" << st-1 << " " << sum_k-1 << " " << sum_r << endl;
        if (res > sum_r)
        {
            res = sum_r;//更新,保存每一个邮局的编号
            for (int i = 1; i <= k; i++)
            {
                ans[i] = cur[i];
            }
        }
        return;
    }
    double dis_r[55];//必须重新定义一个数组才属于这个函数
    for (int i = 1; i <= n; i++)
    {
        dis_r[i] = tmp_r[i];
    }
    dfs(st + 1, sum_k, sum_r, dis_r);//不建造此邮局
    if (flag[st] == 1) return;//表示当前位置的邮局对于缩短距离没有增益,直接return,不需要考虑选择该邮局地点的方案
    //下面表示加入当前邮局的情况
    cur[sum_k] = st;//加入
    int mark1 = 0, mark2 = 0;
    if (sum_k ==1)//第一个邮局,初始化 dis_r
    {
        for (int i = 1; i <= n; i++)
        {
            dis_r[i] = mp[i][st];
            sum_r += mp[i][st];
        }
        mark1 = 1;
    }
    else
    {
        for (int i = 1; i <= n; i++)
        {
            if (dis_r[i] > mp[i][st])
            {
                mark2 = 1;
                sum_r += mp[i][st] - dis_r[i];
                dis_r[i] = mp[i][st];
            }
        }
    }
    if (mark1 == 0 && mark2 == 0) flag[st] = 1;
    if (mark1 == 1 || mark2 == 1) dfs(st + 1, sum_k + 1, sum_r, dis_r);
    return;
}

void init()
{
    for (int i = 1; i <= n; i++)
    {
        scanf("%d %d", &point[i].x, &point[i].y);
    }

    for (int i = 1; i <= m; i++)
    {
        scanf("%d %d", &post[i].x, &post[i].y);
    }

    for (int i = 1; i <= n; i++)
    {
        for (int j = 1; j <= m; j++)
        {
            mp[i][j] = dis(point[i], post[j]);
        }
    }
}

int main()
{
    scanf("%d %d %d", &n, &m, &k);
    init();
    double dis_r[55];
    memset(dis_r, -1, sizeof(dis_r));
    dfs(1, 1, 0, dis_r);
    for (int i = 1; i <= k; i++)
    {
        cout << ans[i] << " ";
    }

    return 0;
}

全部评论

相关推荐

评论
点赞
收藏
分享
牛客网
牛客企业服务