-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[KV Cache Injection] Causal Mask for OPT #1688
[KV Cache Injection] Causal Mask for OPT #1688
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending rebase and comment regarding strong preference on using cast over where
|
||
@classmethod | ||
def add_positions_input(cls, model: ModelProto) -> ModelProto: | ||
def add_causal_mask_input(self, model: ModelProto) -> ModelProto: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like this needs rebase?
@@ -12,78 +12,55 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
from onnx import ModelProto, NodeProto |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rebase?
``` | ||
| causal_mask | ||
| | | ||
| Where |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are we using a where instead of a cast to bool? can you check with runtime for what would work better if either are fine? additionally, seems like a cast would read better in onnx vs where which involves a condition...
…1677) * initial commit * [KV Cache Injection] Causal Mask for CodeGen (#1676) * initial implementation; testing now * fix a small blunder * cleanup --------- Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com> * [KV Cache Injection] Causal Mask for OPT (#1688) * initial implementation; testing now * fix a small blunder * cleanup * initial implementation * on to testing with deepsparse --------- Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com> * replace boolean causal mask for int64 causal mask * better logging info * allow transformations to be also a list --------- Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com>
Add causal mask support for OPT models to enable multitoken prefill in the Deepsparse pipeline.
Manual Testing