题解 | #HH的项链#

[SDOI2009]HH的项链

https://ac.nowcoder.com/acm/contest/19684/B

B HH的项链

在线做法:主席树

思路

last[i]: 数字 i 上一次出现的位置,第一次出现则为 0
w[i]: 位置 i 上的数字,上一次出现的位置,第一次出现则为 0

求[l,r]内的数字种类,即求 w[i] < l 的数量,i属于 [l,r]

ps: w[i] 取整范围为 [0,n-1],为避免出现 0,所以我们将 w[i] 加 1,因此 w[i] 可取值 [1,n]

Code

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6+10;

int n,m,idx;
int rt[N],last[N],w[N];
struct node{
    int l,r;
    int cnt;
}tr[N*30];

int build(int l,int r){
    int p=++idx;
    
    if(l==r) return p;
    
    int mid=l+r>>1;
    tr[p].l=build(l,mid);
    tr[p].r=build(mid+1,r);
    
    return p;
}

int insert(int p,int l,int r,int pos){
    int q=++idx;
    tr[q]=tr[p];
    
    if(l==r){
        tr[q].cnt++;
        return q;
    }
    
    int mid=l+r>>1;
    if(pos<=mid) tr[q].l=insert(tr[p].l,l,mid,pos);
    else tr[q].r=insert(tr[p].r,mid+1,r,pos);
    tr[q].cnt=tr[tr[q].l].cnt+tr[tr[q].r].cnt;
    
    return q;
}

int query(int p,int q,int l,int r,int x){
    if(l==r)  return tr[q].cnt-tr[p].cnt;
    
    int mid=l+r>>1;
    int sum=tr[tr[q].l].cnt-tr[tr[p].l].cnt;
    
    if(x<=mid) return query(tr[p].l,tr[q].l,l,mid,x);
    else return sum+query(tr[p].r,tr[q].r,mid+1,r,x);
}

int main(){
    scanf("%d",&n);
    
    for(int i=1;i<=n;i++) {
        int x; scanf("%d",&x); 
        w[i]=last[x],last[x]=i;
    }
    
    rt[0]=build(1,n);
    for(int i=1;i<=n;i++) rt[i]=insert(rt[i-1],1,n,w[i]+1);
    
    scanf("%d",&m);
    
    while(m--){
        int l,r; scanf("%d%d",&l,&r);
        printf("%d\n",query(rt[l-1],rt[r],1,n,l));
    }
    
    return 0;
}

离线做法:树状数组 + 排序

思路

对于以 r 作为区间右端点的区间 [l,r] 而言,每种数字最后一次出现的位置对结果有贡献,l <= r
所以我们按照 r 从小到大排序,并使用树状数组来维护每个位置的贡献

Code

#include <bits/stdc++.h>

using namespace std;

const int N = 1e6+10;

int n,m;
int w[N],tr[N],last[N],ans[N];
struct node{
    int id;
    int l,r;
    bool operator < (const node &x) const {
        return r < x.r;
    }
}p[N];

int lowbit(int x){
    return x & -x;
}

void add(int x,int v){
    for(int i=x;i<=n;i+=lowbit(i)) tr[i]+=v;
}

int sum(int x){
    int res=0;
    for(int i=x;i;i-=lowbit(i))  res+=tr[i];
    return res;
}

int main(){
    cin>>n;
    
    for(int i=1;i<=n;i++) cin>>w[i];
    
    cin>>m;
    
    for(int i=1;i<=m;i++) {
        int l,r; cin>>l>>r;
        p[i]={i,l,r};
    }
    
    sort(p+1,p+1+m);
    
    for(int i=1,j=1;i<=m;i++){ 
       //可能有 j > r 但 i 仍然 <=m 的情况存在,所以不能加条件 j<=r
       int l=p[i].l,r=p[i].r,id=p[i].id;
       
        while(j<=r){
            if(last[w[j]])  add(last[w[j]],-1);
            add(j,1);
            last[w[j]]=j;
            j++;
        }
        
        ans[id]=sum(r)-sum(l-1);
    }
    
    for(int i=1;i<=m;i++) cout<<ans[i]<<endl;
    
    return 0;
}
全部评论

相关推荐

11-22 16:49
已编辑
北京邮电大学 Java
美团 质效,测开 n*15.5
点赞 评论 收藏
分享
10-30 10:16
南京大学 Java
龚至诚:给南大✌️跪了
点赞 评论 收藏
分享
点赞 评论 收藏
分享
3 收藏 评论
分享
牛客网
牛客企业服务