题解 | #Product of GCDs#
Arithmetic Progression
https://ac.nowcoder.com/acm/contest/11253/A
J - Product of GCDs
题意: 给出一个数组,求这个数组里所有大小为的子集的
的乘积。
题解:
- 我们考虑通过枚举
来求得答案,也就是对于每个
,我们需要求有几个集合的
为
。
- 我们考虑枚举
的倍数,这样
只可能为
或者
的倍数。但这样会有一个问题,会把
为
的倍数的集合也筛选进来。我们可以从大到小枚举
,然后减去
为
的倍数的集合。考虑组合数
,
为
的倍数的个数。假设
为
的集合个数为
,最后答案就是
- 但是因为
会很大,所以没办法直接求上述式子。故我们考虑扩展欧拉降幂,因为模数并没保证是质数。但是模数又来到了很大的
,所以我们无法通过线性递推来求欧拉函数,我们得借助分解质因数和容斥原理来求欧拉函数。因为要分解质因数的数很大,所以我直接上了玄学算法*pollard_rho *来分解质因数。容斥求欧拉函数的时候会爆
。可以先做除法,或者直接上__int128。
- 这题有点卡常,得加快读。最后别忘了初始化。
#include<bits/stdc++.h> using namespace std; #define dbg(x...) do { cout << #x << " -> "; err(x); } while (0) void err () { cout << endl;} template <class T, class... Ts> void err(const T& arg, const Ts&... args) { cout << arg << ' '; err(args...);} #define ll long long #define ull unsigned long long #define LL __int128 #define inf 0x3f3f3f3f #define INF 0x3f3f3f3f3f3f3f3f #define pii pair<int, int> #define PII pair<ll, ll> #define fi first #define se second #define pb push_back #define mp(a,b) make_pair(a,b) #define PAUSE system("pause"); const double Pi = acos(-1.0); const double eps = 1e-8; const int maxn = 4e4 + 10; const int maxm = 6e2 + 10; const int mod = 998244353; const int P = 131; const int S = 5; inline ll rd() { ll f = 0; ll x = 0; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) f |= (ch == '-'); for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0'; if (f) x = -x; return x; } void out(ll a) {if(a<0)putchar('-'),a=-a;if(a>=10)out(a/10);putchar(a%10+'0');} #define pt(x) out(x),puts("") inline void swap(ll &a, ll &b){ll t = a; a = b; b = t;} inline void swap(int &a, int &b){int t = a; a = b; b = t;} inline ll min(ll a, ll b){return a < b ? a : b;} inline ll max(ll a, ll b){return a > b ? a : b;} ll s, k, p, phi, ans[maxn * 2]; ll c[maxn * 2][40], num[maxn * 2]; ll mul(ll a, ll b, ll c) { return (LL)a * b % c; } ll pow_mod(ll a, ll n, ll mod) { ll ret = 1; ll temp = a % mod; while(n) { if(n & 1) ret = mul(ret, temp, mod); temp = mul(temp, temp, mod); n >>= 1; } return ret; } bool check(ll a, ll n, ll x, ll t) { ll ret = pow_mod(a, x, n); ll last = ret; for(int i = 1; i <= t; i++) { ret = mul(ret, ret, n); if(ret == 1 && last != 1 && last != n - 1) return true; last = ret; } if(ret != 1) return true; else return false; } bool Miller_Rabin(ll n) { if(n < 2) return false; if(n == 2) return true; if((n & 1) == 0) return false; ll x = n - 1; ll t = 0; while((x & 1) == 0) { x >>= 1; t++; } srand(time(NULL)); for(int i = 0; i < S; i++) { ll a = rand() % (n - 1) + 1; if(check(a, n, x, t)) return false; } return true; } ll fac[100]; int tol; ll gcd(ll a, ll b) { ll t; while(b) { t = a; a = b; b = t % b; } if(a >= 0) return a; else return -a; } ll pollard_rho(ll x, ll c) { ll i = 1, k = 2; srand(time(NULL)); ll x0 = rand() % (x - 1) + 1; ll y = x0; while(1) { i++; x0 = (mul(x0, x0, x) + c) % x; ll d = gcd(y - x0, x); if(d != 1 && d != x) return d; if(y == x0) return x; if(i == k) { y = x0; k += k; } } } void findfac(ll n, int k) { if(n == 1) return ; if(Miller_Rabin(n)) { fac[++tol] = n; return ; } ll p = n; int c = k; while(p >= n) p = pollard_rho(p, c--); findfac(p, k); findfac(n / p, k); } void init(ll mod, int n) { tol = 0; for(int i = 0; i <= n; i++){ c[i][0] = 1; if(i < 32) c[i][i] = 1; for(int j = 1; j <= 31 && j < i; j++){ c[i][j] = c[i - 1][j] + c[i - 1][j - 1]; if(c[i][j] >= mod) c[i][j] -= mod; } } } ll get_phi(ll x) { for(int i = 1; i <= tol; i++) { x /= fac[i]; x *= (fac[i] - 1); } return x; } void solve(){ s = rd(); k = rd(); p = rd(); findfac(p, 107); tol = unique(fac + 1, fac + tol + 1) - fac - 1; phi = get_phi(p); init(phi, s); ll maxx = -1; for(int i = 1; i <= s; i++) { int x = rd(); num[x]++; maxx = max(maxx, x); } ll res = 1; for(int i = maxx; i >= 1; i--) { int now = 0; for(int j = i; j <= maxx; j += i) { now += num[j]; } ans[i] = c[now][k]; for(int j = i + i; j <= maxx; j += i) ans[i] -= ans[j]; ans[i] = (ans[i] % phi + phi) % phi; res = mul(res, pow_mod(i, ans[i], p), p); } for(int i = 1; i <= maxx; i++) num[i] = ans[i] = 0; printf("%lld\n", res); } int main() { // freopen("02.in","r",stdin); int t = 1; t = rd(); while(t--) solve(); return 0; }