Skip to content

Commit

Permalink
Allow using custom training_tracker_provider
Browse files Browse the repository at this point in the history
  • Loading branch information
nanoeti committed Mar 31, 2024
1 parent 47b91bb commit a99a859
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
19 changes: 14 additions & 5 deletions rasa/engine/recipes/default_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def decorator(registered_class: Type[GraphComponent]) -> Type[GraphComponent]:
else:
unique_types = set(component_types)

cls._registered_components[
registered_class.__name__
] = cls.RegisteredComponent(
registered_class, unique_types, is_trainable, model_from
cls._registered_components[registered_class.__name__] = (
cls.RegisteredComponent(
registered_class, unique_types, is_trainable, model_from
)
)
return registered_class

Expand Down Expand Up @@ -581,12 +581,21 @@ def _add_core_train_nodes(
config={"exclusion_percentage": cli_parameters.get("exclusion_percentage")},
is_input=True,
)

training_tracker_provider_name = train_config.get("training_tracker_provider")
if training_tracker_provider_name is not None:
training_tracker_provider_cls = self._from_registry(
training_tracker_provider_name
).clazz
else:
training_tracker_provider_cls = TrainingTrackerProvider

train_nodes["training_tracker_provider"] = SchemaNode(
needs={
"story_graph": "story_graph_provider",
"domain": "domain_for_core_training_provider",
},
uses=TrainingTrackerProvider,
uses=training_tracker_provider_cls,
constructor_name="create",
fn="provide",
config={
Expand Down
2 changes: 1 addition & 1 deletion version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.6.5-e8-0.3.0
3.6.5-e8-0.4.0

0 comments on commit a99a859

Please sign in to comment.