This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Fine-tuning reproducibility with GPU #752

Open
alanwilter opened this issue May 23, 2024 · 0 comments

I'm using this repo in my app, which is written in Python and, for now, made to cater to our own private data.

The app facilitates creating fine-tuned models from vanilla SAM and MedSAM. We also run inference, of course.

We use only these imports:

from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

Our plan is to release our package for the public via GitHub, open source.

However, we have been having trouble making the fine-tuned models reproducible, and the only culprit I can see here is ResizeLongestSide.

When running inference, the results are deterministic.

I've updated our code to use the latest PyTorch 2.3, and I have done the following hoping to make training reproducible:

import random

import numpy as np
import torch

seed = 42

np.random.seed(seed)
random.seed(seed)
torch.cuda.empty_cache()
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.use_deterministic_algorithms(False)  # True only for CPU
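For reference, this is the fuller setup I would try, following the PyTorch reproducibility notes: seed every RNG source, pin cuDNN to deterministic kernels, and set CUBLAS_WORKSPACE_CONFIG before the first CUDA call. The seed_everything name is my own, and the numpy/torch parts are guarded with try/except so the sketch stands alone:

```python
import os
import random


def seed_everything(seed: int) -> None:
    """Seed every RNG source we know about; numpy/torch parts are skipped if unavailable."""
    # Only affects subprocesses launched after this point, not the current interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    # cuBLAS needs this before the first CUDA call to allow deterministic GEMM kernels.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds CPU and all CUDA devices
        torch.backends.cudnn.deterministic = True  # pick deterministic cuDNN kernels
        torch.backends.cudnn.benchmark = False     # disable nondeterministic autotuning
    except ImportError:
        pass


seed_everything(42)
```

Note that the DataLoader is a further source of nondeterminism: its workers and any shuffling need their own seeding (e.g. a fixed torch.Generator plus a worker_init_fn) for runs to match.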

However, if I set torch.use_deterministic_algorithms(True) and run on GPU, it does not even work, even after trying the suggestions in the warning message. Apparently some PyTorch/CUDA routines simply do not have deterministic implementations yet.
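One way to at least locate the offending kernels without killing the GPU run: since PyTorch 1.11, use_deterministic_algorithms accepts warn_only=True, which emits a UserWarning naming each nondeterministic op instead of raising. A minimal sketch (guarded so it runs even without torch or a GPU):

```python
import warnings

try:
    import torch

    # Warn on nondeterministic ops instead of raising, so training keeps going
    # and the warnings tell us which ops are the culprits.
    torch.use_deterministic_algorithms(True, warn_only=True)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # ... run one training step here; any nondeterministic op shows up in `caught`
    for w in caught:
        print(w.message)
except ImportError:
    pass
```

Once the culprits are known, you can decide whether each one has a deterministic alternative or is a genuine hard blocker.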

By reproducibility I mean: I run the fine-tuning training and get a model. If I repeat the exact same procedure, I get a different model, which gives different results at inference.
Running inference with a given model is deterministic.
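To make "same model" concrete, I compare saved checkpoints byte-for-byte: two runs are reproducible only if the files hash identically (a strict but simple proxy; the run1/run2 paths below are hypothetical):

```python
import hashlib


def file_digest(path: str) -> str:
    """SHA-256 of a file, read in 1 MiB chunks so large checkpoints fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()


# Hypothetical checkpoint paths from two supposedly identical training runs:
# assert file_digest("run1/model.pth") == file_digest("run2/model.pth")
```

If byte-level equality is too strict (serialization metadata can differ), comparing the state_dict tensors with torch.equal per key is a looser alternative.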

Sometimes the resulting model gives really poor results; sometimes they are good or even great.

If I use torch.use_deterministic_algorithms(True) and run on CPU only, I get reproducible training results, but it's about 100x slower, hence impractical.

I'm wondering if anyone has faced this issue.
