From aa218e33f16e1ea10b7ffcc22855e57b5b1e9bfe Mon Sep 17 00:00:00 2001 From: Inseok Lee Date: Fri, 27 Oct 2023 04:55:16 +0200 Subject: [PATCH] Add lock to Pubsub.execute_command to ensure only one connection is created (#19) * Add lock to Pubsub.execute_command to ensure only one connection is created * Add tests --- redis/client.py | 15 +++++++++------ tests/test_pubsub.py | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/redis/client.py b/redis/client.py index ab626ccdf4..bbe46d0d42 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1402,6 +1402,7 @@ def __init__( self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] else: self.health_check_response = [b"pong", self.health_check_response_b] + self._connection_lock = threading.Lock() self.reset() def __enter__(self): @@ -1466,12 +1467,14 @@ def execute_command(self, *args): # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection( - "pubsub", self.shard_hint - ) - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) + with self._connection_lock: + if self.connection is None: + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 5d86934de6..f1cf6ae771 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -777,3 +777,22 @@ def get_msg(): # the timeout on the read should not cause disconnect assert is_connected() + + +@pytest.mark.onlynoncluster +class TestConnectionLeak: + def test_connection_leak(self, r: redis.Redis): + pubsub = r.pubsub() + + def test(): + tid = threading.get_ident() + pubsub.subscribe(f"foo{tid}") + + threads = [threading.Thread(target=test) for _ in range(10)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert r.connection_pool._created_connections == 2