<span>UOJ #62.【UR #5】怎样跑得更快</span>
Description
Solution
如题,有
\[\sum_{j = 1} ^ n gcd(i, j) ^ c \times lcm(i, j) ^ d \times x_j \equiv b_i \pmod p \]
首先先把\(lcm(i, j)\)用\(\frac{i \times j}{gcd(i, j)}\)替换,得到
\[\sum_{j = 1}^n gcd(i, j)^{c - d} \times i^d \times j^d \times x_j \equiv b_i \pmod p \]
设\(h(x) = x ^ d\),\(f(x) = x ^ {c - d}\),那么将它们带入,得到
\[\sum_{j = 1}^n f(gcd(i, j)) \times h(i) \times h(j) \times x_j \equiv b_i \pmod p \]
式子中的\(f(gcd(i, j))\)看起来不大好处理,那么假设有\(f(n) = \sum_{d | n} fr(d)\),就可以将\(fr\)带入替换\(f\),得到
\[\sum_{j = 1}^n \sum_d [d | i] [d | j] fr(d) \times h(i) \times h(j) \times x_j \equiv b_i \pmod p \]
我们发现这样就好处理很多,所以只要求\(fr\)就行,那么求\(fr\)就可以直接莫反。
交换求和号,将\([d|i]\)扔进\(\sum\)里,得到
\[\sum_{d | i} fr(d) \sum_{j = 1}^n [d | j] h(i) \times h(j) \times x_j \equiv b_i \pmod p \]
设\(z(d) = \sum_{d | j, j <= n} h(j) \times x_j\),将\(h(i)\)移到等号右边,得到
\[\sum_{d | i} fr(d) \times z(d) \equiv \frac{b_i}{h(i)} \pmod p \]
那么发现这也是莫反的形式,可以莫反一次求出\(fr(i) \times z(d)\)的值。
然后都\(\div fr(i)\)就得到了\(z(i)\)的值,发现\(z\)的形式也是莫反,就可以求出\(h(i) * x_i\)的值,再\(\div h(i)\)就能得到答案。
注意因为不能有正数除\(0\),所以要判断无解的情况。
复杂度的瓶颈在于莫反,最快大概能做到\(O(n \log \log n)\),懒得写了,直接\(O(n \log n)\)滚粗。
注意\(c - d\)可能很大,因为模数是质数所以可以将\(c - d \bmod (MOD - 1)\)
Code
#include <iostream>
#include <cstdio>
#include <cstring>
const int N = 100000;
const int MOD = 998244353;
int n, c, d, q, b[N + 50], tmp[N + 50], h[N + 50], p, f[N + 50], fr[N + 50], invfr[N + 50], invh[N + 50];
void Read(int &x)
{
x = 0; int p = 0; char st = getchar();
while (st < '0' || st > '9') p = (st == '-'), st = getchar();
while (st >= '0' && st <= '9') x = (x << 1) + (x << 3) + st - '0', st = getchar();
x = p ? -x : x;
return;
}
int Ksm(int a, int b)
{
int tmp = 1;
while (b)
{
if (b & 1) tmp = 1LL * tmp * a % MOD;
a = 1LL * a * a % MOD;
b >>= 1;
}
return tmp;
}
int main()
{
Read(n); Read(c); Read(d); Read(q); p = (c - d + MOD - 1) % (MOD - 1);
for (int i = 1; i <= n; i++) h[i] = Ksm(i, d), invh[i] = Ksm(h[i], MOD - 2), f[i] = Ksm(i, p);
for (int i = 1; i <= n; i++) fr[i] = f[i];
for (int i = 1; i <= n; i++)
for (int j = i + i; j <= n; j += i)
fr[j] = (fr[j] - fr[i] + MOD) % MOD;
for (int i = 1; i <= n; i++) invfr[i] = Ksm(fr[i], MOD - 2);
while (q--)
{
int flag = 0;
for (int i = 1; i <= n; i++) Read(b[i]);
for (int i = 1; i <= n; i++)
{
if (b[i] && !h[i]) flag = 1;
b[i] = 1LL * b[i] * invh[i] % MOD;
}
for (int i = 1; i <= n; i++) tmp[i] = b[i];
for (int i = 1; i <= n; i++)
for (int j = i + i; j <= n; j += i)
tmp[j] = (tmp[j] - tmp[i] + MOD) % MOD;
for (int i = 1; i <= n; i++)
{
if (tmp[i] && !fr[i]) flag = 1;
tmp[i] = 1LL * tmp[i] * invfr[i] % MOD;
b[i] = 0;
}
for (int i = 1; i <= n; i++) b[i] = tmp[i];
for (int i = n; i >= 1; i--)
for (int j = i + i; j <= n; j += i)
b[i] = (b[i] - b[j] + MOD) % MOD;
for (int i = 1; i <= n; i++)
{
if (b[i] && !h[i]) flag = 1;
b[i] = 1LL * b[i] * invh[i] % MOD;
}
if (flag) { puts("-1"); continue; }
for (int i = 1; i <= n; i++) printf("%d ", b[i]);
printf("\n");
}
return 0;
}