Improve outlines.processors, add integration tests to test_generate.py (#998)

A lot of these fixes were intended for #966; however, that's blocked until there's a new `transformers` release.

These improvements are general to all models and will enable PRs resolving #806 and #965.

# Structure of `OutlinesLogitsProcessor`

The goal is to create a base class which allows a logits processor to be implemented once and used with any `outlines.models` inference library.

To accomplish this, we must normalize the input array: it must have a consistent type (`torch.Tensor`) and consistent dimensionality (2). We can normalize both simply, and without any copy operations.

`mlx.core.array`, `numpy.array`, and `torch.Tensor` all support [Python's array API standard
`__dlpack__`](https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html).
This standard allows casting between array types without copying.
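For example, a minimal sketch of a zero-copy handoff through `__dlpack__` (illustrative only; assumes recent `numpy`, `torch`, and `mlx` versions):

```python
import mlx.core as mx
import numpy as np
import torch

np_logits = np.random.rand(2, 8).astype(np.float32)
mlx_logits = mx.random.uniform(shape=(2, 8))

# Both views share memory with the source arrays; no copy is made.
torch_from_numpy = torch.from_dlpack(np_logits)
torch_from_mlx = torch.from_dlpack(mlx_logits)
```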

`torch.Tensor` is the only input type which cannot always be cast to the other types, because torch tensors may live in GPU memory. Therefore, we cast all arrays to `torch.Tensor`, implement logits processors using torch methods, and convert back to the original array type in `OutlinesLogitsProcessor`. See the docstring of `OutlinesLogitsProcessor.__call__()` for more details.
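As an illustration, a processor written once against torch then works unchanged for numpy or torch callers; this sketch mirrors the `HalvingLogitsProcessor` added in `benchmarks/bench_processors.py` below (shapes are arbitrary):

```python
import numpy as np
import torch

from outlines.processors import OutlinesLogitsProcessor


class HalvingLogitsProcessor(OutlinesLogitsProcessor):
    """Toy processor: process_logits() only ever sees a 2D torch.Tensor."""

    def process_logits(self, input_ids, logits):
        return logits / 2


processor = HalvingLogitsProcessor()

# numpy in -> numpy out; the torch round-trip happens inside __call__()
np_out = processor(np.zeros((1, 4), dtype=np.int64), np.ones((1, 8), dtype=np.float32))

# torch in -> torch out, no conversion needed
pt_out = processor(torch.zeros((1, 4), dtype=torch.long), torch.ones((1, 8)))
```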

# Detailed Changes
- Rename `BaseLogitsProcessor` to `OutlinesLogitsProcessor`
- Ensure `OutlinesLogitsProcessor.process_logits()` is always passed a
2D batch request with `torch.Tensor` logits and `List` input_ids. Also
clean up the code in `OutlinesLogitsProcessor.__call__()` to be more readable
- Ensure `FSMLogitsProcessor` tolerates unstable sequence ordering (beam
search in transformers and vLLM changes the order of sequences); see the sketch after this list
- Update `tests/generate/test_generate.py` to cover more permutations of
  - regex / text 
  - batch / single
  - greedy / multinomial / beam search
  - `stream()` / `generate()`
- Ensure performance stability with different array libraries through
`bench_processors.py`
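The unstable-ordering fix keys FSM states on a hash of each sequence's generated tokens rather than on batch position, so beams can be reordered between steps without losing their state. A simplified sketch of that bookkeeping (the full version is in the `outlines/processors/structured.py` diff below; `fsm`, `fsm_states`, `seq_ids`, and `seq_start_idx` are stand-in names):

```python
from typing import Dict, List


def advance_state(fsm, fsm_states: Dict[int, int], seq_ids: List[int], seq_start_idx: int) -> int:
    """Look up the state reached by this sequence's previous tokens, then advance by its last token."""
    prev_state = fsm_states[hash(tuple(seq_ids[seq_start_idx:-1]))]
    curr_state = fsm.get_next_state(prev_state, seq_ids[-1])
    fsm_states[hash(tuple(seq_ids[seq_start_idx:]))] = curr_state
    return curr_state
```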
lapp0 committed Jun 30, 2024
1 parent 2807dca commit a643cb0
Showing 7 changed files with 332 additions and 83 deletions.
52 changes: 52 additions & 0 deletions benchmarks/bench_processors.py
@@ -0,0 +1,52 @@
import mlx.core as mx
import numpy as np
import torch

from outlines.processors import OutlinesLogitsProcessor


def is_mlx_lm_allowed():
    try:
        import mlx.core as mx
    except ImportError:
        return False
    return mx.metal.is_available()


class HalvingLogitsProcessor(OutlinesLogitsProcessor):
    """Simply halve the passed logits"""

    def process_logits(self, input_ids, logits):
        return logits / 2


class LogitsProcessorBenchmark:
    params = ["torch", "numpy"]
    if mx.metal.is_available():
        params += ["mlx"]

    def setup(self, array_library):
        self.logits_processor = HalvingLogitsProcessor()

        # logits: (4, 30,000) dtype=float
        # input_ids shape: (4, 2048) dtype=int
        if array_library == "torch":
            self.logits = torch.rand((4, 30000), dtype=torch.float)
            self.input_ids = torch.randint(
                low=0, high=30000, size=(4, 2048), dtype=torch.int
            )
        elif array_library == "numpy":
            self.logits = np.random.rand(4, 30000).astype(np.float32)
            self.input_ids = np.random.randint(low=0, high=30000, size=(4, 2048))
        elif array_library == "mlx":
            self.logits = mx.random.uniform(
                low=-1e9, high=1e9, shape=(4, 30000), dtype=mx.float32
            )
            self.input_ids = mx.random.randint(
                low=0, high=30000, shape=(4, 2048), dtype=mx.int32
            )
        else:
            raise ValueError

    def time_logits_processor(self, array_library):
        self.logits_processor(self.input_ids, self.logits)
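The class appears to follow airspeed velocity (asv) conventions (`params`, `setup()`, `time_*` methods). A rough sketch of driving it by hand, assuming the module's imports (including `mlx`) resolve in your environment:

```python
import timeit

bench = LogitsProcessorBenchmark()
bench.setup("torch")

# asv normally handles repetition and reporting; this is only a quick manual check.
elapsed = timeit.timeit(lambda: bench.time_logits_processor("torch"), number=100)
print(f"{elapsed / 100 * 1e3:.3f} ms per call")
```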
1 change: 1 addition & 0 deletions outlines/__init__.py
@@ -2,6 +2,7 @@
import outlines.generate
import outlines.grammars
import outlines.models
import outlines.processors
import outlines.types
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
6 changes: 3 additions & 3 deletions outlines/models/mlxlm.py
@@ -9,7 +9,7 @@
from transformers import PreTrainedTokenizer

from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.processors import BaseLogitsProcessor
from outlines.processors import OutlinesLogitsProcessor


class MLXLM:
@@ -127,7 +127,7 @@ def generate_step(
        temp: Optional[float],
        top_p: Optional[float],
        sampler: str,
        logits_processor: "BaseLogitsProcessor",
        logits_processor: "OutlinesLogitsProcessor",
    ) -> Generator[Tuple[int, float], None, None]:
        """
        Adapted from
@@ -142,7 +142,7 @@ def generate_step(
            top_p (float, optional): Nucleus sampling, higher means model considers
                more less likely words.
            sampler (str): The sampler string defined by SequenceGeneratorAdapter
            logits_processor (BaseLogitsProcessor): Augment logits before sampling.
            logits_processor (OutlinesLogitsProcessor): Augment logits before sampling.
        """
        import mlx.core as mx
        import mlx_lm
2 changes: 1 addition & 1 deletion outlines/processors/__init__.py
@@ -1,7 +1,7 @@
from .structured import (
    BaseLogitsProcessor,
    CFGLogitsProcessor,
    FSMLogitsProcessor,
    JSONLogitsProcessor,
    OutlinesLogitsProcessor,
    RegexLogitsProcessor,
)
145 changes: 101 additions & 44 deletions outlines/processors/base_logits_processor.py
@@ -1,23 +1,30 @@
from abc import abstractmethod
from typing import List, Protocol, Union
from typing import TYPE_CHECKING, List, Protocol, Type, Union

import numpy as np
import torch
from numpy.typing import NDArray

if TYPE_CHECKING:
    import mlx.core as mx

def is_mlx_array(logits):

Array = Union[NDArray, torch.Tensor, List, "mx.array"]


def is_mlx_array_type(array_type):
    try:
        import mlx.core as mx
    except ImportError:
        return False
    return isinstance(logits, mx.array)
    return issubclass(array_type, mx.array)


class BaseLogitsProcessor(Protocol):
class OutlinesLogitsProcessor(Protocol):
    """
    Base class for logits processors which normalizes types of logits:
    - ndarray (used by llama-cpp-python), converted to torch.Tensor
    - mlx.core.array (used by mlx-lm), converted to torch.Tensor
    - torch.Tensor (used by everything else)
    Normalization of types and conversion to torch.Tensor
@@ -29,50 +36,100 @@ class BaseLogitsProcessor(Protocol):

    @abstractmethod
    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        ...
        """
        input_ids and logits are always 2D tensors for handling a batch of sequences.
        - input_ids -> List[List[tokens]]
        - logits.shape[0] -> 2D_Tensor[logits]
        Important to keep in mind when designing universal logits processors
        - logits processors are only used once and never re-applied for a new sequence generator
        - Some models only pass output_ids, some models such as llamacpp and transformers prefix with input_ids
        - Some sampling methods, such as beam search, result in unstable sequence ordering in models like vLLM
        """
        pass

    @torch.no_grad()
    def __call__(
        self,
        input_ids: Union[NDArray[np.int64], List[int], torch.Tensor],
        logits: Union[NDArray[np.float32], torch.Tensor],
    ) -> Union[NDArray[np.int64], torch.Tensor]:
        input_ids: Array,
        logits: Array,
    ) -> Array:
        """
        Apply logits processor
        Unify type
        - convert input_ids: either ndarray, List[int], or Tensor -> List[int]
        - convert logits: either ndarray, mlx array, Tensor -> Tensor
        Call process_logits() to perform business logic
        1) Unify type
        - convert input_ids: either ndarray, mlx array, List[int], or Tensor -> List[List[int]]
        - convert logits: either ndarray, mlx array, or Tensor -> 2D float Tensor
        2) Unify shape, ensure logits and input_ids are 2D
        3) Call self.process_logits() to perform business logic
        4) Cast logits back to original array library type
        """
        with torch.no_grad():
            if not isinstance(input_ids, list):
                input_ids = input_ids.tolist()

            if isinstance(logits, np.ndarray):
                # Unify type, convert numpy array to Tensor
                # from_numpy and .numpy() don't copy the data, it uses the same memory address
                torch_logits = torch.from_numpy(logits)
                processed_torch_logits = self.process_logits(input_ids, torch_logits)
                return processed_torch_logits.detach().numpy()

            elif isinstance(logits, torch.Tensor):
                return self.process_logits(input_ids, logits)

            elif is_mlx_array(logits):
                # mlx -> torch -> mlx conversion docs:
                # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
                import mlx.core as mx

                torch_logits = torch.from_dlpack(logits)
                processed_torch_logits = self.process_logits(input_ids, torch_logits)

                # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
                logits_float32_numpy = processed_torch_logits.float().numpy()
                return mx.array(logits_float32_numpy)

            else:
                raise TypeError(
                    "LogitsProcessor must be called with either np.NDArray"
                    ", torch.Tensor, or mlx.core.array typed logits"
                )

        # ensure logits are torch Tensors
        torch_logits = self._to_torch(logits)

        assert torch_logits.shape[:-1] == self._to_torch(input_ids).shape[:-1]

        # ensure input_ids are List
        if not isinstance(input_ids, list):
            input_ids = input_ids.tolist()  # compatible with numpy, torch, and mlx

        # Guarantee passed as 2D Tensors, then convert back to original (1D or 2D) shape
        if len(torch_logits.shape) == 2:
            processed_logits = self.process_logits(input_ids, torch_logits)
        elif len(torch_logits.shape) == 1:
            processed_logits = self.process_logits(
                [input_ids], torch_logits.unsqueeze(0)
            ).squeeze(0)

        # return logits as passed array type
        return self._from_torch(processed_logits, type(logits))

    @staticmethod
    def _to_torch(tensor_like: Array) -> torch.Tensor:
        """Convert various types to torch.Tensor."""
        if isinstance(tensor_like, torch.Tensor):
            return tensor_like

        elif isinstance(tensor_like, np.ndarray):
            return torch.from_numpy(tensor_like)

        elif isinstance(tensor_like, list):
            return torch.tensor(tensor_like)

        elif is_mlx_array_type(type(tensor_like)):
            # mlx -> torch -> mlx conversion docs:
            # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
            return torch.from_dlpack(tensor_like)

        else:
            raise TypeError(
                "LogitsProcessor must be called with either np.NDArray, "
                "torch.Tensor, list, or mlx.core.array typed logits"
            )

    @staticmethod
    def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
        """Convert torch.Tensor to the specified target type."""
        if target_type == torch.Tensor:
            return tensor

        elif target_type == np.ndarray:
            return tensor.detach().numpy()

        elif target_type == list:
            return tensor.detach().tolist()

        elif is_mlx_array_type(target_type):
            import mlx.core as mx

            # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
            return mx.array(tensor.float().numpy())

        else:
            raise TypeError(
                f"Failed to convert torch tensors to target_type `{target_type}`"
            )
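A small sketch of the shape handling described in the `__call__()` docstring above: 1D logits are unsqueezed to a batch of one, processed, and squeezed back (the `IdentityLogitsProcessor` here is hypothetical, defined only for illustration):

```python
import torch

from outlines.processors import OutlinesLogitsProcessor


class IdentityLogitsProcessor(OutlinesLogitsProcessor):
    """Hypothetical no-op processor used only to illustrate shape handling."""

    def process_logits(self, input_ids, logits):
        return logits  # always a 2D torch.Tensor at this point


processor = IdentityLogitsProcessor()

out_2d = processor(torch.zeros((3, 4), dtype=torch.long), torch.ones((3, 10)))
out_1d = processor(torch.zeros(4, dtype=torch.long), torch.ones(10))

assert out_2d.shape == (3, 10)  # 2D in, 2D out
assert out_1d.shape == (10,)    # 1D in, unsqueezed internally, 1D out
```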
39 changes: 26 additions & 13 deletions outlines/processors/structured.py
@@ -24,24 +24,22 @@
limitations under the License.
"""
import math
from typing import TYPE_CHECKING, List, Optional, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union

import numpy as np
import torch
from numpy.typing import NDArray
from pydantic import BaseModel

from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.integrations.utils import convert_json_schema_to_str

from .base_logits_processor import BaseLogitsProcessor
from .base_logits_processor import OutlinesLogitsProcessor

if TYPE_CHECKING:
from outlines.models.tokenizer import Tokenizer


class FSMLogitsProcessor(BaseLogitsProcessor):
class FSMLogitsProcessor(OutlinesLogitsProcessor):
"""Bias generation using a finite state machine.
Attributes
Expand All @@ -63,13 +61,14 @@ def __init__(self, tokenizer: "Tokenizer", fsm: Guide):
The finite state machine which is used to bias the logits.
"""
self.tokenizer = tokenizer
self._fsm_state = 0
self._fsm_states: Dict[int, int] = {}
self.fsm: Guide = fsm
self._is_first_token = True
self._seq_start_idx: Optional[int] = None

    def process_logits(
        self, input_ids: List[int], logits: torch.Tensor
    ) -> NDArray[np.float32]:
        self, input_ids: List[List[int]], logits: torch.Tensor
    ) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token.
        Parameters
@@ -84,17 +83,31 @@ def process_logits(
        torch.Tensor
            The biased logits.
        """
        sequence_states: List[int] = []  # vector of states corresponding to `input_ids`

        if self._is_first_token:
            self._is_first_token = False
            self._seq_start_idx = len(input_ids[0])

            self._fsm_states = {hash(tuple([])): 0}
            sequence_states = [0] * len(input_ids)

        else:
            last_token = input_ids[-1]
            self._fsm_state = self.fsm.get_next_state(self._fsm_state, last_token)
            for seq_ids in input_ids:
                prev_state_key = hash(tuple(seq_ids[self._seq_start_idx : -1]))
                prev_state = self._fsm_states[prev_state_key]

        allowed_tokens = self.fsm.get_next_instruction(self._fsm_state).tokens
        allowed_tokens = torch.tensor(allowed_tokens, device=logits.device)
                curr_state_key = hash(tuple(seq_ids[self._seq_start_idx :]))
                curr_state = self.fsm.get_next_state(prev_state, seq_ids[-1])

                self._fsm_states[curr_state_key] = curr_state
                sequence_states.append(curr_state)

        mask = torch.full_like(logits, -math.inf)
        mask[allowed_tokens] = logits[allowed_tokens]
        for i, fsm_state in enumerate(sequence_states):
            allowed_tokens = self.fsm.get_next_instruction(fsm_state).tokens
            mask[i, allowed_tokens] = logits[i, allowed_tokens]

        return mask

    def copy(self) -> "FSMLogitsProcessor":