-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathbipartite_matching.m
240 lines (211 loc) · 6.43 KB
/
bipartite_matching.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
function [val m1 m2 mi]=bipartite_matching(varargin)
% BIPARTITE_MATCHING Solve a maximum weight bipartite matching problem
%
% [val m1 m2]=bipartite_matching(A) for a rectangular matrix A
% [val m1 m2 mi]=bipartite_matching(x,ei,ej,n,m) for a matrix stored
% in triplet format.  This call also returns a matching indicator mi so
% that val = x'*mi.  The sizes n and m are optional; when omitted (or
% empty) they default to max(ei) and max(ej).
%
% Outputs:
%   val - sum of the weights of the matched edges
%   m1  - row index of each matched edge
%   m2  - column index of each matched edge (same length as m1)
%   mi  - logical vector with one entry per input triplet, true for the
%         triplets that are part of the matching (triplet call only)
%
% The maximum weight bipartite matching problem tries to pick out elements
% from A such that each row and column get only a single non-zero but the
% sum of all the chosen elements is as large as possible.
%
% This function is slightly atypical for a graph library, because it will
% be primarily used on rectangular inputs.  However, these rectangular
% inputs model bipartite graphs and we take advantage of that structure in
% this code.  The underlying graph adjacency matrix is
%   G = spaugment(A,0);
% where A is the rectangular input to the bipartite_matching function.
%
% Matlab already has the dmperm function that computes a maximum
% cardinality matching between the rows and the columns.  This function
% gives us the maximum weight matching instead.  For unweighted graphs, the
% two functions are equivalent.
%
% Note: ei and ej must not contain duplicate edges.  A repeated (ei,ej)
% pair is rejected with the error 'bipartite_matching:duplicateEdge'.
%
% See also DMPERM
%
% Example:
%   A = rand(10,8); % bipartite matching between random data
%   [val m1 m2] = bipartite_matching(A);
%   val

% David F. Gleich and Ying Wang
% Copyright, Stanford University, 2008-2009
% Computational Approaches to Digital Stewardship

% 2008-04-24: Initial coding (copy from Ying Wang matching_sparse_mex.cpp)
% 2008-11-15: Added triplet input/output
% 2009-04-30: Modified for gaimc library
% 2009-05-15: Fixed error with empty inputs and triple added example.

% Convert either input form into the row-compressed arrays used by the
% solver.  Wrong input counts and duplicate edges are reported from here.
[rp ci ai tripi n m] = bipartite_matching_setup(varargin{:});

% Decide the calling form from the number of inputs rather than from
% isempty(tripi): a triplet call with no edges and n=0 produces an empty
% tripi and would otherwise be mistaken for the matrix form, which made a
% request for the fourth output fail with "too many outputs".
istriplet = nargin > 1;

% The indicator output mi only exists for the triplet form.
if istriplet
    error(nargoutchk(0,4,nargout,'struct'));
else
    error(nargoutchk(0,3,nargout,'struct'));
end

% Only ask the solver for mi when the caller wants it; the solver skips
% building the indicator unless it is called with four outputs.
if istriplet && nargout>3
    [val m1 m2 mi] = bipartite_matching_primal_dual(rp, ci, ai, tripi, n, m);
else
    [val m1 m2] = bipartite_matching_primal_dual(rp, ci, ai, tripi, n, m);
end
function [rp ci ai tripi n m]= bipartite_matching_setup(A,ei,ej,n,m)
% BIPARTITE_MATCHING_SETUP Convert the input into augmented CSR arrays
%
% [rp ci ai tripi n m] = bipartite_matching_setup(A) takes a matrix A, or
% a struct with fields rp, ci and ai that is expanded with csr_to_sparse.
% [rp ci ai tripi n m] = bipartite_matching_setup(x,ei,ej,n,m) takes the
% weights x with row indices ei and column indices ej.  n and m are
% optional; when omitted or empty they default to max(ei) and max(ej)
% (0 when there are no edges).
%
% Outputs:
%   rp    - (n+1)-by-1 row pointers, 1-based: row i owns the entries
%           rp(i):rp(i+1)-1 of ci, ai and tripi
%   ci    - column index of every stored entry
%   ai    - weight of every stored entry
%   tripi - triplet input only: for each stored entry, the index of the
%           input triplet it came from, or -1 for an extra edge; [] for
%           matrix/struct input
%   n, m  - number of rows and columns
%
% Every row i receives one extra zero-weight edge to column m+i, stored
% after the row's real entries.
%
% Errors:
%   the nargchk error for 0 or 2 input arguments
%   'bipartite_matching:duplicateEdge' if a row lists a column twice

% convert the input
if nargin == 1
    if isstruct(A)
        [nzi nzj nzv]=csr_to_sparse(A.rp,A.ci,A.ai);
        % size(A) would report the size of the struct itself (1-by-1), not
        % of the matrix it describes, so take the dimensions from the
        % data.  Assumes A.rp holds one pointer per row plus a final one
        % and that csr_to_sparse returns row/column/value triplets - TODO
        % confirm against csr_to_sparse.  Trailing all-zero columns are
        % not counted in m; they cannot take part in a matching anyway.
        n = max(length(A.rp)-1,0);
        m = max(nzj(:));
        if isempty(m), m = 0; end
    else
        [nzi nzj nzv]=find(A);
        [n m]=size(A);
    end
    triplet = 0;
elseif nargin >= 3 && nargin <= 5
    nzi = ei;
    nzj = ej;
    nzv = A;
    if nargin < 4 || isempty(n), n = max(nzi); end
    if nargin < 5 || isempty(m), m = max(nzj); end
    % max of an empty index list is [], which would break the array
    % allocations below; an edgeless problem has zero rows/columns.
    if isempty(n), n = 0; end
    if isempty(m), m = 0; end
    triplet = 1;
else
    error(nargchk(3,5,nargin,'struct'));
end

nedges = length(nzi);

rp = ones(n+1,1); % csr matrix with extra edges (count 1 per row for them)
ci = zeros(nedges+n,1);
ai = zeros(nedges+n,1);
if triplet
    tripi = zeros(nedges+n,1); % triplet index
else
    tripi = [];
end

%
% 1. build csr representation with a set of extra edges from vertex i to
% vertex m+i
%
rp(1)=0;
for i=1:nedges
    % count the entries of each row, shifted by one position
    rp(nzi(i)+1)=rp(nzi(i)+1)+1;
end
rp=cumsum(rp); % rp(i) is now the 0-based offset where row i starts

for i=1:nedges
    % rp(row) doubles as the insertion cursor of that row
    pos = rp(nzi(i))+1;
    if triplet, tripi(pos)=i; end % triplet index
    ai(pos)=nzv(i);
    ci(pos)=nzj(i);
    rp(nzi(i))=pos;
end
for i=1:n % add the extra edges
    pos = rp(i)+1;
    if triplet, tripi(pos)=-1; end % triplet index, -1 marks an extra edge
    ai(pos)=0;
    ci(pos)=m+i;
    rp(i)=pos;
end

% restore the row pointer array: every cursor now sits at the end of its
% row, which is the start of the next one, so shift down by one row
for i=n:-1:1
    rp(i+1)=rp(i);
end
rp(1)=0;
rp=rp+1; % switch to 1-based pointers

%
% 1a. check for duplicates in the data
%
colind = false(m+n,1);
for i=1:n
    for rpi=rp(i):rp(i+1)-1
        if colind(ci(rpi))
            error('bipartite_matching:duplicateEdge',...
                'duplicate edge detected (%i,%i)',i,ci(rpi));
        end
        colind(ci(rpi))=1;
    end
    for rpi=rp(i):rp(i+1)-1, colind(ci(rpi))=0; end % reset indicator
end
function [val m1 m2 mi]=bipartite_matching_primal_dual(...
rp, ci, ai, tripi, n, m)
% BIPARTITE_MATCHING_PRIMAL_DUAL Primal-dual solver for the matching
%
% Inputs (as produced by bipartite_matching_setup in this file):
%   rp, ci, ai - 1-based row pointers, column indices and weights; row i
%                owns the entries rp(i):rp(i+1)-1
%   tripi      - per entry, the index of the input triplet it came from;
%                only read when the fourth output is requested
%   n, m       - number of rows and of real columns; a column index
%                larger than m is one of the extra per-row columns
%
% Outputs:
%   val    - sum of ai over the matched entries
%   m1, m2 - row and column of every matched pair whose column is <= m
%   mi     - logical vector of length length(tripi)-n, true at the triplet
%            index of every matched entry (only built when nargout > 3)
alpha=zeros(n,1); % variables used for the primal-dual algorithm
% alpha: one dual value per row; beta: one dual value per column
beta=zeros(n+m,1);
% queue: rows waiting to be scanned in the current search
queue=zeros(n,1);
% t(j): row from which column j was reached in the current search (0 = not
% reached yet)
t=zeros(n+m,1);
% match1(i): column matched to row i; match2(j): row matched to column j
match1=zeros(n,1);
match2=zeros(n+m,1);
% tmod(1:ntmod): columns whose t entry was set, so they can be cleared
% without touching the whole array
tmod = zeros(n+m,1);
ntmod=0;
%
% initialize the primal and dual variables
%
% alpha(i) becomes the largest weight in row i (never below its initial 0)
for i=1:n
for rpi=rp(i):rp(i+1)-1
if ai(rpi) > alpha(i), alpha(i)=ai(rpi); end
end
end
% dual variables (beta) are initialized to 0 already
% match1 and match2 are both 0, which indicates no matches
i=1;
while i<=n
% repeat the problem for n stages
% each stage tries to match row i; a stage is repeated (via the continue
% below) after a dual update until row i is matched
% clear t(j)
for j=1:ntmod, t(tmod(j))=0; end
ntmod=0;
% add i to the queue
head=1; tail=1;
queue(head)=i; % add i to the head of the queue
% breadth-first search over tight edges, stopping once row i is matched
while head <= tail && match1(i)==0
k=queue(head);
for rpi=rp(k):rp(k+1)-1
j = ci(rpi);
if ai(rpi) < alpha(k)+beta(j) - 1e-8, continue; end % skip edges that are not tight (1e-8 tolerance)
if t(j)==0,
% first visit of column j: enqueue the row it is matched to (0 if none)
% and remember that j was reached from row k
tail=tail+1; queue(tail)=match2(j);
t(j)=k;
ntmod=ntmod+1; tmod(ntmod)=j;
if match2(j)<1,
% column j is unmatched: walk back along the t() predecessors, giving
% each row the column it was reached through and continuing with the
% column that row previously held, until a row with no previous column
while j>0,
match2(j)=t(j);
k=t(j);
temp=match1(k);
match1(k)=j;
j=temp;
end
break; % we found an alternating path
end
end
end
head=head+1;
end
if match1(i) < 1, % still not matched, so update primal, dual and repeat
% theta: smallest slack alpha+beta-weight over the edges that lead from a
% scanned row (queue(1:head-1)) to a column not reached in this search
theta=inf;
for j=1:head-1
t1=queue(j);
for rpi=rp(t1):rp(t1+1)-1
t2=ci(rpi);
if t(t2) == 0 && alpha(t1) + beta(t2) - ai(rpi) < theta,
theta = alpha(t1) + beta(t2) - ai(rpi);
end
end
end
% lower the duals of the scanned rows and raise those of the reached
% columns by theta, then rerun the search for the same row i
for j=1:head-1, alpha(queue(j)) = alpha(queue(j)) - theta; end
for j=1:ntmod, beta(tmod(j)) = beta(tmod(j)) + theta; end
continue;
end
i=i+1; % increment i
end
% total weight: add the weight of the entry each row is matched through
val=0;
for i=1:n
for rpi=rp(i):rp(i+1)-1
if ci(rpi)==match1(i), val=val+ai(rpi); end
end
end
% rows matched to a column above m (an extra column) are left out of the
% reported matching
noute = 0; % count number of output edges
for i=1:n
if match1(i)<=m, noute=noute+1; end
end
m1=zeros(noute,1); m2=m1; % copy over the 0 array
% noute is reused as the next write position in m1/m2
noute=1;
for i=1:n
if match1(i)<=m, m1(noute)=i; m2(noute)=match1(i);noute=noute+1; end
end
% the indicator is only built when the caller asked for a fourth output
if nargout>3
% tripi has n more entries than there are input triplets
mi= false(length(tripi)-n,1);
for i=1:n
for rpi=rp(i):rp(i+1)-1
if match1(i)<=m && ci(rpi)==match1(i), mi(tripi(rpi))=1; end
end
end
end