Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Make registration idempotent #649

Merged
merged 9 commits into from
Mar 16, 2016
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import bcrypt
import pymacaroons
import simplejson
import time

import synapse.util.stringutils as stringutils

Expand All @@ -35,6 +36,7 @@


class AuthHandler(BaseHandler):
SESSION_EXPIRE_SECS = 48 * 60 * 60
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generally try and consistently always use milliseconds.

(You get brownie points for including "secs" in the name though)


def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
Expand Down Expand Up @@ -66,15 +68,18 @@ def check_auth(self, flows, clientdict, clientip):
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
A tuple of (authed, dict, dict, session_id) where authed is true if
the client has successfully completed an auth flow. If it is true
the first dict contains the authenticated credentials of each stage.

If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.

In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).

session_id is the ID of this session, either passed in by the client
or assigned by the call to check_auth
"""

authdict = None
Expand Down Expand Up @@ -103,7 +108,10 @@ def check_auth(self, flows, clientdict, clientip):

if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict)
(
False, self._auth_dict_for_flows(flows, session),
clientdict, session['id']
)
)

if 'creds' not in session:
Expand All @@ -122,12 +130,11 @@ def check_auth(self, flows, clientdict, clientip):
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(session)
defer.returnValue((True, creds, clientdict))
defer.returnValue((True, creds, clientdict, session['id']))

ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
defer.returnValue((False, ret, clientdict, session['id']))

@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
Expand All @@ -154,6 +161,29 @@ def add_oob_auth(self, stagetype, authdict, clientip):
defer.returnValue(True)
defer.returnValue(False)

def set_session_data(self, session_id, key, value):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
self._save_session(sess)

def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)

@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
Expand Down Expand Up @@ -263,7 +293,7 @@ def _get_session_info(self, session_id):
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
session_id = stringutils.random_string_with_symbols(24)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the session ids get sent to the client? Or they purely internal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They get sent to the client

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its probably fine, but random_string_with_symbols will return a lot of silly symbols, so I've tended to avoid using them in public APIs (especially for anything that is used as query string params)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, maybe - this one only goes into json so it should be fine, but possibly the extra token space isn't worth it.

self.sessions[session_id] = {
"id": session_id,
}
Expand Down Expand Up @@ -455,11 +485,17 @@ def add_threepid(self, user_id, medium, address, validated_at):
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = time.time()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use hs.get_clock()?

We generally try and consistently use milliseconds internally, rather than seconds.

self.sessions[session["id"]] = session

def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
self._prune_sessions()

def _prune_sessions(self):
for sid, sess in self.sessions.items():
last_used = 0
if 'last_used' in sess:
last_used = sess['last_used']
if last_used < time.time() - AuthHandler.SESSION_EXPIRE_SECS:
del self.sessions[sid]

def hash(self, password):
"""Computes a secure hash of password.
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/v2_alpha/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def on_POST(self, request):

body = parse_json_object_from_request(request)

authed, result, params = yield self.auth_handler.check_auth([
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body, self.hs.get_ip_from_request(request))
Expand Down
28 changes: 27 additions & 1 deletion synapse/rest/client/v2_alpha/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,34 @@ def on_POST(self, request):
[LoginType.EMAIL_IDENTITY]
]

authed, result, params = yield self.auth_handler.check_auth(
authed, result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)

if not authed:
defer.returnValue((401, result))
return

# have we already registered a user for this session
registered_user_id = self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))

# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
Expand All @@ -161,6 +181,12 @@ def on_POST(self, request):
guest_access_token=guest_access_token,
)

# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)

if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]

Expand Down
9 changes: 5 additions & 4 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def setUp(self):
side_effect=lambda x: defer.succeed(self.appservice))
)

self.auth_result = (False, None, None)
self.auth_result = (False, None, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result)
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
Expand Down Expand Up @@ -112,7 +113,7 @@ def test_POST_user_valid(self):
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
}, None)
self.registration_handler.register = Mock(return_value=(user_id, token))

(code, result) = yield self.servlet.on_POST(self.request)
Expand All @@ -135,7 +136,7 @@ def test_POST_disabled_registration(self):
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)