diff --git a/fakeredis.py b/fakeredis.py index a10d927..22e6173 100644 --- a/fakeredis.py +++ b/fakeredis.py @@ -4,7 +4,6 @@ import copy from ctypes import CDLL, POINTER, c_double, c_char_p, pointer from ctypes.util import find_library -import fnmatch from collections import MutableMapping from datetime import datetime, timedelta import operator @@ -52,11 +51,6 @@ def to_bytes(x, charset=DEFAULT_ENCODING, errors='strict'): return unicode(x).encode(charset, errors) # noqa: F821 raise TypeError('expected bytes or unicode, not ' + type(x).__name__) - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): - if x is None or isinstance(x, str): - return x - return x.encode(charset, errors) - def iteritems(d): return d.iteritems() @@ -86,11 +80,6 @@ def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'): return str(x).encode(charset, errors) raise TypeError('expected bytes or str, not ' + type(x).__name__) - def to_native(x, charset=sys.getdefaultencoding(), errors='strict'): - if x is None or isinstance(x, str): - return x - return x.decode(charset, errors) - def iteritems(d): return iter(d.items()) @@ -250,6 +239,62 @@ def wrapper(self, key, *args, **kwargs): return wrapper +def _compile_pattern(pattern): + """Compile a glob pattern (e.g. for keys) to a bytes regex. + + fnmatch.fnmatchcase doesn't work for this, because it uses different + escaping rules to redis, uses ! instead of ^ to negate a character set, + and handles invalid cases (such as a [ without a ]) differently. This + implementation was written by studying the redis implementation. + """ + # It's easier to work with text than bytes, because indexing bytes + # doesn't behave the same in Python 3. Latin-1 will round-trip safely. + pattern = pattern.decode('latin-1') + parts = ['^'] + i = 0 + L = len(pattern) + while i < L: + c = pattern[i] + if c == '?': + parts.append('.') + elif c == '*': + parts.append('.*') + elif c == '\\': + if i < L - 1: + i += 1 + parts.append(re.escape(pattern[i])) + elif c == '[': + parts.append('[') + i += 1 + if i < L and pattern[i] == '^': + i += 1 + parts.append('^') + while i < L: + if pattern[i] == '\\': + i += 1 + if i < L: + parts.append(re.escape(pattern[i])) + elif pattern[i] == ']': + break + elif i + 2 <= L and pattern[i + 1] == '-': + start = pattern[i] + end = pattern[i + 2] + if start > end: + start, end = end, start + parts.append(re.escape(start) + '-' + re.escape(end)) + i += 2 + else: + parts.append(re.escape(pattern[i])) + i += 1 + parts.append(']') + else: + parts.append(re.escape(pattern[i])) + i += 1 + parts.append('\\Z') + regex = ''.join(parts).encode('latin-1') + return re.compile(regex, re.S) + + class _Lock(object): def __init__(self, redis, name, timeout): self.redis = redis @@ -464,9 +509,10 @@ def incrbyfloat(self, name, amount=1.0): return value def keys(self, pattern=None): + if pattern is not None: + regex = _compile_pattern(to_bytes(pattern)) return [key for key in self._db - if not key or not pattern or - fnmatch.fnmatch(to_native(key), to_native(pattern))] + if pattern is None or regex.match(key)] def mget(self, keys, *args): all_keys = self._list_or_args(keys, args) @@ -1951,8 +1997,12 @@ def _scan(self, keys, cursor, match, count): result_cursor = cursor + count result_data = [] # subset = + if match is not None: + regex = _compile_pattern(to_bytes(match)) + else: + regex = None for val in data[cursor:result_cursor]: - if not match or fnmatch.fnmatch(to_native(val), to_native(match)): + if not regex or regex.match(val): result_data.append(val) if result_cursor >= len(data): result_cursor = 0 diff --git a/test_fakeredis.py b/test_fakeredis.py index e494be7..665af7a 100644 --- a/test_fakeredis.py +++ b/test_fakeredis.py @@ -390,6 +390,43 @@ def test_decr_badtype(self): with self.assertRaises(redis.ResponseError): self.redis.decr('foo2', 15) + def test_keys(self): + self.redis.set('', 'empty') + self.redis.set('abc\n', '') + self.redis.set('abc\\', '') + self.redis.set('abcde', '') + if self.decode_responses: + self.assertEqual(sorted(self.redis.keys()), + [b'', b'abc\n', b'abc\\', b'abcde']) + else: + self.redis.set(b'\xfe\xcd', '') + self.assertEqual(sorted(self.redis.keys()), + [b'', b'abc\n', b'abc\\', b'abcde', b'\xfe\xcd']) + self.assertEqual(self.redis.keys('??'), [b'\xfe\xcd']) + # empty pattern not the same as no pattern + self.assertEqual(self.redis.keys(''), [b'']) + # ? must match \n + self.assertEqual(sorted(self.redis.keys('abc?')), + [b'abc\n', b'abc\\']) + # must be anchored at both ends + self.assertEqual(self.redis.keys('abc'), []) + self.assertEqual(self.redis.keys('bcd'), []) + # wildcard test + self.assertEqual(self.redis.keys('a*de'), [b'abcde']) + # positive groups + self.assertEqual(sorted(self.redis.keys('abc[d\n]*')), + [b'abc\n', 'abcde']) + self.assertEqual(self.redis.keys('abc[c-e]?'), [b'abcde']) + self.assertEqual(self.redis.keys('abc[e-c]?'), [b'abcde']) + self.assertEqual(self.redis.keys('abc[e-e]?'), []) + self.assertEqual(self.redis.keys('abcd[ef'), [b'abcde']) + # negative groups + self.assertEqual(self.redis.keys('abc[^d\\\\]*'), [b'abc\n']) + # some escaping cases that redis handles strangely + self.assertEqual(self.redis.keys('abc\\'), [b'abc\\']) + self.assertEqual(self.redis.keys(r'abc[\c-e]e'), []) + self.assertEqual(self.redis.keys(r'abc[c-\e]e'), []) + def test_exists(self): self.assertFalse('foo' in self.redis) self.redis.set('foo', 'bar') @@ -1258,6 +1295,10 @@ def test_scan_iter_single_page(self): self.redis.set('foo2', 'bar2') self.assertEqual(set(self.redis.scan_iter(match="foo*")), set([b'foo1', b'foo2'])) + self.assertEqual(set(self.redis.scan_iter()), + set([b'foo1', b'foo2'])) + self.assertEqual(set(self.redis.scan_iter(match="")), + set([])) def test_scan_iter_multiple_pages(self): all_keys = key_val_dict(size=100)