15天大厂真题带刷 - ZT25 小红的子序列逆序对
小红的子序列逆序对
https://www.nowcoder.com/practice/189a109747604763932024984f856d99
题意
给出一个数组,求该数组所有的子序列中的逆序对数量之和是多少
数组长度1e5
思路
这种求所有子序列/子数组xxx值的和一般都是考虑单个元素的贡献,这里考虑的是单个逆序对的贡献,对于每个逆序对来说,剩下的n-2个元素每个元素都有选或不选两种选择,可以构成的子序列个数为2^(n-2),那么如果有x个逆序对的话,最后答案就是x*2^(n-2)
归并排序:大概的思想是分治,一个区间的逆序对数量=左边逆序对的数量+右边逆序对的数量+跨左右边界的逆序对的数量,前两个都可以在分治的子过程里计算,跨左右边界的话,其实就是考虑在合并2个已经排序的子数组的时候,如果发现逆序对,也就是a[i] > a[j] ,那么逆序对的数量就会增加mid-i+1,其中mid是合并的子数组的中间下标。因为数组是从小到大排序的,如果a[i]>a[j],那么左边数组在下标i后面的元素也一定>a[j]。
树状数组:
这里套了个离散化的板子
代码1
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int mod = 1e9 + 7; /* 归并排序:分治 1.确定分界点 mid 2.递归排序左边和右边:左边跟右边都是有序的了 3.将两个有序的数组合并为一个有序的数组 */ ll a[100010], n, b[100010]; ll ans = 0; ll merge_sort(ll l, ll r) { if (l >= r) return 0; ll mid = (l + r) / 2; ll res = 0; res += merge_sort(l, mid); res += merge_sort(mid + 1, r); ll i = l, j = mid + 1, k = 1; while (i <= mid && j <= r) { if (a[i] <= a[j]) b[k++] = a[i], i++; else { res += mid - i + 1; b[k++] = a[j], j++; } } while (i <= mid) b[k++] = a[i], i++; while (j <= r) b[k++] = a[j], j++; for (int i = l, u = 1; i <= r; i++, u++) a[i] = b[u]; //cout<<l<<" "<<r<<" "<<res<<endl; return res; } ll ksm(ll a, ll b, ll p) { ll res = 1; a %= p; while (b) { //&运算当相应位上的数都是1时,该位取1,否则该为0。 if (b & 1) res = 1ll * res * a % p; //转换为ll型 a = 1ll * a * a % p; b >>= 1; //十进制下每除10整数位就退一位 } return res; } //4 5 6 | 1 2 3 int main() { cin >> n; for (int i = 1; i <= n; i++) cin >> a[i]; ans = merge_sort(1, n); cout << ans*ksm(2,n-2,mod)%mod << endl; return 0; }
代码2
这个a[i]只有1e5, 不用离散化也可以
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int mod = 1e9 + 7; const int N = 1e5 + 7; ll ksm(ll a, ll b, ll p) { ll res = 1; a %= p; while (b) { //&运算当相应位上的数都是1时,该位取1,否则该为0。 if (b & 1) res = 1ll * res * a % p; //转换为ll型 a = 1ll * a * a % p; b >>= 1; //十进制下每除10整数位就退一位 } return res; } ll tr[N], n, a[N]; vector<ll>nums; ll lowbit(int x) { return x & -x; } void update(int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; } ll query(int x) { int res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } int main() { cin >> n; for (int i = 1; i <= n; i ++ ) cin >> a[i], nums.push_back(a[i]); sort(nums.begin(), nums.end()); nums.erase(unique(nums.begin(), nums.end()), nums.end()); int m = nums.size(); ll res = 0; for (int i = 1; i <= n; i ++ ) { int tmp = lower_bound(nums.begin(), nums.end(), a[i]) - nums.begin() + 1; res += query(m) - query(tmp); update(tmp, 1); } cout << res*ksm(2,n-2,mod) % mod << endl; return 0; }#牛客创作赏金赛#
15天大厂真题带刷Go题解 文章被收录于专栏
15天大厂真题带刷Golang题解