-
Notifications
You must be signed in to change notification settings - Fork 49
/
particles.pyx
236 lines (198 loc) · 7.87 KB
/
particles.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
from pomdp_py.framework.basics cimport GenerativeDistribution
from pomdp_py.utils.cython_utils cimport det_dict_hash
import random
cdef class WeightedParticles(GenerativeDistribution):
    r"""
    Represents a distribution :math:`\Pr(X)` with weighted particles, each is a
    tuple (value, weight). "value" means a value for the random variable X. If
    multiple weights are present for the same value, will interpret the
    probability at X=x as the average of those weights.

    __init__(self, list particles, str approx_method="none", distance_func=None, frozen=False)

    Args:
        particles (list): List of (value, weight) tuples. The weight represents
            the likelihood that the value is drawn from the underlying
            distribution. Note that the weights may not be normalized. To
            normalize, call condense(). The list is stored as-is (not copied),
            so add() appends to the list object that was passed in.
        approx_method (str): 'nearest' if when querying the probability
            of a value, and there is no matching particle for it, return
            the probability of the value closest to it. Assuming values
            are comparable; "none" if no approximation, return 0.
        distance_func: Used when approx_method is 'nearest'. Returns
            a number given two values in this particle set. It is called as
            ``distance_func(particle_value, queried_value)``.
        frozen: if true, then this WeightedParticles object cannot be modified. This
            makes it hashable.
    """
    def __init__(self, list particles, str approx_method="none", distance_func=None, frozen=False):
        self._values = [value for value, _ in particles]
        self._weights = [weight for _, weight in particles]
        self._particles = particles
        self._hist = self.get_histogram()
        self._hist_valid = True
        self._approx_method = approx_method
        self._distance_func = distance_func
        self._frozen = frozen
        if self._frozen:
            # The hash is computed once from the histogram; this is only
            # sound because a frozen object rejects add().
            self._hashcode = det_dict_hash(self._hist)

    @property
    def particles(self):
        """The list of (value, weight) tuples."""
        return self._particles

    @property
    def values(self):
        """The list of particle values, in the same order as `particles`."""
        return self._values

    @property
    def weights(self):
        """The list of particle weights, in the same order as `particles`."""
        return self._weights

    @property
    def frozen(self):
        """True if this object rejects modification (and is hashable)."""
        return self._frozen

    @property
    def hist(self):
        """The normalized histogram (value -> probability).

        The cached histogram is rebuilt here if add() has invalidated it, so
        the returned mapping always reflects the current particles."""
        if not self._hist_valid:
            self._hist = self.get_histogram()
            self._hist_valid = True
        return self._hist

    @property
    def hist_valid(self):
        """False if particles were added since the histogram was last built."""
        return self._hist_valid

    def add(self, particle):
        """add(self, particle)
        particle: (value, weight) tuple

        Raises NotImplementedError if this object is frozen."""
        if self._frozen:
            raise NotImplementedError("weighted particles is frozen and cannot be modified")
        self._particles.append(particle)
        s, w = particle
        self._values.append(s)
        self._weights.append(w)
        # The cached histogram is rebuilt lazily on the next query.
        self._hist_valid = False

    def __str__(self):
        return str(self.condense().particles)

    def __len__(self):
        return len(self._particles)

    def __hash__(self):
        """Only frozen objects are hashable; otherwise NotImplementedError is raised."""
        if self._frozen:
            return self._hashcode
        raise NotImplementedError

    def __eq__(self, other):
        """Two particle sets are equal if their normalized histograms are equal."""
        if isinstance(other, WeightedParticles):
            # Go through the `hist` property on both sides so that neither
            # histogram is stale after an add().
            return self.hist == other.hist
        return False

    def __getitem__(self, value):
        """Returns the probability of `value`; normalized

        If `value` has no matching particle, the result depends on
        approx_method: "none" returns 0.0; "nearest" returns the probability
        of the particle value with the smallest distance_func distance to
        `value`.

        Raises ValueError if there are no particles, if approx_method is
        "nearest" but no distance_func was given, or if approx_method is
        not recognized."""
        if len(self.particles) == 0:
            raise ValueError("Particles is empty.")

        hist = self.hist
        if value in hist:
            return hist[value]

        if self._approx_method == "none":
            return 0.0
        elif self._approx_method == "nearest":
            if self._distance_func is None:
                raise ValueError("approx_method 'nearest' requires a distance_func")
            # Find the particle value closest to the *queried* value.
            nearest = self._values[0]
            nearest_dist = float('inf')
            for s in self._values:
                dist = self._distance_func(s, value)
                if dist < nearest_dist:
                    nearest_dist = dist
                    nearest = s
            return hist[nearest]
        else:
            raise ValueError("Cannot handle approx_method: {}".format(self._approx_method))

    def __setitem__(self, value, prob):
        """
        The particle belief does not support assigning an exact probability to a value.
        """
        raise NotImplementedError

    def random(self):
        """Samples a value based on the particles; each value is drawn with
        probability proportional to its weight."""
        value = random.choices(self._values, weights=self._weights, k=1)[0]
        return value

    def mpe(self):
        """Returns the value with the highest probability in the histogram."""
        hist = self.hist
        return max(hist, key=hist.get)

    def __iter__(self):
        return iter(self._particles)

    cpdef dict get_histogram(self):
        """
        get_histogram(self)
        Returns a mapping from value to probability, normalized.

        Weights of particles sharing the same value are averaged (not
        summed) before normalization. Returns an empty dict when there are
        no particles."""
        cdef dict hist = {}
        cdef dict counts = {}
        # first, sum the weights
        for s, w in self._particles:
            hist[s] = hist.get(s, 0) + w
            counts[s] = counts.get(s, 0) + 1
        # then, average the sums
        total_weights = 0.0
        for s in hist:
            hist[s] = hist[s] / counts[s]
            total_weights += hist[s]
        # finally, normalize
        for s in hist:
            hist[s] /= total_weights
        return hist

    @classmethod
    def from_histogram(cls, histogram, frozen=False):
        """Given a pomdp_py.Histogram return a particle representation of it,
        which is an approximation

        One particle is created per value in `histogram`, weighted by
        ``histogram[value]``."""
        particles = [(v, histogram[v]) for v in histogram]
        return WeightedParticles(particles, frozen=frozen)

    def condense(self):
        """
        Returns a new set of weighted particles with unique values
        and weights aggregated (taken average) and normalized.
        """
        return WeightedParticles.from_histogram(self.get_histogram(), frozen=self._frozen)
cdef class Particles(WeightedParticles):
    r""" Particles is a set of unweighted particles; This set of particles represent
    a distribution :math:`\Pr(X)`. Each particle takes on a specific value of :math:`X`.
    Inherits :py:mod:`~pomdp_py.representations.distribution.particles.WeightedParticles`.

    __init__(self, particles, **kwargs)

    Args:
        particles (list): List of values.
        kwargs: see __init__() of :py:mod:`~pomdp_py.representations.distribution.particles.WeightedParticles`.
    """
    def __init__(self, particles, **kwargs):
        # Unweighted particles are stored as (value, None) pairs so the
        # parent's bookkeeping (values / weights / particles lists) applies.
        super().__init__([(value, None) for value in particles], **kwargs)

    def __iter__(self):
        """Iterates over the particle values (not (value, weight) tuples)."""
        return iter(self.particles)

    def add(self, particle):
        """add(self, particle)
        particle: just a value

        Raises NotImplementedError if this object is frozen."""
        if self._frozen:
            # A frozen object has a cached hash; mutating it would make the
            # hash disagree with the contents.
            raise NotImplementedError("particles is frozen and cannot be modified")
        self._particles.append((particle, None))
        self._values.append(particle)
        self._weights.append(None)
        # The cached histogram is now out of date.
        self._hist_valid = False

    @property
    def particles(self):
        """For unweighted particles, the particles are just values."""
        return self._values

    def get_abstraction(self, state_mapper):
        """get_abstraction(self, state_mapper)
        feeds all particles through a state abstraction function.
        Or generally, it could be any function.

        Returns the list of mapped values, one per particle."""
        particles = [state_mapper(s) for s in self.particles]
        return particles

    @classmethod
    def from_histogram(cls, histogram, num_particles=1000):
        """Given a pomdp_py.Histogram return a particle representation of it,
        which is an approximation

        Draws `num_particles` samples by calling ``histogram.random()``."""
        particles = [histogram.random() for _ in range(num_particles)]
        return Particles(particles)

    cpdef dict get_histogram(self):
        """
        get_histogram(self)
        Returns a mapping from value to the fraction of particles that
        have that value. Returns an empty dict when there are no particles."""
        cdef dict hist = {}
        values = self.particles
        for s in values:
            hist[s] = hist.get(s, 0) + 1
        total = len(values)
        for s in hist:
            hist[s] = hist[s] / total
        return hist

    def random(self):
        """Samples a value based on the particles; every particle is equally
        likely. Returns None if there are no particles."""
        if self._particles:
            return random.choice(self._values)
        else:
            return None