BFS和DFS版本的KM算法,复杂度$O(n^3)$

由于能力有限,不当的地方还请指出。

KM算法

KM是用来求带权二分图的最优匹配的一种算法。

概念:

  1. 顶标:每一个点有一个顶标,左边的点的顶标为lx[i],右边的点的顶标为ly[i]。
  2. 性质:保证对于算法进行的任意时刻,对于属于此二分图的任意一条边$e(u,v)$都有$lx[u]+ly[v]≥w(u,v)$。
  3. 可行边:$u,v$ 满足$lx[u]+ly[v]=w(u,v)$的边
  4. 增广路:全部由 可行边 构成,类似于匈牙利算法那样,由一条匹配边,一条非匹配边构成的,且两端都是非匹配边的一条路,这样我们反转匹配边和非匹配边,就会增加一条可行边,使得我们得到的权值更大。
  5. 交错树:当我们从某个左部节点出发寻找增广路失败时,那么在DFS过程中,所有访问过的节点,以及为了访问这些节点经过的边,共同构成一棵树。这棵树的根是左部节点,所有叶子节点也是左部节点(因为最终匹配失败了),且这棵树具有偶数条边(奇数个节点),并且树上第$1,3,5, \cdots$ 层的边都是非匹配边,第$2,4,5,\cdots$层都是匹配边,这棵树被称为交错树。

做法:

既然对于每一条边都满足上述性质, 如果我们能够在二分图中找到一种全部由可行边构成的完备匹配,使得任意$u,v$ 都满足 $lx[u]+ly[v]=w(u,v)$,那么我们就可以证明这个完备匹配是最优的(参见顶标的性质)。
既然我们可以这样来求出最优匹配,那么就要构造这样的顶标集合满足上面的所有条件。
初始的时候我们把每个左边的点的顶标设为和它相连的权值最大的边的权值。然后用匈牙利算法去一遍一遍地跑增广路,直到是全部由可行边构成的完备匹配为止。但是数据给的图不一定有完备匹配怎么办?我们可以构造一个完全图,原图中没有的边把它的权值设为0即可。

顶标的修改:

显然这样不一定可以找到满足条件的完备匹配。所以我们要在满足性质的情况下修改点的顶标使得可以找得到满足条件的完备匹配。修改顶标也是该算法的核心所在。 通过修改顶标来增加可行边,即增加我们的匹配。考虑修改目前在交错树中的边,如果将左边的点的顶标减去了$delta$,那么相应的,对于已经匹配上的边,右边的对应点的顶标要加上$delta$,这样原来的边的可行性才不会发生变化。

  1. 两端都在交错树中的边(u,v),$lx[u]+ly[v]$的值没有变化。也就是说,它原来是可行边,现在仍是可行边。
  2. 两端都不在交错树中的边$(u,v)$,$lx[u]和ly[v]$都没有变化。也就是说,这条边的可行性并没有发生改变。
  3. $u$端不在交错树中,$v$端在交错树中的边$(u,v)$,它的$lx[u]+ly[v]$的值有所增大。它原来不是可行边,现在仍不是可行边。
  4. $u$端在交错树中,$v$端不在交错树中的边$(u,v)$,它的$lx[u]+ly[v]$的值有所减小。也就说,它原来不是可行边,现在就可能成为了可行边。

所以,为了找到由可行边构成的增广路,我们要尽量修改顶标使得可行边的条数增多,对应上面的第四点,修改量即为$delta = lx[u]+ly[v]−va[u][v]$。修改顶标使得可行边增多,同样会使我们最大匹配的答案减小,并且$delta$就是我们增加可行边后所减少的最大匹配的值。为了让整个图都满足限制条件,同时使得答案最大,修改量越小越好。即为所有满足$u$端在在交错树中,而$v$端不在交错树的$min(lx[u]+ly[v]−va[u][v])$。

过程:

通过上面的证明和分析,可以得到KM算法的步骤:

1. 用匈牙利算法给每一个左端点来找增广路。
2. 若找不到增广路,记下修改量的最小值,修改访问过的点的顶标,重复第一步
3. 换下一个点来找增广路

时间复杂度:

网上有很多的分析是不严谨的,我们按照上面的方法,外层循环中为每一个点匹配一条边,时间上一个$n$,对于每一次寻找,可能要将$n$个点添加进目前的可行边,时间上又一个$n$,每一次匈牙利至多访问$n^2$条边,所以时间复杂度是$O(n^4)$。

模板题

模板题:二分图最大权匹配
按照我们刚刚的分析过程,对应这下面的代码。

代码

#include <bits/stdc++.h>
#define eps 1e-6
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit(x) (x & (-x))
#define SZ(v) ((int)(v).size())
#define All(v) (v).begin(), (v).end()
#define mp(x, y) make_pair(x, y)
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> P;
const int N = 4e2 + 10;
const int M = 1e5 + 10;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
int nl, nr, n, m ,match[N] , mb[N];
int a[N][N] , la[N] , lb[N] , slack[N];//左右顶标
bool va[N] , vb[N];
int delta;
bool dfs(int u)
{
    va[u] = 1;
    for(int i = 1; i <= n; ++ i)
    {
        if(!vb[i])
        {
            int val = la[u] + lb[i] - a[u][i];
            if(val == 0) 
            {
                vb[i] = 1;
                if(!match[i] || dfs(match[i])) 
                {
                    match[i] = u;
                    return 1;
                }
            }else{
                slack[i] = min(slack[i] ,val ) ;
            }
        }
    }
    return 0;
}
void KM()
{
    for(int i = 1; i <= n; ++ i) {
        la[i] = - inf;
        for(int j = 1; j <= n; ++ j) {
            la[i] = max(la[i] , a[i][j]);
        }
    }
    for(int i = 1; i <= n; ++ i)
    {
        memset(slack,inf,sizeof(slack));
        while(1)
        {
            memset(va,0,sizeof(va));
            memset(vb,0,sizeof(vb));
            delta = inf;
            if(dfs(i)) break;
            int delta = inf;
            for(int j = 1; j <= n; ++ j) {
                if(!vb[j] ) delta = min(delta , slack[j]);
            }
            for(int j = 1; j <= n; ++ j) {
                if(va[j]) la[j] -= delta;
            }
            for(int j = 1; j <= n; ++ j) {
                if(vb[j]) lb[j] += delta;
                else slack[j] -= delta;
            }
        }
    }
    
}
int main()
{
    scanf("%d%d%d",&nl,&nr,&m);
    n = max(nl,nr);
    int u , v, w ;
    while(m --)
    {
        scanf("%d%d%d",&v,&u,&w);
        a[v][u] = w;
    }
    KM();
    ll ans = 0;
    for(int i = 1; i <= n; ++ i){
        if(match[i]) {
            ans += a[match[i]][i];
            mb[match[i]] = i;//转化为男生对应的女生
        }
    }
    printf("%lld\n",ans);
    for(int i = 1; i <= nl; ++ i) printf("%d ", a[i][mb[i]] ? mb[i] : 0);
    return 0;
}


😊这个代码交上去会T,因为其复杂度为$O(n^4)$

优化

我们接受不了如此高的复杂度,所以要进行优化。发现如果修改了顶标之后再从最开始的点去跑匈牙利会浪费很多次无用的循环,因为当前修改了之后可能最少只增加了一条可行边,所以我们可以修改之后只取顶标的最小修改量所在的那个点进行下一次匈牙利,这样对于每一个要增广的点,访问的总的边数便至多是$n^2$条。但是这样会有一个问题,我们无法在回溯的时候修改匹配,这里用一个数组保存其前驱,然后在交错树上利用这个前驱不断修改即可。同时再次强调,我们所找的每一条边都是可行边!
下面是DFS版本的。代码中有注释。

#include <bits/stdc++.h>
#define eps 1e-6
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit(x) (x & (-x))
#define SZ(v) ((int)(v).size())
#define All(v) (v).begin(), (v).end()
#define mp(x, y) make_pair(x, y)
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> P;
const int N = 4e2 + 10;
const int M = 1e5 + 10;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
int nl, nr, n, m ,match[N] , mb[N];
int a[N][N]  , slack[N] , fa[N];//fa其实是二级祖先
int la[N] , lb[N];//左右顶标
bool va[N] , vb[N];
int delta;
bool dfs(int py)//u是右部节点
{
    vb[py] = 1;
    int x = match[py];//找到右部节点py的匹配x,x为左部节点
    if(!x) //找到增广路
    {
        //此时要把整个增光路取反
        while(py){//借助二级祖先,修改匹配
            match[py] = match[fa[py]];//fa[py] 一定被匹配过了,修改当前节点的匹配
            py = fa[py];
        }
        return 1;
    }
    //左部节点x存在,继续寻找增光路
    for(int y = 1; y <= n; ++ y)//尝试匹配右部节点y
    {
        if(!vb[y])//x-y为非匹配边,y未被尝试去匹配过(右部节点y可能之前已经有了匹配,即match[y]不为0), 尝试x - y
        {
            int val = la[x] + lb[y] - a[x][y];
            if(val == 0) 
            {  
                fa[y] = py;//记录二级祖先,此时py一定被匹配过了,因为match[py]不为0
                if( dfs(y) ) return 1;
            }else if(val < slack[y]){
                slack[y] = val;
                fa[y] = py; //这里同样要修改前驱
            }
        }
    }
    return 0;
}
void KM()
{
    for(int i = 1; i <= n; ++ i) {
        la[i] = - inf;
        for(int j = 1; j <= n; ++ j) {
            la[i] = max(la[i] , a[i][j]);
        }
    }
    for(int x = 1; x <= n; ++ x)
    {
        memset(slack,inf,sizeof(slack));
        memset(vb, 0, sizeof(vb));
        match[0] = x;//虚拟一个右部节点0
        int py = 0;
        while(1)
        {
           
            if (dfs(py) ) break;
            int delta = inf;
            for(int j = 1; j <= n; ++ j) {
                if(!vb[j] && slack[j] < delta )  {
                    delta = slack[j];
                    py = j;
                }
            }
            for(int j = 0; j <= n; ++ j) {
                if(vb[j]) la[match[j]] -= delta, lb[j] += delta;
                else slack[j] -= delta;
            }
        }
    }
    
}
int main()
{
    scanf("%d%d%d",&nl,&nr,&m);
    int u , v, w , op = 1;
    n = max(nl,nr);
    while(m --)
    {
        scanf("%d%d%d",&v,&u,&w);
        a[v][u] = w;
    }
    KM();
    ll ans = 0;
    for(int i = 1; i <= n; ++ i){
        if(match[i]) {
            ans += a[match[i]][i];
            mb[match[i]] = i;
        }
    }
    printf("%lld\n",ans);
    for(int i = 1; i <= nl; ++ i) printf("%d ", a[i][mb[i]] ? mb[i] : 0);
    return 0;
}

画个图来理解一下$fa$数组,红色边是匹配边,蓝色边是非匹配边,我们找到了增广路 ①②③
b.png

下面是BFS版本的,并且不需要再初始$la$数组(即$lx$),我们每次需要用到最小的$slack[i]$对应的点, 不需要其具体值。

#include <bits/stdc++.h>
#define eps 1e-6
#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit(x) (x & (-x))
#define SZ(v) ((int)(v).size())
#define All(v) (v).begin(), (v).end()
#define mp(x, y) make_pair(x, y)
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> P;
const int N = 4e2 + 10;
const int M = 1e5 + 10;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
int nl, nr, n, m ,match[N] , mb[N];
int a[N][N]  , slack[N] , fa[N];//fa其实是二级祖先
int la[N] , lb[N];//左右顶标
bool va[N] , vb[N];
void bfs(int u)//左部节点u
{
    memset(slack,inf,sizeof(slack));
    memset(vb, 0, sizeof(vb));
    int py , x , p, delta;
    //虚拟一个右部节点
    for(match[py = 0] = u; match[py]; py = p)
    {
        vb[py] = 1, x = match[py] , delta = inf;
        for(int y = 1; y <= n; ++ y) {
            if(vb[y]) continue ;
            if(la[x] + lb[y] - a[x][y] < slack[y])
            {
                slack[y] = la[x] + lb[y] - a[x][y];
                fa[y] = py;//二级祖先
            }
            if(slack[y] < delta) {
                delta = slack[y];
                p = y;
            }
        }
        for(int y = 0; y <= n; ++ y ) {
            if(vb[y]) la[match[y]] -=  delta , lb[y] += delta;
            else slack[y] -= delta;
        }
    }
    for(; py ; py = fa[py]) match[py] = match[fa[py]];
}
void KM()
{
    for(int x = 1; x <= n; ++ x) bfs(x);
}
int main()
{
    scanf("%d%d%d",&nl,&nr,&m);
    int u , v, w , op = 1;
    n = max(nl,nr);
    while(m --)
    {
        scanf("%d%d%d",&v,&u,&w);
        a[v][u] = w;
    }
    KM();
    ll ans = 0;
    for(int i = 1; i <= n; ++ i){
        ans += a[match[i]][i];
        mb[match[i]] = i;
    }
    printf("%lld\n",ans);
    for(int i = 1; i <= nl; ++ i) printf("%d ", a[i][mb[i]] ? mb[i] : 0);
    return 0;
}

参考博客: KM算法详解

Last modification:April 4th, 2020 at 12:27 am
如果觉得我的文章对你有用,请随意赞赏