题解 | #We love string#
We Love Strings
https://ac.nowcoder.com/acm/contest/57361/I
I . We love string
分治 容斥!!
将所有长度相同和模式串分到一个组中,使用一个类对象来表示一个组;
表示能于集合内的所有正则表达式匹配的的字符串个数
集合 用二进制数表示
考虑到所有模式串的长度之和不超过 ,所以当模式串的长度小于,时模式串可能会比较多,可采用暴力枚举。如果模式串长度大于 那么个数不会超过 个。
重复数的计算:
如果一个集合中包含了第 个模式串,设在当前集合中去除 第 个模式串得到集合为
, 由于模式串的增加,相当于是增加了限制能匹配到的字符串会减少,所以 。
在考虑第 个模式串时 中匹配到的字符串有一部分和 中重合了,都匹配在,中减去重复的即可。 与 的差就是重复数减去即可。
即
在实现去除第 个模式串中使用异或运算即可,异或上位,相同位上的 的相消了,其他位置保持不变。
参考代码:
#include <string>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 400, LIM = 20;
const int mod = 998244353;
ll pow2[N+5];
// 同一个类对象中的正则表达示 长度相同
class soluation{
vector<string> p;
ll solve_small(int len){
ll res = 0;
// 当字符串长度不大时 直接进行二进制枚举
for(int i =0; i < 1 << len; i++){
// 遍历所有模式串 看当前字符串是否能被匹配
bool find = false;
for(string pat : p){
bool cur = true;// 记录在当前模式串是否能被匹配
// 从低位开始向高位枚举
for(int j = 0; j < len; j++){
if(pat[j] == '?') continue;
if(pat[j] != (i >> j & 1)+ '0'){
cur = false;
}
}
if(cur){
find = true;
break;
}
}
if(find)
res++;
}
return res;
}
ll solve_big(int len,int size){
// len :模式串长度 size : 由于题目限制可保证 size < 20
vector<int> f(1 << size);
// f_i 记录有多少个字符串可以和集合 i 中的所有模式串匹配
// 初始使 每一个位置上都没有限制
string base = string(len,'?');
for(int i = 1; i < 1 << size; i++){
// 枚举每一个集合
string cur = base;
// 记录是否存在可以与集合中所有模式串匹配的字符串
bool flag = true;
for(int j = 0; j < size && flag; j++){
if(i >> j & 1){
// 如果当前集合包含第 j+1 个 字符串 对应的下标为 j
// 匹配每一位
for(int k = 0; k < len; k++){
if(p[j][k] != '?'){
if(cur[k] == '?'){
cur[k] = p[j][k];
}
else{
if(cur[k] != p[j][k]){
flag = false;
break;
}
}
}
}
}
}
if(flag){
// 最后cur 中有cnt个 ? f_i = 2^cnt;
f[i] = pow2[count(cur.begin(),
cur.end(),'?')];
}
}
// 容斥原理减去重复元素
for(int i =1; i < 1 << size; i++){
for(int j = 0; j < size; j++){
if(i >> j & 1){
// 集合i 中包含第 i 个元素
int m = i ^ (1 << j);
// +mod 是为了防止出现负数
f[m] = (f[m]-f[i]+mod) % mod;
}
}
}
ll ans = 0;
for(int i =1; i < 1 << size; i++){
ans += f[i];
ans %= mod;
}
return ans;
}
public:
ll query(){
if(p.empty()) return 0;
int len = p[0].size();
if(len > LIM){
return solve_big(len,p.size());
}
return solve_small(len);
}
void insert(string s){
p.push_back(s);
}
} S[N+5];
int main(){
int n; cin >> n;
string p;
for(int i =1; i <= n; i++){
cin >> p;
S[p.size()].insert(p);
}
// 预处理所有会用到的2的幂
pow2[0] = 1;
for(int i =1; i <= N; i++){
pow2[i] = (pow2[i-1]*2)%mod;
}
ll ans = 0;
for(int i =1; i <= N; i++){
ans = ans+S[i].query();
ans %= mod;
}
cout << ans << '\n';
}
注意事项:记得处理 不然会有 的测试点过不去;