Skip to content

Commit

Permalink
Solve Kombu filesystem transport not thread safe (celery#1593)
Browse files Browse the repository at this point in the history
* Solve Kombu filesystem transport not thread safe

fix: celery#398
Currently, only a write lock is taken when a message/exchange file is
written, so a reader in another thread can get an incomplete result.

1. Add a timeout for lock acquisition.
2. Add shared locks when reading messages from the filesystem.
3. Add a unit test for `lock` and `unlock`.
4. Add a unit test for locking during message processing.

* Replace deprecated function.
  • Loading branch information
karajan1001 authored Sep 7, 2022
1 parent ec533af commit 8699920
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 29 deletions.
87 changes: 59 additions & 28 deletions kombu/transport/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def callback(body, message):

import os
import shutil
import signal
import tempfile
import uuid
from collections import namedtuple
Expand All @@ -111,6 +112,26 @@ def callback(body, message):
VERSION = (1, 0, 0)
__version__ = '.'.join(map(str, VERSION))


@contextmanager
def timeout_manager(seconds: int):
    """Arm a ``SIGALRM`` timer around the ``with`` body.

    While the body runs, a temporary ``SIGALRM`` handler is installed
    that raises :exc:`InterruptedError` when the alarm fires.  On exit
    the pending alarm is cancelled and the previous handler is restored.

    Signal handlers can only be installed from the main thread, and
    ``SIGALRM`` is not defined on every platform.  In either situation
    the body runs *without* a timeout instead of failing with
    ``ValueError`` / ``AttributeError``.

    :param seconds: number of seconds handed to :func:`signal.alarm`.
    :raises InterruptedError: inside the body, from the signal handler,
        when the alarm fires.
    """
    # Local import: only this helper needs it.
    import threading

    def timeout_handler(signum, frame):
        # flock retries automatically when it is interrupted by a signal,
        # so an exception is needed to make it stop.  The exception
        # propagates on the main thread; the timeout is therefore only
        # effective for flock calls made there.
        raise InterruptedError

    alarm_usable = (
        hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not alarm_usable:
        # signal.signal() would raise here; degrade to "no timeout".
        yield
        return

    original_handler = signal.signal(signal.SIGALRM, timeout_handler)

    try:
        signal.alarm(seconds)
        yield
    finally:
        # Cancel the alarm before putting the previous handler back so a
        # late SIGALRM can never reach a handler that does not expect it.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)


# needs win32all to work on Windows
if os.name == 'nt':

Expand Down Expand Up @@ -138,7 +159,7 @@ def unlock(file):
elif os.name == 'posix':

import fcntl
from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa
from fcntl import LOCK_EX, LOCK_SH

def lock(file, flags):
"""Create file lock."""
Expand All @@ -154,6 +175,21 @@ def unlock(file):
'Filesystem plugin only defined for NT and POSIX platforms')


@contextmanager
def lock_with_timeout(file, flags, timeout: int = 1):
    """Hold a lock on ``file`` for the duration of the ``with`` body.

    Only the *acquisition* of the lock is bounded by ``timeout``; once the
    lock is held the body may run for as long as it needs.  The lock is
    released with ``unlock(file)`` when the block exits, whether or not
    the body (or the acquisition) raised.

    :param file: file object handed to ``lock`` and ``unlock``.
    :param flags: lock flags handed to ``lock`` (e.g. ``LOCK_EX`` or
        ``LOCK_SH``).
    :param timeout: seconds to wait for the lock, handed to
        ``timeout_manager``.
    :raises BlockingIOError: when acquiring the lock is interrupted by
        the timeout.
    """
    try:
        try:
            # Keep the timer scoped to the lock call: the alarm is
            # cancelled before the caller's code starts running.
            with timeout_manager(timeout):
                lock(file, flags)
        except InterruptedError as err:
            # Raised by the timeout handler.  Without that exception
            # flock would simply be retried after the signal.
            raise BlockingIOError("Lock timed out") from err
        yield
    finally:
        # Always attempt the release, as the original did; this also
        # covers an alarm firing right after the lock was obtained.
        unlock(file)


exchange_queue_t = namedtuple("exchange_queue_t",
["routing_key", "pattern", "queue"])

Expand All @@ -168,18 +204,14 @@ def _get_exchange_file_obj(self, exchange, mode="rb"):
file = self.control_folder / f"{exchange}.exchange"
if "w" in mode:
self.control_folder.mkdir(exist_ok=True)
f_obj = file.open(mode)
lock_mode = LOCK_EX if "w" in mode else LOCK_SH

try:
if "w" in mode:
lock(f_obj, LOCK_EX)
yield f_obj
except OSError:
raise ChannelError(f"Cannot open {file}")
finally:
if "w" in mode:
unlock(f_obj)
f_obj.close()
with file.open(mode) as f_obj:
try:
with lock_with_timeout(f_obj, lock_mode):
yield f_obj
except OSError as err:
raise ChannelError(f"Cannot open {file}") from err

def get_table(self, exchange):
try:
Expand Down Expand Up @@ -209,15 +241,12 @@ def _put(self, queue, payload, **kwargs):
filename = os.path.join(self.data_folder_out, filename)

try:
f = open(filename, 'wb')
lock(f, LOCK_EX)
f.write(str_to_bytes(dumps(payload)))
except OSError:
with open(filename, 'wb') as f:
with lock_with_timeout(f, LOCK_EX):
f.write(str_to_bytes(dumps(payload)))
except OSError as err:
raise ChannelError(
f'Cannot add file {filename!r} to directory')
finally:
unlock(f)
f.close()
f'Cannot add file {filename!r} to directory') from err

def _get(self, queue):
"""Get next message from `queue`."""
Expand Down Expand Up @@ -245,14 +274,14 @@ def _get(self, queue):

filename = os.path.join(processed_folder, filename)
try:
f = open(filename, 'rb')
payload = f.read()
f.close()
if not self.store_processed:
os.remove(filename)
except OSError:
with open(filename, 'rb') as f:
with lock_with_timeout(f, LOCK_SH):
payload = f.read()
if not self.store_processed:
os.remove(filename)
except OSError as err:
raise ChannelError(
f'Cannot read file {filename!r} from queue.')
f'Cannot read file {filename!r} from queue.') from err

return loads(bytes_to_str(payload))

Expand All @@ -272,7 +301,9 @@ def _purge(self, queue):
continue

filename = os.path.join(self.data_folder_in, filename)
os.remove(filename)
with open(filename, 'wb') as f:
with lock_with_timeout(f, LOCK_EX):
os.remove(filename)

count += 1

Expand Down
112 changes: 111 additions & 1 deletion t/unit/transport/test_filesystem.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

import tempfile
from fcntl import LOCK_EX, LOCK_NB, LOCK_SH
from queue import Empty
from unittest.mock import call, patch

import pytest

import t.skip
from kombu import Connection, Consumer, Exchange, Producer, Queue
from kombu.transport.filesystem import lock, unlock


@t.skip.if_win32
class test_FilesystemTransport:

def setup(self):
self.channels = set()
try:
Expand Down Expand Up @@ -145,6 +147,7 @@ def callback2(message_data, message):

@t.skip.if_win32
class test_FilesystemFanout:

def setup(self):
try:
data_folder_in = tempfile.mkdtemp()
Expand Down Expand Up @@ -234,3 +237,110 @@ def callback2(message_data, message):
assert self.q2(self.consume_channel).get()
self.q2(self.consume_channel).purge()
assert self.q2(self.consume_channel).get() is None


@t.skip.if_win32
class test_FilesystemLock:
    """Exercise ``lock``/``unlock`` with two handles on the same file."""

    def test_lock(self):
        """Shared locks coexist; an exclusive lock needs the file alone."""
        # Context managers close both handles even when an assertion
        # fails part-way through the test.
        with tempfile.NamedTemporaryFile() as file_obj1, \
                open(file_obj1.name) as file_obj2:
            lock(file_obj1, LOCK_SH)

            # A non-blocking exclusive request must be refused while the
            # first handle still holds its shared lock.
            with pytest.raises(BlockingIOError):
                lock(file_obj2, LOCK_EX | LOCK_NB)

            # A second shared lock is granted alongside the first.
            lock(file_obj2, LOCK_SH)
            unlock(file_obj2)

            # With the first lock released, the exclusive lock succeeds.
            unlock(file_obj1)
            lock(file_obj2, LOCK_EX)
            unlock(file_obj2)


@t.skip.if_win32
class test_FilesystemLockDuringProcess:
    """Verify which locks the transport takes while handling messages.

    ``lock`` and ``unlock`` in ``kombu.transport.filesystem`` are patched
    around each step, and the recorded calls are compared against the
    expected lock mode for that step.
    """

    def setup(self):
        """Create a consumer/producer pair sharing swapped data folders."""
        try:
            data_folder_in = tempfile.mkdtemp()
            data_folder_out = tempfile.mkdtemp()
            control_folder = tempfile.mkdtemp()
        except Exception:
            pytest.skip("filesystem transport: cannot create tempfiles")

        self.consumer_connection = Connection(
            transport="filesystem",
            transport_options={
                "data_folder_in": data_folder_in,
                "data_folder_out": data_folder_out,
                "control_folder": control_folder,
            },
        )
        self.consume_channel = self.consumer_connection.channel()
        # The producer uses the same folders with in/out swapped, so what
        # it writes lands in the consumer's input folder.
        self.produce_connection = Connection(
            transport="filesystem",
            transport_options={
                "data_folder_in": data_folder_out,
                "data_folder_out": data_folder_in,
                "control_folder": control_folder,
            },
        )
        self.producer_channel = self.produce_connection.channel()
        self.exchange = Exchange("filesystem_exchange_lock", type="fanout")
        self.q = Queue("queue1", exchange=self.exchange)

    def teardown(self):
        """Drop pending QoS state on both channels."""
        # make sure we don't attempt to restore messages at shutdown.
        # Iterate over the two *channels*: the original listed
        # ``self.consumer_connection`` here, whose missing ``_qos`` was
        # silently ignored, leaving the consumer channel untouched.
        for channel in [self.producer_channel, self.consume_channel]:
            try:
                channel._qos._dirty.clear()
            except AttributeError:
                pass
            try:
                channel._qos._delivered.clear()
            except AttributeError:
                pass

    def test_lock_during_process(self):
        """Check lock/unlock calls for declare, publish, consume, purge."""
        producer = Producer(self.producer_channel, self.exchange)

        # Creating the consumer: one exclusive lock, released once, on the
        # same file object.
        with patch("kombu.transport.filesystem.lock") as lock_m, patch(
            "kombu.transport.filesystem.unlock"
        ) as unlock_m:
            consumer = Consumer(self.consume_channel, self.q)
            assert unlock_m.call_count == 1
            lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX)

        self.q(self.consume_channel).declare()

        # Publishing: a shared lock first, then an exclusive lock on a
        # second file object (named after their expected roles below).
        with patch("kombu.transport.filesystem.lock") as lock_m, patch(
            "kombu.transport.filesystem.unlock"
        ) as unlock_m:
            producer.publish({"foo": 1})
            assert unlock_m.call_count == 2
            assert lock_m.call_count == 2
            exchange_file_obj = unlock_m.call_args_list[0][0][0]
            msg_file_obj = unlock_m.call_args_list[1][0][0]
            assert lock_m.call_args_list == [call(exchange_file_obj, LOCK_SH),
                                             call(msg_file_obj, LOCK_EX)]

        def callback(_, message):
            message.ack()

        consumer.register_callback(callback)
        consumer.consume()

        # Draining events: a single shared lock.
        with patch("kombu.transport.filesystem.lock") as lock_m, patch(
            "kombu.transport.filesystem.unlock"
        ) as unlock_m:
            self.consume_channel.drain_events()
            assert lock_m.call_count == 1
            assert unlock_m.call_count == 1
            lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_SH)

        # Purging a queue holding one message: a single exclusive lock.
        producer.publish({"foo": 0})
        with patch("kombu.transport.filesystem.lock") as lock_m, patch(
            "kombu.transport.filesystem.unlock"
        ) as unlock_m:
            self.q(self.consume_channel).purge()
            assert lock_m.call_count == 1
            assert unlock_m.call_count == 1
            lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX)

0 comments on commit 8699920

Please sign in to comment.