how to postpone filter init till it's running #242

Open
stas00 opened this issue Jul 9, 2024 · 5 comments


stas00 commented Jul 9, 2024

So it appears that currently I can't instantiate a model on a gpu, because the filter object is created by the launcher, which either doesn't have a gpu, or most likely has the wrong one even if it does, since we would need dedicated gpu(s) for each task.

Is it possible to add a second init, a user-defined init, that would run on the actual job?

The filter task is simple: instantiate a model on a gpu and then run the filter with it; of course we don't want the model to be re-instantiated on every filter call.

Needing to import torch inside the filter is super-weird as well, but I get that it's due to pickle. Perhaps we could have two inits: one for the framework, and then another for the user.

So when a job is launched, the first thing the framework would run is the user-defined init, if any, and then it would proceed normally.
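
Something along these lines is what I have in mind (just a sketch: user_init, MyGpuFilter and the AutoModel stand-in are all made up for illustration, not an existing datatrove API):

import torch
from transformers import AutoModel, AutoTokenizer

from datatrove.pipeline.filters.base_filter import BaseFilter


class MyGpuFilter(BaseFilter):
    name = "My GPU Filter"

    def user_init(self):
        # hypothetical hook: the framework would call this once per worker task,
        # after unpickling and before the first filter() call
        self.device = torch.device("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
        self.model = AutoModel.from_pretrained("nvidia/domain-classifier").to(self.device)

    def filter(self, doc) -> bool:
        # by this point the model already lives on this task's gpu; the real
        # classification logic is elided, this only illustrates where init happens
        inputs = self.tokenizer([doc.text], return_tensors="pt", truncation=True).to(self.device)
        return inputs["input_ids"].shape[1] > 1  # placeholder decision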

I guess I will try to overcome this meanwhile using @functools.cache or something similar.

Thank you!

tag: @guipenedo


stas00 commented Jul 9, 2024

I'm trying:

    @functools.cached_property
    def device(self):
        return torch.device('cuda:0')

    @functools.cached_property
    def model(self):
        return ClassifierHead.from_pretrained(mname).to(self.device)

    @functools.cached_property
    def tokenizer(self):
        return AutoTokenizer.from_pretrained(mname)

    @functools.cached_property
    def config(self):
        return AutoConfig.from_pretrained(mname)

and then inside the filter only the first item gets hit by this init.

edit: hmm, this approach seems to hang. strace shows it's stuck trying to read from fd=4

so I got to:

+ export PYTHONUNBUFFERED=TRUE
+ PYTHONUNBUFFERED=TRUE
+ srun -l launch_pickled_pipeline /data/stas/classify2/data/logs/slurm_processing/executor.pik
0: 2024-07-09 01:18:35.831 | INFO     | datatrove.utils.logging:add_task_logger:58 - Launching pipeline for rank=0
0: 2024-07-09 01:18:35.831 | INFO     | datatrove.utils.logging:log_pipeline:90 -
0: --- 🛠️ PIPELINE 🛠
0: 📖 - READER: 🤗 HuggingFace
0: 🔻 - FILTER: Classifier Filter
0: 💽 - WRITER: 🐿 Jsonl

and then nothing happens.

I set the reader to limit=10, so it should be real fast.

Must be something pickle-related

Do you by chance have an example of a working filter that uses the gpu assigned to the srun task?


stas00 commented Jul 9, 2024

update: if I run the same job with the local executor it works fine; with slurm it hangs on the first sample, so it must be some pickle-related issue.

When I scancel the job it shows the buffered-up output:

srun: Job step aborted: Waiting up to 32 seconds for job step to finish.
 10%|█         | 1/10 [00:06<00:54,  6.00s/it]slurmstepd: error: *** STEP 101728.0 ON -5 CANCELLED AT 2024-07-09T03:51:10 ***
slurmstepd: error: *** JOB 101728 ON 5 CANCELLED AT 2024-07-09T03:51:10 ***
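
A quick local check of whether serializing the filter itself is the problem (just a sketch; I'm assuming dill is what the launcher uses under the hood, given the executor.pik file, and checking plain pickle too). It only exercises serialization, not what happens on the worker after unpickling, so a pass here doesn't rule out the hang:

import pickle

import dill  # assumption: the slurm launcher serializes the pipeline with dill (hence executor.pik)

f = ClassifierFilter()

# round-trip the freshly constructed filter through both serializers, i.e. the same
# state the launcher ships to the workers before any cached property is touched
for lib in (pickle, dill):
    restored = lib.loads(lib.dumps(f))
    print(f"{lib.__name__}: round-trip ok -> {type(restored).__name__}")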

@guipenedo
Collaborator

can you share the full class so I can try to reproduce the issue?


stas00 commented Jul 9, 2024

Yes, of course

import functools
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from datatrove.pipeline.filters.base_filter import BaseFilter
from datatrove.pipeline.writers.disk_base import DiskWriter

# ClassifierHead is a custom classifier class for nvidia/domain-classifier (definition not included here)

class ClassifierFilter(BaseFilter):
    name = "Classifier Filter"

    def __init__(
        self,
        exclusion_writer: DiskWriter = None,
    ):
        super().__init__(exclusion_writer)

    @functools.cached_property
    def device(self):
        return torch.device('cuda')

    @functools.cached_property
    def model(self):
        return ClassifierHead.from_pretrained("nvidia/domain-classifier").to(self.device)

    @functools.cached_property
    def tokenizer(self):
        return AutoTokenizer.from_pretrained("nvidia/domain-classifier")

    @functools.cached_property
    def config(self):
        return AutoConfig.from_pretrained("nvidia/domain-classifier")

    def filter(self, doc) -> bool | tuple[bool, str]:
        import torch  # noqa - pickle quirk

        inputs = self.tokenizer(
            [doc.text], return_tensors="pt", padding="longest", truncation=True
        ).to(self.device)
        outputs = self.model(inputs["input_ids"], inputs["attention_mask"])

        predicted_classes = torch.argmax(outputs, dim=1)
        predicted_domains = [
            self.config.id2label[class_idx.item()]
            for class_idx in predicted_classes.cpu().numpy()
        ]
        if predicted_domains[0] == "Health":
            return True
        else:
            return False, predicted_domains[0]

Or perhaps, if you have an example of a filter that runs something on cuda, that could help too.


stas00 commented Jul 10, 2024

Is there a plan for another way of passing the jobs instead of pickle?

The hanging happens because of functools.cached_property, so it seems it can't be used.

I came up with the following workaround, creating my own post-unpickle init via to_device():

class ClassifierFilter(BaseFilter):
    name = "Classifier Filter"

    def __init__(
        self,
        exclusion_writer: DiskWriter = None,
    ):
        super().__init__(exclusion_writer)

        self.device = None
        self.config = AutoConfig.from_pretrained("nvidia/domain-classifier")
        self.tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
        self.model = ClassifierHead.from_pretrained("nvidia/domain-classifier")

    def to_device(self):
        import torch
        if self.device is not None:
            return
        self.device = torch.device('cuda')
        self.model = self.model.to(self.device)

    def filter(self, doc) -> bool | tuple[bool, str]:
        import torch  # noqa - pickle quirk
        
        self.to_device()
        
        inputs = self.tokenizer(
            [doc.text], return_tensors="pt", padding="longest", truncation=True
        ).to(self.device)
        outputs = self.model(inputs["input_ids"], inputs["attention_mask"])

        predicted_classes = torch.argmax(outputs, dim=1)
        predicted_domains = [
            self.config.id2label[class_idx.item()]
            for class_idx in predicted_classes.cpu().numpy()
        ]
        if predicted_domains[0] == "Health":
            return True
        else:
            return False, predicted_domains[0]
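
A variation on the same idea that also defers the model construction itself, not just the .to(device), so the launcher never loads any transformers weights and the pickled object stays small; an untested sketch of the pattern, using the same ClassifierHead as above:

class LazyClassifierFilter(BaseFilter):
    name = "Classifier Filter (lazy)"

    def __init__(self, exclusion_writer: DiskWriter = None):
        super().__init__(exclusion_writer)
        # nothing heavy here: the launcher only has to pickle plain attributes
        self._loaded = False

    def _lazy_init(self):
        # runs once, on the worker, at the first filter() call
        if self._loaded:
            return
        import torch
        from transformers import AutoConfig, AutoTokenizer
        self.device = torch.device('cuda')
        self.config = AutoConfig.from_pretrained("nvidia/domain-classifier")
        self.tokenizer = AutoTokenizer.from_pretrained("nvidia/domain-classifier")
        self.model = ClassifierHead.from_pretrained("nvidia/domain-classifier").to(self.device)
        self._loaded = True

    def filter(self, doc) -> bool | tuple[bool, str]:
        import torch  # noqa - pickle quirk

        self._lazy_init()

        inputs = self.tokenizer(
            [doc.text], return_tensors="pt", padding="longest", truncation=True
        ).to(self.device)
        outputs = self.model(inputs["input_ids"], inputs["attention_mask"])

        predicted = self.config.id2label[torch.argmax(outputs, dim=1).item()]
        return True if predicted == "Health" else (False, predicted)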
