
Add support for ragged inputs to model #666

Merged

Conversation

oliverholworthy
Member

@oliverholworthy oliverholworthy commented Apr 5, 2023

Part of NVIDIA-Merlin/Merlin#255

Goals ⚽

  • Enable Transformers4Rec model to be called with ragged input representation.

Implementation Details 🚧

  • Adds a pre-processing step to the start of the model's forward method that pads any tensors in the ragged representation.
    • A ragged feature is represented by two tensors named {feature}__values and {feature}__offsets.
    • Each is padded to the minimum of the maximum sequence length in the batch and the model's max_sequence_length (if defined).
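The padding step described above can be sketched as follows. This is a minimal illustration under the `__values`/`__offsets` naming convention from the bullet points, not the actual Transformers4Rec implementation; the `pad_ragged_inputs` name and zero-padding are assumptions:

```python
from typing import Dict, Optional

import torch


def pad_ragged_inputs(
    inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None
) -> Dict[str, torch.Tensor]:
    """Pad ragged "<name>__values"/"<name>__offsets" tensor pairs to a dense 2-D tensor."""
    padded: Dict[str, torch.Tensor] = {}
    for name, values in inputs.items():
        if name.endswith("__offsets"):
            continue  # handled together with the matching "__values" tensor
        if name.endswith("__values"):
            feature = name[: -len("__values")]
            offsets = inputs[feature + "__offsets"]
            lengths = offsets[1:] - offsets[:-1]
            batch_max = int(lengths.max())
            # pad to the batch maximum, capped by max_sequence_length if set
            if max_sequence_length is None:
                target = batch_max
            else:
                target = min(batch_max, max_sequence_length)
            out = torch.zeros((len(lengths), target), dtype=values.dtype)
            for row in range(len(lengths)):
                start = int(offsets[row])
                n = min(int(lengths[row]), target)
                out[row, :n] = values[start : start + n]
            padded[feature] = out
        else:
            padded[name] = values  # non-ragged tensors pass through unchanged
    return padded
```

For example, a batch of two sequences `[1, 2]` and `[3, 4, 5]` stored as values `[1, 2, 3, 4, 5]` with offsets `[0, 2, 5]` pads to `[[1, 2, 0], [3, 4, 5]]`.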

Testing Details 🔍

  • Adds a test for a model with sequence inputs that is called with ragged representation inputs.

@oliverholworthy oliverholworthy added the enhancement New feature or request label Apr 5, 2023
@oliverholworthy oliverholworthy added this to the Merlin 23.04 milestone Apr 5, 2023
@oliverholworthy oliverholworthy self-assigned this Apr 5, 2023
Co-authored-by: Marc Romeyn <marcromeyn@gmail.com>
@oliverholworthy oliverholworthy marked this pull request as ready for review April 5, 2023 17:18
```python
)
model_output = model(inference_inputs)

# if the model is traced with ragged inputs it can only be called with ragged inputs
```
oliverholworthy (Member Author):
Note that when tracing the model, the representation used as the example input determines what inputs the traced model expects (padded vs. ragged).
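The reason the tracing-time representation is baked in can be shown with a small standalone sketch (not from this PR): `torch.jit.trace` records only the operations taken for the example input, so a branch on the input's structure is frozen at trace time.

```python
import torch


def maybe_unsqueeze(x: torch.Tensor) -> torch.Tensor:
    # branch on the input's rank: a 1-D input gets a batch dimension added
    if x.dim() == 1:
        return x.unsqueeze(0)
    return x


# traced with a 1-D example input, so the unsqueeze branch is recorded
traced = torch.jit.trace(maybe_unsqueeze, torch.arange(3))

# the recorded branch now applies to every call, even 2-D inputs
out = traced(torch.ones(2, 3))  # shape (1, 2, 3), not (2, 3)
```

In the same way, a model traced with ragged example inputs records the padding path and must always be called with ragged inputs (and vice versa for padded inputs).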

```python
batch_padded = {}
for col_name, col in TensorTable(batch).items():
```
oliverholworthy (Member Author):

TensorTable is not currently compatible with torch.jit.script compilation.

Example of one of the errors that shows up (I don't think it prints all the errors, only the first one it encounters, so there may be more unsupported things besides the example below):

```
E   torch.jit.frontend.UnsupportedNodeError: SetComp aren't supported:
E     File "/workspace/merlin/core/merlin/table/tensor_table.py", line 61
E       def _validate_columns(self, cols_dict):
E           col_types = {type(col_obj) for col_obj in cols_dict.values()}
E                       ~ <--- HERE
E           if len(col_types) >= 2:
E               raise TypeError(
E   '__torch__.merlin.table.tensor_table.TensorTable' is being compiled since it was called from 'pad_batch'
```
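For context, the usual workaround for this class of error is to rewrite the comprehension as an explicit loop. A hypothetical scriptable variant of the check is sketched below; it is not a proposed change to merlin-core, and it uses `dtype` as a stand-in for `type(col_obj)` because Python-level `type()` is not available inside TorchScript:

```python
from typing import Dict, List

import torch


@torch.jit.script
def validate_columns(cols_dict: Dict[str, torch.Tensor]) -> None:
    # the set comprehension from _validate_columns rewritten as an explicit
    # loop over values, which torch.jit.script can compile
    dtypes: List[int] = []
    for col_obj in cols_dict.values():
        if col_obj.dtype not in dtypes:
            dtypes.append(col_obj.dtype)
    if len(dtypes) >= 2:
        raise TypeError("TensorTable columns must all have the same dtype")
```

Even with such rewrites, there may be further unsupported constructs elsewhere in TensorTable, per the note above.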

```python
]
),
)
assert torch.equal(
```
oliverholworthy (Member Author):

Dense sequence inputs are not currently padded as part of pad_inputs. We're assuming inputs will be either all ragged or all padded sequences, not a mix of both.

```diff
@@ -481,6 +482,7 @@ def __init__(
     head_reduction: str = "mean",
     optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
     name: str = None,
+    max_sequence_length: Optional[int] = None,
```
oliverholworthy (Member Author):

Added a max_sequence_length to limit the size of the padding when receiving ragged inputs.
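The capping behaviour can be stated as a small helper. This is a hypothetical illustration of the rule, not code from the PR; the `padding_target` name is an assumption:

```python
from typing import Optional


def padding_target(batch_max_len: int, max_sequence_length: Optional[int] = None) -> int:
    # the padded length is the batch maximum, capped by max_sequence_length
    # when that is configured on the model
    if max_sequence_length is None:
        return batch_max_len
    return min(batch_max_len, max_sequence_length)
```

So a batch whose longest sequence has 100 items is padded to 100 by default, but only to 20 when `max_sequence_length=20` is set.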



```python
@torch.jit.script
def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None):
```
Contributor:

Suggested change:

```diff
-def pad_inputs(inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None):
+def pad_inputs(
+    inputs: Dict[str, torch.Tensor], max_sequence_length: Optional[int] = None
+) -> Dict[str, torch.Tensor]:
```
