diff --git a/kombu/transport/filesystem.py b/kombu/transport/filesystem.py index 92d8c4e43..9d2b35819 100644 --- a/kombu/transport/filesystem.py +++ b/kombu/transport/filesystem.py @@ -96,18 +96,16 @@ def callback(body, message): import tempfile import uuid from collections import namedtuple -from contextlib import contextmanager from pathlib import Path from queue import Empty from time import monotonic from kombu.exceptions import ChannelError +from kombu.transport import virtual from kombu.utils.encoding import bytes_to_str, str_to_bytes from kombu.utils.json import dumps, loads from kombu.utils.objects import cached_property -from . import virtual - VERSION = (1, 0, 0) __version__ = '.'.join(map(str, VERSION)) @@ -138,7 +136,7 @@ def unlock(file): elif os.name == 'posix': import fcntl - from fcntl import LOCK_EX, LOCK_NB, LOCK_SH # noqa + from fcntl import LOCK_EX, LOCK_SH def lock(file, flags): """Create file lock.""" @@ -163,40 +161,45 @@ class Channel(virtual.Channel): supports_fanout = True - @contextmanager - def _get_exchange_file_obj(self, exchange, mode="rb"): - file = self.control_folder / f"{exchange}.exchange" - if "w" in mode: - self.control_folder.mkdir(exist_ok=True) - f_obj = file.open(mode) - - try: - if "w" in mode: - lock(f_obj, LOCK_EX) - yield f_obj - except OSError: - raise ChannelError(f"Cannot open {file}") - finally: - if "w" in mode: - unlock(f_obj) - f_obj.close() - def get_table(self, exchange): + file = self.control_folder / f"{exchange}.exchange" try: - with self._get_exchange_file_obj(exchange) as f_obj: + f_obj = file.open("r") + try: + lock(f_obj, LOCK_SH) exchange_table = loads(bytes_to_str(f_obj.read())) return [exchange_queue_t(*q) for q in exchange_table] + finally: + unlock(f_obj) + f_obj.close() except FileNotFoundError: return [] + except OSError: + raise ChannelError(f"Cannot open {file}") def _queue_bind(self, exchange, routing_key, pattern, queue): - queues = self.get_table(exchange) + file = self.control_folder / f"{exchange}.exchange" + self.control_folder.mkdir(exist_ok=True) queue_val = exchange_queue_t(routing_key or "", pattern or "", queue or "") - if queue_val not in queues: - queues.insert(0, queue_val) - with self._get_exchange_file_obj(exchange, "wb") as f_obj: - f_obj.write(str_to_bytes(dumps(queues))) + try: + if file.exists(): + f_obj = file.open("rb+", buffering=0) + lock(f_obj, LOCK_EX) + exchange_table = loads(bytes_to_str(f_obj.read())) + queues = [exchange_queue_t(*q) for q in exchange_table] + if queue_val not in queues: + queues.insert(0, queue_val) + f_obj.seek(0) + f_obj.write(str_to_bytes(dumps(queues))) + else: + f_obj = file.open("wb", buffering=0) + lock(f_obj, LOCK_EX) + queues = [queue_val] + f_obj.write(str_to_bytes(dumps(queues))) + finally: + unlock(f_obj) + f_obj.close() def _put_fanout(self, exchange, payload, routing_key, **kwargs): for q in self.get_table(exchange): @@ -209,7 +212,7 @@ def _put(self, queue, payload, **kwargs): filename = os.path.join(self.data_folder_out, filename) try: - f = open(filename, 'wb') + f = open(filename, 'wb', buffering=0) lock(f, LOCK_EX) f.write(str_to_bytes(dumps(payload))) except OSError: @@ -241,7 +244,8 @@ def _get(self, queue): shutil.move(os.path.join(self.data_folder_in, filename), processed_folder) except OSError: - pass # file could be locked, or removed in meantime so ignore + # file could be locked, or removed in meantime so ignore + continue filename = os.path.join(processed_folder, filename) try: diff --git a/t/unit/transport/test_filesystem.py b/t/unit/transport/test_filesystem.py index b22e3b8de..20c7f47a6 100644 --- a/t/unit/transport/test_filesystem.py +++ b/t/unit/transport/test_filesystem.py @@ -1,7 +1,9 @@ from __future__ import annotations import tempfile +from fcntl import LOCK_EX, LOCK_SH from queue import Empty +from unittest.mock import call, patch import pytest @@ -234,3 +236,69 @@ def callback2(message_data, message): assert self.q2(self.consume_channel).get() self.q2(self.consume_channel).purge() assert self.q2(self.consume_channel).get() is None + + +@t.skip.if_win32 +class test_FilesystemLock: + def setup(self): + try: + data_folder_in = tempfile.mkdtemp() + data_folder_out = tempfile.mkdtemp() + control_folder = tempfile.mkdtemp() + except Exception: + pytest.skip("filesystem transport: cannot create tempfiles") + + self.consumer_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_in, + "data_folder_out": data_folder_out, + "control_folder": control_folder, + }, + ) + self.consume_channel = self.consumer_connection.channel() + self.produce_connection = Connection( + transport="filesystem", + transport_options={ + "data_folder_in": data_folder_out, + "data_folder_out": data_folder_in, + "control_folder": control_folder, + }, + ) + self.producer_channel = self.produce_connection.channel() + self.exchange = Exchange("filesystem_exchange_lock", type="fanout") + self.q = Queue("queue1", exchange=self.exchange) + + def teardown(self): + # make sure we don't attempt to restore messages at shutdown. + for channel in [self.producer_channel, self.consumer_connection]: + try: + channel._qos._dirty.clear() + except AttributeError: + pass + try: + channel._qos._delivered.clear() + except AttributeError: + pass + + def test_lock_during_process(self): + producer = Producer(self.producer_channel, self.exchange) + + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + Consumer(self.consume_channel, self.q) + assert unlock_m.call_count == 1 + lock_m.assert_called_once_with(unlock_m.call_args[0][0], LOCK_EX) + + self.q(self.consume_channel).declare() + with patch("kombu.transport.filesystem.lock") as lock_m, patch( + "kombu.transport.filesystem.unlock" + ) as unlock_m: + producer.publish({"foo": 1}) + assert unlock_m.call_count == 2 + assert lock_m.call_count == 2 + exchange_file_obj = unlock_m.call_args_list[0][0][0] + msg_file_obj = unlock_m.call_args_list[1][0][0] + assert lock_m.call_args_list == [call(exchange_file_obj, LOCK_SH), + call(msg_file_obj, LOCK_EX)]