题解 | #小苯的蓄水池(hard)#

小苯吃糖果

https://ac.nowcoder.com/acm/contest/93847/A

我曾经是 acmer,现在是社畜,但是最近心血来潮想刷题,发现线段树都坑点 这里给出这个问题的线段树题解

首先,先写一个暴力版本的,也就是 easy

const ll inf = (1LL << 60);

void solve() {
    int n, q;
    cin >> n >> q;

    vector<double> a(n+1, 0.0), s(n+1, 0);
    double sum = 0;
    for (int i = 1; i <= n; i++) cin >> a[i], sum += a[i];
    for (int i = 1; i <= n; i++) s[i] = s[i-1] + a[i];

    vector<int> le(n+1, 0), ri(n+1, 0);
    for (int i = 1; i <= n; i++) le[i] = ri[i] = i;

    while (q--) {
        int op;
        cin >> op;

        if(op == 1) {
            int l, r;
            cin >> l >> r;
            //if (l == r) continue;
            //debug(l);

            int low = l, high = r;
            for (int i = l; i <= r; i++) chmin(low, le[i]), chmax(high, ri[i]);
            // debug(cl), debug(cr);

            // double tot = 0;
            for (int i = low; i <= high; i++) le[i] = low, ri[i] = high;

            for (int i = low; i <= high; i++) {
                int len = high - low + 1;
                double tot = s[high] - s[low - 1]; 
                a[i] = (tot) * 1.0 / (1.0 * len);
            }
        }
        else {
            int id;
            cin >> id;

            double res = a[id];
            printf("%.10lf\n", res);
        }
    }
}

相信上面那个版本,大家都没有什么问题 我们要优化的是几个点

for (int i = l; i <= r; i++) chmin(low, le[i]), chmax(high, ri[i]);

这个是要维护每一段区间 [l, r] 最左边的位置,和最右边的位置
所以线段树中要维护 le, ri 的信息

for (int i = low; i <= high; i++) le[i] = low, ri[i] = high;
for (int i = low; i <= high; i++) {
      int len = high - low + 1;
      double tot = s[high] - s[low - 1];
      a[i] = (tot) * 1.0 / (1.0 * len);
}
这个对应区间修改,[low, high] 区间内的所有点,le 改成 low,ri 改成 high
同时我们还需要把区间内所有的数都改成一个数

那么,维护平均数 avg 是最方便的,因为操作下来,一整段的平均数都会是一样的
都改成一个数,区间的平均数都是修改后的 tot / len

比赛的时候,线段树模版比较老旧,我更新了一下线段树的模版,大家可以参考


const double eps = 1e-6;

// 根据我们上面的分析
// 我们需要维护 le, ri, 以及区间和
//(这里不取平均数,怕精度丢失,最后计算时候转平均数,所以额外维护区间长度)
struct S {
    int le, ri, len;
    double sum;
};

// 定义线段树儿子合并操作
S op(S a, S b) {
    S res;
    res.sum = a.sum + b.sum;
    res.len = a.len + b.len;
    res.le = min(a.le, b.le);
    res.ri = max(a.ri, b.ri);
    return res;
}

// F 是算子,也就是线段树中的懒标记
// 这里线段树中只有赋值操作,因为根据上面的分析,我们是修改整段区间的值
// F f
// (f, node),我们把 node 区间的值,node.le = f.le 以此类推
// 我们需要修改 le,ri,以及平均值
struct F {
    int le, ri;
    double avg;

    friend bool operator== (const F &lhs, const F &rhs) {
        bool ok = true;
        ok &= (lhs.le == rhs.le && lhs.ri == rhs.ri);
        ok &= (fabs(lhs.avg - rhs.avg) <= eps);
        return ok;
    }
};

// mapping 是区间的映射
// 也就是如果把节点 x,上面有懒标记,我们要对 x 进行修改
// x.le -> f.le,   x.ri -> f.ri,  x.sum -> f.avg * x.len
// 这里为什么 f 存数据存 avg 平均值
// 因为挡板拿掉之后,所有点的最终水位都一样,也就是说把一整个区间的值都改成它的平均值
S mapping(F f, S x) {
    S res = x;
    res.le = f.le, res.ri = f.ri;
    res.sum = 1.0 * f.avg * x.len;
    return res;
}

// 懒标记复合,意思是
// 比如之前的操作,节点上已经有懒标记 g 了
// 又进行了 标记 f
// 因为本题是赋值操作,后面的懒标记会覆盖前面的
// 其他题不一定能这么写,有的题目可能是给区间增加一些数
// 之前加了 g,后来又加了 f,那么要 return f+g
// 这个具体情况具体分析
F comp(F f, F g) {
    return f;
}

// 哨兵节点
// 如果懒标记 = id,那么说明这个点标记已经压下去了
// 这个点的标记清空
// push 操作的时候,如果发现一个点的标记是 id,不需要进行压标记操作了

F id() {
    return F{-1, -1, 0.0};
}

template<class S, S (*op)(S, S), class F, S (*mapping)(F, S), F (*comp)(F, F), F (*id)()>
class lazy_segtree {
public:
    int _n;
    vector<S> node;
    vector<F> lz;

    explicit lazy_segtree(const vector<S> &v, int n) : _n(n) {
        node = vector<S>(_n*4, S{});
        lz = vector<F>(_n*4, id());

        function<void(int, int, int)> build = [&](int p, int l, int r) -> void {
            if (l >= r) {
                node[p] = v[l];
                return;
            }
            int mid = (l + r) >> 1;
            build(p<<1, l, mid), build(p<<1|1, mid+1, r);
            pull(p);
        };

        build(1, 1, _n);
    }

    inline void pull(int p) {
        node[p] = op(node[p<<1], node[p<<1|1]);
    }
    inline void push(int p) {
        if (lz[p] == id()) return;

        auto apply = [&](int p, F f) -> void {
            node[p] = mapping(f, node[p]);
            lz[p] = comp(f, lz[p]);
        };
        apply(p<<1, lz[p]), apply(p<<1|1, lz[p]);
        lz[p] = id();
    }

    // a[pos] -> x
    void change(int p, int l, int r, int pos, S x) {
        if (l >= r) {
            node[p] = x;
            return;
        }
        int mid = (l + r) >> 1;
        push(p);

        if (pos <= mid) change(p<<1, l, mid, pos, x);
        else change(p<<1|1, mid+1, r, pos, x);

        pull(p);
    }

    void modify(int p, int l, int r, int ql, int qr, F f) {
        if (ql == l && r == qr) {
            node[p] = mapping(f, node[p]);
            lz[p] = comp(f, lz[p]);
            return;
        }

        int mid = (l + r) >> 1;
        push(p);

        if (qr <= mid) modify(p<<1, l, mid, ql, qr, f);
        else if (ql > mid) modify(p<<1|1, mid+1, r, ql, qr, f);
        else {
            modify(p<<1, l, mid, ql, mid, f);
            modify(p<<1|1, mid+1, r, mid+1, qr, f);
        }

        pull(p);
    } 

    S query(int p, int l, int r, int ql, int qr) {
        if (ql == l && r == qr) return node[p];
        int mid = (l + r) >> 1;

        push(p);
        
        if (qr <= mid) return query(p<<1, l, mid, ql, qr);
        else if (ql > mid) return query(p<<1|1, mid+1, r, ql, qr);
        else {
            return op( query(p<<1, l, mid, ql, mid),
                    query(p<<1|1, mid+1, r, mid+1, qr) );
        }
    }

    // return r <= qr, satisfy:
    // f( op(a[ql], a[ql+1], ..., a[r-1]) ) = true
    // f( op(a[ql], a[ql+1], ..., a[r]) ) = false
    // -1, we cannot find such r
    // we want to find first position f() false!
    // find in [ql, qr]
    // usage: max_right<f> (1, 1, _n, ql, qr)
    template<bool (*g)(S)>
    int max_right(int p, int l, int r, int ql, int qr) const {
        if (ql == _n) return _n;
        //debug(p);

        if (ql == l && r == qr) {
            //debug(l);
            if ( g(node[p]) ) return -1;
            if (l == r) return l;

            int mid = (l + r) >> 1;
            push(p);
            if ( !g(node[p<<1]) ) return max_right<g>(p<<1, l, mid, ql, mid);
            else return max_right<g>(p<<1|1, mid+1, r, mid+1, qr);
        }

        int mid = (l + r) >> 1;
        push(p);
        if (qr <= mid) return max_right<g>(p<<1, l, mid, ql, qr);
        else if (ql > mid) return max_right<g>(p<<1|1, mid+1, r, ql, qr);
        else {
            int pos = max_right<g>(p<<1, l, mid, ql, mid);
            if (pos != -1) return pos;
            else return max_right<g>(p<<1|1, mid+1, r, mid+1, qr);
        }
    }

    // find p
    // f( op(a[qr], a[qr-1], ..., a[p]) ) = true;
    // f( op(a[qr], a[qr-1], ..., a[p-1]) ) = false;
    // -1 cannot find such p
    // we want to find p <= qr, left most position which is satisfied f() = true
    // find in [ql, qr]
    // usage: max_left<f> (1, 1, _n, ql, qr)

    template<bool (*g)(S)>
    int min_left(int p, int l, int r, int ql, int qr) const {
        if (qr == 0) return 0;

        if (ql == l && r == qr) {
            if ( !g(node[p]) ) return -1;
            if (l == r) return l;

            int mid = (l + r) >> 1;
            push(p);
            if ( g(node[p<<1]) ) return min_left<g>(p<<1, l, mid, ql, mid);
            else return min_left<g>(p<<1|1, mid+1, r, mid+1, qr);
        }
        int mid = (l + r) >> 1;
        push(p);

        if (qr <= mid) return min_left<g>(p<<1, l, mid, ql, qr);
        else if (ql > mid) return min_left<g>(p<<1|1, mid+1, r, ql, qr);
        else {
            int pos = min_left<g>(p<<1, l, mid, ql, mid);
            if (pos != -1) return pos;
            else return min_left<g>(p<<1|1, mid+1, r, mid+1, qr);
        }
    }
};


const ll inf = (1LL << 60);

void solve() {
    int n, q;
    cin >> n >> q;

    vector<double> a(n+1, 0.0), s(n+1, 0);
    double sum = 0;
    for (int i = 1; i <= n; i++) cin >> a[i], sum += a[i];
    for (int i = 1; i <= n; i++) s[i] = s[i-1] + a[i];

    vector<S> vec(n+1, S{});
    for (int i = 1; i <= n; i++) vec[i] = S{ i, i, 1, a[i] };
    lazy_segtree<S, op, F, mapping, comp, id> tr(vec, n);

    // vector<int> le(n+1, 0), ri(n+1, 0);
    // for (int i = 1; i <= n; i++) le[i] = ri[i] = i;

    while (q--) {
        int op;
        cin >> op;

        if(op == 1) {
            int l, r;
            cin >> l >> r;
            //if (l == r) continue;
            //debug(l);

            int low = l, high = r;
            // for (int i = l; i <= r; i++) chmin(low, le[i]), chmax(high, ri[i]);
            // debug(cl), debug(cr);
            auto cur = tr.query(1, 1, n, l, r);
            chmin(low, cur.le), chmax(high, cur.ri);

            // double tot = 0;
            // for (int i = low; i <= high; i++) le[i] = low, ri[i] = high;
            double tot = s[high] - s[low - 1];
            double avg = tot * 1.0 / (high - low + 1);
            auto fun = F{ low, high, avg };
            tr.modify(1, 1, n, low, high, fun);

            // for (int i = low; i <= high; i++) {
            //     int len = high - low + 1;
            //     double tot = s[high] - s[low - 1]; 
            //     a[i] = (tot) * 1.0 / (1.0 * len);
            // }
        }
        else {
            int id;
            cin >> id;

            auto res = tr.query(1, 1, n, id, id);
            printf("%.10lf\n", res.sum);
        }
    }
}

int main() {
    freopen("input.txt", "r", stdin);
    ios::sync_with_stdio(false), cin.tie(0);

    int cas = 1;
    // cin >> cas;
    while (cas--) {
        solve();
    }
}
全部评论

相关推荐

01-16 10:30
已编辑
华南师范大学 Java
点赞 评论 收藏
分享
评论
1
收藏
分享

创作者周榜

更多
牛客网
牛客企业服务