-
Notifications
You must be signed in to change notification settings - Fork 0
/
sampler.py
45 lines (40 loc) · 1.4 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from random import randrange
import numpy as np
from numba import njit
@njit(fastmath=True)
def sample_unseen(sample_size, sampler_state, remaining, result):
    """
    Draw `sample_size` integers without repetition from a range that
    starts at zero, skipping the elements that were moved out of the
    sampling region. Meant to be used together with `prime_sampler_state`,
    which builds the initial `sampler_state`.
    Inspired by the Fisher-Yates shuffle.

    `sampler_state` maps a position to the value stored there; a position
    that is absent from the mapping stands for itself. `remaining` is the
    size of the region to draw from. Drawn values are written to
    `result[0:sample_size]`; nothing is returned and `sampler_state` is
    modified in place.
    """
    for slot in range(sample_size):
        # uniform draw over the positions that are still available
        pick = randrange(remaining)
        result[slot] = sampler_state.get(pick, pick)
        # shrink the region by one and move the value from its last
        # position into the position that has just been consumed
        last = remaining - 1
        sampler_state[pick] = sampler_state.get(last, last)
        # the last position is now outside the region, drop its entry
        sampler_state.pop(last, -1)
        remaining = last
@njit(fastmath=True)
def prime_sampler_state(n, exclude):
    """
    Initialize state to be used in `sample_unseen`.
    Ensures seen items are never sampled by placing them
    outside of sampling region.

    Parameters
    ----------
    n
        Size of the full range; positions run from 0 to n - 1.
    exclude
        Iterable of items that must end up outside the sampling region.
        NOTE(review): the loop does not detect repeated values, so this
        appears to require distinct items from [0, n) - confirm with callers.

    Returns
    -------
    Mapping from position to the item stored there. Only positions whose
    item differs from the position itself are present. The excluded items
    occupy the last `len(exclude)` positions, whose entries are dropped.
    """
    # initialize typed numba dicts: the literal fixes key/value types from
    # `n`, the immediate pop leaves the dict empty
    state = {n: n}
    state.pop(n)
    track = {n: n}
    track.pop(n)
    # last position of the range; excluded items are parked from here down
    n_pos = n - 1
    # reindex excluded items, placing them in the end
    for i, item in enumerate(exclude):
        pos = n_pos - i
        # `track` is the inverse of `state`: item -> current position
        x = track.get(item, item)
        # item that currently sits at the tail position being claimed
        t = state.get(pos, pos)
        # move the tail item into the position vacated by `item`
        state[x] = t
        track[t] = x
        # the tail position is outside the sampling region from now on,
        # so neither mapping needs to remember it
        state.pop(pos, n)
        track.pop(item, n)
    return state