Skip to content

Commit

Permalink
feat: add multigpu support for external native models
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob committed May 27, 2021
1 parent 5ae837f commit 90dcadd
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions src/backends/torch/native/native_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,35 @@

#include "native_net.h"

namespace std
{
// from
// https://stackoverflow.com/questions/47496358/c-lambdas-how-to-capture-variadic-parameter-pack-from-the-upper-scope
// see also:
// https://en.cppreference.com/w/cpp/experimental/apply
namespace detail
{
template <class F, class Tuple, std::size_t... I>
constexpr decltype(auto) apply_impl(F &&f, Tuple &&t,
std::index_sequence<I...>)
{
// return std::invoke(std::forward<F>(f),
// std::get<I>(std::forward<Tuple>(t))...);
// Note: std::invoke is a C++17 feature
return std::forward<F>(f)(std::get<I>(static_cast<Tuple &&>(t))...);
}
} // namespace detail

template <class F, class Tuple>
constexpr decltype(auto) apply(F &&f, Tuple &&t)
{
return detail::apply_impl(
std::forward<F>(f), std::forward<Tuple>(t),
std::make_index_sequence<
std::tuple_size<std::decay_t<Tuple>>::value>{});
}
}

namespace dd
{
template <typename TModule>
Expand All @@ -41,6 +70,14 @@ namespace dd
template <typename... Args>
NativeModuleWrapper(Args &&... args) : _module(args...)
{
_clone_function
= [args = std::make_tuple(std::forward<Args>(args)...)]() {
return std::apply(
[](auto &&... args) {
return new NativeModuleWrapper<TModule>(args...);
},
std::move(args));
};
this->register_module("wrapped", _module);
}

Expand Down Expand Up @@ -86,6 +123,9 @@ namespace dd
throw MLLibInternalException(
"NativeModuleWrapper::loss not implemented");
}

private:
std::function<NativeModuleWrapper<TModule> *()> _clone_function;
};
}

Expand Down

0 comments on commit 90dcadd

Please sign in to comment.