点分治

思想

点分治是一种用于处理静态树上路径统计问题的算法,其核心原理还是基于分治思想。

我们不妨对这类树上静态路径统计问题抽象化,例如:

给定一个棵无根树,求满足要求P的路径有几条(点对)

我们可以使用分治算法求解本题,对于无根树,我们指定一个节点为根,显然,这样的路径可以归为两类:

  1. 满足要求P且经过根root的路径
  2. 满足要求P但不经过根root的路径

第二类路径不经过根,所以一定在根的某一刻子树中,这我们就可以递归地进行求解,而重点就在于统计第一类路径。 先不考虑统计问题,我们可以设计如下的分治算法:

  1. 求出根root
  2. 得到相关信息,统计第一类路径并累加入答案中
  3. 删除节点root ,对于root 的每一棵子树递归的执行分治算法

显然,递归的层数和每一次选取的根root有关,当树退化成一条链时,如果任意的选取根,递归层数很可能就达到$O(n)$层。解决方法就是选取树的重心作为树的根,容易证明,每一次选取树的重心作为根时,点分治算法最多递归$log_2n$层。

一般我们有下面几种方法来计算以当前重心作为根对答案的贡献

第一种

先一次性的计算出当前根$u$所有子树的贡献,这里面包括了经过当前树根的路径和不经过当前树根的路径,因此我们还要遍历$u$的所有子树,减去各个子树单独的贡献,即不经过当前树根$u$的路径数量。

POJ1741
按照dis从小到大排序,对于每个i维护一个R使得区间[i+1,R]内所有点都满足,可知R是递减的,先算总的贡献,再减去各个子树的贡献。

#include <cstdio>
#include <cstring>
#include <algorithm>

#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e4 + 10;
const int M = 1e6 + 10;
const int mod = 998244353;
const int inf = 0x3f3f3f3f;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , cnt , MX , ans , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
ll dis[N] , dist[N];

void addedge ( int from , int to , int len ) {
    ver[++ tot] = to;
    w[tot] = len;
    nxt[tot] = head[from];
    head[from] = tot;
}

void get ( int u , int fa ) {
    sz[u] = 1;
    int mx = 0;
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            get ( v , u );
            sz[u] += sz[v];
            mx = max ( mx , sz[v] );
        }
    }
    mx = max ( mx , cursz - sz[u] );
    if ( mx < MX ) root = u , MX = mx;
}

void dfs1 ( int u , int fa ) {
    dist[++ cnt] = dis[u];
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            dis[v] = dis[u] + w[i];
            dfs1 ( v , u );
        }
    }
}

int cal ( int u ) {
    int res = 0;
    cnt = 0;
    dfs1 ( u , 0 );
    sort ( dist + 1 , dist + cnt + 1 );
    int L = 1 , R = cnt;
    while ( L < R ) {
        while ( L < R && dist[L] + dist[R] > k ) -- R;
        res += R - L;
        ++ L;
    }
    return res;
}

void dfs ( int u ) {
    vis[u] = 1;
    dis[u] = 0;
    ans += cal ( u );
    get ( u , 0 );
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] ) {
            //dis[v] = w[i];
            ans -= cal ( v );
            MX = inf;
            cursz = sz[v];
            get ( v , 0 );
            dfs ( root );
        }
    }
}

int main () {
    int u , v , len;
    while ( scanf ( "%d%d" , &n , &k ) != EOF && ( n || k ) ) {
        tot = 0;
        root = 0;
        ans = 0;
        for ( int i = 0 ; i <= n ; ++ i ) head[i] = - 1 , vis[i] = 0;
        for ( int i = 1 ; i < n ; ++ i ) {
            scanf ( "%d%d%d" , &u , &v , &len );
            addedge ( u , v , len );
            addedge ( v , u , len );
        }
        MX = inf;
        cursz = n;
        get ( 1 , 0 );
        dfs ( root );
        printf ( "%d\n" , ans );
    }
    return 0;
}

HDU 5314
按照最小值从小到大排序,对于i,所有在i之前,且最小值大于i的最大值-D都满足,所以二分最大值-D即可,为什么要满足在i之前呢,因为在i之后的无法保证最大值减去最小值小于等于D。列如i之后有一个j,(i,j)之间最小值肯定是i的最小值,最大值可能是i也可能是j,如果是j可能就不满足最大值减去最小值小于等于D。因此我们只处理i之前的,最后将答案乘2即可。

#include<bits/stdc++.h>

#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e5 + 10;
const int M = 1e7 + 10;
const int mod = 1e6 + 3;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , MX , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
ll ans;
int a[N] , MAX[N] , MIN[N] , D;
P tmp[N];

void addedge ( int from , int to , int len ) {
    ver[++ tot] = to;
    w[tot] = len;
    nxt[tot] = head[from];
    head[from] = tot;
}

void get ( int u , int fa ) {
    sz[u] = 1;
    int mx = 0;
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            get ( v , u );
            sz[u] += sz[v];
            mx = max ( mx , sz[v] );
        }
    }
    mx = max ( mx , cursz - sz[u] );
    if ( mx < MX ) root = u , MX = mx;
}

void dfs1 ( int u , int fa ) {
    if ( MAX[u] - MIN[u] <= D )
        tmp[++ tot] = mp( MIN[u] , MAX[u] );
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            MAX[v] = max ( MAX[u] , a[v] );
            MIN[v] = min ( MIN[u] , a[v] );
            dfs1 ( v , u );
        }
    }
}

ll cal ( int u ) {
    ll res = 0;
    tot = 0;
    dfs1 ( u , 0 );
    sort ( tmp + 1 , tmp + tot + 1 );
    for ( int i = 1 ; i <= tot ; ++ i ) {
        int dis = lower_bound ( tmp + 1 , tmp + i + 1 , mp( tmp[i].second - D , 0 ) ) - tmp;
        res += i - dis;
    }
    return res;
}

void dfs ( int u ) {
    MAX[u] = MIN[u] = a[u];
    vis[u] = 1;
    ans += cal ( u );
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] ) {
            ans -= cal ( v );
            MX = inf;
            cursz = sz[v];
            get ( v , 0 );
            dfs ( root );
        }
    }
}

int main () {
    int u , v , len , T;
    scanf ( "%d" , &T );
    while ( T -- ) {
        tot = 0;
        ans = 0;
        scanf ( "%d%d" , &n , &D );
        for ( int i = 1 ; i <= n ; ++ i ) {
            scanf ( "%d" , &a[i] );
            head[i] = - 1;
            vis[i] = 0;
        }
        for ( int i = 1 ; i < n ; ++ i ) {
            scanf ( "%d%d" , &u , &v );
            addedge ( u , v , 0 );
            addedge ( v , u , 0 );
        }
        MX = inf;
        cursz = n;
        get ( 1 , 0 );
        dfs ( root );
        printf ( "%lld\n" , ans * 2 );
    }
    return 0;
}

第二种

逐个处理当前根u的各个子树,处理到某一子树时,之前遍历过的子树信息已经保存下来了,可供当前子树使用,即是之前子树中的点与当前子树形成的贡献。每处理完一个子树,将其信息保存下来。
这种方法一般用于枚举点才能算出贡献的题目中,既是枚举一个点,利用某个东西求得与这个点满足关系的点有多少个。一般是求解无序对(u,v)的个数,既是(u,v)与(v,u) 是一样的。下面一种方法我们将看到有的(u,v)与(v,u) 是不一样的。

SPOJ 1825
这题使用到了第二种方法。我们用树状数组保存访问<=i个拥挤点所经过道路的最大距离,对于每个子树中的每个点v,根到v经过了若干距离,同时访问了若干个拥挤点,我们可以利用树状数组查询(k-已访问的拥挤点)所经过的最大距离,再加上当前的距离,尝试更新最大值。

#include<bits/stdc++.h>
 
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 2e5 + 10;
const int M = 1e7 + 10;
const int mod = 998244353;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int  tot , root , cnt[N] , cr[N] , MX , cursz , vis[N] ,  sz[N];
int dis[N] , ans , _k , tmp[N] , c[N];
P dist[N];
vector < P > G[N];
 
void add ( int x , int v ) {
    for ( ; x <= m + 1 ; x += lowbit( x ) ) c[x] = max ( c[x] , v );
}
 
void del ( int x ) {
    for ( ; x <= m + 1 ; x += lowbit( x ) ) c[x] = - INF;
}
 
int ask ( int x ) {
    int res = - INF;
    for ( ; x > 0 ; x -= lowbit ( x ) ) res = max ( res , c[x] );
    return res;
}
void get ( int u , int fa ) {
    sz[u] = 1;
    int mx = 0;
    for ( auto &E:G[u]) {
        int v = E.first;
        if ( ! vis[v] && v != fa ) {
            get ( v , u );
            sz[u] += sz[v];
            mx = max ( mx , sz[v] );
        }
    }
    mx = max ( mx , cursz - sz[u] );
    if ( mx < MX ) root = u , MX = mx;
}
 
void dfs1 ( int u , int fa ) {
    dist[++ tot] = mp( dis[u] , cnt[u] );
    for ( auto &E:G[u] ) {
        int v = E.first , w=E.second;
        if ( ! vis[v] && v != fa ) {
            dis[v] = dis[u] + w;
            cnt[v] = cnt[u] + cr[v];
            dfs1 ( v , u );
        }
    }
}
 
void cal ( int u ) {
    dis[u] = 0;
    //cnt[u] = cr[u];
    cnt[u] = 0;
    add ( 1 , 0 );
    int tot1 = 0;
    tmp[++ tot1] = cnt[u] + 1;
    for ( auto &E:G[u] ) {
        int v = E.first , w = E.second;
        if ( ! vis[v] ) {
            dis[v] = w;
            cnt[v] = cnt[u] + cr[v];
            tot = 0;
            dfs1 ( v , u );
            for ( int j = 1 ; j <= tot ; ++ j ) {
                if ( k - ( dist[j].second + cr[u] ) + 1 > 0 ) {
                    ans = max ( 1LL * ans , 1LL * dist[j].first + ask ( k - ( dist[j].second + cr[u] ) + 1 ) );
                }
            }
            for ( int j = 1 ; j <= tot ; ++ j )
                if ( dist[j].second <= k )
                    add ( dist[j].second + 1 , dist[j].first ) , tmp[++ tot1] = dist[j].second + 1;
        }
    }
    for ( int i = 1 ; i <= tot1 ; ++ i ) del ( tmp[i] );
}
 
void dfs ( int u ) {
    vis[u] = 1;
    cal ( u );
    for ( auto &E:G[u]) {
        int v = E.first;
        if ( ! vis[v] ) {
            MX = inf;
            cursz = sz[v];
            get ( v , 0 );
            dfs ( root );
        }
    }
}
 
int main () {
    int u , v , len;
    scanf ( "%d%d%d" , &n , &k , &m );
    for ( int i = 0 ; i <= m + 1 ; ++ i ) c[i] = - INF;
    for ( int i = 1 ; i <= m ; ++ i ) {
        scanf ( "%d" , &u );
        cr[u] = 1;
    }
    for ( int i = 1 ; i < n ; ++ i ) {
        scanf ( "%d%d%d" , &u , &v , &len );
        G[u].push_back ( mp( v , len ) );
        G[v].push_back ( mp( u , len ) );
    }
    MX = inf;
    cursz = n;
    get ( 1 , 0 );
    dfs ( root );
    printf ( "%d\n" , ans );
    return 0;
}

HDU 4812
将从根出发的乘积保存下来,直接使用即可(利用逆元)

第三种

这个同时结合了第一种和第二种方法,有时我们必须要枚举点才能算贡献,并且(u,v)和(v,u)不等价,我们就要用到这种方法了。我们通过一次dfs把u的所有子树保存下来,然后遍历到某个子树时,我们先删去这个子树的信息,然后遍历这个子树所有点,计算答案,遍历完后,将信息恢复。

codeforce 716E

本题u->v和v->u形成的数是不一样的,所以对答案的贡献都要计算。参考:题解

#include<bits/stdc++.h>

#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 1e5 + 10;
const int M = 1e7 + 10;
const int mod = 1e6 + 3;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
int n , m , k;
int head[N] , tot , root , MX , cursz , vis[N] , ver[N << 1] , w[N << 1] , nxt[N << 1] , sz[N];
int dep[N] , tmp[N];
P dist[N];
ll a[N] , b[N] , c[N] , inv10[N] , exp10[N] , ans;
/*
 b[i]*10^(dep)+a[i]=0(%m)
 b[i] = -a[i]/(10^(dep))%m = (m-a[i])/(10^(dep))
 a[i]: 从根到i
 b[i]: 从i到根
 c[i]:
 */
map < ll , int > mp;

void addedge ( int from , int to , int len ) {
    ver[++ tot] = to;
    w[tot] = len;
    nxt[tot] = head[from];
    head[from] = tot;
}

int exgcd ( int a , int b , int &x , int &y ) {
    if ( ! b ) {
        x = 1;
        y = 0;
        return a;
    }
    int d = exgcd ( b , a % b , y , x );
    y -= ( a / b ) * x;
    return d;
}

void get ( int u , int fa ) {
    sz[u] = 1;
    int mx = 0;
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            get ( v , u );
            sz[u] += sz[v];
            mx = max ( mx , sz[v] );
        }
    }
    mx = max ( mx , cursz - sz[u] );
    if ( mx < MX ) root = u , MX = mx;
}

void dfs1 ( int u , int fa ) {
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) {
            dep[v] = dep[u] + 1;
            a[v] = ( a[u] * 10 + w[i] ) % m;
            b[v] = ( exp10[dep[v] - 1] * w[i] + b[u] ) % m;
            c[v] = 1LL * ( m - a[v] ) % m * inv10[dep[v]] % m;
            dfs1 ( v , u );
        }
    }
}

void add ( int u , int fa , int val ) {
    mp[c[u]] += val;
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) add ( v , u , val );
    }
}

void solve ( int u , int fa ) {
    if ( mp.count ( b[u] ) ) ans += mp[b[u]];
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] && v != fa ) solve ( v , u );
    }
}

void cal ( int u ) {
    a[u] = b[u] = c[u] = dep[u] = 0;
    mp.clear ();
    dfs1 ( u , 0 );
    add ( u , 0 , 1 );
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] ) {
            add ( v , u , - 1 );
            solve ( v , u );
            add ( v , u , 1 );
        }
    }
    ans += mp[0] - 1;//从根出发的答案
    //add ( u , 0 , - 1 );
}

void dfs ( int u ) {
    vis[u] = 1;
    cal ( u );
    for ( int i = head[u] ; i != - 1 ; i = nxt[i] ) {
        int v = ver[i];
        if ( ! vis[v] ) {
            MX = inf;
            cursz = sz[v];
            get ( v , 0 );
            dfs ( root );
        }
    }
}

int main () {
    int u , v , len;
    int inv , t;
    scanf ( "%d%d" , &n , &m );
    exgcd ( 10 , m , inv , t );
    inv = ( inv%m + m ) % m;
    inv10[0] = exp10[0] = 1;
    for ( int i = 1 ; i <= n ; ++ i ) {
        exp10[i] = exp10[i - 1] * 10 % m;
        inv10[i] = inv10[i - 1] * inv % m;
    }
    for ( int i = 1 ; i <= n ; ++ i ) head[i] = - 1;
    for ( int i = 1 ; i < n ; ++ i ) {
        scanf ( "%d%d%d" , &u , &v , &len );
        ++ u , ++ v;
        addedge ( u , v , len );
        addedge ( v , u , len );
    }
    MX = inf;
    cursz = n;
    get ( 1 , 0 );
    dfs ( root );
    printf ( "%lld\n" , ans );

    return 0;
}
Last modification:January 28th, 2020 at 11:32 pm
如果觉得我的文章对你有用,请随意赞赏

Leave a Comment