unable to train #134

Open
riyajatar37003 opened this issue Jun 22, 2024 · 3 comments

Comments


riyajatar37003 commented Jun 22, 2024

I installed Tevatron using
pip install tevatron

and am trying to train RepLLaMA from the examples folder, but I get the following import error:

Traceback (most recent call last):
File "un_3/tevatron/examples/repllama/train.py", line 12, in
from tevatron.arguments import ModelArguments, DataArguments,
ImportError: cannot import name 'TevatronTrainingArguments' from 'tevatron.arguments' (/tmp/.local/lib/python3.10/site-packages/tevatron/arguments.py)

Could you give simple steps to train RepLLaMA?
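An ImportError like the one above usually means the installed package and the example script come from different versions of the code. A minimal sketch for checking this, using only the standard library: it reports which file a module is loaded from and whether that module exposes a given name. The helper name `describe_import` is hypothetical and is not part of Tevatron.

```python
import importlib
import importlib.util


def describe_import(module_name: str, attr: str) -> dict:
    """Report where a module is loaded from and whether it exposes `attr`."""
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # Raised when the parent package itself is not installed.
        spec = None
    if spec is None:
        return {"found": False, "origin": None, "has_attr": False}
    module = importlib.import_module(module_name)
    return {
        "found": True,
        "origin": spec.origin,  # path of the file that is actually imported
        "has_attr": hasattr(module, attr),
    }
```

Calling `describe_import("tevatron.arguments", "TevatronTrainingArguments")` in the training environment would show whether the name is missing from the copy under site-packages, as the traceback suggests.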

@MXueguang
Contributor

Hi, please clone the repo and install it via "pip install -e .". Then the command on the main page should reproduce repllama/repmistral training:

deepspeed --include localhost:0,1,2,3 --master_port 60000 --module tevatron.retriever.driver.train \
  --deepspeed deepspeed/ds_zero3_config.json \
  --output_dir retriever-mistral \
  --model_name_or_path mistralai/Mistral-7B-v0.1 \
  --lora \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --save_steps 50 \
  --dataset_name Tevatron/msmarco-passage-aug \
  --query_prefix "Query: " \
  --passage_prefix "Passage: " \
  --bf16 \
  --pooling eos \
  --append_eos_token \
  --normalize \
  --temperature 0.01 \
  --per_device_train_batch_size 8 \
  --gradient_checkpointing \
  --train_group_size 16 \
  --learning_rate 1e-4 \
  --query_max_len 32 \
  --passage_max_len 156 \
  --num_train_epochs 1 \
  --logging_steps 10 \
  --overwrite_output_dir \
  --gradient_accumulation_steps 4
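When adapting the command above to fewer GPUs, it can help to see how the batch-related flags combine. The sketch below is plain arithmetic. It assumes that `--train_group_size` is the number of passages paired with each query, which the thread does not state. The function name `batch_sizes` is hypothetical.

```python
def batch_sizes(num_gpus: int, per_device_batch: int,
                grad_accum: int, group_size: int) -> tuple[int, int]:
    """Return (queries, passages) seen per optimizer step."""
    queries_per_step = num_gpus * per_device_batch * grad_accum
    # Assumption: each query comes with `group_size` passages.
    passages_per_step = queries_per_step * group_size
    return queries_per_step, passages_per_step


# Values from the command: 4 GPUs (localhost:0,1,2,3), batch size 8,
# 4 accumulation steps, group size 16.
print(batch_sizes(4, 8, 4, 16))  # (128, 2048)
```

Under that assumption, halving the GPU count while doubling `--gradient_accumulation_steps` would keep the number of queries per optimizer step unchanged.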

@riyajatar37003
Author

I am getting this error:

UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.

I set it to both True and False in training_args.gradient_checkpointing_kwargs, but I still get the same error.
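The message quoted above is a Python `UserWarning`, and by default a warning does not stop execution. The thread does not say whether training actually halted, so the following is only a standard-library sketch of how such a message behaves and how it could be filtered. `run_step` is a hypothetical stand-in for a training step, not Tevatron or PyTorch code.

```python
import warnings

MESSAGE = ("torch.utils.checkpoint: please pass in use_reentrant=True "
           "or use_reentrant=False explicitly.")


def run_step() -> str:
    """Hypothetical stand-in for a step that emits the warning."""
    warnings.warn(MESSAGE, UserWarning)
    return "step finished"


# 1) The warning is recorded, but the function still returns normally.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = run_step()

# 2) With an "ignore" filter matching the start of the message,
#    nothing is recorded at all.
with warnings.catch_warnings(record=True) as caught_filtered:
    warnings.simplefilter("always")
    warnings.filterwarnings("ignore", message=r"torch\.utils\.checkpoint",
                            category=UserWarning)
    run_step()

print(result, len(caught), len(caught_filtered))
```

If training really does stop, the cause is likely a separate exception further down the log rather than this message.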

@MXueguang
Contributor

What is your PyTorch version?
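One way to answer this is to print the installed versions of the relevant packages in a single step. A minimal sketch using the standard library's `importlib.metadata`: the package list is an assumption about what matters here, and `installed_version` is a hypothetical helper.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Distribution names as they would appear in `pip list`.
for name in ("torch", "transformers", "deepspeed", "tevatron"):
    print(f"{name}: {installed_version(name)}")
```

Pasting this output into the issue gives maintainers the PyTorch version along with the versions of the libraries it interacts with.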
