forked from Airspace-Encounter-Models/em-model-manned-bayes
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bn_sample.m
58 lines (52 loc) · 2.09 KB
/
bn_sample.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
function S = bn_sample(G, r, N, alpha, num_samples, start, order)
% Copyright 2008 - 2021, MIT Lincoln Laboratory
% SPDX-License-Identifier: BSD-2-Clause
% BN_SAMPLE Produces samples from a Bayesian network.
%   Returns a matrix whose rows consist of n-dimensional samples from the
%   specified Bayesian network, where n = numel(N).
%
%   S = BN_SAMPLE(G, R, N, ALPHA, NUM_SAMPLES) returns a NUM_SAMPLES-by-n
%   matrix S whose rows are samples from a Bayesian network with graphical
%   structure G, sufficient statistics N, and prior ALPHA. Nonzero entries
%   of column G(:, i) mark the parents of variable i. The array R specifies
%   the number of bins associated with each variable.
%
%   S = BN_SAMPLE(..., START) presets variables. START is a cell array with
%   the same number of elements as N. A non-empty, non-NaN entry START{i}
%   fixes variable i to that value in every sample; an empty or NaN entry
%   means variable i is sampled. A variable can only be preset if all of
%   its parents, if any, are also preset (non-empty and non-NaN);
%   otherwise the error 'Attempt to preset a dependent variable' is raised.
%   Omitting START, or passing [], presets nothing.
%   WARNING: Will currently only allow discrete parameters to be preset.
%
%   S = BN_SAMPLE(..., START, ORDER) uses ORDER, a topological sort of the
%   graphical structure G (as calculated using bn_sort in em_read), instead
%   of computing one with bn_sort. ORDER may be a row or column vector.
%
% SEE ALSO bn_sort select_random em_read

n = length(N);

% Optional inputs: no presets by default, and a topological sort of the
% Bayes net computed from G when no order is supplied.
if nargin < 6 || isempty(start)
    start = cell(1, n);
end
if nargin < 7 || isempty(order)
    order = bn_sort(G);
end
% A MATLAB for loop iterates over columns, so force a row vector; a
% column-vector order would otherwise be consumed in a single iteration.
order = order(:).';

% Input handling
assert(iscell(start), 'start must be a cell array');
assert(numel(start) == n, 'start must have one element per variable');
assert(numel(order) == n, 'order must have one element per variable');

% A variable is preset when its start entry is non-empty and not NaN; the
% same predicate is used for the variable itself and for its parents.
is_preset = cellfun(@(v) ~isempty(v) && ~any(isnan(v(:))), start);

% Parent sets and preset validity depend only on G and start, so compute
% them once instead of on every sample.
parent_sets = cell(1, n);
for i = 1:n
    % logical() so a numeric 0/1 adjacency matrix can be used for indexing
    parents = logical(G(:, i));
    parent_sets{i} = parents;
    if is_preset(i) && any(parents) && ~all(is_preset(parents))
        error('Attempt to preset a dependent variable');
    end
end

% Preallocate
S = zeros(num_samples, n);
for sample_index = 1:num_samples
    % Generate each sample, visiting variables so that parents are filled
    % in before their children.
    for i = order
        if is_preset(i)
            S(sample_index, i) = start{i};
        else
            parents = parent_sets{i};
            j = 1;
            if any(parents)
                % asub2ind is defined elsewhere; presumably it maps the
                % parents' already-sampled values (with bin counts r) to a
                % column index of N{i} / alpha{i} -- confirm against helper.
                j = asub2ind(r(parents), S(sample_index, parents));
            end
            % Draw from column j of the sufficient statistics plus prior.
            S(sample_index, i) = select_random(N{i}(:, j) + alpha{i}(:, j));
        end
    end
end
end