Skip to content

Commit

Permalink
jupyter lab fixes (#1139) [skip ci]
Browse files Browse the repository at this point in the history
* add a basic notebook for lab users in the root

* update notebook and fix cors for jupyter

* cell is code

* fix eval batch size check

* remove intro notebook
  • Loading branch information
winglian committed Jan 22, 2024
1 parent f5a828a commit eaaeefc
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docker/Dockerfile-cloud
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ EXPOSE 22

COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh

# Install JupyterLab plus the classic notebook and ipywidgets (needed for
# interactive widgets inside Lab), then clear Lab's build caches to keep
# the image small. The scraped diff left both the pre- and post-change RUN
# lines in place; only the post-change line (with ipywidgets) is kept.
RUN pip install jupyterlab notebook ipywidgets && \
    jupyter lab clean
RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
Expand Down
2 changes: 1 addition & 1 deletion scripts/cloud-entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ fi

if [ "$JUPYTER_DISABLE" != "1" ]; then
    # Run Jupyter Lab in the background unless explicitly disabled.
    # Binds all interfaces, allows cross-origin requests (required when the
    # UI is proxied, e.g. on cloud GPU providers), and opens in /workspace.
    # NOTE(review): the scraped diff contained both the old and new launch
    # lines; keeping both would start two servers, so only the post-change
    # command is kept. `--ip=*` / `allow_origin=*` are permissive — assumes
    # the container sits behind the provider's auth proxy; confirm.
    jupyter lab --port=8888 --ip=* --allow-root --ServerApp.allow_origin=* --ServerApp.preferred_dir=/workspace &
fi

# Execute the passed arguments (CMD)
Expand Down
20 changes: 13 additions & 7 deletions src/axolotl/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""
import logging
from pathlib import Path
from typing import Tuple

import fire
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer

from axolotl.cli import (
check_accelerate_default_config,
Expand All @@ -24,19 +26,23 @@
def do_cli(config: Path = Path("examples/"), **kwargs):
    """CLI entry point: load the config, parse CLI flags, and run training.

    Banner printing and environment checks were moved into ``do_train`` by
    this commit so notebook users can call it directly; the stale pre-change
    lines interleaved by the diff scrape are dropped here.

    :param config: path to an axolotl YAML config (or a directory of examples).
    :param kwargs: extra overrides forwarded to ``load_cfg``.
    :return: whatever ``do_train`` returns (the trained model and tokenizer).
    """
    # pylint: disable=duplicate-code
    parsed_cfg = load_cfg(config, **kwargs)
    parser = transformers.HfArgumentParser((TrainerCliArgs))
    # Tolerate unknown CLI flags: they may belong to accelerate/launcher tooling.
    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
        return_remaining_strings=True
    )
    return do_train(parsed_cfg, parsed_cli_args)


def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Run one full training pass for an already-parsed config.

    Prints the banner, validates the accelerate default config and the HF
    user token, loads datasets, and delegates to ``train``. The stale
    pre-change lines (referencing ``parsed_cfg``/``parsed_cli_args``) that
    the diff scrape interleaved into this function are removed.

    :param cfg: parsed axolotl config object.
    :param cli_args: parsed ``TrainerCliArgs``.
    :return: the trained model and its tokenizer.
    """
    print_axolotl_text_art()
    check_accelerate_default_config()
    check_user_token()
    # RL configs use a dedicated dataset-loading path.
    if cfg.rl:
        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
    else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


if __name__ == "__main__":
Expand Down
7 changes: 4 additions & 3 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,9 +746,10 @@ def build(self, total_num_steps):
# Batch-size related TrainingArguments. The scraped diff left both the
# pre-change unconditional eval-batch-size assignment and its post-change
# conditional replacement; only the conditional (post-commit) form is kept.
training_arguments_kwargs[
    "per_device_train_batch_size"
] = self.cfg.micro_batch_size
# Only forward an eval batch size when one is configured; otherwise let
# TrainingArguments fall back to its own default instead of receiving None.
if self.cfg.eval_batch_size:
    training_arguments_kwargs[
        "per_device_eval_batch_size"
    ] = self.cfg.eval_batch_size
training_arguments_kwargs[
    "gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/utils/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def wrapper(*args, **kwargs):
device = kwargs.get("device", args[0] if args else None)

if (
not torch.cuda.is_available()
device is None
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
):
Expand Down
8 changes: 6 additions & 2 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import math
import os
from typing import Any, Optional, Tuple, Union # noqa: F401
from typing import Any, Dict, Optional, Tuple, Union # noqa: F401

import addict
import bitsandbytes as bnb
Expand Down Expand Up @@ -348,7 +348,11 @@ def load_model(
LOG.info("patching _expand_mask")
hijack_expand_mask()

# Kwargs forwarded to the HF model loader, seeded from the user config.
model_kwargs: Dict[str, Any] = {}

# BUG FIX: the committed code iterated ``model_kwargs.items()`` — the empty
# dict created just above — so every value in ``cfg.model_kwargs`` was
# silently discarded. Iterate the config's dict instead.
if cfg.model_kwargs:
    for key, val in cfg.model_kwargs.items():
        model_kwargs[key] = val

max_memory = cfg.max_memory
device_map = cfg.device_map
Expand Down

0 comments on commit eaaeefc

Please sign in to comment.