Skip to content

Commit

Permalink
[Cherry-Pick] Fix the token_generator behavior for non-kv-cache models (
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Nov 30, 2023
1 parent d8d84ed commit 39e21d3
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/deepsparse/transformers/utils/token_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Optional

import numpy

Expand All @@ -32,7 +32,7 @@ class TokenGenerator:
def __init__(
self,
logits_shape: int,
tokens: List[int] = [],
tokens: Optional[List[int]] = None,
deterministic: bool = True,
sampling_temperature: float = 1.0,
top_k: int = 0,
Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.tokens = tokens
self.tokens = [] if tokens is None else tokens

self._initialize_token_frequencies()

Expand All @@ -77,11 +77,16 @@ def generate(self, logits: numpy.ndarray) -> numpy.ndarray:
:param logits: the logits from the model with shape (vocab_size,)
:return: the sampled token
"""

if self.deterministic:
token = numpy.argmax(logits)
self.tokens.append(token)
return token

# make a copy of logits to avoid modifying the original
# logits distribution in-place
logits = logits.copy()

if self.top_k:
logits = self.apply_top_k(logits)

Expand Down Expand Up @@ -173,5 +178,5 @@ def _update_frequencies(self, token: numpy.ndarray):

def _initialize_token_frequencies(self):
unique_tokens, frequencies = numpy.unique(self.tokens, return_counts=True)
for token, frequnecies in zip(unique_tokens, frequencies):
self.token_frequencies[token] += frequnecies
for token, freq in zip(unique_tokens, frequencies):
self.token_frequencies[token] += freq

0 comments on commit 39e21d3

Please sign in to comment.