-
Notifications
You must be signed in to change notification settings - Fork 2
/
MRF_MAP_GraphCutAExpansion.m
156 lines (142 loc) · 6.52 KB
/
MRF_MAP_GraphCutAExpansion.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
function [X, posterior]=MRF_MAP_GraphCutAExpansion(X,logprobs,b,L,MAP_iter,neighbours_count)
%MRF_MAP_GRAPHCUTAEXPANSION  MAP estimate of a Potts-model MRF labelling
%   using alpha-expansion graph cuts.
%
%   [X, posterior] = MRF_MAP_GraphCutAExpansion(X, logprobs, b, L, ...
%                                               MAP_iter, neighbours_count)
%
%   Input:
%     X                - initial labelling (matrix/array); each element holds
%                        a label in 1:L. Elements equal to 0 are excluded:
%                        they are never relabelled and get no graph edges.
%     logprobs         - L x numel(X) matrix of negative log-likelihoods;
%                        logprobs(l, k) is the cost of giving X(k) label l.
%     b                - Potts model parameter (scalar): cost paid for every
%                        pair of neighbours with different labels.
%     L                - number of labels.
%     MAP_iter         - maximum number of sweeps over all labels.
%     neighbours_count - neighbourhood size forwarded to GetNeighbours;
%                        supported values: 2-D: 4, 8, 16; 3-D: 6, 26.
%
%   Output:
%     X                - final labelling, same size as the input X.
%     posterior        - L x numel(X) matrix; posterior(l, k) is proportional
%                        to exp(-logprobs(l, k) - b * n_k(l)), where n_k(l) is
%                        the number of neighbours of element k whose final
%                        label is not l; each column is normalised to sum to
%                        1. Columns of excluded (0) elements are zero.

% Work on a column view of the labelling: flat(k) == X(k).
sz = size(X);
flat = X(:);
flatsize = numel(flat);
% Neighbour indices: row d holds the index of the d-th neighbour of every
% element. NOTE(review): layout inferred from how it is indexed below;
% confirm against GetNeighbours.
all_neighbours_ind = GetNeighbours(sz, neighbours_count);
% Elements labelled 0 never change, so this mask is fixed for the whole run.
active = flat ~= 0;

% Undirected neighbour pairs between active elements. The pair set does not
% depend on the labels, so it is built once. Self-pairs (if GetNeighbours
% returns any) are dropped, and so are pairs touching an excluded element:
% the Potts term against a fixed label 0 equals b for every label 1:L, so it
% is a constant that cannot influence the cut.
pairs = [reshape(repmat(1:flatsize, size(all_neighbours_ind, 1), 1), [], 1), ...
         double(all_neighbours_ind(:))];
pairs = pairs(pairs(:, 2) ~= 0, :);
pairs = pairs(pairs(:, 1) ~= pairs(:, 2), :);
pairs = pairs(active(pairs(:, 1)) & active(pairs(:, 2)), :);
pairs = unique(sort(pairs, 2), 'rows');

% Terminal vertices; auxiliary vertices are numbered after them.
terminal0 = flatsize + 1;   % alpha terminal (source)
terminal1 = flatsize + 2;   % "keep current label" terminal (sink)
minimum_U = Inf;
for i=1:MAP_iter
    success = 0;
    fprintf('\tInner Iteration: %d out of %d\n',i,MAP_iter);
    fprintf('\tCurrent U: %d\n',minimum_U);
    for alpha=1:L
        % Active elements that already have label alpha / some other label.
        ind_alpha = find(flat == alpha);
        ind_other = find(flat ~= alpha & active);
        if isempty(ind_other)
            continue;   % nothing can be expanded to alpha
        end
        n_nodes = numel(ind_alpha) + numel(ind_other);

        % Split neighbour pairs using the CURRENT labels.
        lab1 = flat(pairs(:, 1));
        lab2 = flat(pairs(:, 2));
        is_diff = lab1 ~= lab2;
        edge_diff = pairs(is_diff, :);
        edge_same = pairs(~is_diff, :);
        % One auxiliary vertex per pair of differently labelled neighbours.
        a = ((terminal1+1):(terminal1+size(edge_diff, 1)))';

        % t-links. Cutting a node's link to terminal0 costs D(alpha) and
        % gives it label alpha; cutting its link to terminal1 costs
        % D(current label) and keeps that label. Nodes already labelled
        % alpha must stay alpha, hence the Inf.
        s_t = [ind_alpha; ind_other; ind_alpha; ind_other];
        t_t = [repmat(terminal0, n_nodes, 1); repmat(terminal1, n_nodes, 1)];
        w_t = [logprobs(alpha, ind_alpha)'; ...
               logprobs(alpha, ind_other)'; ...
               Inf(size(ind_alpha)); ...
               logprobs(sub2ind(size(logprobs), flat(ind_other), ind_other))];
        % maxflow needs non-negative capacities. Every node has exactly one
        % of its two t-links in any finite cut, so adding `shift` to every
        % t-link raises every cut by n_nodes * shift. That offset is removed
        % from the flow below so the energies of different moves compare.
        shift = 0;
        min_weight = min(w_t);
        if min_weight < 0
            fprintf('\t\tSome weight is lower than zero: %f\n', min_weight);
            shift = -min_weight;
            w_t = w_t + shift;
        end

        % n-links for the Potts metric V(l1, l2) = b * (l1 ~= l2):
        %  - differently labelled pair (p, q) with auxiliary vertex a:
        %    p-a: V(l_p, alpha), q-a: V(l_q, alpha), a-terminal1: V(l_p, l_q) = b
        %  - equally labelled pair (p, q): p-q: V(l_p, alpha)
        s_n = [edge_diff(:, 1); edge_diff(:, 2); a; edge_same(:, 1)];
        t_n = [a; a; repmat(terminal1, size(a)); edge_same(:, 2)];
        w_n = [b * (lab1(is_diff) ~= alpha); ...
               b * (lab2(is_diff) ~= alpha); ...
               repmat(b, size(a)); ...
               b * (lab1(~is_diff) ~= alpha)];

        % Every edge above is unique and no edge is a self-loop, so no
        % de-duplication is needed before building the graph.
        G = graph([s_t; s_n], [t_t; t_n], [w_t; w_n]);
        [sum_U,~,~,ct] = maxflow(G,terminal0,terminal1);
        % Energy of the labelling this move would produce.
        sum_U = sum_U - n_nodes * shift;
        if sum_U < minimum_U
            fprintf('\tSetting new U: %d\n',sum_U);
            minimum_U = sum_U;
            % Sink-side element nodes had their terminal0 link cut, i.e. they
            % take label alpha; terminals and auxiliary vertices are dropped.
            ct = ct(ct < terminal0);
            ct = ct(active(ct));
            % Only an actual relabelling counts as progress, so a U decrease
            % caused purely by floating-point rounding cannot keep looping.
            if any(flat(ct) ~= alpha)
                success = 1;
            end
            flat(ct) = alpha;
            X(ct) = alpha;
        end
    end
    if success == 0
        break;
    end
end

% Posterior of every label given the FINAL labels of the neighbours.
all_neighbours = flat(all_neighbours_ind);
energy = zeros(L, flatsize);
for i=1:L
    energy(i, :) = logprobs(i, :) + b * sum(all_neighbours ~= i, 1);
end
% Subtract each column's minimum before exp() so large energies do not
% underflow to 0 for every label; the shift cancels in the normalisation.
% Columns whose minimum is not finite are left unshifted.
col_min = min(energy, [], 1);
col_min(~isfinite(col_min)) = 0;
posterior = min(exp(-bsxfun(@minus, energy, col_min)), 10^100);
posterior(:, ~active) = 0;
if(any(isnan(posterior(:))))
    fprintf('WARNING: posterior is NaN\n');
end
norm_const = sum(posterior, 1);
if(any(norm_const(:)==0) || any(isnan(norm_const(:))))
    fprintf('WARNING: norm const is zero or Nan\n');
    % Excluded columns are all zero; leave them as zeros instead of 0/0.
    norm_const(norm_const==0 | isnan(norm_const)) = 1;
end
posterior = bsxfun(@rdivide,posterior,norm_const);
if(any(isnan(posterior(:))))
    fprintf('WARNING: posterior is NaN\n');
end