Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[micro_wake_word] save/restore wake word enabled state to flash #149

Merged
merged 3 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion esphome/components/micro_wake_word/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ async def to_code(config):
# Use the general model loading code for the VAD codegen
config[CONF_MODELS].append(vad_model)

for model_parameters in config[CONF_MODELS]:
for i, model_parameters in enumerate(config[CONF_MODELS]):
model_config = model_parameters.get(CONF_MODEL)
data = []
manifest, data = _model_config_to_manifest_data(model_config)
Expand Down Expand Up @@ -500,6 +500,8 @@ async def to_code(config):
)
)
else:
# Only enable the first wake word by default. After first boot, the enable state is saved/loaded to the flash
default_enabled = i == 0
wake_word_model = cg.new_Pvariable(
model_parameters[CONF_ID],
str(model_parameters[CONF_ID]),
Expand All @@ -508,6 +510,7 @@ async def to_code(config):
sliding_window_size,
manifest[KEY_WAKE_WORD],
manifest[KEY_MICRO][CONF_TENSOR_ARENA_SIZE],
default_enabled,
)

for lang in manifest[KEY_TRAINED_LANGUAGES]:
Expand Down
26 changes: 23 additions & 3 deletions esphome/components/micro_wake_word/streaming_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ void StreamingModel::reset_probabilities() {
this->ignore_windows_ = -MIN_SLICES_BEFORE_DETECTION;
}

WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start,
uint8_t probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size) {
WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
bool default_enabled) {
this->id_ = id;
this->model_start_ = model_start;
this->probability_cutoff_ = probability_cutoff;
Expand All @@ -172,8 +172,28 @@ WakeWordModel::WakeWordModel(const std::string &id, const uint8_t *model_start,
this->tensor_arena_size_ = tensor_arena_size;
this->register_streaming_ops_(this->streaming_op_resolver_);
this->current_stride_step_ = 0;

this->pref_ = global_preferences->make_preference<bool>(fnv1_hash(id));
bool enabled;
if (this->pref_.load(&enabled)) {
// Use the enabled state loaded from flash
this->enabled_ = enabled;
} else {
// If no state saved, then use the default
this->enabled_ = default_enabled;
}
};

// Enable this wake word model and persist the enabled state to flash,
// so it is restored on the next boot (the constructor re-loads it via
// this->pref_.load()). The next performing_streaming_inference call
// will load the model.
void WakeWordModel::enable() {
this->enabled_ = true;
// Order matters: set the member first, then save it — save() reads
// from &this->enabled_.
this->pref_.save(&this->enabled_);
}

// Disable this wake word model and persist the disabled state to flash,
// so it is restored on the next boot (the constructor re-loads it via
// this->pref_.load()). The next performing_streaming_inference call
// will unload the model.
void WakeWordModel::disable() {
this->enabled_ = false;
// Order matters: clear the member first, then save it — save() reads
// from &this->enabled_.
this->pref_.save(&this->enabled_);
}

DetectionEvent WakeWordModel::determine_detected() {
DetectionEvent detection_event;
detection_event.wake_word = &this->wake_word_;
Expand Down
22 changes: 15 additions & 7 deletions esphome/components/micro_wake_word/streaming_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "preprocessor_settings.h"

#include "esphome/core/preferences.h"

#include <tensorflow/lite/core/c/common.h>
#include <tensorflow/lite/micro/micro_interpreter.h>
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
Expand Down Expand Up @@ -44,12 +46,10 @@ class StreamingModel {
void unload_model();

/// @brief Enable the model. The next performing_streaming_inference call will load it.
void enable() {
this->enabled_ = true;
}
virtual void enable() { this->enabled_ = true; }

/// @brief Disable the model. The next performing_streaming_inference call will unload it.
void disable() { this->enabled_ = false; }
virtual void disable() { this->enabled_ = false; }

/// @brief Return true if the model is enabled.
bool is_enabled() { return this->enabled_; }
Expand Down Expand Up @@ -87,9 +87,9 @@ class StreamingModel {

class WakeWordModel final : public StreamingModel {
public:
WakeWordModel(const std::string &id, const uint8_t *model_start,
uint8_t probability_cutoff, size_t sliding_window_average_size,
const std::string &wake_word, size_t tensor_arena_size);
WakeWordModel(const std::string &id, const uint8_t *model_start, uint8_t probability_cutoff,
size_t sliding_window_average_size, const std::string &wake_word, size_t tensor_arena_size,
bool default_enabled);

void log_model_config() override;

Expand All @@ -104,10 +104,18 @@ class WakeWordModel final : public StreamingModel {
void add_trained_language(const std::string &language) { this->trained_languages_.push_back(language); }
const std::vector<std::string> &get_trained_languages() const { return this->trained_languages_; }

/// @brief Enable the model and save to flash. The next performing_streaming_inference call will load it.
virtual void enable() override;

/// @brief Disable the model and save to flash. The next performing_streaming_inference call will unload it.
virtual void disable() override;

protected:
std::string id_;
std::string wake_word_;
std::vector<std::string> trained_languages_;

ESPPreferenceObject pref_;
};

class VADModel final : public StreamingModel {
Expand Down