This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Improved Tensor Dimension Handling in predict_masks Method #581

Open
sushmanthreddy opened this issue Sep 29, 2023 · 0 comments · May be fixed by #580

@sushmanthreddy

Issue:
This issue proposes a change to tensor dimension handling in the predict_masks method of the MaskDecoder class. Here's a detailed breakdown:

  1. Conditional Check:

    • A new check, if image_embeddings.shape[0] != tokens.shape[0]:, compares the batch sizes of the two tensors before torch.repeat_interleave is applied.
  2. Usage of torch.repeat_interleave:

    • When the batch sizes differ, torch.repeat_interleave expands image_embeddings along the batch dimension (dim 0) so that its batch size aligns with that of tokens.
  3. Ensuring Consistency:

    • Because of this check, torch.repeat_interleave is applied only when the batch sizes differ; when they already match, image_embeddings is used as is. The original implementation, by contrast, applies torch.repeat_interleave unconditionally.
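The conditional expansion described above can be sketched in PyTorch roughly as follows. This is a minimal, hypothetical illustration, not the code from #580: the helper names expand_image_embeddings and expand_unconditionally are invented here, the tensor shapes are toy values, and the sketch assumes that when the batch sizes differ, image_embeddings holds a single image (batch size 1).

```python
import torch


def expand_image_embeddings(image_embeddings: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: repeat image_embeddings only if batch sizes differ.

    Assumes image_embeddings has batch size 1 whenever the sizes differ.
    """
    if image_embeddings.shape[0] != tokens.shape[0]:
        # Repeat each image embedding once per token set along the batch dim.
        return torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # Batch sizes already agree: return the tensor unchanged (no copy).
    return image_embeddings


def expand_unconditionally(image_embeddings: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper modelling the original behaviour described in the
    # issue: the repeat is always applied.
    return torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)


# Toy shapes: image_embeddings is (B, C, H, W), tokens is (N, T, C).
tokens = torch.zeros(3, 5, 4)

# One image, three prompts: both versions produce a batch of 3.
single = torch.zeros(1, 4, 8, 8)
print(expand_image_embeddings(single, tokens).shape)  # torch.Size([3, 4, 8, 8])
print(expand_unconditionally(single, tokens).shape)   # torch.Size([3, 4, 8, 8])

# Three images already paired with three prompts: only the conditional
# version keeps the batch at 3; the unconditional one yields 9.
batched = torch.zeros(3, 4, 8, 8)
print(expand_image_embeddings(batched, tokens).shape)  # torch.Size([3, 4, 8, 8])
print(expand_unconditionally(batched, tokens).shape)   # torch.Size([9, 4, 8, 8])
```

Note that the check only compares batch sizes. If image_embeddings has a batch size other than 1 that also differs from that of tokens (for example 2 versus 3), repeating still produces a mismatched batch (6), so the sketch, like the check described in the issue, relies on the single-image assumption in that case.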