-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
71 lines (62 loc) · 1.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from gym.envs.registration import register
import numpy as np
# Register a deterministic (non-slippery) variant of the 4x4 FrozenLake
# environment with gym's environment registration, under the id below.
register(
    id='FrozenLakeNotSlippery-v0',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs=dict(map_name='4x4', is_slippery=False),
)
def act_to_str(act: int) -> str:
    """
    Translate a FrozenLake action index into a one-letter direction symbol.

    The mapping is 0 -> "L", 1 -> "D", 2 -> "R", 3 -> "U".
    :param act (int): action index to translate
    :return (str): the direction symbol for ``act``
    :raises ValueError: if ``act`` does not compare equal to 0, 1, 2 or 3
    """
    # Checked in order with ``==`` so any value equal to an action code
    # (e.g. numpy integer scalars) is accepted, exactly as before.
    for code, symbol in ((0, "L"), (1, "D"), (2, "R"), (3, "U")):
        if act == code:
            return symbol
    raise ValueError("Invalid action value")
def visualise_q_table(q_table):
    """
    Print every q_table entry on its own line, in ascending key order.

    Each line shows the position, the action as a direction symbol and the
    stored q-value.
    :param q_table (Dict): mapping from (observation, action) pairs to q-values
    """
    for entry in sorted(q_table):
        obs, act = entry
        symbol = act_to_str(act)
        value = q_table[entry]
        print(f"Pos={obs}\tAct={symbol}\t->\t{value}")
def visualise_policy(q_table, *, n_rows: int = 4, n_cols: int = 4):
    """
    Given q_table print greedy policy for each FrozenLake position

    Positions are numbered row-major (``pos = row * n_cols + col``). For each
    position the action with the highest q-value is shown as a direction
    symbol; ties resolve to the lowest action index.

    :param q_table (Dict): q_table in form of a dict mapping (observation, action) pairs to
        q-values; indexed for every (position, action) pair of the grid
    :param n_rows (int): number of grid rows (default 4, i.e. the 4x4 map)
    :param n_cols (int): number of grid columns (default 4, i.e. the 4x4 map)
    :raises KeyError: if q_table is a plain dict lacking some (position, action) pair
    """
    # FrozenLake has four actions, the ones act_to_str knows how to name.
    actions = range(4)
    str_table = []
    for row in range(n_rows):
        cells = []
        for col in range(n_cols):
            pos = row * n_cols + col
            # max() keeps the first maximal item, matching a strict ">" scan:
            # ties go to the lowest action index. The lambda is consumed
            # immediately, so capturing the loop variable `pos` is safe.
            best_act = max(actions, key=lambda a: q_table[(pos, a)])
            cells.append(act_to_str(best_act))
        str_table.append("".join(cells))
    # The whole table is built before anything is printed, so a missing
    # entry raises before any partial output appears.
    print("\nAction selection table:")
    for row_str in str_table:
        print(row_str)
    print()