% Generates samples from the Gibbs distribution of a Potts model
% X - initial configuration matrix (labels in 1..k)
% B - number of initial states (sweeps) to skip as burn-in
% M - number of states to return
% k - number of segmentation classes
% beta - Potts model parameter
% neighbours_count - number of neighbours per voxel
% Returns a matrix of size M x (number of voxels)
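function [Y] = GibbsSamplerPotts(X, B, M, k, beta, neighbours_count)
sz = size(X);
flat = X(:);
flatsz = numel(flat);
% Column i of all_neighbours_ind holds the linear indices of the neighbours
% of voxel i; entries equal to i itself are padding and are skipped below.
all_neighbours_ind = GetNeighbours(sz, neighbours_count);
Y = zeros(M, flatsz);
for j = 1:(B+M)
    % One sweep: update every voxel once, in random order
    for i = randperm(flatsz)
        % The neighbour lookup does not depend on the candidate label,
        % so do it once per voxel instead of once per label
        col = all_neighbours_ind(:, i);
        neighbour_labels = flat(col(col ~= i));
        % Potts energy of label l: beta * (number of disagreeing neighbours)
        energy = zeros(1, k);
        for l = 1:k
            energy(l) = beta * sum(neighbour_labels ~= l);
        end
        % Unnormalised conditional probabilities exp(-energy); shifting by
        % min(energy) keeps the largest term at 1, so exp() cannot underflow
        % to all zeros (which would otherwise give NaN probabilities)
        pex = exp(-(energy - min(energy)));
        % Sample a label by inverting the cumulative distribution; this is
        % equivalent in distribution to randsample(1:k, 1, true, pex) but
        % does not require the Statistics and Machine Learning Toolbox
        c = cumsum(pex);
        flat(i) = find(rand * c(end) < c, 1);
    end
    % Record the configuration after each post-burn-in sweep
    if j > B
        Y(j-B, :) = flat;
    end
end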
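
% Hypothetical sketch, not part of the original code: one possible way to
% build the neighbour table that GibbsSamplerPotts reads from GetNeighbours.
% It assumes a 2-D grid with a 4-neighbourhood (up, down, left, right).
% Column i holds the linear indices of the neighbours of pixel i, and
% positions that fall outside the grid are filled with i itself, which the
% sampler filters out. The real GetNeighbours may use a different layout or
% support other neighbours_count values; the name GetNeighbours2DSketch is
% illustrative only.
function N = GetNeighbours2DSketch(sz)
assert(numel(sz) == 2, 'This sketch only handles 2-D grids');
rows = sz(1);
cols = sz(2);
[r, c] = ndgrid(1:rows, 1:cols);
% Linear index of every pixel, in the same column-major order as X(:)
self = sub2ind(sz, r(:), c(:))';
offsets = [-1 0; 1 0; 0 -1; 0 1];   % up, down, left, right
% Start with every entry pointing at the pixel itself (padding)
N = repmat(self, 4, 1);
for d = 1:4
    rr = r(:)' + offsets(d, 1);
    cc = c(:)' + offsets(d, 2);
    % Only overwrite the padding where the neighbour lies inside the grid
    inside = rr >= 1 & rr <= rows & cc >= 1 & cc <= cols;
    N(d, inside) = sub2ind(sz, rr(inside), cc(inside));
end

% To call it from outside this file, save it as GetNeighbours2DSketch.m.
% For a 2x2 grid, column 1 is [1; 2; 1; 3]: pixel 1 has no upper or left
% neighbour, so those slots repeat its own index.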