Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored Oct 11, 2024
1 parent c8ffcd5 commit 3db58b6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion optimum/utils/input_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
shape = [self.batch_size, self.sequence_length]
elif input_name == 'timesteps':
shape = [self.batch_size, self.sequence_length]
return self.random_int_tensor(shape=shape, max_value=max_ep_len, framework=framework, dtype=int_dtype)
return self.random_int_tensor(shape=shape, max_value=self.max_ep_len, framework=framework, dtype=int_dtype)

return self.random_float_tensor(shape, min_value=-2., max_value=2., framework=framework, dtype=float_dtype)

Expand Down
3 changes: 2 additions & 1 deletion optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ class NormalizedTextConfig(NormalizedConfig):

class NormalizedDecisionTransformerConfig(NormalizedConfig):
# REFERENCE: https://huggingface.co/docs/transformers/model_doc/decision_transformer
STATE_DIM = "state_dim"
ACT_DIM = "act_dim"
STATE_DIM = "state_dim"

MAX_EP_LEN = "max_ep_len"
HIDDEN_SIZE = "hidden_size"

Expand Down

0 comments on commit 3db58b6

Please sign in to comment.