-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Make registration idempotent #649
Changes from 4 commits
c12b9d7
9979794
ff7d3dc
f5e9042
742b6c6
9671e67
3176aeb
3ee7d7d
b58d10a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
import bcrypt | ||
import pymacaroons | ||
import simplejson | ||
import time | ||
|
||
import synapse.util.stringutils as stringutils | ||
|
||
|
@@ -35,6 +36,7 @@ | |
|
||
|
||
class AuthHandler(BaseHandler): | ||
SESSION_EXPIRE_SECS = 48 * 60 * 60 | ||
|
||
def __init__(self, hs): | ||
super(AuthHandler, self).__init__(hs) | ||
|
@@ -66,15 +68,18 @@ def check_auth(self, flows, clientdict, clientip): | |
'auth' key: this method prompts for auth if none is sent. | ||
clientip (str): The IP address of the client. | ||
Returns: | ||
A tuple of (authed, dict, dict) where authed is true if the client | ||
has successfully completed an auth flow. If it is true, the first | ||
dict contains the authenticated credentials of each stage. | ||
A tuple of (authed, dict, dict, session_id) where authed is true if | ||
the client has successfully completed an auth flow. If it is true | ||
the first dict contains the authenticated credentials of each stage. | ||
|
||
If authed is false, the first dictionary is the server response to | ||
the login request and should be passed back to the client. | ||
|
||
In either case, the second dict contains the parameters for this | ||
request (which may have been given only in a previous call). | ||
|
||
session_id is the ID of this session, either passed in by the client | ||
or assigned by the call to check_auth | ||
""" | ||
|
||
authdict = None | ||
|
@@ -103,7 +108,10 @@ def check_auth(self, flows, clientdict, clientip): | |
|
||
if not authdict: | ||
defer.returnValue( | ||
(False, self._auth_dict_for_flows(flows, session), clientdict) | ||
( | ||
False, self._auth_dict_for_flows(flows, session), | ||
clientdict, session['id'] | ||
) | ||
) | ||
|
||
if 'creds' not in session: | ||
|
@@ -122,12 +130,11 @@ def check_auth(self, flows, clientdict, clientip): | |
for f in flows: | ||
if len(set(f) - set(creds.keys())) == 0: | ||
logger.info("Auth completed with creds: %r", creds) | ||
self._remove_session(session) | ||
defer.returnValue((True, creds, clientdict)) | ||
defer.returnValue((True, creds, clientdict, session['id'])) | ||
|
||
ret = self._auth_dict_for_flows(flows, session) | ||
ret['completed'] = creds.keys() | ||
defer.returnValue((False, ret, clientdict)) | ||
defer.returnValue((False, ret, clientdict, session['id'])) | ||
|
||
@defer.inlineCallbacks | ||
def add_oob_auth(self, stagetype, authdict, clientip): | ||
|
@@ -154,6 +161,29 @@ def add_oob_auth(self, stagetype, authdict, clientip): | |
defer.returnValue(True) | ||
defer.returnValue(False) | ||
|
||
def set_session_data(self, session_id, key, value): | ||
""" | ||
Store a key-value pair into the sessions data associated with this | ||
request. This data is stored server-side and cannot be modified by | ||
the client. | ||
:param session_id: (string) The ID of this session as returned from check_auth | ||
:param key: (string) The key to store the data under | ||
:param value: (any) The data to store | ||
""" | ||
sess = self._get_session_info(session_id) | ||
sess.setdefault('serverdict', {})[key] = value | ||
self._save_session(sess) | ||
|
||
def get_session_data(self, session_id, key, default=None): | ||
""" | ||
Retrieve data stored with set_session_data | ||
:param session_id: (string) The ID of this session as returned from check_auth | ||
:param key: (string) The key to store the data under | ||
:param default: (any) Value to return if the key has not been set | ||
""" | ||
sess = self._get_session_info(session_id) | ||
return sess.setdefault('serverdict', {}).get(key, default) | ||
|
||
@defer.inlineCallbacks | ||
def _check_password_auth(self, authdict, _): | ||
if "user" not in authdict or "password" not in authdict: | ||
|
@@ -263,7 +293,7 @@ def _get_session_info(self, session_id): | |
if not session_id: | ||
# create a new session | ||
while session_id is None or session_id in self.sessions: | ||
session_id = stringutils.random_string(24) | ||
session_id = stringutils.random_string_with_symbols(24) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do the session ids get sent to the client? Or they purely internal? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. They get sent to the client There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Its probably fine, but There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, maybe - this one only goes into json so it should be fine, but possibly the extra token space isn't worth it. |
||
self.sessions[session_id] = { | ||
"id": session_id, | ||
} | ||
|
@@ -455,11 +485,17 @@ def add_threepid(self, user_id, medium, address, validated_at): | |
def _save_session(self, session): | ||
# TODO: Persistent storage | ||
logger.debug("Saving session %s", session) | ||
session["last_used"] = time.time() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use We generally try and consistently use milliseconds internally, rather than seconds. |
||
self.sessions[session["id"]] = session | ||
|
||
def _remove_session(self, session): | ||
logger.debug("Removing session %s", session) | ||
del self.sessions[session["id"]] | ||
self._prune_sessions() | ||
|
||
def _prune_sessions(self): | ||
for sid, sess in self.sessions.items(): | ||
last_used = 0 | ||
if 'last_used' in sess: | ||
last_used = sess['last_used'] | ||
if last_used < time.time() - AuthHandler.SESSION_EXPIRE_SECS: | ||
del self.sessions[sid] | ||
|
||
def hash(self, password): | ||
"""Computes a secure hash of password. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We generally try and consistently always use milliseconds.
(You get brownie points for including "secs" in the name though)