可持久化线段树
题目链接: F ABCBA
题目要求树上任意两点$u,v$形成的字符串中子序列为ABCBA个数,我们利用线段树维护一条链上各个子串的个数,$u,v$两点的$LCA$为$lca$,设$lca$到$v$形成的字符串为$L$,$lca$到$u$形成的字符串为$R$,则答案既是$$ L.ABCBA + R.ABCBA + L.A * R.BCBA + L.BA * R.CBA + L.CBA * R.BA + L.BCBA * R.A $$ (注意顺序,建树是从上到下建树,题目要求 $u->lca->v$,因此有些字符串顺序要注意) 由于是区间查询子串的个数,我们可以根据父节点建立可持久化线段树,维护树上节点深度上的信息,这样查询某个区间就可以转化为查询深度区间。
#include<bits/stdc++.h>
using namespace std;
#define lowbit( x ) (x&(-x))
typedef long long ll;
const int N = 3e4 + 10;
const int mod = 10007;
int n , q , root[N] , tot , dep[N] , fat[N][22];
int ls[N * 32] , rs[N * 32];//注意这里必须要单独出来,不然合并的时候会导致lc,rc消失
struct tree {
int cnt , A , AB , ABC , ABCB , B , BC , BCB , BCBA , C , CB , CBA , BA;
tree operator + ( const tree &O ) const {
tree res;
res.cnt = ( cnt + O.cnt + A * O.BCBA + AB * O.CBA + ABC * O.BA + ABCB * O.A ) % mod;
res.A = ( A + O.A ) % mod;
res.AB = ( AB + A * O.B + O.AB ) % mod;
res.ABC = ( ABC + A * O.BC + AB * O.C + O.ABC ) % mod;
res.ABCB = ( ABCB + A * O.BCB + AB * O.CB + ABC * O.B + O.ABCB ) % mod;
res.B = ( B + O.B ) % mod;
res.BC = ( BC + B * O.C + O.BC ) % mod;
res.BCB = ( BCB + B * O.CB + BC * O.B + O.BCB ) % mod;
res.BCBA = ( BCBA + B * O.CBA + BC * O.BA + BCB * O.A + O.BCBA ) % mod;
res.C = ( C + O.C ) % mod;
res.CB = ( CB + C * O.B + O.CB ) % mod;
res.CBA = ( CBA + C * O.BA + CB * O.A + O.CBA ) % mod;
res.BA = ( BA + B * O.A + O.BA ) % mod;
return res;
}
} T[N * 22];
void UP ( int &cur , int pre , int l , int r , int p , char ch ) {
cur = ++ tot;
//T[cur] = T[pre];//不影响
ls[cur] = ls[pre];
rs[cur] = rs[pre];
if ( l == r ) {
if ( ch == 'A' ) T[cur].A = 1;
else if ( ch == 'B' ) T[cur].B = 1;
else if ( ch == 'C' ) T[cur].C = 1;
return;
}
int mid = ( l + r ) >> 1;
if ( p <= mid ) UP ( ls[cur] , ls[pre] , l , mid , p , ch );
else UP ( rs[cur] , rs[pre] , mid + 1 , r , p , ch );
T[cur] = T[ls[cur]] + T[rs[cur]];
}
tree query ( int rt , int l , int r , int x , int y ) {
if ( x <= l && r <= y ) return T[rt];
int mid = ( l + r ) >> 1;
if ( y <= mid ) return query ( ls[rt] , l , mid , x , y );
else if ( x > mid ) return query ( rs[rt] , mid + 1 , r , x , y );
else {
tree L , R;
L = query ( ls[rt] , l , mid , x , mid );
R = query ( rs[rt] , mid + 1 , r , mid + 1 , y );
return L + R;
}
}
vector < int > G[N];
char s[N];
void dfs ( int u , int fa ) {
dep[u] = dep[fa] + 1;
UP ( root[u] , root[fa] , 1 , n , dep[u] , s[u] );
for ( auto v : G[u] ) {
if ( v == fa ) continue;
fat[v][0] = u;
for ( int i = 1 ; i <= 18 ; i ++ )
fat[v][i] = fat[fat[v][i - 1]][i - 1];
dfs ( v , u );
}
}
int LCA ( int u , int v ) {
if ( dep[u] < dep[v] ) swap ( u , v );
for ( int j = 18 ; j >= 0 ; j -- ) {
if ( dep[fat[u][j]] >= dep[v] ) u = fat[u][j];
}
if ( u == v ) return v;
for ( int j = 18 ; j >= 0 ; j -- ) {
if ( fat[v][j] != fat[u][j] ) {
v = fat[v][j];
u = fat[u][j];
}
}
return fat[v][0];
}
int main () {
int u , v , ans;
scanf ( "%d%d%s" , &n , &q , s + 1 );
for ( int i = 1 ; i < n ; i ++ ) {
scanf ( "%d%d" , &u , &v );
G[u].push_back ( v );
G[v].push_back ( u );
}
dfs ( 1 , 0 );
tree L , R;
while ( q -- ) {
scanf ( "%d%d" , &u , &v );
int lca = LCA ( u , v );
if ( u == lca )
printf ( "%d\n" , query ( root[v] , 1 , n , dep[u] , dep[v] ).cnt );
else if ( v == lca )
printf ( "%d\n" , query ( root[u] , 1 , n , dep[v] , dep[u] ).cnt );
else {
L = query ( root[v] , 1 , n , dep[lca] + 1 , dep[v] );
R = query ( root[u] , 1 , n , dep[lca] , dep[u] );
ans = ( L.cnt + R.cnt + L.A * R.BCBA + L.BA * R.CBA + L.CBA * R.BA + L.BCBA * R.A ) % mod;
printf ( "%d\n" , ans );
}
}
return 0;
}
我改域名了!http://shanbu.fun 山卜方