Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jupyter lab fixes #1139

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile-cloud
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ ENV HF_HUB_ENABLE_HF_TRANSFER="1"

COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh

RUN pip install jupyterlab notebook && \
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
Expand Down
2 changes: 1 addition & 1 deletion scripts/cloud-entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fi

if [ "$JUPYTER_DISABLE" != "1" ]; then
# Run Jupyter Lab in the background
jupyter lab --allow-root --ip 0.0.0.0 &
jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi

# Execute the passed arguments (CMD)
Expand Down
20 changes: 13 additions & 7 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""
import logging
from pathlib import Path
from typing import Tuple

import fire
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -24,19 +26,23 @@
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)


if parsed_cfg.rl:
dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl:
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,10 @@ def build(self, total_num_steps):
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
if self.cfg.eval_batch_size:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but I'm trying to prevent it from being set to None here, which results in a downstream error I found.

training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)

if (
not torch.cuda.is_available()
device is None
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
Expand Down
8 changes: 6 additions & 2 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Any, Optional, Tuple, Union # noqa: F401
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401

import addict
import bitsandbytes as bnb
Expand Down Expand Up @@ -339,7 +339,11 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()

model_kwargs = {}
model_kwargs: Dict[str, Any] = {}

if cfg.model_kwargs:
for key, val in model_kwargs.items():
model_kwargs[key] = val

max_memory = cfg.max_memory
device_map = cfg.device_map
Expand Down