views
Word count: 1.4k
(~7 mins to read)
Last updated:
原题地址:https://www.luogu.org/problemnew/show/P3384
题目简述
给定一些序列(没有重复数字),每个序列支持:
给定一些数k(对于每个序列不重复),每次在序列里找到最接近k的数删除(如果有2个数字与k差一样,即分别是k-b和k+b,则选择较小的k-b),累加与k的差,输出。
思路
其实关键就是维护一个有序序列,支持插入,查询前继后继,删除指定数字。
自然我们会想到手打平衡树,Treap/Splay皆可。(这里只有旋转实现的Treap,非旋Treap(Split+Merge)和Splay日后加上)
Tips:为了防止越界等问题以及方便提取区间(尤其是Splay),序列前后一般塞上一个-INF和INF
然而作为C++选手,我们应该妙用STL。set可以实现这样的功能,内部是红黑树实现的也很快。
代码
- 旋转实现的Treap(160ms,3.03MB)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
using namespace std;
const int INF=1e9;
inline int randad(){
static int seed=114514;
return seed=int(seed*48271LL%2147483647);//48271使得随机数有完全周期,即2147483647内取遍不重复
}
int delta=0;
struct node {
int pri,val,ch[2],size,tot;
//pri:Treap的随机数
//val:数字
//ch[0,1]:左孩子右孩子
//size:以该节点为根的子树里有几个数字
//tot:这个数字出现了几次(本题无用)
}T[111111];
int k,size=0,ANS,ans;//k:根节点,size:树的大小,ANS:临时,ans:赶走了几个人
void update(int k){T[k].size=T[T[k].ch[0]].size+T[T[k].ch[1]].size+T[k].tot;}
void rturn(int &k)//右旋,把k旋到右边,k左孩子提到根
{
int t=T[k].ch[0];
T[k].ch[0]=T[t].ch[1];
T[t].ch[1]=k;
T[t].size=T[k].size;
update(k);
k=t;
}
void lturn(int &k)//左旋,把k旋到左边,k右孩子提到根
{
int t=T[k].ch[1];
T[k].ch[1]=T[t].ch[0];
T[t].ch[0]=k;
T[t].size=T[k].size;
update(k);
k=t;
}
void ins(int &k,int val) //插入
{
if (k==0) {
size++;
k=size;
T[k].pri=randad();
T[k].val=val;
T[k].size=T[k].tot=1;
return ;
}
T[k].size++;
if (T[k].val==val) T[k].tot++;
else if (val>T[k].val) {
ins(T[k].ch[1],val);
if (T[T[k].ch[1]].pri<T[k].pri) lturn(k);
} else {
ins(T[k].ch[0],val);
if (T[T[k].ch[0]].pri<T[k].pri) rturn(k);
}
}
void del(int &k,int val)//删除值为val的数
{
if (k==0) return ;
if (T[k].val==val) {
if (T[k].tot>1) {
T[k].tot--;
T[k].size--;
return ;
}
if (T[k].ch[0]==0||T[k].ch[1]==0) k=T[k].ch[0]+T[k].ch[1];
else if(T[T[k].ch[0]].pri<T[T[k].ch[1]].pri) rturn(k),del(k,val);
else lturn(k),del(k,val);
} else if(val>T[k].val) T[k].size--,del(T[k].ch[1],val);
else T[k].size--,del(T[k].ch[0],val);
}
int xth(int &k,int x)//查询第x小的数是什么
{
if(k==0||x==0)return 0;
if(x<=T[T[k].ch[0]].size) return xth(T[k].ch[0],x);
else if(x>T[T[k].ch[0]].size+T[k].tot) return xth(T[k].ch[1],x-T[T[k].ch[0]].size-T[k].tot);
else return T[k].val;
}
int find(int &k,int x)//查询第x小数在树中位置
{
if (k==0||x==0) return 0;
if(x<=T[T[k].ch[0]].size)return find(T[k].ch[0],x);
if(x==T[T[k].ch[0]].size+1)return k;
return find(T[k].ch[1],x-T[T[k].ch[0]].size-1);
}
void pre(int k,int x)//查询不比x大的且最接近x的数所在位置(x前继)
{
if(k==0)return;
if(T[k].val<x) ANS=k,pre(T[k].ch[1],x);
else pre(T[k].ch[0],x);
}
void next(int k,int x)//查询不比x小的且最接近x的数所在位置(x后继)
{
if(k==0)return;
if(T[k].val>x) ANS=k,next(T[k].ch[0],x);
else next(T[k].ch[1],x);
}
void Catch(int num)//匹配宠物和饲养人
{
int a,b;
pre(k,num),a=T[ANS].val;
next(k,num), b=T[ANS].val;
if(num-a<=b-num && a != -INF) {
ans += num-a;
del(k,a);
} else {
ans += b-num;
del(k,b);
}
ans %= 1000000;
}
int main()
{
int n;
scanf("%d", &n);
int cur;
ins(k,-INF),ins(k,INF);
for(int i = 1; i <= n; i++) {
int a, b;
scanf("%d%d", &a, &b);
if(T[k].size == 2) {
cur=a;//cur:当前是宠物等人认领还是人在等着接受宠物(看原题,不然谁看得懂啊= =)
ins(k,b);
} else if(a == cur) ins(k,b);
else Catch(b);
}
printf("%d\n", ans);
return 0;
return 0;
}
```
2. set实现(304ms,2.57MB)
```cpp
using namespace std;
const int maxn = 1111111;
const int INF = 1000000000;
int n, ans;
set <int> s;
void find(int x) {
set<int>::iterator left=--s.lower_bound(x),right=s.lower_bound(x);//lower_bound的实现是二分查找,迭代器指向不比x小的且最接近x的数的位置,所以left就是前继,right就是后继
if(x-*left<=*right-x&&*left!=-INF) {
ans+=x-*left;
s.erase(left);
} else {
ans+=*right-x;
s.erase(right);
}
ans%=1000000;
}
int main()
{
scanf("%d",&n);
int cur;
s.insert(-INF),s.insert(INF);
for(int i=1;i<=n;i++) {
int a,b;
scanf("%d%d", &a, &b);
if(s.size()==2) {
cur=a;
s.insert(b);
} else if(a==cur) s.insert(b);
else find(b);
}
printf("%d\n", ans);
return 0;
}