[BUG] get_backend is not thread-safe #279

Closed
tran-khoa opened this issue Sep 14, 2023 · 3 comments
Labels
bug Something isn't working

@tran-khoa

Describe the bug
Using einops from multiple threads can lead to a race condition: the backend dictionary (_backends) is updated by one thread while get_backend is iterating over it in another.

Traceback (most recent call last):
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 995, in _bootstrap
    self._bootstrap_inner()
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 1038, in _bootstrap_inner
    self.run()
  File "/p/software/jurecadc/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/threading.py", line 975, in run
    self._target(*self._args, **self._kwargs)
  File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 70, in _prefetch_thread
    self._buffer.put(self._obtain_sample())
                     ^^^^^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/llfs/data/iterator.py", line 64, in _obtain_sample
    sample = self.jax_pipeline(sample)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/projs/ll_barlow/experiment.py", line 57, in __call__
    images = filter_vmap(self.tokenizer_fn)(images)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/api.py", line 1258, in vmap_f
    out_flat = batching.batch(
               ^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/equinox/_vmap_pmap.py", line 204, in _fun_wrapper
    _out = self._fun(*_args)
           ^^^^^^^^^^^^^^^^^
  File "/p/home/jusers/tran4/jureca/llfs/llfs/nn/transformers/vit.py", line 100, in __call__
    x = eo.rearrange(x, "c p -> p c")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 412, in reduce
    return _apply_recipe(recipe, tensor, reduction_type=reduction)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/einops.py", line 233, in _apply_recipe
    backend = get_backend(tensor)
              ^^^^^^^^^^^^^^^^^^^
  File "/p/project/jinm60/users/tran4/env_llfs/lib/python3.11/site-packages/einops/_backends.py", line 27, in get_backend
    for framework_name, backend in _backends.items():
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: dictionary changed size during iteration
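The RuntimeError at the bottom of this traceback can be reproduced deterministically in a single thread by doing, in the middle of the loop, what the other thread does: inserting into the dictionary being iterated. This is a minimal sketch; registry and the string values are hypothetical stand-ins for einops' _backends and its backend objects, not einops code.

```python
# Hypothetical stand-in for einops' module-level backend registry.
registry = {"numpy": "numpy-backend"}


def iterate_while_other_thread_registers():
    """Simulate the interleaving: thread A iterates, thread B inserts mid-loop."""
    for name, backend in registry.items():
        # Pretend the scheduler switched to another thread here, which
        # lazily registered a new backend in the shared dict.
        registry["jax"] = "jax-backend"
    return "finished"


try:
    iterate_while_other_thread_registers()
    outcome = "no error"
except RuntimeError as exc:
    outcome = str(exc)

print(outcome)  # dictionary changed size during iteration
```

With real threads the insert only happens mid-loop occasionally, which is why the failure is intermittent.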

Reproduction steps
The race is intermittent, so run the following script several times:

import einops as eo
import jax.numpy as jnp
import threading

x = jnp.ones((2, 2))
y = jnp.zeros((2, 2))

def thread(*args, **kwargs):
    global x
    x = eo.rearrange(x, "n c -> (n c)")
    print(x)

threading.Thread(target=thread, daemon=True).start()

y = eo.rearrange(y, "n c -> (n c)")
print(y)

Expected behavior
Calling einops functions concurrently from several threads should work without a race condition.

Your platform
einops 0.6.1, python 3.11.3, jax v0.4.14

@tran-khoa tran-khoa added the bug Something isn't working label Sep 14, 2023
@arogozhnikov
Owner

Hmm, good point. Does wrapping _backends.items() in list() solve the problem?
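A sketch of this suggestion, using the same kind of hypothetical registry stand-in (not einops code): iterating over a list(...) snapshot means a concurrent insert no longer invalidates the loop. It only removes the RuntimeError; it does not change the check-then-register logic that follows the loop.

```python
# Hypothetical stand-in for einops' module-level backend registry.
registry = {"numpy": "numpy-backend"}


def iterate_snapshot_while_other_thread_registers():
    """Same interleaving as before, but iterate over a copy of the items."""
    seen = []
    # list(...) copies the items up front, so later inserts do not affect the loop.
    for name, backend in list(registry.items()):
        registry["jax"] = "jax-backend"  # simulated concurrent insert
        seen.append(name)
    return seen


result = iterate_snapshot_while_other_thread_registers()
print(result)  # ['numpy']
```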

@tran-khoa
Author

I don't think so. Assume two threads T1 and T2, using the same backend, both fail to find it in the dict.
If T1 then creates and registers that backend, T2 may reach
if BackendSubclass.framework_name not in _backends:
find the condition false, skip the branch, and therefore never return the backend that T1 has already registered. T2 then falls through to the "Tensor type unknown" RuntimeError.
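The scenario can be replayed deterministically with a toy model that mirrors the structure of get_backend. ToyBackend, registry, get_backend_unlocked and the before_check hook are all hypothetical names for illustration; the hook lets "T1" run to completion exactly between T2's two lookups, and a single backend class stands in for the subclass search.

```python
class ToyBackend:
    framework_name = "toy"

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, list)


registry = {}


def get_backend_unlocked(tensor, before_check=lambda: None):
    # Fast path: look for an already-registered backend.
    for backend in list(registry.values()):
        if backend.is_appropriate_type(tensor):
            return backend
    # Hook used to simulate another thread running at this exact point.
    before_check()
    # Slow path, mirroring the original: only a backend created *here* is returned.
    if ToyBackend.framework_name not in registry:
        backend = ToyBackend()
        registry[backend.framework_name] = backend
        if backend.is_appropriate_type(tensor):
            return backend
    raise RuntimeError(f"Tensor type unknown {type(tensor)}")


def simulated_t1():
    # T1 finishes its own registration while T2 is between the two lookups.
    get_backend_unlocked([1.0])


try:
    get_backend_unlocked([2.0], before_check=simulated_t1)  # this call plays T2
    outcome = "found"
except RuntimeError as exc:
    outcome = str(exc)

print(outcome)  # Tensor type unknown <class 'list'>
```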

A simple solution would be to introduce a lock, something like

import sys
import threading

# Assumed to sit in einops/_backends.py, next to _backends, _debug_importing
# and AbstractBackend. Guards creation and registration of new backends.
lock = threading.Lock()


def get_backend(tensor) -> 'AbstractBackend':
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    # Fast path without the lock. Iterate over a snapshot, because another
    # thread may insert into _backends (under the lock) while this loop runs.
    for framework_name, backend in list(_backends.items()):
        if backend.is_appropriate_type(tensor):
            return backend

    with lock:
        # Try to find backend again: another thread may have registered it
        # while this one was waiting for the lock
        for framework_name, backend in _backends.items():
            if backend.is_appropriate_type(tensor):
                return backend

        # Find backend subclasses recursively
        backend_subclasses = []
        backends = AbstractBackend.__subclasses__()
        while backends:
            backend = backends.pop()
            backends += backend.__subclasses__()
            backend_subclasses.append(backend)

        for BackendSubclass in backend_subclasses:
            if _debug_importing:
                print('Testing for subclass of ', BackendSubclass)
            if BackendSubclass.framework_name not in _backends:
                # check that module was already imported. Otherwise it can't be imported
                if BackendSubclass.framework_name in sys.modules:
                    if _debug_importing:
                        print('Imported backend for ', BackendSubclass.framework_name)
                    backend = BackendSubclass()
                    _backends[backend.framework_name] = backend
                    if backend.is_appropriate_type(tensor):
                        return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

@arogozhnikov
Owner

Quoting @tran-khoa: "If T1 then imports the respective backend, T2 may encounter if BackendSubclass.framework_name not in _backends: therefore not finding the backend that has already been imported by T1."

True, but that is not an issue: the rest of the function is idempotent, and it does no harm if a backend is created twice.

arogozhnikov added a commit that referenced this issue Sep 18, 2023