Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qlearning+sarsa v2.0 #1005

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions dlib/control.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include "control/lspi.h"
#include "control/mpc.h"
#include "control/qlearning.h"
#include "control/sarsa.h"

#endif // DLIB_CONTRoL_

Expand Down
147 changes: 119 additions & 28 deletions dlib/control/approximate_linear_models.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@

#include "approximate_linear_models_abstract.h"
#include "../matrix.h"
#include <random>

namespace dlib
{

// ----------------------------------------------------------------------------------------

template <
typename feature_extractor
typename model_type
>
struct process_sample
{
typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;
typedef typename model_type::state_type state_type;
typedef typename model_type::action_type action_type;

process_sample(){}

Expand Down Expand Up @@ -56,68 +56,159 @@ namespace dlib
// ----------------------------------------------------------------------------------------

// policy: stores a weight vector together with a model object; operator()
// returns find_best_action(state, weights) as computed by that object.
// NOTE(review): this span appears to be a rendered diff whose +/- markers were
// lost, so each replaced line (feature_extractor / fe / w) sits next to its
// replacement (model_type / model / weights) -- confirm against the PR.
template <
typename feature_extractor
typename model_type
>
class policy
{
public:

typedef feature_extractor feature_extractor_type;
typedef typename feature_extractor::state_type state_type;
typedef typename feature_extractor::action_type action_type;

typedef typename model_type::state_type state_type;
typedef typename model_type::action_type action_type;

// Construct from a model (default-constructed when omitted); the weight
// vector is sized to num_features() and set to zero.
policy (
)
const model_type& model_ = model_type()
) : model(model_)
{
w.set_size(fe.num_features());
w = 0;
weights.set_size(model.num_features());
weights = 0;
}

// Construct from an existing weight vector and model; both are copied into
// the members.
policy (
const matrix<double,0,1>& weights_,
const feature_extractor& fe_
) : w(weights_), fe(fe_) {}
const model_type &model_
) : weights(weights_), model(model_) {}

// Copy and move operations are explicitly defaulted.
policy(const policy<model_type>&) = default;
policy<model_type>& operator=(const policy<model_type>&) = default;

policy(policy<model_type>&&) = default;
policy<model_type>& operator=(policy<model_type>&&) = default;

// Returns the action selected by find_best_action() for the given state
// under the current weights.
action_type operator() (
const state_type& state
) const
{
return fe.find_best_action(state,w);
return model.find_best_action(state,weights);
}

// Read-only access to the stored model object.
const feature_extractor& get_feature_extractor (
) const { return fe; }
const model_type& get_model (
) const { return model; }

// Read-only access to the weight vector.
const matrix<double,0,1>& get_weights (
) const { return w; }
) const { return weights; }

// Mutable access to the weight vector.
matrix<double,0,1>& get_weights (
) { return weights; }

private:
matrix<double,0,1> w;
feature_extractor fe;
matrix<double,0,1> weights;
model_type model;
};

// Writes a policy to `out`: a version number (1), then the model object,
// then the weight vector.
// NOTE(review): old (feature_extractor) and new (model_type) diff lines are
// interleaved in this span -- confirm against the PR.
template < typename feature_extractor >
inline void serialize(const policy<feature_extractor>& item, std::ostream& out)
template < typename model_type >
inline void serialize(const policy<model_type>& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.get_feature_extractor(), out);
serialize(item.get_model(), out);
serialize(item.get_weights(), out);
}
// Reads a policy from `in`: a version number, then the model object, then the
// weight vector, and assigns a policy built from them to `item`.
// Throws serialization_error when the stored version is not 1.
// NOTE(review): old (feature_extractor / fe) and new (model_type / model) diff
// lines are interleaved in this span -- confirm against the PR.
template < typename feature_extractor >
inline void deserialize(policy<feature_extractor>& item, std::istream& in)
template < typename model_type >
inline void deserialize(policy<model_type>& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::policy object.");
feature_extractor fe;
model_type model;
matrix<double,0,1> w;
deserialize(fe, in);
deserialize(model, in);
deserialize(w, in);
item = policy<feature_extractor>(w,fe);
item = policy<model_type>(w,model);
}

// ----------------------------------------------------------------------------------------

// epsilon_policy wraps another policy object. On every call to operator() it
// draws from a Bernoulli distribution with success probability epsilon: on
// success the action comes from random_action(state) on the wrapped policy's
// model, otherwise the wrapped policy itself chooses the action.
template <
    typename policy_type,
    typename prng_engine = std::default_random_engine
    >
class epsilon_policy
{
public:
    typedef typename policy_type::state_type state_type;
    typedef typename policy_type::action_type action_type;

    // Stores copies of the policy and the generator along with epsilon.
    epsilon_policy (
        double epsilon_,
        const policy_type& policy_,
        const prng_engine& gen_ = prng_engine()
    ) : underlying_policy(policy_), epsilon(epsilon_), gen(gen_) {}

    // Copy and move operations are explicitly defaulted.
    epsilon_policy(const epsilon_policy&) = default;
    epsilon_policy& operator=(const epsilon_policy&) = default;

    epsilon_policy(epsilon_policy&&) = default;
    epsilon_policy& operator=(epsilon_policy&&) = default;

    // Selects an action for `state`. The generator is a mutable member, so
    // this const call still advances its state.
    action_type operator() (
        const state_type& state
    ) const
    {
        std::bernoulli_distribution explore(epsilon);
        if (explore(gen))
            return underlying_policy.get_model().random_action(state);
        return underlying_policy(state);
    }

    // The wrapped policy.
    const policy_type& get_policy (
    ) const { return underlying_policy; }

    // Forwards to the wrapped policy's get_model(), keeping its return type.
    auto get_model (
    ) const -> decltype(this->get_policy().get_model())
    {
        return underlying_policy.get_model();
    }

    // Mutable access to the wrapped policy's weight vector.
    matrix<double,0,1>& get_weights (
    ) { return underlying_policy.get_weights(); }

    // Read-only access to the wrapped policy's weight vector.
    const matrix<double,0,1>& get_weights (
    ) const { return underlying_policy.get_weights(); }

    // Probability with which operator() picks a random action.
    double get_epsilon (
    ) const { return epsilon; }

    // The random number engine used by operator().
    const prng_engine& get_generator (
    ) const { return gen; }

private:
    policy_type underlying_policy;
    double epsilon;

    // mutable so that the const operator() can draw random numbers.
    mutable prng_engine gen;
};

// Writes an epsilon_policy to `out`: a version number (1), then the wrapped
// policy, epsilon, and the random number generator, in that order.
// NOTE(review): the lines between the last serialize() call and the closing
// brace are review-comment text from the pull-request page, not code.
template < typename policy_type, typename generator >
inline void serialize(const epsilon_policy<policy_type, generator>& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
serialize(item.get_policy(), out);
serialize(item.get_epsilon(), out);
serialize(item.get_generator(), out);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work? Are there serialize routines defined for the random number generators in std::? You should add unit tests that invoke the serialization routines for these new objects to make sure they all work.

}

// Reads an epsilon_policy from `in`: a version number, then the wrapped
// policy, epsilon, and the random number generator, in that order, and
// assigns an epsilon_policy built from them to `item`.
//
// Throws serialization_error when the stored version is not 1.
//
// NOTE(review): this relies on a deserialize() overload existing for the
// generator type; none is visible in this file -- confirm before merging.
template < typename policy_type, typename generator >
inline void deserialize(epsilon_policy<policy_type, generator>& item, std::istream& in)
{
    int version = 0;
    deserialize(version, in);
    if (version != 1)
        throw serialization_error("Unexpected version found while deserializing dlib::epsilon_policy object.");

    policy_type policy;
    // Initialized so it never holds an indeterminate value when passed by
    // reference to deserialize().
    double epsilon = 0;
    generator gen;
    deserialize(policy, in);
    deserialize(epsilon, in);
    deserialize(gen, in);
    item = epsilon_policy<policy_type, generator>(epsilon, policy, gen);
}

// ----------------------------------------------------------------------------------------
Expand Down
Loading