P6154 游走[dag + dp + 数学期望]

图片说明
图片说明
图片说明

#include <stdio.h>
#include <cstring>
#include <algorithm>
#include <vector>
#include <stack>
#include <queue>
#include <iostream>
#include <map>

#define go(i, l, r) for(int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define god(i, r, l) for(int i = (r), i##end = (int)(l); i >= i##end; --i)
#define ios ios_base::sync_with_stdio(0),cin.tie(0),cout.tie(0)
#define debug_in  freopen("in.txt","r",stdin)
#define debug_out freopen("out.txt","w",stdout);
#define pb push_back
#define all(x) x.begin(),x.end()
#define fs first
#define sc second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll,ll> pii;
const ll maxn = 1e5+10;
const ll maxM = 1e6+10;
const ll inf_int = 1e8;
const ll inf_ll = 1e17;

template<class T>void read(T &x){
    T s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
    x = s*w;
}
template<class H, class... T> void read(H& h, T&... t) {
    read(h);
    read(t...);
}

void pt(){ cout<<'\n';}
template<class H, class ... T> void pt(H h,T... t){ cout<<" "<<h; pt(t...);}

//--------------------------------------------
const int mod = 998244353;
ll N,M,total = 0;
int h[maxn],e[maxn*7],ne[maxn*7],idx = 1;
int length[maxn],dp[maxn];
bool rec[maxn];
void add(int a,int b){
    e[++idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
}
void dfs(int u){
    if(rec[u]) return ;
    dp[u] = 1;
    for(int i = h[u];i;i = ne[i]){
        int v = e[i];
        dfs(v);
        dp[u] = (1LL * dp[u] + dp[v])%mod;
        length[u] = (1LL * length[u] + length[v] + dp[v])%mod;
        //length[v] + dp[v]就相当于是原来是以v开始的所有路径长度,现在要表示成以u开始的所有路径长度,现在增加了一个点,所有路径长度都+1
    }
    rec[u] = 1;
}
ll ksm(ll a,ll b){
    ll res = 1;
    while(b){
        if(b&1) res= res * a%mod;
        b>>=1;
        a = a*a%mod;
    }
    return res;
}
int main() {
//    debug_in;
//    debug_out;

    read(N,M);
    for(int i = 1;i<=M;i++){
        int x,y;read(x,y);
        add(x,y);
    }
    for(int i = 1;i<=N;i++) rec[i] = 0;
    for(int i = 1;i<=N;i++) dfs(i);
    ll sum = 0;
    for(int i = 1;i<=N;i++) sum += length[i],sum%=mod;
    for(int i = 1;i<=N;i++) total += dp[i],total%=mod;
    printf("%lld\n",sum * ksm(total,mod-2)%mod);


    return 0;
}
Ryuichi的算法分享 文章被收录于专栏

分享我的一些算法题解,致力于把题解做好,部分题解可能会采用视频的形式来讲解。

全部评论

相关推荐

10-09 09:39
门头沟学院 C++
HHHHaos:这也太虚了,工资就一半是真的
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务