Commit e572cdc

Merge branch 'master' into semimdp

markkho committed Oct 17, 2023
2 parents df63214 + 1b7e11b
Showing 3 changed files with 318 additions and 42 deletions.
msdm/algorithms/search.py: 113 changes (85 additions & 28 deletions)

@@ -3,7 +3,7 @@
 import heapq
 import random
 from re import M
-from typing import Dict, Union
+from typing import Dict, Union, NamedTuple, Any
 
 from msdm.core.algorithmclasses import Plans, Result
 from msdm.core.distributions import DeterministicDistribution, DictDistribution, dictdistribution
@@ -86,6 +86,19 @@ def plan_on(self, dsp: MarkovDecisionProcess):
             queue.append(ns)
             camefrom[ns] = (s, a)
 
+
+class AStarSearchNode(NamedTuple):
+    '''
+    Search nodes used in AStarSearch.
+    NOTE: Order of fields is important because it determines how elements are ordered in the heap.
+    In particular, heuristic_cost needs to be the first field, and the tie_break should be before
+    any others.
+    '''
+    heuristic_cost: float
+    tie_break: float
+    cost_from_start: float
+    state: Any
+
 class AStarSearch(Plans):
     """
     A* Search is an informed best-first search algorithm. It considers states in priority order
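
Note on the field ordering above: Python's heapq is a min-heap that compares NamedTuple entries lexicographically, so heuristic_cost is compared first and tie_break settles equal-cost nodes before the (possibly unorderable) state field is ever inspected. A minimal standalone sketch of this behavior (the Node class here is illustrative, not part of the commit):

    import heapq
    from typing import NamedTuple, Any

    class Node(NamedTuple):
        heuristic_cost: float
        tie_break: float
        state: Any

    heap = []
    heapq.heappush(heap, Node(heuristic_cost=2.0, tie_break=0, state='b'))
    heapq.heappush(heap, Node(heuristic_cost=1.0, tie_break=-1, state='a'))
    # Pushed later with a decreasing tie-breaker, as in the 'lifo' strategy:
    heapq.heappush(heap, Node(heuristic_cost=1.0, tie_break=-2, state='c'))
    print([heapq.heappop(heap).state for _ in range(3)])  # ['c', 'a', 'b']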
@@ -99,13 +112,17 @@ def __init__(
         heuristic_value=lambda s: 0,
         seed=None,
         randomize_action_order=False,
-        tie_breaking_strategy='lifo'
+        tie_breaking_strategy='lifo',
+        assert_monotone_heuristic=True,
     ):
         self.heuristic_value = heuristic_value
         self.seed = seed
         self.randomize_action_order = randomize_action_order
         assert tie_breaking_strategy in ['random', 'lifo', 'fifo']
         self.tie_breaking_strategy = tie_breaking_strategy
+        self.assert_monotone_heuristic = assert_monotone_heuristic
+        if seed is not None:
+            assert tie_breaking_strategy == 'random' or randomize_action_order, 'Seed was supplied, but tie-breaking and action order are deterministic.'
 
     def plan_on(self, dsp: MarkovDecisionProcess):
         rnd = random.Random(self.seed)
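
Note: a hypothetical construction exercising the new arguments (the argument values are illustrative only; the parameters themselves are those added in this hunk):

    planner = AStarSearch(
        heuristic_value=lambda s: 0,
        seed=42,
        tie_breaking_strategy='random',   # randomized, so supplying a seed is meaningful
        assert_monotone_heuristic=False,  # count non-monotonic steps instead of raising
    )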
@@ -116,49 +133,89 @@ def plan_on(self, dsp: MarkovDecisionProcess):
 
         dsp = DeterministicShortestPathProblem.from_mdp(dsp)
 
-        # Every queue entry is a pair of
-        # - a tuple of priorities/costs (the cost-to-go, a tie-breaker, and cost-so-far)
-        # - the state
-        tie_break = 0
-        if self.tie_breaking_strategy == 'lifo':
-            # The heap is a min-heap, so to ensure last-in first-out
-            # the tie-breaker must decrease. Since it's always
-            # decreasing, later elements of equivalent value have greater priority.
-            tie_break_delta = -1
-        elif self.tie_breaking_strategy == 'fifo':
-            # See above comment. First-in first-out requires that our tie-breaker increases.
-            tie_break_delta = +1
-
         queue = []
-        start = dsp.initial_state()
-        if self.tie_breaking_strategy in ['lifo', 'fifo']:
-            tie_break += tie_break_delta
-        else:
-            tie_break = rnd.random()
-        heapq.heappush(queue, ((-self.heuristic_value(start), tie_break, 0), start))
+        if self.tie_breaking_strategy in ['lifo', 'fifo']:
+            tie_break = 0
+            if self.tie_breaking_strategy == 'lifo':
+                # The heap is a min-heap, so to ensure last-in first-out
+                # the tie-breaker must decrease. Since it's always
+                # decreasing, later elements of equivalent value have greater priority.
+                tie_break_delta = -1
+            else:
+                # See above comment. First-in first-out requires that our tie-breaker increases.
+                tie_break_delta = +1
+
+        # This holds the previous best node that was added to the queue, for each state.
+        # This previous best node is also the last node for a state, since we only add when a node is an improvement.
+        best_in_queue_by_state = dict()
+
+        def push(*, heuristic_cost, cost_from_start, state):
+            nonlocal tie_break
+            if self.tie_breaking_strategy in ['lifo', 'fifo']:
+                tie_break += tie_break_delta
+            else:
+                tie_break = rnd.random()
+            node = AStarSearchNode(heuristic_cost=heuristic_cost, tie_break=tie_break, cost_from_start=cost_from_start, state=state)
+            heapq.heappush(queue, node)
+            best_in_queue_by_state[node.state] = node
+            return node
+
+        # Add the initial node.
+        start = dsp.initial_state()
+        push(heuristic_cost=-self.heuristic_value(start), cost_from_start=0, state=start)
 
         visited = set([])
         camefrom = dict()
+        non_monotonic_counter = 0
 
         while queue:
-            (heuristic_cost, _, cost_from_start), s = heapq.heappop(queue)
+            node = heapq.heappop(queue)
+            s = node.state
+
+            # If the state has been previously visited, then this is a worse node that should be skipped.
+            if s in visited:
+                assert s not in best_in_queue_by_state, 'Previously visited node should not be best node.'
+                continue
+            else:
+                # We use `is` instead of `==` to ensure the nodes are the same object instances, not just equal.
+                assert best_in_queue_by_state[s] is node, 'Newly visited state should be stored as best node.'
+                # Remove the reference to this node, now that it's been removed from the queue.
+                del best_in_queue_by_state[s]
 
+            # Handle the case of a goal state.
             if dsp.is_absorbing(s):
+                assert node.heuristic_cost == node.cost_from_start
                 path = reconstruct_path(camefrom, start, s)
                 return Result(
                     path=path,
-                    path_value=cost_from_start,
+                    path_value=node.cost_from_start,
                     policy=camefrom_to_policy(path, camefrom, dsp),
                     visited=visited,
+                    non_monotonic_counter=non_monotonic_counter,
                 )
 
+            # Mark the current state as visited.
             visited.add(s)
 
             for a in shuffled(dsp.actions(s)):
                 ns = dsp.next_state(s, a)
-                if ns not in visited and ns not in [el[-1] for el in queue]:
-                    next_cost_from_start = cost_from_start - dsp.reward(s, a, ns)
-                    next_heuristic_cost = next_cost_from_start - self.heuristic_value(ns)
-                    if self.tie_breaking_strategy in ['lifo', 'fifo']:
-                        tie_break += tie_break_delta
-                    else:
-                        tie_break = rnd.random()
-                    heapq.heappush(queue, ((next_heuristic_cost, tie_break, next_cost_from_start), ns))
-                    camefrom[ns] = (s, a)
+                # We skip previously-visited states.
+                if ns in visited:
+                    continue
+                next_cost_from_start = node.cost_from_start - dsp.reward(s, a, ns)
+                # If the state has been reached before in a lower-cost node, then we skip.
+                if ns in best_in_queue_by_state and best_in_queue_by_state[ns].cost_from_start <= next_cost_from_start:
+                    continue
+                # At this point, we've either newly reached this state, or we have reached it in
+                # a lower-cost way. So, we add it to the search queue.
+                next_node = push(
+                    heuristic_cost=next_cost_from_start - self.heuristic_value(ns),
+                    cost_from_start=next_cost_from_start,
+                    state=ns,
+                )
+                camefrom[ns] = (s, a)
+
+                # Checking that the heuristic is monotonic, i.e. that our previous heuristic cost was a lower bound to the current one.
+                monotone = node.heuristic_cost <= next_node.heuristic_cost
+                if not monotone:
+                    non_monotonic_counter += 1
+                    if self.assert_monotone_heuristic:
+                        assert monotone, f'Heuristic is non-monotonic, with previous node {node} and next node {next_node}'
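
Note on the loop above: because heapq has no decrease-key operation, the new code re-pushes a fresh node whenever a queued state is reached more cheaply, and the stale entry is discarded at pop time via the visited / best_in_queue_by_state bookkeeping. The monotonicity check mirrors the textbook consistency condition h(s) <= cost(s, s') + h(s'), under which a node's heuristic_cost can never decrease along an expansion. A minimal standalone sketch of the re-push pattern (not the commit's API):

    import heapq

    queue, best_cost = [], {}

    def push(cost, state):
        heapq.heappush(queue, (cost, state))
        best_cost[state] = cost

    push(310, 'Bucharest')                 # first route found
    if 278 < best_cost['Bucharest']:       # a cheaper route is discovered later
        push(278, 'Bucharest')             # re-push; (310, 'Bucharest') is now stale
    cost, state = heapq.heappop(queue)
    assert (cost, state) == (278, 'Bucharest')  # the cheaper node pops first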
msdm/tests/domains/__init__.py: 39 changes (39 additions & 0 deletions)

@@ -670,3 +670,42 @@ def optimal_state_value(self):
             (3, 1): 0.,
             (3, 2): 0.,
         }
+
+class RomaniaSubsetAIMA(DeterministicShortestPathProblem, TabularMarkovDecisionProcess, TestDomain):
+    '''
+    This small weighted graph is from Figure 3.15 in Artificial Intelligence: A Modern Approach, 3rd edition.
+    It's used to illustrate an important case for Uniform Cost Search (and A* with no heuristic), where
+    a state can be subsequently encountered through a more efficient path.
+    '''
+    state_list = ('Sibiu', 'Fagaras', 'Rimnicu Vilcea', 'Pitesti', 'Bucharest')
+    costs = {
+        frozenset({'Sibiu', 'Fagaras'}): 99,
+        frozenset({'Sibiu', 'Rimnicu Vilcea'}): 80,
+        frozenset({'Rimnicu Vilcea', 'Pitesti'}): 97,
+        frozenset({'Pitesti', 'Bucharest'}): 101,
+        frozenset({'Fagaras', 'Bucharest'}): 211,
+    }
+
+    def initial_state(self):
+        return 'Sibiu'
+
+    def is_absorbing(self, s):
+        return s == 'Bucharest'
+
+    def actions(self, s):
+        return [
+            ns
+            for ns in self.state_list
+            if frozenset({s, ns}) in self.costs
+        ]
+
+    def next_state(self, s, a):
+        assert frozenset({s, a}) in self.costs
+        return a
+
+    def reward(self, s, a, ns):
+        assert a == ns, (a, ns)
+        return -self.costs[frozenset({s, ns})]
+
+    def optimal_path(self):
+        return ['Sibiu', 'Rimnicu Vilcea', 'Pitesti', 'Bucharest']
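
Note: with a zero heuristic, A* reduces to uniform cost search, and this graph makes Bucharest enter the queue first via Fagaras (99 + 211 = 310) before being re-reached more cheaply via Rimnicu Vilcea and Pitesti (80 + 97 + 101 = 278). A hedged sketch of how the domain might be exercised (import paths follow the files in this commit; the commit's actual test code is not shown on this page):

    from msdm.algorithms.search import AStarSearch
    from msdm.tests.domains import RomaniaSubsetAIMA

    dsp = RomaniaSubsetAIMA()
    result = AStarSearch(heuristic_value=lambda s: 0).plan_on(dsp)
    assert result.path == dsp.optimal_path()  # Sibiu -> Rimnicu Vilcea -> Pitesti -> Bucharest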