Skip to content

Commit

Permalink
Add lock to Pubsub.execute_command to ensure only one connection is c…
Browse files Browse the repository at this point in the history
…reated (#19)

* Add lock to Pubsub.execute_command to ensure only one connection is created

* Add tests
  • Loading branch information
dlunch authored Oct 27, 2023
1 parent a0ad397 commit aa218e3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
15 changes: 9 additions & 6 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,7 @@ def __init__(
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
else:
self.health_check_response = [b"pong", self.health_check_response_b]
self._connection_lock = threading.Lock()
self.reset()

def __enter__(self):
Expand Down Expand Up @@ -1466,12 +1467,14 @@ def execute_command(self, *args):
# subscribed to one or more channels

if self.connection is None:
self.connection = self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
with self._connection_lock:
if self.connection is None:
self.connection = self.connection_pool.get_connection(
"pubsub", self.shard_hint
)
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,22 @@ def get_msg():

# the timeout on the read should not cause disconnect
assert is_connected()


@pytest.mark.onlynoncluster
class TestConnectionLeak:
def test_connection_leak(self, r: redis.Redis):
pubsub = r.pubsub()

def test():
tid = threading.get_ident()
pubsub.subscribe(f"foo{tid}")

threads = [threading.Thread(target=test) for _ in range(10)]
for thread in threads:
thread.start()

for thread in threads:
thread.join()

assert r.connection_pool._created_connections == 2

0 comments on commit aa218e3

Please sign in to comment.