Skip to content

Commit

Permalink
Add user error superclass (#1225)
Browse files Browse the repository at this point in the history
* Add user error superclass

* update class inheritance structure
  • Loading branch information
milocress authored May 22, 2024
1 parent 001e7c3 commit 8e29698
Showing 1 changed file with 55 additions and 23 deletions.
78 changes: 55 additions & 23 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
ALLOWED_MESSAGES_KEYS = {'messages'}

ErrorLocation = Union[Literal['TrainDataloader'], Literal['EvalDataloader']]
ErrorAttribution = Union[Literal['UserError'], Literal['InternalError'],
Literal['NetworkError']]
TrainDataLoaderLocation = 'TrainDataloader'
EvalDataLoaderLocation = 'EvalDataloader'

Expand All @@ -43,10 +45,29 @@ class ContextualError(Exception):
"""Error thrown when an error occurs in the context of a specific task."""

location: Optional[ErrorLocation] = None
error_attribution: Optional[ErrorAttribution] = None


class UserError(ContextualError):
"""Error thrown when an error is caused by user input."""

error_attribution = 'UserError'


class NetworkError(ContextualError):
"""Error thrown when an error is caused by a network issue."""

error_attribution = 'NetworkError'


class InternalError(ContextualError):
"""Error thrown when an error is caused by an internal issue."""

error_attribution = 'InternalError'


# Finetuning dataloader exceptions
class MissingHuggingFaceURLSplitError(ValueError, ContextualError):
class MissingHuggingFaceURLSplitError(ValueError, UserError):
"""Error thrown when there's no split used in HF dataset config."""

def __init__(self) -> None:
Expand All @@ -55,7 +76,7 @@ def __init__(self) -> None:
super().__init__(message)


class NotEnoughDatasetSamplesError(ValueError, ContextualError):
class NotEnoughDatasetSamplesError(ValueError, UserError):
"""Error thrown when there is not enough data to train a model."""

def __init__(
Expand Down Expand Up @@ -85,7 +106,7 @@ def __init__(


## Tasks exceptions
class UnknownExampleTypeError(KeyError, ContextualError):
class UnknownExampleTypeError(KeyError, UserError):
"""Error thrown when an unknown example type is used in a task."""

def __init__(self, example_keys: str) -> None:
Expand All @@ -99,15 +120,15 @@ def __init__(self, example_keys: str) -> None:
super().__init__(message)


class NotEnoughChatDataError(ValueError, ContextualError):
class NotEnoughChatDataError(ValueError, UserError):
"""Error thrown when there is not enough chat data to train a model."""

def __init__(self) -> None:
message = 'Chat example must have at least two messages'
super().__init__(message)


class ConsecutiveRepeatedChatRolesError(ValueError, ContextualError):
class ConsecutiveRepeatedChatRolesError(ValueError, UserError):
"""Error thrown when there are consecutive repeated chat roles."""

def __init__(self, repeated_role: str) -> None:
Expand All @@ -116,7 +137,7 @@ def __init__(self, repeated_role: str) -> None:
super().__init__(message)


class InvalidLastChatMessageRoleError(ValueError, ContextualError):
class InvalidLastChatMessageRoleError(ValueError, UserError):
"""Error thrown when the last message role in a chat example is invalid."""

def __init__(self, last_role: str, expected_roles: set[str]) -> None:
Expand All @@ -126,7 +147,7 @@ def __init__(self, last_role: str, expected_roles: set[str]) -> None:
super().__init__(message)


class IncorrectMessageKeyQuantityError(ValueError, ContextualError):
class IncorrectMessageKeyQuantityError(ValueError, UserError):
"""Error thrown when a message has an incorrect number of keys."""

def __init__(self, keys: List[str]) -> None:
Expand All @@ -135,7 +156,7 @@ def __init__(self, keys: List[str]) -> None:
super().__init__(message)


class InvalidRoleError(ValueError, ContextualError):
class InvalidRoleError(ValueError, UserError):
"""Error thrown when a role is invalid."""

def __init__(self, role: str, valid_roles: set[str]) -> None:
Expand All @@ -145,7 +166,7 @@ def __init__(self, role: str, valid_roles: set[str]) -> None:
super().__init__(message)


class InvalidContentTypeError(TypeError, ContextualError):
class InvalidContentTypeError(TypeError, UserError):
"""Error thrown when the content type is invalid."""

def __init__(self, content_type: type) -> None:
Expand All @@ -154,7 +175,7 @@ def __init__(self, content_type: type) -> None:
super().__init__(message)


class InvalidPromptTypeError(TypeError, ContextualError):
class InvalidPromptTypeError(TypeError, UserError):
"""Error thrown when the prompt type is invalid."""

def __init__(self, prompt_type: type) -> None:
Expand All @@ -163,7 +184,7 @@ def __init__(self, prompt_type: type) -> None:
super().__init__(message)


class InvalidResponseTypeError(TypeError, ContextualError):
class InvalidResponseTypeError(TypeError, UserError):
"""Error thrown when the response type is invalid."""

def __init__(self, response_type: type) -> None:
Expand All @@ -172,7 +193,7 @@ def __init__(self, response_type: type) -> None:
super().__init__(message)


class InvalidPromptResponseKeysError(ValueError, ContextualError):
class InvalidPromptResponseKeysError(ValueError, UserError):
"""Error thrown when missing expected prompt and response keys."""

def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
Expand All @@ -181,7 +202,7 @@ def __init__(self, mapping: Dict[str, str], example: Dict[str, Any]):
super().__init__(message)


class InvalidFileExtensionError(FileNotFoundError, ContextualError):
class InvalidFileExtensionError(FileNotFoundError, UserError):
"""Error thrown when a file extension is not a safe extension."""

def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
Expand All @@ -194,7 +215,10 @@ def __init__(self, dataset_name: str, valid_extensions: List[str]) -> None:
super().__init__(message)


class UnableToProcessPromptResponseError(ValueError, ContextualError):
class UnableToProcessPromptResponseError(
ValueError,
UserError,
):
"""Error thrown when a prompt and response cannot be processed."""

def __init__(self, input: Dict) -> None:
Expand All @@ -204,7 +228,7 @@ def __init__(self, input: Dict) -> None:


## Convert Delta to JSON exceptions
class ClusterDoesNotExistError(ValueError, ContextualError):
class ClusterDoesNotExistError(ValueError, NetworkError):
"""Error thrown when the cluster does not exist."""

def __init__(self, cluster_id: str) -> None:
Expand All @@ -213,15 +237,22 @@ def __init__(self, cluster_id: str) -> None:
super().__init__(message)


class FailedToCreateSQLConnectionError(RuntimeError, ContextualError):
class FailedToCreateSQLConnectionError(
RuntimeError,
NetworkError,
):
"""Error thrown when client can't sql connect to Databricks."""

def __init__(self) -> None:
message = 'Failed to create sql connection to db workspace. To use sql connect, you need to provide http_path and cluster_id!'
message = 'Failed to create sql connection to db workspace. ' + \
'To use sql connect, you need to provide http_path and cluster_id!'
super().__init__(message)


class FailedToConnectToDatabricksError(RuntimeError, ContextualError):
class FailedToConnectToDatabricksError(
RuntimeError,
NetworkError,
):
"""Error thrown when the client fails to connect to Databricks."""

def __init__(self) -> None:
Expand All @@ -230,7 +261,7 @@ def __init__(self) -> None:


## Convert Text to MDS exceptions
class InputFolderMissingDataError(ValueError, ContextualError):
class InputFolderMissingDataError(ValueError, UserError):
"""Error thrown when the input folder is missing data."""

def __init__(self, input_folder: str) -> None:
Expand All @@ -239,7 +270,7 @@ def __init__(self, input_folder: str) -> None:
super().__init__(message)


class OutputFolderNotEmptyError(FileExistsError, ContextualError):
class OutputFolderNotEmptyError(FileExistsError, UserError):
"""Error thrown when the output folder is not empty."""

def __init__(self, output_folder: str) -> None:
Expand All @@ -248,17 +279,18 @@ def __init__(self, output_folder: str) -> None:
super().__init__(message)


class MisconfiguredHfDatasetError(ValueError, ContextualError):
class MisconfiguredHfDatasetError(ValueError, UserError):
"""Error thrown when a HuggingFace dataset is misconfigured."""

def __init__(self, dataset_name: str, split: str) -> None:
self.dataset_name = dataset_name
self.split = split
message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. Please check your dataset format and make sure you can load your dataset locally.'
message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. ' + \
'Please check your dataset format and make sure you can load your dataset locally.'
super().__init__(message)


class RunTimeoutError(RuntimeError):
class RunTimeoutError(RuntimeError, InternalError):
"""Error thrown when a run times out."""

def __init__(self, timeout: int) -> None:
Expand Down

0 comments on commit 8e29698

Please sign in to comment.