第一行包含两个整数n,m.
第二行给出 n 个整数 $A_1, A_2, \ldots, A_n$。
数据范围
对于30%的数据,1 <= n, m <= 1000
对于100%的数据,1 <= n, m, Ai <= 10^5
输出仅包括一行,即所求的答案
3 10
6 5 10
2
/*
 * Count the unordered pairs (i, j) with A[i] ^ A[j] > m.
 *
 * No direct comparison between m and a_i is needed. Example with 5-bit
 * numbers and m = 0b00100: for each a_i it is enough to count the a_j whose
 * XOR with a_i has one of the forms 1****, 01***, 0011*, 00101 -- i.e. it
 * agrees with m on some prefix and then has a 1 where m has a 0.
 * Because XOR is invertible, that is the number of a_j whose prefix equals
 * (that pattern) ^ a_i.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>

int n, m;
int a[100010];
/* c[node] counts the inputs whose high bits form the prefix encoded by node.
 * A prefix is encoded as a leading 1 followed by the prefix bits, so prefixes
 * of different lengths map to different indices (17 bits + marker < 2^18). */
int c[(1 << 18) + 1];
/* Highest bit examined: inputs are at most 1e5 < 2^17, so bits 16..0 suffice. */
int maxDepth = 16;

/* Record every prefix (bits maxDepth..j, for every j) of every a[i] in c[]. */
void buildTrie() {
    memset(c, 0, sizeof(c));
    for (int i = 0; i < n; i++) {
        int now = 1;
        for (int j = maxDepth; j >= 0; j--) {
            now = (now << 1) + ((a[i] >> j) & 1);
            c[now]++;
        }
    }
}

/* Number of inputs whose bits maxDepth..depth equal those of p. */
int countTrie(int p, int depth) {
    int now = 1;
    for (int j = maxDepth; j >= depth; j--) {
        now = (now << 1) + ((p >> j) & 1);
    }
    return c[now];
}

/* Number of inputs a_j (over all j) with p ^ a_j > m. */
int calTrie(int p) {
    int count = 0;
    int mPrefix = 0; /* bits of m at and above the current position */
    for (int j = maxDepth; j >= 0; j--) {
        int bit = 1 << j;
        int k = m & bit;
        mPrefix |= k;
        /* m has a 0 here: an XOR equal to m's prefix with this bit set is > m. */
        if (k == 0) {
            count += countTrie((mPrefix | bit) ^ p, j);
        }
    }
    return count;
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) {
        scanf("%d", &a[i]);
    }
    buildTrie();
    /* Up to n*(n-1) ~ 1e10 ordered pairs: needs a 64-bit counter. */
    long long count = 0;
    for (int i = 0; i < n; i++) {
        count += calTrie(a[i]);
    }
    /* Each unordered pair was counted from both ends. %lld is the correct
     * specifier for long long on every platform (unlike %ld with __int64_t). */
    printf("%lld", count / 2);
    return 0;
}
#include <iostream>
#include <vector>
using namespace std;

// Binary trie node. count = number of inserted values whose path passes
// through this node (a node starts at 1 for the value that created it).
class query_tree {
public:
    query_tree *next[2]{nullptr, nullptr};
    int count;
    query_tree() : count(1) {}
};

// Global trie root; its own count is never read.
query_tree root;

// Insert the value m into the global trie using bits 16..0, most significant
// first (inputs are assumed to be below 2^17).
void build_tree(int m) {
    query_tree *cur = &root;
    for (int j = 16; j >= 0; j--) {
        const bool flag = (m >> j) & 1;
        if (!cur->next[flag]) {
            cur->next[flag] = new query_tree;
        } else {
            cur->next[flag]->count++;
        }
        cur = cur->next[flag];
    }
}

// Count the values v stored below `root` with (n ^ v) > m, looking only at
// bits index..0. Returns 0 for a null root or index < 0.
long long query_num(int n, int m, query_tree *root, int index) {
    long long total = 0;
    query_tree *node = root;
    for (int bit = index; bit >= 0 && node != nullptr; --bit) {
        const int n_bit = (n >> bit) & 1;
        const int m_bit = (m >> bit) & 1;
        // Child whose XOR with n is 1 at this bit.
        query_tree *differ = node->next[n_bit ^ 1];
        if (m_bit == 1) {
            // XOR bit must be 1 to keep matching m; the other branch is smaller.
            node = differ;
        } else {
            // m has 0 here: every value in the differing branch exceeds m.
            if (differ) total += differ->count;
            node = node->next[n_bit];
        }
    }
    return total;
}

int main() {
    int n, m;
    cin >> n >> m;
    vector<int> vi(n);
    for (int i = 0; i < n; i++) {
        cin >> vi[i];
        build_tree(vi[i]);
    }
    long long count = 0;
    for (int i = 0; i < n; i++) {
        count += query_num(vi[i], m, &root, 16);
    }
    // Every unordered pair was counted once from each side.
    cout << count / 2;
    return 0;
}
#include <iostream>
#include <cstring>
#include <algorithm>
#include <string>
using namespace std;

// Binary trie node; cnt = number of inserted bit strings passing through it.
struct Node {
    Node* next[2];
    int cnt;
    Node() {
        cnt = 0;
        memset(next, 0, sizeof(next));
    }
};

// Insert the '0'/'1' string s below root, bumping cnt along its path.
void insert(Node* root, const char* s) {
    while (*s) {
        if (!root->next[*s - '0']) root->next[*s - '0'] = new Node();
        root = root->next[*s - '0'];
        root->cnt++;
        ++s;
    }
}

// Count stored strings t with (now ^ t) > s, where now and s are equal-length
// '0'/'1' strings (s is m in binary).
long long search(Node* root, const char* now, const char* s) {
    long long cnt = 0;
    while (*s) {
        if (*now == '0' && *s == '1') {
            if (root->next[1] == NULL) break;
            root = root->next[1];
        } else if (*now == '0' && *s == '0') {
            // XOR bit 1 where m has 0: the whole 1-branch exceeds m.
            if (root->next[1]) cnt += root->next[1]->cnt;
            if (root->next[0] == NULL) break;
            root = root->next[0];
        } else if (*now == '1' && *s == '1') {
            if (root->next[0] == NULL) break;
            root = root->next[0];
        } else if (*now == '1' && *s == '0') {
            // XOR bit 1 where m has 0: the whole 0-branch exceeds m.
            if (root->next[0]) cnt += root->next[0]->cnt;
            if (root->next[1] == NULL) break;
            root = root->next[1];
        }
        ++s;
        ++now;
    }
    return cnt;
}

// 18-character binary representation of a non-negative value, MSB first.
string int2str(int value) {
    string s;
    while (value) {
        s += static_cast<char>('0' + value % 2);
        value /= 2;
    }
    while (s.size() < 18) s += '0';
    reverse(s.begin(), s.end());
    return s;
}

int main() {
    int n, m;
    cin >> n >> m;
    int k;
    Node* root = new Node();
    long long cnt = 0;
    string sz_m = int2str(m);
    while (n--) {
        cin >> k;
        string s = int2str(k);
        insert(root, s.c_str());
        // s is already stored, but s ^ s = 0 never exceeds m, so only earlier
        // values contribute.
        cnt += search(root, s.c_str(), sz_m.c_str());
    }
    cout << cnt << endl;
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Number of qualifying pairs found so far (updated by Trie::query).
ll ans;
// Threshold: a pair counts when the XOR of its two values is greater than m.
int m;

// Trie node: sz = how many inserted values pass through this node.
struct Node {
    int sz;
    struct Node *ls, *rs;  // ls = child for bit 0, rs = child for bit 1
    Node() : sz(0), ls(NULL), rs(NULL) {}
    ~Node() {
        delete ls;
        delete rs;
    }
};

// Binary trie over the low B bits of each value, most significant bit first.
struct Trie {
    int B;
    Node *root;
    Trie(int B) : B(B), root(new Node) {}
    ~Trie() { delete root; }

    // Add k, bumping sz on the root and on every node along k's path.
    void insert(int k) {
        Node *cur = root;
        cur->sz += 1;
        for (int bit = B - 1; bit >= 0; --bit) {
            Node *&slot = ((k >> bit) & 1) ? cur->rs : cur->ls;
            if (slot == NULL) slot = new Node;
            cur = slot;
            cur->sz += 1;
        }
    }

    // Add to ans the number of stored values v with (k ^ v) > m.
    void query(int k) {
        const int target = k ^ m;  // the v for which k ^ v equals m exactly
        Node *cur = root;
        for (int bit = B - 1; bit >= 0; --bit) {
            const bool right = (target >> bit) & 1;
            if (!((m >> bit) & 1)) {
                // m has 0 here: the sibling branch makes the XOR bit 1.
                Node *sibling = right ? cur->ls : cur->rs;
                if (sibling != NULL) ans += sibling->sz;
            }
            Node *next = right ? cur->rs : cur->ls;
            if (next == NULL) break;
            cur = next;
        }
    }
};

int main() {
    Trie t(20);
    int n;
    scanf("%d%d", &n, &m);
    // Each value is paired with the ones read before it, then stored.
    while (n-- > 0) {
        int k;
        scanf("%d", &k);
        t.query(k);
        t.insert(k);
    }
    cout << ans << '\n';
    return 0;
}
#include <algorithm>
#include <stdio.h>
#include <string.h>
#include <math.h>
using namespace std;
typedef long long ll;

const int maxn = 100007;
// Inputs satisfy max(m, A_i) <= 1e5 < 2^17, so every bit string has at most
// 17 characters and each insertion allocates at most 17 nodes.
const int maxbits = 17;
int a[maxn];

// Trie node: num = number of inserted strings passing through this node.
struct Node {
    int num;
    Node *next[2];
    void init() {
        num = 0;
        next[0] = next[1] = NULL;
    }
};
// Node pool: root + at most maxbits nodes per value (~1.7e6 nodes), instead of
// the former maxn*100 (~1e7 nodes, ~240 MB).
Node newnode[maxn * maxbits + 1];
Node *root;
int p;  // next free slot in newnode

// Take a fresh, cleared node from the pool.
Node *getnewnode() {
    newnode[p].init();
    return &newnode[p++];
}

// Reset the pool and allocate the root.
void init() {
    p = 0;
    root = getnewnode();
}

// Insert the '0'/'1' string s below cur, bumping num along the path.
void insert(Node *cur, char *s) {
    if (*s == '\0') return;
    int index = *s - '0';
    if (cur->next[index] == NULL) cur->next[index] = getnewnode();
    cur->next[index]->num++;
    insert(cur->next[index], s + 1);
}

// cur is the node reached by bit curi at the current position. Count the
// stored strings t in cur's subtree with (t ^ s) > sm on the remaining bits.
ll query(Node *cur, int curi, char *s, char *sm) {
    if (!cur) return 0;
    int si = *s - '0', smi = *sm - '0';
    if ((curi ^ si) < smi) return 0;
    if ((curi ^ si) == smi)
        return query(cur->next[0], 0, s + 1, sm + 1) + query(cur->next[1], 1, s + 1, sm + 1);
    // XOR bit is 1 where m has 0: the whole subtree exceeds m.
    return 1ll * cur->num;
}

// Write num as an n-character '0'/'1' string (MSB first) into s.
void getstr(int num, char *s, int n) {
    for (int i = 0; i < n; i++) s[i] = '0';
    s[n] = '\0';
    for (int i = n - 1; num && i >= 0; i--, num >>= 1) {
        if (num & 1) s[i] = '1';
    }
}

char str[27], strm[27];

int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) scanf("%d", a + i);
    // Bit length of the largest number involved; shorter values are zero-padded.
    int mx = n > 0 ? max(m, *max_element(a, a + n)) : m;
    int mx2 = 0;
    while (mx) mx2++, mx >>= 1;
    getstr(m, strm, mx2);
    init();
    ll ans = 0;
    for (int i = 0; i < n; i++) {
        getstr(a[i], str, mx2);
        // Pair a[i] with every value inserted before it.
        ans += query(root->next[0], 0, str, strm) + query(root->next[1], 1, str, strm);
        insert(root, str);
    }
    printf("%lld\n", ans);
    return 0;
}
import java.util.Scanner;

public class Main {
    /** Binary trie node; count = number of inserted values passing through it. */
    private static class TrieTree {
        TrieTree[] next = new TrieTree[2];
        int count = 1;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        // Handle test cases until the input runs out.
        while (sc.hasNext()) {
            int n = sc.nextInt();
            int m = sc.nextInt();
            int[] a = new int[n];
            for (int i = 0; i < n; i++) a[i] = sc.nextInt();
            System.out.println(solve(a, m));
        }
    }

    /** Number of unordered pairs (i, j) with a[i] ^ a[j] > m. */
    private static long solve(int[] a, int m) {
        TrieTree root = buildTrieTree(a);
        long total = 0;
        for (int value : a) total += queryTrieTree(root, value, m, 31);
        // Each pair was found once from each side.
        return total / 2;
    }

    /** Count stored values v (bits index..0) with (a ^ v) > m. */
    private static long queryTrieTree(TrieTree trieTree, int a, int m, int index) {
        long total = 0;
        TrieTree node = trieTree;
        for (int bit = index; bit >= 0 && node != null; bit--) {
            int aBit = (a >> bit) & 1;
            TrieTree opposite = node.next[aBit ^ 1];  // XOR bit 1 along this branch
            if (((m >> bit) & 1) == 1) {
                node = opposite;
            } else {
                if (opposite != null) total += opposite.count;
                node = node.next[aBit];
            }
        }
        return total;
    }

    /** Build a 32-level trie (bit 31 down to bit 0) containing every value of a. */
    private static TrieTree buildTrieTree(int[] a) {
        TrieTree root = new TrieTree();
        for (int value : a) {
            TrieTree node = root;
            for (int bit = 31; bit >= 0; bit--) {
                int d = (value >> bit) & 1;
                if (node.next[d] == null) {
                    node.next[d] = new TrieTree();
                } else {
                    node.next[d].count++;
                }
                node = node.next[d];
            }
        }
        return root;
    }
}
/*
 * Count unordered pairs (i, j) with A[i] ^ A[j] > m.
 * (Approach credited to 潇潇古月; not original.)
 *
 * Brute force over all pairs times out. Instead, compare bits from the most
 * significant down: if, at the first bit where the XOR of two numbers differs
 * from m, the XOR has 1 and m has 0, the XOR exceeds m whatever the lower
 * bits are.
 *
 * So build a binary trie of all values (high bit first). For each value a,
 * walk the trie along the path where a ^ v matches m bit by bit; whenever m
 * has a 0, every value in the sibling branch (XOR bit 1) is counted at once.
 * Summing over all a counts each pair twice, hence the final division by 2.
 *
 * The four cases in queryTrieTree (aDigit, mDigit):
 *   1. (1, 1): the 0-branch gives XOR bit 1 == m's bit, keep searching it;
 *              the 1-branch gives XOR bit 0 < m's bit, ignore it.
 *   2. (0, 1): the 1-branch gives XOR bit 1, keep searching; ignore 0-branch.
 *   3. (1, 0): every value in the 0-branch gives XOR > m, add its count;
 *              the 1-branch gives XOR bit 0 == m's bit, keep searching it.
 *   4. (0, 0): add the 1-branch's count; keep searching the 0-branch.
 *
 * Values are at most 1e5 < 2^17, so bits 16..0 are enough.
 * Time O(n * B) and space O(n * B) trie nodes, with B = 17 bits.
 */
#include <iostream>
#include <memory>
#include <vector>
using namespace std;

struct TrieTree {
    int count;                            // values passing through this node
    TrieTree* next[2]{nullptr, nullptr};  // children for bit 0 / bit 1
    TrieTree() : count(1) {}
    // The trie owns its children: deleting the root frees the whole trie.
    ~TrieTree() {
        delete next[0];
        delete next[1];
    }
    TrieTree(const TrieTree&) = delete;
    TrieTree& operator=(const TrieTree&) = delete;
};

// Build a trie (bits 16..0) holding every value of array. Caller owns the result.
TrieTree* buildTrieTree(const vector<int>& array) {
    TrieTree* trieTree = new TrieTree();
    for (int value : array) {
        TrieTree* cur = trieTree;
        for (int j = 16; j >= 0; --j) {
            const int digit = (value >> j) & 1;
            if (cur->next[digit] == nullptr) {
                cur->next[digit] = new TrieTree();
            } else {
                ++cur->next[digit]->count;
            }
            cur = cur->next[digit];
        }
    }
    return trieTree;
}

// Number of stored values v with (a ^ v) > m, looking at bits index..0.
// Returns 0 for a null trie or index < 0 (an XOR equal to m is not counted).
long long queryTrieTree(const TrieTree* trieTree, const int a, const int m, const int index) {
    long long total = 0;
    const TrieTree* cur = trieTree;
    for (int i = index; i >= 0 && cur != nullptr; --i) {
        const int aDigit = (a >> i) & 1;
        const int mDigit = (m >> i) & 1;
        const TrieTree* differ = cur->next[aDigit ^ 1];  // XOR bit 1 branch
        if (mDigit == 1) {
            cur = differ;  // cases 1 and 2
        } else {
            if (differ != nullptr) total += differ->count;  // cases 3 and 4
            cur = cur->next[aDigit];
        }
    }
    return total;
}

// The result can exceed the int range, hence long long.
long long solve(const vector<int>& array, const int& m) {
    // unique_ptr frees the trie before the next test case is read.
    unique_ptr<TrieTree> trieTree(buildTrieTree(array));
    long long result = 0;
    for (int value : array) {
        result += queryTrieTree(trieTree.get(), value, m, 16);
    }
    return result / 2;
}

int main() {
    int n, m;
    while (cin >> n >> m) {
        vector<int> array(n);
        for (int i = 0; i < n; ++i) cin >> array[i];
        cout << solve(array, m) << '\n';
    }
    return 0;
}
import java.util.*;

/*
 * Count pairs (i < j) with A[i] ^ A[j] > m.
 *
 * The previous brute-force double loop was O(n^2), which timed out for
 * n = 1e5. It also kept the count in an int, which overflows once the answer
 * passes 2^31 - 1 (there can be n*(n-1)/2 ~ 5e9 pairs).
 *
 * This version inserts the values one by one into an array-backed binary
 * trie. Before inserting A[i], it counts the earlier values whose XOR with
 * A[i] exceeds m. Cost is O(n * B) time and memory, where B is the bit length
 * of max(m, A[i]).
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int m, n;
        n = sc.nextInt();
        m = sc.nextInt();
        int[] A = new int[n];
        for (int i = 0; i < n; i++) A[i] = sc.nextInt();
        f(n, m, A);
    }

    /** Prints the number of pairs (i < j) among A[0..n-1] with A[i] ^ A[j] > m. */
    public static void f(int n, int m, int[] A) {
        // Bits above the highest set bit of max(m, A[i]) are 0 everywhere and cannot matter.
        int maxVal = m;
        for (int i = 0; i < n; i++) maxVal = Math.max(maxVal, A[i]);
        int bits = Math.max(1, 32 - Integer.numberOfLeadingZeros(maxVal));

        // Node 0 is the root. The root is never anyone's child, so 0 also means "no child".
        int capacity = n * bits + 1;
        int[][] child = new int[2][capacity];
        // cnt[v] = number of values inserted through node v. cnt[0] stays 0, so adding
        // cnt[child] for a missing child adds nothing.
        int[] cnt = new int[capacity];
        int used = 1;

        long ct = 0;
        for (int i = 0; i < n; i++) {
            int x = A[i];

            // Count earlier values v with x ^ v > m.
            int node = 0;
            for (int b = bits - 1; b >= 0; b--) {
                int xb = (x >> b) & 1;
                if (((m >> b) & 1) == 0) {
                    ct += cnt[child[xb ^ 1][node]];  // XOR bit 1 where m has 0
                    node = child[xb][node];
                } else {
                    node = child[xb ^ 1][node];      // XOR bit must stay 1
                }
                if (node == 0) break;
            }

            // Insert x.
            node = 0;
            for (int b = bits - 1; b >= 0; b--) {
                int xb = (x >> b) & 1;
                if (child[xb][node] == 0) child[xb][node] = used++;
                node = child[xb][node];
                cnt[node]++;
            }
        }
        System.out.println(ct);
    }
}
#include <bits/stdc++.h>
using namespace std;

// Binary trie over all 32 bits of an int, most significant bit first.
// Every node (including the root) stores in cnt how many inserted values
// pass through it.
class Trie {
public:
    Trie() : cnt(0) {}
    // Each node owns its children, so destroying the root frees the trie.
    ~Trie() {
        delete child[0];
        delete child[1];
    }
    Trie(const Trie&) = delete;
    Trie& operator=(const Trie&) = delete;

    // Add n to the trie.
    void insert(int n) {
        Trie* p = this;
        ++p->cnt;
        for (int i = 31; i >= 0; --i) {
            const int t = (n >> i) & 1;
            if (!p->child[t]) p->child[t] = new Trie();
            p = p->child[t];
            ++p->cnt;
        }
    }

    // Number of inserted values v with (x ^ v) > M.
    long long search(int x, int M) {
        Trie* p = this;
        long long ret = 0;
        for (int i = 31; i >= 0 && p; --i) {
            const int t0 = (x >> i) & 1;
            const int t1 = (M >> i) & 1;
            if (t1 == 1) {
                // XOR bit must be 1 to keep matching M.
                p = p->child[t0 ^ 1];
            } else {
                // M has 0 here: every value in the opposite branch exceeds M.
                if (p->child[t0 ^ 1]) ret += p->child[t0 ^ 1]->cnt;
                p = p->child[t0];
            }
        }
        return ret;
    }

    // Must be initialized explicitly. Because the constructor is user-provided,
    // `new Trie()` does not zero these, and `!p->child[t]` would read garbage.
    Trie* child[2] = {nullptr, nullptr};
    long long cnt;
};

int main() {
    int N, M, x;
    cin >> N >> M;
    Trie trie;
    long long res = 0;
    // Pair each value with those read before it, then insert it.
    while (N--) {
        scanf("%d", &x);
        res += trie.search(x, M);
        trie.insert(x);
    }
    cout << res << "\n";
    return 0;
}
#include <bits/stdc++.h>
using namespace std;

const int maxn = 1e5 + 7;
// Bits 17..0 are stored for each value (inputs are at most 1e5 < 2^18).
const int BITS = 18;
// Node 0 is the root. Each insertion adds up to BITS nodes, so the old
// 2*maxn bound could be overrun.
const int MAXNODES = maxn * BITS + 1;
int tree[MAXNODES][2];  // tree[v][b] = child of v along bit b (0 = none)
int cnt[MAXNODES][2];   // cnt[v][b] = number of values inserted through edge (v, b)
int tot = 0;            // index of the last allocated node

// Insert x, bits 17..0, most significant first.
void ins(int x) {
    int idx = 0;
    for (int i = BITS - 1; i >= 0; i--) {
        const int bit = (x >> i) & 1;
        if (tree[idx][bit] == 0) tree[idx][bit] = ++tot;
        cnt[idx][bit]++;
        idx = tree[idx][bit];
    }
}

long long ans = 0;

// Add to ans the number of values v below node idx (bits n_bit..0) with
// (x ^ v) > m.
void fd(int idx, int x, int m, int n_bit) {
    for (; n_bit >= 0; --n_bit) {
        const int b1 = (x >> n_bit) & 1;
        const int b2 = (m >> n_bit) & 1;
        // m has 0 here: the branch with bit b1^1 makes the XOR bit 1, so all of it exceeds m.
        if (b2 == 0) ans += cnt[idx][b1 ^ 1];
        // Continue along the branch where the XOR bit equals m's bit.
        const int next = tree[idx][b1 ^ b2];
        if (next == 0) return;
        idx = next;
    }
}

int main() {
    int N, M;
    cin >> N >> M;
    vector<int> vec(N);
    for (int i = 0; i < N; i++) {
        cin >> vec[i];
        ins(vec[i]);
    }
    for (int i = 0; i < N; i++) {
        fd(0, vec[i], M, BITS - 1);
    }
    // Each unordered pair was counted from both ends.
    cout << ans / 2 << endl;
    return 0;
}
#include <iostream>
using namespace std;

#define MAX_N 100000

// Trie node: size = number of inserted values whose path passes through it.
// Nodes are created with `new TrieTreeNode()`, which value-initializes child.
struct TrieTreeNode {
    int size = 0;
    TrieTreeNode* child[2];
};

int minSum;  // the threshold m
int a[MAX_N];
TrieTreeNode* root;

// Insert num using bits 17..0, most significant first.
void InsertToTree(int num) {
    TrieTreeNode* node = root;
    for (int bit = 17; bit >= 0; --bit) {
        TrieTreeNode*& slot = node->child[(num >> bit) & 1];
        if (slot == nullptr) slot = new TrieTreeNode();
        node = slot;
        ++node->size;
    }
}

// Number of stored values v with (num ^ v) > minSum.
int Calculate(int num) {
    int result = 0;
    TrieTreeNode* node = root;
    for (int bit = 17; bit >= 0 && node != nullptr; --bit) {
        const int numBit = (num >> bit) & 1;
        TrieTreeNode* opposite = node->child[!numBit];
        if ((minSum >> bit) & 1) {
            // m has 1 here: only the opposite branch can keep the XOR >= m.
            node = opposite;
        } else {
            // m has 0 here: the whole opposite branch exceeds m.
            if (opposite != nullptr) result += opposite->size;
            node = node->child[numBit];
        }
    }
    return result;
}

int main() {
    root = new TrieTreeNode();
    minSum = 0;
    int n;
    cin >> n >> minSum;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        InsertToTree(a[i]);
    }
    long long total = 0;
    for (int i = 0; i < n; ++i) total += Calculate(a[i]);
    // Each unordered pair is seen once from each of its two values.
    cout << total / 2;
}
def get_bin(x):
    """Return x as an 18-character, zero-padded, most-significant-first binary string."""
    digits = []
    while x:
        digits.append(str(x & 1))
        x >>= 1
    digits.extend(['0'] * (18 - len(digits)))
    return ''.join(reversed(digits))


first = input().split()
n, m = int(first[0]), int(first[1])
s = [int(tok) for tok in input().split()]

m_bin = get_bin(m)
k = m_bin.find('1')  # position of m's highest set bit in the 18-bit string

# bins[j] maps every length-j prefix (j = k..18) to how many inputs have it.
bins = {j: {} for j in range(k, 19)}
for num in s:
    num_bin = get_bin(num)
    for j in range(k, 19):
        prefix = num_bin[:j]
        bins[j][prefix] = bins[j].get(prefix, 0) + 1

# Pairs whose first k bits differ: the XOR has a 1 above m's top bit.
ans = sum(cnt * (n - cnt) for cnt in bins[k].values()) // 2


def search(kk, key1, key2):
    """Add to ans the pairs drawn from prefixes key1 and key2 (length kk) whose
    XOR is greater than m; the two prefixes' XOR equals m's first kk bits."""
    global ans
    if kk >= 18:
        return
    nxt = bins[kk + 1]
    k1_0, k1_1 = key1 + '0', key1 + '1'
    k2_0, k2_1 = key2 + '0', key2 + '1'
    if m_bin[kk] != '1':
        # m has 0 here: differing next bits make the XOR exceed m.
        ans += nxt.get(k1_0, 0) * nxt.get(k2_1, 0) + nxt.get(k1_1, 0) * nxt.get(k2_0, 0)
        if k1_1 in nxt and k2_1 in nxt:
            search(kk + 1, k1_1, k2_1)
        if k1_0 in nxt and k2_0 in nxt:
            search(kk + 1, k1_0, k2_0)
    else:
        # m has 1 here: only differing next bits keep the XOR equal to m so far.
        if k1_1 in nxt and k2_0 in nxt:
            search(kk + 1, k1_1, k2_0)
        if k1_0 in nxt and k2_1 in nxt:
            search(kk + 1, k1_0, k2_1)


# Pairs sharing the first k bits must differ at bit k (where m has its top 1).
for key in bins[k]:
    search(k + 1, key + "1", key + "0")
print(ans)
看到没有Go语言版本的,那我就写个Go语言版本的吧,思路参考第一页的大佬的思路,字典树是个好东西
package main

import (
	"bufio"
	"fmt"
	"os"
)

// TrieTree is a node of a binary (0/1) trie. count is the number of inserted
// values whose path passes through this node.
type TrieTree struct {
	next  [2]*TrieTree
	count int
}

// createTrieTree returns a fresh node whose count is 1, for the value that
// creates it.
func createTrieTree() *TrieTree {
	return &TrieTree{count: 1}
}

// buildTrieTree inserts every value of A (bits 31..0, most significant first)
// below trieTree and returns trieTree.
func buildTrieTree(trieTree *TrieTree, A []int) *TrieTree {
	for _, v := range A {
		current := trieTree
		for j := 31; j >= 0; j-- {
			digit := (v >> j) & 1
			if current.next[digit] == nil {
				current.next[digit] = createTrieTree()
			} else {
				current.next[digit].count++
			}
			current = current.next[digit]
		}
	}
	return trieTree
}

// queryTrieTree counts the stored values v (bits digitNum..0) with (a ^ v) > m.
func queryTrieTree(trieTree *TrieTree, a int, m int, digitNum int) int {
	total := 0
	current := trieTree
	for i := digitNum; i >= 0 && current != nil; i-- {
		aDigit := (a >> i) & 1
		opposite := current.next[aDigit^1] // XOR bit 1 along this branch
		if (m>>i)&1 == 1 {
			current = opposite
		} else {
			// m has 0 here: every value in the opposite branch exceeds m.
			if opposite != nil {
				total += opposite.count
			}
			current = current.next[aDigit]
		}
	}
	return total
}

// solve returns the number of unordered pairs (i, j) with A[i] ^ A[j] > m.
func solve(n int, m int, A []int) int {
	trieTree := buildTrieTree(createTrieTree(), A)
	num := 0
	for i := 0; i < n; i++ {
		num += queryTrieTree(trieTree, A[i], m, 31)
	}
	// Each pair was counted from both ends.
	return num / 2
}

func main() {
	// fmt.Fscan treats newlines as spaces. Per-value fmt.Scanf("%d") calls
	// fail with "unexpected newline" when A starts on the second line, which
	// shifts and loses values.
	reader := bufio.NewReader(os.Stdin)
	var n, m int
	fmt.Fscan(reader, &n, &m)
	A := make([]int, n)
	for i := range A {
		fmt.Fscan(reader, &A[i])
	}
	fmt.Print(solve(n, m, A))
}
#include <bits/stdc++.h>
using namespace std;

// Binary trie node. path = number of inserted values passing through this
// node; end = number of values ending here; map[0]/map[1] = children.
struct TrieNode {
    long path;
    long end;
    vector<TrieNode*> map;
    TrieNode() {
        path = 0;
        end = 0;
        map.resize(2, NULL);
    }
};

class Trie {
public:
    TrieNode* root;
    vector<long> arr;  // every inserted value, in insertion order
    long n;            // number of inserted values

    Trie() {
        root = new TrieNode();
        n = 0;
    }
    ~Trie() { destroy(root); }
    Trie(const Trie&) = delete;
    Trie& operator=(const Trie&) = delete;

    void insert(long num);
    // Pair count can reach ~n^2/2 ~ 5e9 (1e10 before halving), beyond a
    // 32-bit long on LLP64/32-bit targets, hence long long.
    long long solve(long m);

private:
    // Free node and its whole subtree.
    static void destroy(TrieNode* node);
};

// Insert num using bits 16..0 (inputs are at most 1e5 < 2^17).
void Trie::insert(long num) {
    n++;
    arr.push_back(num);
    TrieNode* p = root;
    p->path++;
    for (long i = 16; i >= 0; i--) {
        const long index = (num >> i) & 1;
        if (p->map[index] == NULL) p->map[index] = new TrieNode();
        p = p->map[index];
        p->path++;
    }
    p->end++;
}

// Number of unordered pairs of inserted values whose XOR is greater than m.
long long Trie::solve(long m) {
    long long res = 0;
    for (long i = 0; i < n; i++) {
        TrieNode* p = root;
        for (long j = 16; j >= 0 && p != NULL; j--) {
            const long mBit = (m >> j) & 1;
            const long vBit = (arr[i] >> j) & 1;
            TrieNode* opposite = p->map[vBit ^ 1];  // XOR bit 1 along this branch
            if (mBit == 0) {
                // m has 0 here: every value in the opposite branch exceeds m.
                if (opposite != NULL) res += opposite->path;
                p = p->map[vBit];
            } else {
                p = opposite;
            }
        }
    }
    // Each pair was counted from both ends.
    return res / 2;
}

void Trie::destroy(TrieNode* node) {
    if (node == NULL) return;
    destroy(node->map[0]);
    destroy(node->map[1]);
    delete node;
}

int main() {
    Trie myTrie;
    long n, m;
    cin >> n >> m;
    long num;
    for (long i = 0; i < n; i++) {
        cin >> num;
        myTrie.insert(num);
    }
    long long res = myTrie.solve(m);
    cout << res << endl;
    return 0;
}