-
Notifications
You must be signed in to change notification settings - Fork 130
/
lqr.py
65 lines (50 loc) · 2.18 KB
/
lqr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
'''Linear Quadratic Regulator (LQR).'''
from safe_control_gym.controllers.base_controller import BaseController
from safe_control_gym.controllers.lqr.lqr_utils import compute_lqr_gain, get_cost_weight_matrix
from safe_control_gym.envs.benchmark_env import Task
class LQR(BaseController):
    '''Linear quadratic regulator.

    Computes a fixed state-feedback gain about the model equilibrium and
    applies u = -K (x - x_goal) + u_eq for stabilization or trajectory
    tracking tasks.
    '''

    def __init__(
            self,
            env_func,
            # Model args.
            q_lqr: list = None,
            r_lqr: list = None,
            discrete_dynamics: bool = True,
            **kwargs):
        '''Creates task and controller.

        Args:
            env_func (Callable): Function to instantiate task/environment.
            q_lqr (list): Diagonals of state cost weight.
            r_lqr (list): Diagonals of input/action cost weight.
            discrete_dynamics (bool): If to use discrete or continuous dynamics.
        '''
        super().__init__(env_func, **kwargs)
        self.env = env_func()
        # Controller params.
        self.model = self.get_prior(self.env)
        self.discrete_dynamics = discrete_dynamics
        self.Q = get_cost_weight_matrix(q_lqr, self.model.nx)
        self.R = get_cost_weight_matrix(r_lqr, self.model.nu)
        # Sync the environment's cost weights with the controller's, so the
        # reported cost matches the objective the gain was computed for.
        self.env.set_cost_function_param(self.Q, self.R)
        # The gain is time-invariant: precompute it once about the equilibrium.
        self.gain = compute_lqr_gain(self.model, self.model.X_EQ, self.model.U_EQ,
                                     self.Q, self.R, self.discrete_dynamics)

    def reset(self):
        '''Prepares for evaluation.'''
        self.env.reset()

    def close(self):
        '''Cleans up resources.'''
        self.env.close()

    def select_action(self, obs, info=None):
        '''Determine the action to take at the current timestep.

        Args:
            obs (ndarray): The observation at this timestep.
            info (dict): The info at this timestep.

        Returns:
            action (ndarray): The action chosen by the controller.

        Raises:
            ValueError: If the environment task is neither stabilization
                nor trajectory tracking.
        '''
        step = self.extract_step(info)
        if self.env.TASK == Task.STABILIZATION:
            return -self.gain @ (obs - self.env.X_GOAL) + self.model.U_EQ
        if self.env.TASK == Task.TRAJ_TRACKING:
            # Regulate toward the reference state for the current timestep.
            return -self.gain @ (obs - self.env.X_GOAL[step]) + self.model.U_EQ
        # Previously an unrecognized task fell through and implicitly returned
        # None, producing an opaque failure when the action was applied.
        # Fail loudly at the source instead.
        raise ValueError(f'Unsupported task type: {self.env.TASK}.')