Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Remove some boilerplate in tests (#4156)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkowl committed Nov 6, 2018
1 parent 0f5e51f commit e62f7f1
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 216 deletions.
1 change: 1 addition & 0 deletions changelog.d/4156.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
HTTP tests have been refactored to contain less boilerplate.
116 changes: 53 additions & 63 deletions tests/rest/client/v1/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,17 @@

from mock import Mock

from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock

from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)


class UserRegisterTestCase(unittest.TestCase):
def setUp(self):
class UserRegisterTestCase(unittest.HomeserverTestCase):

servlets = [register_servlets]

def make_homeserver(self, reactor, clock):

self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register"

self.registration_handler = Mock()
Expand All @@ -50,17 +43,14 @@ def setUp(self):

self.secrets = Mock()

self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs = self.setup_test_homeserver()

self.hs.config.registration_shared_secret = u"shared"

self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock()

self.resource = JsonResource(self.hs)
register_servlets(self.hs, self.resource)
return self.hs

def test_disabled(self):
"""
Expand All @@ -69,8 +59,8 @@ def test_disabled(self):
"""
self.hs.config.registration_shared_secret = None

request, channel = make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, b'{}')
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
Expand All @@ -87,8 +77,8 @@ def test_get_nonce(self):

self.hs.get_secrets = Mock(return_value=secrets)

request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)

self.assertEqual(channel.json_body, {"nonce": "abcd"})

Expand All @@ -97,25 +87,25 @@ def test_expired_nonce(self):
Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s).
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

# 59 seconds
self.clock.advance(59)
self.reactor.advance(59)

body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])

# 61 seconds
self.clock.advance(2)
self.reactor.advance(2)

request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
Expand All @@ -124,8 +114,8 @@ def test_register_incorrect_nonce(self):
"""
Only the provided nonce can be used, as it's checked in the MAC.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -141,8 +131,8 @@ def test_register_incorrect_nonce(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"])
Expand All @@ -152,8 +142,8 @@ def test_register_correct_nonce(self):
When the correct nonce is provided, and the right key is provided, the
user is registered.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -169,8 +159,8 @@ def test_register_correct_nonce(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])
Expand All @@ -179,8 +169,8 @@ def test_nonce_reuse(self):
"""
A valid unrecognised nonce.
"""
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
nonce = channel.json_body["nonce"]

want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
Expand All @@ -196,15 +186,15 @@ def test_nonce_reuse(self):
"mac": want_mac,
}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"])

# Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"])
Expand All @@ -217,8 +207,8 @@ def test_missing_parts(self):
"""

def nonce():
request, channel = make_request("GET", self.url)
render(request, self.resource, self.clock)
request, channel = self.make_request("GET", self.url)
self.render(request)
return channel.json_body["nonce"]

#
Expand All @@ -227,8 +217,8 @@ def nonce():

# Must be present
body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"])
Expand All @@ -239,32 +229,32 @@ def nonce():

# Must be present
body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"])

# Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])

# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])

# Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"])
Expand All @@ -275,16 +265,16 @@ def nonce():

# Must be present
body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"])

# Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
Expand All @@ -293,16 +283,16 @@ def nonce():
body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
)
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])

# Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock)
request, channel = self.make_request("POST", self.url, body.encode('utf8'))
self.render(request)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"])
10 changes: 5 additions & 5 deletions tests/rest/client/v1/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def setUp(self):
)

handlers = Mock(registration_handler=self.registration_handler)
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.reactor = MemoryReactorClock()
self.hs_clock = Clock(self.reactor)

self.hs = self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers)
Expand All @@ -76,8 +76,8 @@ def test_POST_createuser_with_valid_user(self):
return_value=(user_id, token)
)

request, channel = make_request(b"POST", url, request_data)
render(request, res, self.clock)
request, channel = make_request(self.reactor, b"POST", url, request_data)
render(request, res, self.reactor)

self.assertEquals(channel.result["code"], b"200")

Expand Down
22 changes: 7 additions & 15 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def create_room_as(self, room_creator, is_public=True, tok=None):
path = path + "?access_token=%s" % tok

request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())

Expand Down Expand Up @@ -217,7 +217,9 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

data = {"membership": membership}

request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
)

render(request, self.resource, self.hs.get_reactor())

Expand All @@ -228,18 +230,6 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

self.auth_user_id = temp_id

@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/_matrix/client/r0/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code)
defer.returnValue(response)

def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
Expand All @@ -251,7 +241,9 @@ def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if tok:
path = path + "?access_token=%s" % tok

request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor())

assert int(channel.result["code"]) == expect_code, (
Expand Down
Loading

0 comments on commit e62f7f1

Please sign in to comment.