USACO 2021 US Open 铂金组 T1 题解

想了挺久

United Cows of Farmer John

题意:

给定一个长度为 n 的数列 $a_i$,求符合下列条件的三元组 $\{i,j,k\}$ 个数:

  • $1\le i < j < k \le n$
  • $a_i,a_j,a_k$ 的值在 $a_{[i,k]}$ 各只出现一次(即他们自己)。

思路:

为了方便描述,我们称 $a_i$ 为 i 点的颜色,记 $pre_i$ 表示 $a_i$ 的前驱(即 i 之前与 $a_i$ 相同的最后一个数),记集合 $last$ 表示右端点扫描到 r 时,每个颜色最后一次出现的位置。

朴素想法

考虑固定右端点 r,然后枚举左端点 l 的位置,求此时 m 有多少种选择。发现答案即为 $a_{[l+1,r-1]}$ 中仅出现过一次的颜色数。

为了方便计算,我们可以使用代表元法。即随着 r 的向右推进,维护一个序列 $appear_i$:

$$ appear_i=\left\{ \begin{aligned} 1\qquad&(a_{[i+1,r]}中没有与a_i相同的颜色)\cr -1\qquad&(a_{[i+1,r]}中恰有一个与a_i相同的颜色)\cr 0\qquad&(a_{[i+1,r]}中有多于一个与a_i相同的颜色) \end{aligned} \right. $$

具体的维护方式就是每当 r 右移到新的一位,就进行以下操作:

1
2
3
appear[i]++;
appear[pre[i]]-=2;
appear[pre[pre[i]]]++;

然后每次暴力枚举 l,并将后缀和加起来,我们就得到了一个 $O(n^2)$ 的做法。

注: 当然,在 r 固定的情况下,l 也是有限制的,即:$l > pre[r]$ 并且 $l \in last$。

线段树优化

虽然在枚举 l 的时候有 $l \in last$ 的限制,但是这个二重后缀求和仍有希望用数据结构维护。

记数组 $sumappear_i$ 表示以下的内容:

$$ sumappear_i=\left\{ \begin{aligned} \sum_{j=i+1}^{r-1}appear_j \qquad&(i \in last)\cr 0\qquad&(i \notin last) \end{aligned} \right. $$

则对 $appear_i$ 的修改可以转化为对 $sumappear$ 前缀第 $j(j\in last \cap [1,i])$ 位的修改。而枚举 l 的过程也可以转化为对 $sumappear$ 第 $j,j\in[pre_r+1,r-1]$ 的求和。

考虑用线段树维护这个信息,在每个节点上不仅要存区间的权值和,还要存一个 cnt 表示区间内有多少项在 $last$ 集合里,每当 r 右移到新的一位,就将 $sumappear_i$ 的 cnt 设为 1,并将 $sumappear_{pre_i}$ 的 cnt 设为 0,给区间打 tag 时也按照 sum+=cnt*tag 加,于是这就是一个 $O(n\log n)$ 的算法。

代码

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include<bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define pb push_back
#define mp make_pair
#define rd scanf
#define pt printf
#define setp(x) setprecision(x)
#define mem(x,y) memset(x,y,sizeof(x))
#define umn(x,y) x=min(x,y)
#define umx(x,y) x=max(x,y)
typedef long long ll;
typedef unsigned long long ull;
typedef pair<ll,ll> pi;
typedef queue<ll> qi;
typedef vector<ll> vi;
typedef vector<pi> vpi;

const ll INF=1e18;
const ll P=1e9+7;
const ll MXN=2e5+5;

inline ll mod(const ll &x){return x>=P?x-P:x;}
inline ll mul(ll x,ll y){if(x<y)swap(x,y);ll res=0;while(y){if(y&1)res=mod(res+x);x=mod(x<<1);y>>=1;}return res;}
inline ll qpow(ll x,ll y){ll res=1;while(y){if(y&1)res=res*x%P;x=x*x%P;y>>=1;}return res;}
ll gcd(const ll &x,const ll &y){return !y?x:gcd(y,x%y);}
inline void pbin(const ll &x,ll y){for(;~y;y--)cerr<<char('0'+((x>>y)&1));cerr<<endl;}

ll n,arr[MXN],res;
ll last[MXN],pre[MXN];

#define ls (p<<1)
#define rs (p<<1|1)
const ll NR=MXN<<2;
struct node{
	ll va,sz,tag;
	inline node(ll va=0,ll sz=0,ll tag=0):va(va),sz(sz),tag(tag){}
	inline void addt(ll k){
		tag+=k;
		va+=sz*k;
	}
	inline node operator + (const node &y)const{return node(va+y.va,sz+y.sz);}
}t[NR];
inline void pushu(ll p){t[p]=t[ls]+t[rs];}
inline void pushd(ll p){
	t[ls].addt(t[p].tag);
	t[rs].addt(t[p].tag);
	t[p].tag=0;
}
void build(ll p,ll l,ll r){
	if(l==r){
		t[p]=node(-2,1);
		return;
	}
	ll mid=(l+r)>>1;
	build(ls,l,mid);
	build(rs,mid+1,r);
	pushu(p);
}
void del(ll p,ll l,ll r,ll dx){
	if(l==r){
		t[p].sz=t[p].va=0;
		return;
	}
	ll mid=(l+r)>>1;
	pushd(p);
	if(dx<=mid)del(ls,l,mid,dx);
	else del(rs,mid+1,r,dx);
	pushu(p);
}
void mod(ll p,ll l,ll r,ll ml,ll mr,ll mx){
	if(ml<=l && r<=mr){
		t[p].addt(mx);
		return;
	}
	ll mid=(l+r)>>1;
	pushd(p);
	if(ml<=mid)mod(ls,l,mid,ml,mr,mx);
	if(mr>mid)mod(rs,mid+1,r,ml,mr,mx);
	pushu(p);
}
node que(ll p,ll l,ll r,ll ql,ll qr){
	if(ql<=l && r<=qr)return t[p];
	ll mid=(l+r)>>1;node res;
	pushd(p);
	if(ql<=mid)res=res+que(ls,l,mid,ql,qr);
	if(qr>mid)res=res+que(rs,mid+1,r,ql,qr);
	return res;
}


inline void solve(){
	scanf("%lld",&n);
	build(1,1,n);
	for(int i=1;i<=n;i++){
		scanf("%lld",arr+i);
		pre[i]=last[arr[i]];
		last[arr[i]]=i;
		
		if(pre[i])del(1,1,n,pre[i]);
		mod(1,1,n,1,i,1);
		if(pre[i])mod(1,1,n,1,pre[i],-2);
		if(pre[pre[i]])mod(1,1,n,1,pre[pre[i]],1);
		if(pre[i]+2<=i)res+=que(1,1,n,pre[i]+1,i-1).va;
	}
	printf("%lld",res);
    
}
int main(){
	//freopen(".in","r",stdin);
	//freopen(".out","w",stdout);
	//ios_base::sync_with_stdio(0);cin.tie(0);
	cout<<setiosflags(ios::fixed);
	srand(time(NULL));

	ll tq=1;
	//cin>>tq;
	while(tq--)solve();
	return 0;
}