双周赛131(240525)

AK的第二场😝, 第四题是用C++ 找网上模板,套用查找最长的连0串过的, 最后两分钟才过,很惊险。赛后发现做法并不优秀,但是不需要平衡树。

3161. 物块放置查询

题目大意:

有一条无限长的数轴,原点在 0 处,沿着 x 轴正方向无限延伸。给定一个二维数组 queries 包含两种操作:

  1. 操作类型 1:在距离原点 x 处建一个障碍物。
  2. 操作类型 2:判断在数轴范围 [0, x] 内是否可以放置一个长度为 sz 的物块,该物块需要完全放置在范围 [0, x] 内。如果物块与任何障碍物有重合,则该物块不能被放置,但物块可以与障碍物刚好接触。

对于操作类型 2 的查询,若可以放置物块,则返回 true,否则返回 false。

实现思路:

  1. 线段树的构建和维护:

    • 使用一个数组 tr 来表示线段树,初始化为全零。数组的长度可以通过查询中的最大范围确定,即 m = max(q[1] for q in queries)
    • 定义 pushup 函数用于更新线段树中的节点信息,每当修改线段树中某个节点的值时调用。
    • 定义 modify 函数用于修改线段树中的值,将某个位置的值更新为新值,并根据需要递归更新父节点的信息。
    • 定义 query 函数用于查询线段树中某个区间内的最大值。
  2. Sorted List 的维护:

    • 使用 SortedList 结构来维护障碍物的位置信息,SortedSet 是一个有序的集合数据结构。
    • 在操作类型 1 中,即添加障碍物时,使用 bisect_left 找到应该插入的位置,并在相应的位置插入新的障碍物位置,并更新线段树的信息。
  3. 查询操作类型 2:

    • 对于每一个操作类型 2 的查询,先找到该位置之前最近的障碍物的位置 pre,然后通过线段树查询该区间内的最大值。
    • 如果该值大于等于所需放置物块的大小,则说明可以放置物块,否则不行。
from sortedcontainers import SortedList
class Solution:
    def getResults(self, queries: List[List[int]]) -> List[bool]:
        m = max(q[1] for q in queries)
        tr = [0]*(1<<m.bit_length() + 1)

        def pushup(u):
            tr[u] = max(tr[u<<1], tr[u<<1|1])

        def modify(u, l, r, x, val):
            if l==r:
                tr[u] = val
            else:
                mid = l+r>>1
                if x<=mid:
                    modify(u<<1, l, mid, x, val)
                else:
                    modify(u<<1|1, mid+1, r, x, val)
                pushup(u)
            
        def query(u, l, r, ql, qr):
            if ql<=l and r<=qr:
                return tr[u]
            
            mid = l+r>>1
            res = 0

            if ql<=mid:
                res = query(u<<1, l, mid, ql, qr)
            if qr>mid:
                res = max(res, query(u<<1|1, mid+1, r, ql, qr))
            
            return res

        sl = SortedList([0, m])
        res = []
        for q in queries:
            x = q[1]
            i = sl.bisect_left(x)
            pre = sl[i-1]
            if q[0]==1:
                nxt = sl[i]
                sl.add(x)
                modify(1, 0, m, x, x-pre)
                modify(1, 0, m, nxt, nxt-x)
            else:
                sz = q[2]
                ans = max(x-pre, query(1, 0, m, 0, pre))
                res.append(ans >= sz)

        return res

比赛时的做法,本来是C++,转成了python

class Node:
    def __init__(self):
        self.l = 0
        self.r = 0
        self.d = 0
        self.ld = 0
        self.rd = 0
        self.low = -1

class Solution:
    def getResults(self, queries):
        def build(u, l, r):
            tr[u].l = l
            tr[u].r = r
            tr[u].d = r - l + 1
            tr[u].ld = r - l + 1
            tr[u].rd = r - l + 1
            tr[u].low = -1
            if l != r:
                mid = (l + r) >> 1
                build(u << 1, l, mid)
                build(u << 1 | 1, mid + 1, r)

        def pushup(u):
            tr[u].ld = tr[u << 1].ld
            if tr[u << 1].ld == tr[u << 1].r - tr[u << 1].l + 1:
                tr[u].ld += tr[u << 1 | 1].ld

            tr[u].rd = tr[u << 1 | 1].rd
            if tr[u << 1 | 1].rd == tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1:
                tr[u].rd += tr[u << 1].rd

            tr[u].d = max(max(tr[u << 1].d, tr[u << 1 | 1].d), tr[u << 1].rd + tr[u << 1 | 1].ld)

        def pushdown(u):
            if tr[u].low == -1:
                return

            tr[u << 1].low = tr[u << 1 | 1].low = tr[u].low
            if tr[u].low:
                tr[u << 1].d = tr[u << 1].ld = tr[u << 1].rd = 0
                tr[u << 1 | 1].d = tr[u << 1 | 1].ld = tr[u << 1 | 1].rd = 0
            else:
                tr[u << 1].d = tr[u << 1].ld = tr[u << 1].rd = tr[u << 1].r - tr[u << 1].l + 1
                tr[u << 1 | 1].d = tr[u << 1 | 1].ld = tr[u << 1 | 1].rd = tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1
            tr[u].low = -1

        def query(u, ql, qr):
            if tr[u].l > qr or tr[u].r < ql:
                return 0
            if tr[u].l >= ql and tr[u].r <= qr:
                return tr[u].d

            pushdown(u)

            mid = (tr[u].l + tr[u].r) >> 1
            res = 0
            if ql <= mid:
                res = max(res, query(u << 1, ql, qr))
            if qr > mid:
                res = max(res, query(u << 1 | 1, ql, qr))

            # Cross left and right subintervals handling
            if ql <= mid and qr > mid:
                left_suffix = min(tr[u << 1].rd, mid - ql + 1)
                right_prefix = min(tr[u << 1 | 1].ld, qr - mid)

                res = max(res, left_suffix + right_prefix)

            return res

        def modify(u, l, r, d):
            if tr[u].l >= l and tr[u].r <= r:
                if d:
                    tr[u].d = tr[u].ld = tr[u].rd = 0
                    tr[u].low = 1
                else:
                    tr[u].d = tr[u].ld = tr[u].rd = tr[u].r - tr[u].l + 1
                    tr[u].low = 0
                return

            pushdown(u)
            mid = (tr[u].l + tr[u].r) >> 1
            if l <= mid:
                modify(u << 1, l, r, d)
            if r > mid:
                modify(u << 1 | 1, l, r, d)
            pushup(u)

        n = max(q[1] for q in queries)
        tr = [Node() for _ in range(1<<n.bit_length() + 1)]
        build(1, 0, n)

        res = []
        modify(1, 0, 0, 1)
        mp = {}

        for q in queries:
            if q[0] == 1:
                x = q[1]
                modify(1, x, x, 1)
                mp[x] = 1
            else:
                x = q[1]
                sz = q[2]

                if x < sz:
                    res.append(False)
                    continue

                if mp.get(x) != 1:
                    modify(1, x, x, 1)

                ans = query(1, 0, x)
                
                if mp.get(x) != 1:
                    modify(1, x, x, 0)

                res.append(ans >= sz-1)

        return res