分层图最短路 (layered-graph shortest path)
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity limits and the "unreached" sentinel distance.
constexpr int MAXN = 10005;
constexpr ll inf = 1e18 + 7;

// One directed edge of the forward-star adjacency list: destination vertex,
// weight, and index of the next edge leaving the same source (-1 ends a list).
struct edge
{
    int to;
    ll w;
    int nex;
};
edge e[50000 * 2];

int head[MAXN]; // head[v]: index of the most recently added edge out of v
int tot;        // index the next added edge will occupy
int n, m, k;    // vertex count, edge count, maximum number of halvings
// Append directed edge a -> b with weight c: the new edge points at a's
// previous first edge and then becomes a's new list head.
void add(int a, int b, ll c)
{
    const int id = tot++;
    e[id].to = b;
    e[id].w = c;
    e[id].nex = head[a];
    head[a] = id;
}
ll dis[MAXN][12];
bool vis[MAXN][12];
struct node
{
int now;
ll w;
int cnt;
};
auto cmp = [](node q, node w) {
return q.w > w.w;
};
void dij()
{
priority_queue<node, vector<node>, decltype(cmp) >q(cmp);
q.push(node{ 1,0,0 });
dis[1][0] = 0;
while (!q.empty())
{
node temp = q.top();
q.pop();
if (vis[temp.now][temp.cnt])
continue;
vis[temp.now][temp.cnt] = true;
for (int i = head[temp.now]; i + 1; i = e[i].nex)
{
int t = e[i].to;
if (dis[t][temp.cnt] > dis[temp.now][temp.cnt] + e[i].w)
{
dis[t][temp.cnt] = dis[temp.now][temp.cnt] + e[i].w;
q.push(node{ t,dis[t][temp.cnt],temp.cnt });
}
if (dis[t][temp.cnt + 1] > dis[temp.now][temp.cnt] + e[i].w / 2)
{
if (temp.cnt > 10)
continue;
dis[t][temp.cnt + 1] = dis[temp.now][temp.cnt] + e[i].w / 2;
q.push(node{ t,dis[t][temp.cnt + 1],temp.cnt + 1 });
}
}
}
//printf("%lld\n", dis[n]);
}
// Reset state before reading input: every (vertex, layer) distance becomes
// inf, every adjacency list becomes empty (head = -1), and edge numbering
// restarts at 1.
void init()
{
    for (auto &row : dis)
        std::fill(std::begin(row), std::end(row), inf);
    std::fill(std::begin(head), std::end(head), -1);
    tot = 1;
}
int main()
{
init();
scanf("%d%d%d", &n, &m, &k);
while (m--)
{
int a, b;
ll c;
scanf("%d%d%lld", &a, &b, &c);
add(a, b, c);
}
dij();
ll ans = inf;
for (int i = 0; i <= k; i++)
{
ans = min(ans, dis[n][i]);
}
printf("%lld\n", ans);
}
code (example 2: one graph layer per input group, plus a shared transfer layer):
#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity limit (vertices), modulus (not referenced by the code below) and
// the "unreached" sentinel distance.
constexpr int MAXN = 600000;
constexpr ll mod = 1e9 + 7;
constexpr ll inf = 1e18 + 7;

// Forward-star edge: destination, weight, and index of the next edge out of
// the same source (-1 terminates the list).
struct edge
{
    int to;
    ll w;
    int nex;
};
edge e[MAXN * 2];

int head[MAXN]; // first edge index per vertex, -1 when it has none
int tot;        // slot the next added edge will use
ll dis[MAXN];   // shortest distance from the Dijkstra source
bool vis[MAXN]; // whether dij() has settled the vertex
// Prepare for a fresh graph: empty every adjacency list, restart edge
// numbering at 1, and mark every vertex as unreached (inf).
void init()
{
    std::fill(std::begin(head), std::end(head), -1);
    tot = 1;
    std::fill(std::begin(dis), std::end(dis), inf);
}
void add(int a, int b, ll c)
{
e[tot] = edge{ b,c,head[a] };
head[a] = tot++;
}
// Heap entry: (tentative distance, vertex). The distance must be 64-bit:
// dis[] is long long, and a 32-bit key would silently truncate large
// distances and corrupt the heap order.
using Pair = std::pair<long long, int>;

// Dijkstra from source s using a binary min-heap with lazy deletion.
// Fills dis[v] with the shortest distance from s; vertices that are never
// reached keep their initial value (inf after init()).
void dij(int s)
{
    std::priority_queue<Pair, std::vector<Pair>, std::greater<Pair> > q;
    dis[s] = 0;
    q.push(Pair(0, s));
    while (!q.empty())
    {
        const int u = q.top().second;
        q.pop();
        if (vis[u])
            continue; // stale entry: u was already settled with a smaller key
        vis[u] = true;
        for (int i = head[u]; i != -1; i = e[i].nex)
        {
            const int v = e[i].to;
            const long long cand = dis[u] + e[i].w;
            if (cand < dis[v])
            {
                dis[v] = cand;
                q.push(Pair(cand, v));
            }
        }
    }
}
int n, m, s, t;
int in[1005];
int main()
{
init();
scanf("%d%d%d%d", &n, &m, &s, &t);
int a, b, cnt;
for (int i = 0; i < m; i++)
{
scanf("%d%d%d", &a, &b, &cnt);
for (int j = 0; j < cnt; j++)
{
scanf("%d", &in[j]);
if (j > 0)
{
add(i * n + in[j - 1], i * n + in[j], b);
add(i * n + in[j], i * n + in[j - 1], b);
}
add(i * n + in[j], m * n + in[j], 0);
add(m * n + in[j], i * n + in[j], a);
}
}
dij(m * n + s);
if (dis[m * n + t] == inf)
dis[m * n + t] = -1;
printf("%lld\n", dis[m * n + t]);
}