思想
点分治是一种用于处理静态树上路径统计问题的算法,其核心原理还是基于分治思想。
我们不妨对这类树上静态路径统计问题抽象化,例如:
给定一个棵无根树,求满足要求P的路径有几条(点对)
我们可以使用分治算法求解本题,对于无根树,我们指定一个节点为根,显然,这样的路径可以归为两类:
- 满足要求P且经过根root的路径
- 满足要求P但不经过根root的路径
第二类路径不经过根,所以一定在根的某一刻子树中,这我们就可以递归地进行求解,而重点就在于统计第一类路径。 先不考虑统计问题,我们可以设计如下的分治算法:
- 求出根root
- 得到相关信息,统计第一类路径并累加入答案中
- 删除节点root ,对于root 的每一棵子树递归的执行分治算法
显然,递归的层数和每一次选取的根root有关,当树退化成一条链时,如果任意的选取根,递归层数很可能就达到$O(n)$层。解决方法就是选取树的重心作为树的根,容易证明,每一次选取树的重心作为根时,点分治算法最多递归$log_2n$层。
一般我们有下面几种方法来计算以当前重心作为根对答案的贡献
第一种
先一次性的计算出当前根$u$所有子树的贡献,这里面包括了经过当前树根的路径和不经过当前树根的路径,因此我们还要遍历$u$的所有子树,减去各个子树单独的贡献,即不经过当前树根$u$的路径数量。
POJ1741
按照dis从小到大排序,对于每个i维护一个R使得区间[i+1,R]内所有点都满足,可知R是递减的,先算总的贡献,再减去各个子树的贡献。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e4 + 10;
const int M = 1e6 + 10;
const int mod = 998244353;
const int inf = 0x3f3f3f3f;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , cnt , MX , ans , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
ll dis[N] , dist[N];
void addedge ( int from , int to , int len ) {
ver[++ tot] = to;
w[tot] = len;
nxt[tot] = head[from];
head[from] = tot;
}
void get ( int u , int fa ) {
sz[u] = 1;
int mx = 0;
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
get ( v , u );
sz[u] += sz[v];
mx = max ( mx , sz[v] );
}
}
mx = max ( mx , cursz - sz[u] );
if ( mx < MX ) root = u , MX = mx;
}
void dfs1 ( int u , int fa ) {
dist[++ cnt] = dis[u];
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
dis[v] = dis[u] + w[i];
dfs1 ( v , u );
}
}
}
int cal ( int u ) {
int res = 0;
cnt = 0;
dfs1 ( u , 0 );
sort ( dist + 1 , dist + cnt + 1 );
int L = 1 , R = cnt;
while ( L < R ) {
while ( L < R && dist[L] + dist[R] > k ) -- R;
res += R - L;
++ L;
}
return res;
}
void dfs ( int u ) {
vis[u] = 1;
dis[u] = 0;
ans += cal ( u );
get ( u , 0 );
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] ) {
//dis[v] = w[i];
ans -= cal ( v );
MX = inf;
cursz = sz[v];
get ( v , 0 );
dfs ( root );
}
}
}
int main () {
int u , v , len;
while ( scanf ( "%d%d" , &n , &k ) != EOF && ( n || k ) ) {
tot = 0;
root = 0;
ans = 0;
for ( int i = 0 ; i <= n ; ++ i ) head[i] = - 1 , vis[i] = 0;
for ( int i = 1 ; i < n ; ++ i ) {
scanf ( "%d%d%d" , &u , &v , &len );
addedge ( u , v , len );
addedge ( v , u , len );
}
MX = inf;
cursz = n;
get ( 1 , 0 );
dfs ( root );
printf ( "%d\n" , ans );
}
return 0;
}
HDU 5314
按照最小值从小到大排序,对于i,所有在i之前,且最小值大于i的最大值-D都满足,所以二分最大值-D即可,为什么要满足在i之前呢,因为在i之后的无法保证最大值减去最小值小于等于D。列如i之后有一个j,(i,j)之间最小值肯定是i的最小值,最大值可能是i也可能是j,如果是j可能就不满足最大值减去最小值小于等于D。因此我们只处理i之前的,最后将答案乘2即可。
#include<bits/stdc++.h>
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e5 + 10;
const int M = 1e7 + 10;
const int mod = 1e6 + 3;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , MX , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
ll ans;
int a[N] , MAX[N] , MIN[N] , D;
P tmp[N];
void addedge ( int from , int to , int len ) {
ver[++ tot] = to;
w[tot] = len;
nxt[tot] = head[from];
head[from] = tot;
}
void get ( int u , int fa ) {
sz[u] = 1;
int mx = 0;
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
get ( v , u );
sz[u] += sz[v];
mx = max ( mx , sz[v] );
}
}
mx = max ( mx , cursz - sz[u] );
if ( mx < MX ) root = u , MX = mx;
}
void dfs1 ( int u , int fa ) {
if ( MAX[u] - MIN[u] <= D )
tmp[++ tot] = mp( MIN[u] , MAX[u] );
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
MAX[v] = max ( MAX[u] , a[v] );
MIN[v] = min ( MIN[u] , a[v] );
dfs1 ( v , u );
}
}
}
ll cal ( int u ) {
ll res = 0;
tot = 0;
dfs1 ( u , 0 );
sort ( tmp + 1 , tmp + tot + 1 );
for ( int i = 1 ; i <= tot ; ++ i ) {
int dis = lower_bound ( tmp + 1 , tmp + i + 1 , mp( tmp[i].second - D , 0 ) ) - tmp;
res += i - dis;
}
return res;
}
void dfs ( int u ) {
MAX[u] = MIN[u] = a[u];
vis[u] = 1;
ans += cal ( u );
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] ) {
ans -= cal ( v );
MX = inf;
cursz = sz[v];
get ( v , 0 );
dfs ( root );
}
}
}
int main () {
int u , v , len , T;
scanf ( "%d" , &T );
while ( T -- ) {
tot = 0;
ans = 0;
scanf ( "%d%d" , &n , &D );
for ( int i = 1 ; i <= n ; ++ i ) {
scanf ( "%d" , &a[i] );
head[i] = - 1;
vis[i] = 0;
}
for ( int i = 1 ; i < n ; ++ i ) {
scanf ( "%d%d" , &u , &v );
addedge ( u , v , 0 );
addedge ( v , u , 0 );
}
MX = inf;
cursz = n;
get ( 1 , 0 );
dfs ( root );
printf ( "%lld\n" , ans * 2 );
}
return 0;
}
第二种
逐个处理当前根u的各个子树,处理到某一子树时,之前遍历过的子树信息已经保存下来了,可供当前子树使用,即是之前子树中的点与当前子树形成的贡献。每处理完一个子树,将其信息保存下来。
这种方法一般用于枚举点才能算出贡献的题目中,既是枚举一个点,利用某个东西求得与这个点满足关系的点有多少个。一般是求解无序对(u,v)的个数,既是(u,v)与(v,u) 是一样的。下面一种方法我们将看到有的(u,v)与(v,u) 是不一样的。
SPOJ 1825
这题使用到了第二种方法。我们用树状数组保存访问<=i个拥挤点所经过道路的最大距离,对于每个子树中的每个点v,根到v经过了若干距离,同时访问了若干个拥挤点,我们可以利用树状数组查询(k-已访问的拥挤点)所经过的最大距离,再加上当前的距离,尝试更新最大值。
#include<bits/stdc++.h>
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 2e5 + 10;
const int M = 1e7 + 10;
const int mod = 998244353;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int tot , root , cnt[N] , cr[N] , MX , cursz , vis[N] , sz[N];
int dis[N] , ans , _k , tmp[N] , c[N];
P dist[N];
vector < P > G[N];
void add ( int x , int v ) {
for ( ; x <= m + 1 ; x += lowbit( x ) ) c[x] = max ( c[x] , v );
}
void del ( int x ) {
for ( ; x <= m + 1 ; x += lowbit( x ) ) c[x] = - INF;
}
int ask ( int x ) {
int res = - INF;
for ( ; x > 0 ; x -= lowbit ( x ) ) res = max ( res , c[x] );
return res;
}
void get ( int u , int fa ) {
sz[u] = 1;
int mx = 0;
for ( auto &E:G[u]) {
int v = E.first;
if ( ! vis[v] && v != fa ) {
get ( v , u );
sz[u] += sz[v];
mx = max ( mx , sz[v] );
}
}
mx = max ( mx , cursz - sz[u] );
if ( mx < MX ) root = u , MX = mx;
}
void dfs1 ( int u , int fa ) {
dist[++ tot] = mp( dis[u] , cnt[u] );
for ( auto &E:G[u] ) {
int v = E.first , w=E.second;
if ( ! vis[v] && v != fa ) {
dis[v] = dis[u] + w;
cnt[v] = cnt[u] + cr[v];
dfs1 ( v , u );
}
}
}
void cal ( int u ) {
dis[u] = 0;
//cnt[u] = cr[u];
cnt[u] = 0;
add ( 1 , 0 );
int tot1 = 0;
tmp[++ tot1] = cnt[u] + 1;
for ( auto &E:G[u] ) {
int v = E.first , w = E.second;
if ( ! vis[v] ) {
dis[v] = w;
cnt[v] = cnt[u] + cr[v];
tot = 0;
dfs1 ( v , u );
for ( int j = 1 ; j <= tot ; ++ j ) {
if ( k - ( dist[j].second + cr[u] ) + 1 > 0 ) {
ans = max ( 1LL * ans , 1LL * dist[j].first + ask ( k - ( dist[j].second + cr[u] ) + 1 ) );
}
}
for ( int j = 1 ; j <= tot ; ++ j )
if ( dist[j].second <= k )
add ( dist[j].second + 1 , dist[j].first ) , tmp[++ tot1] = dist[j].second + 1;
}
}
for ( int i = 1 ; i <= tot1 ; ++ i ) del ( tmp[i] );
}
void dfs ( int u ) {
vis[u] = 1;
cal ( u );
for ( auto &E:G[u]) {
int v = E.first;
if ( ! vis[v] ) {
MX = inf;
cursz = sz[v];
get ( v , 0 );
dfs ( root );
}
}
}
int main () {
int u , v , len;
scanf ( "%d%d%d" , &n , &k , &m );
for ( int i = 0 ; i <= m + 1 ; ++ i ) c[i] = - INF;
for ( int i = 1 ; i <= m ; ++ i ) {
scanf ( "%d" , &u );
cr[u] = 1;
}
for ( int i = 1 ; i < n ; ++ i ) {
scanf ( "%d%d%d" , &u , &v , &len );
G[u].push_back ( mp( v , len ) );
G[v].push_back ( mp( u , len ) );
}
MX = inf;
cursz = n;
get ( 1 , 0 );
dfs ( root );
printf ( "%d\n" , ans );
return 0;
}
HDU 4812
将从根出发的乘积保存下来,直接使用即可(利用逆元)
第三种
这个同时结合了第一种和第二种方法,有时我们必须要枚举点才能算贡献,并且(u,v)和(v,u)不等价,我们就要用到这种方法了。我们通过一次dfs把u的所有子树保存下来,然后遍历到某个子树时,我们先删去这个子树的信息,然后遍历这个子树所有点,计算答案,遍历完后,将信息恢复。
本题u->v和v->u形成的数是不一样的,所以对答案的贡献都要计算。参考:题解
#include<bits/stdc++.h>
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e5 + 10;
const int M = 1e7 + 10;
const int mod = 1e6 + 3;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , MX , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
int dep[N] , tmp[N];
P dist[N];
ll a[N] , b[N] , c[N] , inv10[N] , exp10[N] , ans;
/*
b[i]*10^(dep)+a[i]=0(%m)
b[i] = -a[i]/(10^(dep))%m = (m-a[i])/(10^(dep))
a[i]: 从根到i
b[i]: 从i到根
c[i]:
*/
map < ll , int > mp;
void addedge ( int from , int to , int len ) {
ver[++ tot] = to;
w[tot] = len;
nxt[tot] = head[from];
head[from] = tot;
}
int exgcd ( int a , int b , int &x , int &y ) {
if ( ! b ) {
x = 1;
y = 0;
return a;
}
int d = exgcd ( b , a % b , y , x );
y -= ( a / b ) * x;
return d;
}
void get ( int u , int fa ) {
sz[u] = 1;
int mx = 0;
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
get ( v , u );
sz[u] += sz[v];
mx = max ( mx , sz[v] );
}
}
mx = max ( mx , cursz - sz[u] );
if ( mx < MX ) root = u , MX = mx;
}
void dfs1 ( int u , int fa ) {
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) {
dep[v] = dep[u] + 1;
a[v] = ( a[u] * 10 + w[i] ) % m;
b[v] = ( exp10[dep[v] - 1] * w[i] + b[u] ) % m;
c[v] = 1LL * ( m - a[v] ) % m * inv10[dep[v]] % m;
dfs1 ( v , u );
}
}
}
void add ( int u , int fa , int val ) {
mp[c[u]] += val;
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) add ( v , u , val );
}
}
void solve ( int u , int fa ) {
if ( mp.count ( b[u] ) ) ans += mp[b[u]];
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] && v != fa ) solve ( v , u );
}
}
void cal ( int u ) {
a[u] = b[u] = c[u] = dep[u] = 0;
mp.clear ();
dfs1 ( u , 0 );
add ( u , 0 , 1 );
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] ) {
add ( v , u , - 1 );
solve ( v , u );
add ( v , u , 1 );
}
}
ans += mp[0] - 1;//从根出发的答案
//add ( u , 0 , - 1 );
}
void dfs ( int u ) {
vis[u] = 1;
cal ( u );
for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
int v = ver[i];
if ( ! vis[v] ) {
MX = inf;
cursz = sz[v];
get ( v , 0 );
dfs ( root );
}
}
}
int main () {
int u , v , len;
int inv , t;
scanf ( "%d%d" , &n , &m );
exgcd ( 10 , m , inv , t );
inv = ( inv%m + m ) % m;
inv10[0] = exp10[0] = 1;
for ( int i = 1 ; i <= n ; ++ i ) {
exp10[i] = exp10[i - 1] * 10 % m;
inv10[i] = inv10[i - 1] * inv % m;
}
for ( int i = 1 ; i <= n ; ++ i ) head[i] = - 1;
for ( int i = 1 ; i < n ; ++ i ) {
scanf ( "%d%d%d" , &u , &v , &len );
++ u , ++ v;
addedge ( u , v , len );
addedge ( v , u , len );
}
MX = inf;
cursz = n;
get ( 1 , 0 );
dfs ( root );
printf ( "%lld\n" , ans );
return 0;
}