题解 | #Product of GCDs#
Arithmetic Progression
https://ac.nowcoder.com/acm/contest/11253/A
J - Product of GCDs
题意: 给出一个数组,求这个数组里所有大小为的子集的
的乘积。
题解:
- 我们考虑通过枚举
来求得答案,也就是对于每个
,我们需要求有几个集合的
为
。
- 我们考虑枚举
的倍数,这样
只可能为
或者
的倍数。但这样会有一个问题,会把
为
的倍数的集合也筛选进来。我们可以从大到小枚举
,然后减去
为
的倍数的集合。考虑组合数
,
为
的倍数的个数。假设
为
的集合个数为
,最后答案就是
- 但是因为
会很大,所以没办法直接求上述式子。故我们考虑扩展欧拉降幂,因为模数并没保证是质数。但是模数又来到了很大的
,所以我们无法通过线性递推来求欧拉函数,我们得借助分解质因数和容斥原理来求欧拉函数。因为要分解质因数的数很大,所以我直接上了玄学算法*pollard_rho *来分解质因数。容斥求欧拉函数的时候会爆
。可以先做除法,或者直接上__int128。
- 这题有点卡常,得加快读。最后别忘了初始化。
#include<bits/stdc++.h>
using namespace std;
#define dbg(x...) do { cout << #x << " -> "; err(x); } while (0)
void err () { cout << endl;}
template <class T, class... Ts>
void err(const T& arg, const Ts&... args) { cout << arg << ' '; err(args...);}
#define ll long long
#define ull unsigned long long
#define LL __int128
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define PII pair<ll, ll>
#define fi first
#define se second
#define pb push_back
#define mp(a,b) make_pair(a,b)
#define PAUSE system("pause");
const double Pi = acos(-1.0);
const double eps = 1e-8;
const int maxn = 4e4 + 10;
const int maxm = 6e2 + 10;
const int mod = 998244353;
const int P = 131;
const int S = 5;
inline ll rd() {
ll f = 0; ll x = 0; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) f |= (ch == '-');
for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
if (f) x = -x;
return x;
}
void out(ll a) {if(a<0)putchar('-'),a=-a;if(a>=10)out(a/10);putchar(a%10+'0');}
#define pt(x) out(x),puts("")
inline void swap(ll &a, ll &b){ll t = a; a = b; b = t;}
inline void swap(int &a, int &b){int t = a; a = b; b = t;}
inline ll min(ll a, ll b){return a < b ? a : b;}
inline ll max(ll a, ll b){return a > b ? a : b;}
ll s, k, p, phi, ans[maxn * 2];
ll c[maxn * 2][40], num[maxn * 2];
ll mul(ll a, ll b, ll c) {
return (LL)a * b % c;
}
ll pow_mod(ll a, ll n, ll mod) {
ll ret = 1;
ll temp = a % mod;
while(n) {
if(n & 1) ret = mul(ret, temp, mod);
temp = mul(temp, temp, mod);
n >>= 1;
}
return ret;
}
bool check(ll a, ll n, ll x, ll t) {
ll ret = pow_mod(a, x, n);
ll last = ret;
for(int i = 1; i <= t; i++) {
ret = mul(ret, ret, n);
if(ret == 1 && last != 1 && last != n - 1) return true;
last = ret;
}
if(ret != 1) return true;
else return false;
}
bool Miller_Rabin(ll n) {
if(n < 2) return false;
if(n == 2) return true;
if((n & 1) == 0) return false;
ll x = n - 1;
ll t = 0;
while((x & 1) == 0) {
x >>= 1;
t++;
}
srand(time(NULL));
for(int i = 0; i < S; i++) {
ll a = rand() % (n - 1) + 1;
if(check(a, n, x, t))
return false;
}
return true;
}
ll fac[100];
int tol;
ll gcd(ll a, ll b) {
ll t;
while(b) {
t = a;
a = b;
b = t % b;
}
if(a >= 0) return a;
else return -a;
}
ll pollard_rho(ll x, ll c) {
ll i = 1, k = 2;
srand(time(NULL));
ll x0 = rand() % (x - 1) + 1;
ll y = x0;
while(1) {
i++;
x0 = (mul(x0, x0, x) + c) % x;
ll d = gcd(y - x0, x);
if(d != 1 && d != x) return d;
if(y == x0) return x;
if(i == k) {
y = x0;
k += k;
}
}
}
void findfac(ll n, int k) {
if(n == 1) return ;
if(Miller_Rabin(n)) {
fac[++tol] = n;
return ;
}
ll p = n;
int c = k;
while(p >= n)
p = pollard_rho(p, c--);
findfac(p, k);
findfac(n / p, k);
}
void init(ll mod, int n) {
tol = 0;
for(int i = 0; i <= n; i++){
c[i][0] = 1;
if(i < 32) c[i][i] = 1;
for(int j = 1; j <= 31 && j < i; j++){
c[i][j] = c[i - 1][j] + c[i - 1][j - 1];
if(c[i][j] >= mod) c[i][j] -= mod;
}
}
}
ll get_phi(ll x) {
for(int i = 1; i <= tol; i++) {
x /= fac[i];
x *= (fac[i] - 1);
}
return x;
}
void solve(){
s = rd(); k = rd(); p = rd();
findfac(p, 107);
tol = unique(fac + 1, fac + tol + 1) - fac - 1;
phi = get_phi(p);
init(phi, s);
ll maxx = -1;
for(int i = 1; i <= s; i++) {
int x = rd();
num[x]++;
maxx = max(maxx, x);
}
ll res = 1;
for(int i = maxx; i >= 1; i--) {
int now = 0;
for(int j = i; j <= maxx; j += i) {
now += num[j];
}
ans[i] = c[now][k];
for(int j = i + i; j <= maxx; j += i)
ans[i] -= ans[j];
ans[i] = (ans[i] % phi + phi) % phi;
res = mul(res, pow_mod(i, ans[i], p), p);
}
for(int i = 1; i <= maxx; i++)
num[i] = ans[i] = 0;
printf("%lld\n", res);
}
int main() {
// freopen("02.in","r",stdin);
int t = 1;
t = rd();
while(t--)
solve();
return 0;
}
查看18道真题和解析