Skip to content

Commit

Permalink
Replace fnmatch with hand-coded pattern compiler
Browse files Browse the repository at this point in the history
This matches the quirks in the redis pattern matching, which has some
differences from fnmatch. It also applies all the logic in bytes, rather
than the native character encoding.

Fixes #182.
  • Loading branch information
bmerry committed Apr 2, 2018
1 parent 5bfa56a commit 3ebaa63
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 14 deletions.
78 changes: 64 additions & 14 deletions fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import copy
from ctypes import CDLL, POINTER, c_double, c_char_p, pointer
from ctypes.util import find_library
import fnmatch
from collections import MutableMapping
from datetime import datetime, timedelta
import operator
Expand Down Expand Up @@ -52,11 +51,6 @@ def to_bytes(x, charset=DEFAULT_ENCODING, errors='strict'):
return unicode(x).encode(charset, errors) # noqa: F821
raise TypeError('expected bytes or unicode, not ' + type(x).__name__)

def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
    """Return ``x`` unchanged if it is ``None`` or a ``str``, else encoded.

    Any other value is converted with ``x.encode(charset, errors)``. Note
    that the default ``charset`` is evaluated once, when the function is
    defined.
    """
    needs_encoding = x is not None and not isinstance(x, str)
    return x.encode(charset, errors) if needs_encoding else x

def iteritems(d):
    """Return the result of ``d.iteritems()`` for the mapping ``d``."""
    pairs = d.iteritems()
    return pairs

Expand Down Expand Up @@ -86,11 +80,6 @@ def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'):
return str(x).encode(charset, errors)
raise TypeError('expected bytes or str, not ' + type(x).__name__)

def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
    """Return ``x`` unchanged if it is ``None`` or a ``str``, else decoded.

    Any other value is converted with ``x.decode(charset, errors)``. Note
    that the default ``charset`` is evaluated once, when the function is
    defined.
    """
    if x is not None and not isinstance(x, str):
        return x.decode(charset, errors)
    return x

def iteritems(d):
    """Return an iterator over ``d.items()``."""
    pairs = d.items()
    return iter(pairs)

Expand Down Expand Up @@ -250,6 +239,62 @@ def wrapper(self, key, *args, **kwargs):
return wrapper


def _compile_pattern(pattern):
"""Compile a glob pattern (e.g. for keys) to a bytes regex.
fnmatch.fnmatchcase doesn't work for this, because it uses different
escaping rules to redis, uses ! instead of ^ to negate a character set,
and handles invalid cases (such as a [ without a ]) differently. This
implementation was written by studying the redis implementation.
"""
# It's easier to work with text than bytes, because indexing bytes
# doesn't behave the same in Python 3. Latin-1 will round-trip safely.
pattern = pattern.decode('latin-1')
parts = ['^']
i = 0
L = len(pattern)
while i < L:
c = pattern[i]
if c == '?':
parts.append('.')
elif c == '*':
parts.append('.*')
elif c == '\\':
if i < L - 1:
i += 1
parts.append(re.escape(pattern[i]))
elif c == '[':
parts.append('[')
i += 1
if i < L and pattern[i] == '^':
i += 1
parts.append('^')
while i < L:
if pattern[i] == '\\':
i += 1
if i < L:
parts.append(re.escape(pattern[i]))
elif pattern[i] == ']':
break
elif i + 2 <= L and pattern[i + 1] == '-':
start = pattern[i]
end = pattern[i + 2]
if start > end:
start, end = end, start
parts.append(re.escape(start) + '-' + re.escape(end))
i += 2
else:
parts.append(re.escape(pattern[i]))
i += 1
parts.append(']')
else:
parts.append(re.escape(pattern[i]))
i += 1
parts.append('\\Z')
regex = ''.join(parts).encode('latin-1')
return re.compile(regex, re.S)


class _Lock(object):
def __init__(self, redis, name, timeout):
self.redis = redis
Expand Down Expand Up @@ -464,9 +509,10 @@ def incrbyfloat(self, name, amount=1.0):
return value

def keys(self, pattern=None):
if pattern is not None:
regex = _compile_pattern(to_bytes(pattern))
return [key for key in self._db
if not key or not pattern or
fnmatch.fnmatch(to_native(key), to_native(pattern))]
if pattern is None or regex.match(key)]

def mget(self, keys, *args):
all_keys = self._list_or_args(keys, args)
Expand Down Expand Up @@ -1951,8 +1997,12 @@ def _scan(self, keys, cursor, match, count):
result_cursor = cursor + count
result_data = []
# subset =
if match is not None:
regex = _compile_pattern(to_bytes(match))
else:
regex = None
for val in data[cursor:result_cursor]:
if not match or fnmatch.fnmatch(to_native(val), to_native(match)):
if not regex or regex.match(val):
result_data.append(val)
if result_cursor >= len(data):
result_cursor = 0
Expand Down
41 changes: 41 additions & 0 deletions test_fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,43 @@ def test_decr_badtype(self):
with self.assertRaises(redis.ResponseError):
self.redis.decr('foo2', 15)

def test_keys(self):
    # Exercises KEYS pattern matching, including the places where redis
    # glob semantics differ from fnmatch. All expected keys are written as
    # bytes literals so the comparisons also hold on Python 3, where
    # b'x' != 'x'.
    self.redis.set('', 'empty')
    self.redis.set('abc\n', '')
    self.redis.set('abc\\', '')
    self.redis.set('abcde', '')
    if self.decode_responses:
        self.assertEqual(sorted(self.redis.keys()),
                         [b'', b'abc\n', b'abc\\', b'abcde'])
    else:
        # A key that is not valid UTF-8 is only added when responses are
        # not decoded.
        self.redis.set(b'\xfe\xcd', '')
        self.assertEqual(sorted(self.redis.keys()),
                         [b'', b'abc\n', b'abc\\', b'abcde', b'\xfe\xcd'])
        self.assertEqual(self.redis.keys('??'), [b'\xfe\xcd'])
    # empty pattern not the same as no pattern
    self.assertEqual(self.redis.keys(''), [b''])
    # ? must match \n
    self.assertEqual(sorted(self.redis.keys('abc?')),
                     [b'abc\n', b'abc\\'])
    # must be anchored at both ends
    self.assertEqual(self.redis.keys('abc'), [])
    self.assertEqual(self.redis.keys('bcd'), [])
    # wildcard test
    self.assertEqual(self.redis.keys('a*de'), [b'abcde'])
    # positive groups
    self.assertEqual(sorted(self.redis.keys('abc[d\n]*')),
                     [b'abc\n', b'abcde'])
    self.assertEqual(self.redis.keys('abc[c-e]?'), [b'abcde'])
    # reversed range endpoints are swapped
    self.assertEqual(self.redis.keys('abc[e-c]?'), [b'abcde'])
    self.assertEqual(self.redis.keys('abc[e-e]?'), [])
    # a group without a closing ] runs to the end of the pattern
    self.assertEqual(self.redis.keys('abcd[ef'), [b'abcde'])
    # negative groups
    self.assertEqual(self.redis.keys('abc[^d\\\\]*'), [b'abc\n'])
    # some escaping cases that redis handles strangely
    self.assertEqual(self.redis.keys('abc\\'), [b'abc\\'])
    self.assertEqual(self.redis.keys(r'abc[\c-e]e'), [])
    self.assertEqual(self.redis.keys(r'abc[c-\e]e'), [])

def test_exists(self):
self.assertFalse('foo' in self.redis)
self.redis.set('foo', 'bar')
Expand Down Expand Up @@ -1258,6 +1295,10 @@ def test_scan_iter_single_page(self):
self.redis.set('foo2', 'bar2')
self.assertEqual(set(self.redis.scan_iter(match="foo*")),
set([b'foo1', b'foo2']))
self.assertEqual(set(self.redis.scan_iter()),
set([b'foo1', b'foo2']))
self.assertEqual(set(self.redis.scan_iter(match="")),
set([]))

def test_scan_iter_multiple_pages(self):
all_keys = key_val_dict(size=100)
Expand Down

0 comments on commit 3ebaa63

Please sign in to comment.