changeset 6823:fe0091279f50

Refactor session db logging and key generation for sessions/otks While I was working on the redis sessiondb stuff, I noticed that log_wanrning, get_logger ... was duplicated. Also there was code to generate a unique key for otks that was duplicated. Changes: creating new sessions_common.py and SessionsCommon class to provide methods: log_warning, log_info, log_debug, get_logger, getUniqueKey getUniqueKey method is closer to the method used to make session keys in client.py. sessions_common.py now report when random_.py chooses a weak random number generator. Removed same from rest.py. get_logger reconciles all logging under roundup.hyperdb.backends.<name of BasicDatabase class> some backends used to log to root logger. have BasicDatabase in other sessions_*.py modules inherit from SessionCommon. change logging to use log_* methods. In addition: remove unused imports reported by flake8 and other formatting changes modify actions.py, rest.py, templating.py to use getUniqueKey method. add tests for new methods test_redis_session.py swap out ModuleNotFoundError for ImportError to prevent crash in python2 when redis is not present. allow injection of username:password or just password into redis connection URL. set pytest_redis_pw envirnment variable to password or user:password when running test.
author John Rouillard <rouilj@ieee.org>
date Sun, 07 Aug 2022 01:51:11 -0400
parents 5053ee6c846b
children 9811073b289e
files roundup/backends/sessions_common.py roundup/backends/sessions_dbm.py roundup/backends/sessions_rdbms.py roundup/backends/sessions_redis.py roundup/backends/sessions_sqlite.py roundup/cgi/actions.py roundup/cgi/templating.py roundup/rest.py test/session_common.py test/test_redis_session.py
diffstat 10 files changed, 216 insertions(+), 103 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/roundup/backends/sessions_common.py	Sun Aug 07 01:51:11 2022 -0400
@@ -0,0 +1,50 @@
+import base64, logging
+
+import roundup.anypy.random_ as random_
+from roundup.anypy.strings import b2s
+
+logger = logging.getLogger('roundup.hyperdb.backend.sessions')
+if not random_.is_weak:
+    logger.debug("Importing good random generator")
+else:
+    logger.warning("**SystemRandom not available. Using poor random generator")
+
+
+class SessionCommon:
+
+    def log_debug(self, msg, *args, **kwargs):
+        """Log a message with level DEBUG."""
+
+        logger = self.get_logger()
+        logger.debug(msg, *args, **kwargs)
+
+    def log_info(self, msg, *args, **kwargs):
+        """Log a message with level INFO."""
+
+        logger = self.get_logger()
+        logger.info(msg, *args, **kwargs)
+
+    def log_warning(self, msg, *args, **kwargs):
+        """Log a message with level INFO."""
+        logger = self.get_logger()
+        logger.warning(msg, *args, **kwargs)
+
+    def get_logger(self):
+        """Return the logger for this database."""
+
+        # Because getting a logger requires acquiring a lock, we want
+        # to do it only once.
+        if not hasattr(self, '__logger'):
+            self.__logger = logging.getLogger('roundup.hyperdb.backends.%s' %
+                                              self.name or "basicdb" )
+
+        return self.__logger
+
+    def getUniqueKey(self, length=40):
+        otk = b2s(base64.b64encode(
+            random_.token_bytes(length))).rstrip('=')
+        while self.exists(otk):
+            otk = b2s(base64.b64encode(
+                random_.token_bytes(length))).rstrip('=')
+
+        return otk
--- a/roundup/backends/sessions_dbm.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/backends/sessions_dbm.py	Sun Aug 07 01:51:11 2022 -0400
@@ -6,16 +6,17 @@
 """
 __docformat__ = 'restructuredtext'
 
-import os, marshal, time, logging, random
+import marshal, os, random, time
 
 from roundup.anypy.html import html_escape as escape
 
 from roundup import hyperdb
 from roundup.i18n import _
 from roundup.anypy.dbm_ import anydbm, whichdb
+from roundup.backends.sessions_common import SessionCommon
 
 
-class BasicDatabase:
+class BasicDatabase(SessionCommon):
     ''' Provide a nice encapsulation of an anydbm store.
 
         Keys are id strings, values are automatically marshalled data.
@@ -88,7 +89,7 @@
 
     def set(self, infoid, **newvalues):
         db = self.opendb('c')
-        timestamp=None
+        timestamp = None
         try:
             if infoid in db:
                 values = marshal.loads(db[infoid])
@@ -147,7 +148,6 @@
         dbm = __import__(db_type)
 
         retries_left = 15
-        logger = logging.getLogger('roundup.hyperdb.backend.sessions')
         while True:
             try:
                 handle = dbm.open(path, mode)
@@ -157,14 +157,16 @@
                 #   [Errno 11] Resource temporarily unavailable retry
                 # FIXME: make this more specific
                 if retries_left < 10:
-                    logger.warning('dbm.open failed on ...%s, retry %s left: %s, %s'%(path[-15:],15-retries_left,retries_left,e))
+                    self.log_warning(
+                        'dbm.open failed on ...%s, retry %s left: %s, %s' %
+                        (path[-15:], 15-retries_left, retries_left, e))
                 if retries_left < 0:
                     # We have used up the retries. Reraise the exception
                     # that got us here.
                     raise
                 else:
                     # stagger retry to try to get around thundering herd issue.
-                    time.sleep(random.randint(0,25)*.005)
+                    time.sleep(random.randint(0, 25)*.005)
                     retries_left = retries_left - 1
                     continue  # the while loop
         return handle
--- a/roundup/backends/sessions_rdbms.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/backends/sessions_rdbms.py	Sun Aug 07 01:51:11 2022 -0400
@@ -5,67 +5,70 @@
 class. It's now also used for One Time Key handling too.
 """
 __docformat__ = 'restructuredtext'
-import os, time, logging
+import time
 
 from roundup.anypy.html import html_escape as escape
+from roundup.backends.sessions_common import SessionCommon
 
-class BasicDatabase:
+
+class BasicDatabase(SessionCommon):
     ''' Provide a nice encapsulation of an RDBMS table.
 
         Keys are id strings, values are automatically marshalled data.
     '''
     name = None
+
     def __init__(self, db):
         self.db = db
         self.conn, self.cursor = self.db.sql_open_connection()
 
     def clear(self):
-        self.cursor.execute('delete from %ss'%self.name)
+        self.cursor.execute('delete from %ss' % self.name)
 
     def exists(self, infoid):
         n = self.name
-        self.cursor.execute('select count(*) from %ss where %s_key=%s'%(n,
-            n, self.db.arg), (infoid,))
+        self.cursor.execute('select count(*) from %ss where %s_key=%s' %
+                            (n, n, self.db.arg), (infoid,))
         return int(self.cursor.fetchone()[0])
 
     _marker = []
+
     def get(self, infoid, value, default=_marker):
         n = self.name
-        self.cursor.execute('select %s_value from %ss where %s_key=%s'%(n,
-            n, n, self.db.arg), (infoid,))
+        self.cursor.execute('select %s_value from %ss where %s_key=%s' %
+                            (n, n, n, self.db.arg), (infoid,))
         res = self.cursor.fetchone()
         if not res:
             if default != self._marker:
                 return default
-            raise KeyError('No such %s "%s"'%(self.name, escape(infoid)))
+            raise KeyError('No such %s "%s"' % (self.name, escape(infoid)))
         values = eval(res[0])
         return values.get(value, None)
 
     def getall(self, infoid):
         n = self.name
-        self.cursor.execute('select %s_value from %ss where %s_key=%s'%(n,
-            n, n, self.db.arg), (infoid,))
+        self.cursor.execute('select %s_value from %ss where %s_key=%s' %
+                            (n, n, n, self.db.arg), (infoid,))
         res = self.cursor.fetchone()
         if not res:
-            raise KeyError('No such %s "%s"'%(self.name, escape (infoid)))
+            raise KeyError('No such %s "%s"' % (self.name, escape(infoid)))
         return eval(res[0])
 
     def set(self, infoid, **newvalues):
         """ Store all newvalues under key infoid with a timestamp in database.
 
-            If newvalues['__timestamp'] exists and is representable as a floating point number
-            (i.e. could be generated by time.time()), that value is used for the <name>_time
-            column in the database.
+            If newvalues['__timestamp'] exists and is representable as
+            a floating point number (i.e. could be generated by time.time()),
+            that value is used for the <name>_time column in the database.
         """
         c = self.cursor
         n = self.name
         a = self.db.arg
-        c.execute('select %s_value from %ss where %s_key=%s'% \
-                  (n, n, n, a),
-            (infoid,))
+        c.execute('select %s_value from %ss where %s_key=%s' %
+                  (n, n, n, a), (infoid,))
         res = c.fetchone()
 
-        timestamp=time.time()
+        timestamp = time.time()
         if res:
             values = eval(res[0])
         else:
@@ -85,43 +88,43 @@
         values.update(newvalues)
         if res:
             sql = ('update %ss set %s_value=%s, %s_time=%s '
-                       'where %s_key=%s'%(n, n, a, n, a, n, a))
+                   'where %s_key=%s' % (n, n, a, n, a, n, a))
             args = (repr(values), timestamp, infoid)
         else:
             sql = 'insert into %ss (%s_key, %s_time, %s_value) '\
-                'values (%s, %s, %s)'%(n, n, n, n, a, a, a)
+                'values (%s, %s, %s)' % (n, n, n, n, a, a, a)
             args = (infoid, timestamp, repr(values))
         c.execute(sql, args)
 
     def list(self):
         c = self.cursor
         n = self.name
-        c.execute('select %s_key from %ss'%(n, n))
+        c.execute('select %s_key from %ss' % (n, n))
         return [res[0] for res in c.fetchall()]
 
     def destroy(self, infoid):
-        self.cursor.execute('delete from %ss where %s_key=%s'%(self.name,
-            self.name, self.db.arg), (infoid,))
+        self.cursor.execute('delete from %ss where %s_key=%s' %
+                            (self.name, self.name, self.db.arg), (infoid,))
 
     def updateTimestamp(self, infoid):
         """ don't update every hit - once a minute should be OK """
         now = time.time()
-        self.cursor.execute('''update %ss set %s_time=%s where %s_key=%s
-            and %s_time < %s'''%(self.name, self.name, self.db.arg,
-            self.name, self.db.arg, self.name, self.db.arg),
-            (now, infoid, now-60))
+        self.cursor.execute('''update %ss set %s_time=%s where %s_key=%s '''
+            '''and %s_time < %s''' %
+                            (self.name, self.name, self.db.arg, self.name,
+                             self.db.arg, self.name, self.db.arg),
+                            (now, infoid, now-60))
 
     def clean(self):
         ''' Remove session records that haven't been used for a week. '''
         now = time.time()
         week = 60*60*24*7
         old = now - week
-        self.cursor.execute('delete from %ss where %s_time < %s'%(self.name,
-            self.name, self.db.arg), (old, ))
+        self.cursor.execute('delete from %ss where %s_time < %s' %
+                            (self.name, self.name, self.db.arg), (old, ))
 
     def commit(self):
-        logger = logging.getLogger('roundup.hyperdb.backend')
-        logger.info('commit %s' % self.name)
+        self.log_info('commit %s' % self.name)
         self.conn.commit()
         self.cursor = self.conn.cursor()
 
@@ -136,9 +139,11 @@
     def close(self):
         self.conn.close()
 
+
 class Sessions(BasicDatabase):
     name = 'session'
 
+
 class OneTimeKeys(BasicDatabase):
     name = 'otk'
 
--- a/roundup/backends/sessions_redis.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/backends/sessions_redis.py	Sun Aug 07 01:51:11 2022 -0400
@@ -18,14 +18,16 @@
 """
 __docformat__ = 'restructuredtext'
 
-import logging, marshal, redis, time
+import marshal, redis, time
 
 from roundup.anypy.html import html_escape as escape
 
 from roundup.i18n import _
 
+from roundup.backends.sessions_common import SessionCommon
 
-class BasicDatabase:
+
+class BasicDatabase(SessionCommon):
     ''' Provide a nice encapsulation of a redis store.
 
         Keys are id strings, values are automatically marshalled data.
@@ -185,9 +187,9 @@
                     transaction.execute()
                     break
                 except redis.Exceptions.WatchError:
-                    logging.getLogger('roundup.redis').info(
+                    self.log_info(
                         _('Key %(key)s changed in %(name)s db' %
-                        {"key": escape(infoid), "name": self.name})
+                          {"key": escape(infoid), "name": self.name})
                     )
             else:
                 raise Exception(_("Redis set failed afer 3 retries"))
--- a/roundup/backends/sessions_sqlite.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/backends/sessions_sqlite.py	Sun Aug 07 01:51:11 2022 -0400
@@ -11,53 +11,33 @@
 provide a performance speedup.
 """
 __docformat__ = 'restructuredtext'
-import os, time, logging
+
+from roundup.backends import sessions_rdbms
 
-from roundup.anypy.html import html_escape as escape
-import roundup.backends.sessions_rdbms as rdbms_session
 
-class BasicDatabase(rdbms_session.BasicDatabase):
+class BasicDatabase(sessions_rdbms.BasicDatabase):
     ''' Provide a nice encapsulation of an RDBMS table.
 
         Keys are id strings, values are automatically marshalled data.
     '''
     name = None
+
     def __init__(self, db):
         self.db = db
         self.conn, self.cursor = self.db.sql_open_connection(dbname=self.name)
 
-        self.sql('''SELECT name FROM sqlite_master WHERE type='table' AND name='%ss';'''%self.name)
+        self.sql('''SELECT name FROM sqlite_master WHERE type='table' AND '''
+                 '''name='%ss';''' % self.name)
         table_exists = self.cursor.fetchone()
 
         if not table_exists:
             # create table/rows etc.
             self.sql('''CREATE TABLE %(name)ss (%(name)s_key VARCHAR(255),
-            %(name)s_value TEXT, %(name)s_time REAL)'''%{"name":self.name})
-            self.sql('CREATE INDEX %(name)s_key_idx ON %(name)ss(%(name)s_key)'%{"name":self.name})
+            %(name)s_value TEXT, %(name)s_time REAL)''' % {"name": self.name})
+            self.sql('CREATE INDEX %(name)s_key_idx ON '
+                     '%(name)ss(%(name)s_key)' % {"name": self.name})
             self.commit()
 
-    def log_debug(self, msg, *args, **kwargs):
-        """Log a message with level DEBUG."""
-
-        logger = self.get_logger()
-        logger.debug(msg, *args, **kwargs)
-
-    def log_info(self, msg, *args, **kwargs):
-        """Log a message with level INFO."""
-
-        logger = self.get_logger()
-        logger.info(msg, *args, **kwargs)
-
-    def get_logger(self):
-        """Return the logger for this database."""
-
-        # Because getting a logger requires acquiring a lock, we want
-        # to do it only once.
-        if not hasattr(self, '__logger'):
-            self.__logger = logging.getLogger('roundup')
-
-        return self.__logger
-
     def sql(self, sql, args=None, cursor=None):
         """ Execute the sql with the optional args.
         """
@@ -69,9 +49,11 @@
         else:
             cursor.execute(sql)
 
+
 class Sessions(BasicDatabase):
     name = 'session'
 
+
 class OneTimeKeys(BasicDatabase):
     name = 'otk'
 
--- a/roundup/cgi/actions.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/cgi/actions.py	Sun Aug 07 01:51:11 2022 -0400
@@ -10,7 +10,7 @@
 from roundup.exceptions import Reject, RejectRaw
 from roundup.anypy import urllib_
 from roundup.anypy.strings import StringIO
-import roundup.anypy.random_ as random_
+
 
 from roundup.anypy.html import html_escape
 
@@ -23,10 +23,6 @@
            'ConfRegoAction', 'RegisterAction', 'LoginAction', 'LogoutAction',
            'NewItemAction', 'ExportCSVAction', 'ExportCSVWithIdAction']
 
-# used by a couple of routines
-chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
-
-
 class Action:
     def __init__(self, client):
         self.client = client
@@ -1005,9 +1001,8 @@
             return
 
         # generate the one-time-key and store the props for later
-        otk = ''.join([random_.choice(chars) for x in range(32)])
-        while otks.exists(otk):
-            otk = ''.join([random_.choice(chars) for x in range(32)])
+        otk = otks.getUniqueKey(length=32)
+
         otks.set(otk, uid=uid, uaddress=address)
         otks.commit()
 
@@ -1150,9 +1145,7 @@
             elif isinstance(proptype, hyperdb.Password):
                 user_props[propname] = str(value)
         otks = self.db.getOTKManager()
-        otk = ''.join([random_.choice(chars) for x in range(32)])
-        while otks.exists(otk):
-            otk = ''.join([random_.choice(chars) for x in range(32)])
+        otk = otks.getUniqueKey(length=32)
         otks.set(otk, **user_props)
 
         # send the email
--- a/roundup/cgi/templating.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/cgi/templating.py	Sun Aug 07 01:51:11 2022 -0400
@@ -208,11 +208,7 @@
         module/function.
     '''
     otks=client.db.getOTKManager()
-    key = b2s(base64.b32encode(random_.token_bytes(40)))
-
-    while otks.exists(key):
-        key = b2s(base64.b32encode(random_.token_bytes(40)))
-
+    key = otks.getUniqueKey()
     # lifetime is in minutes.
     if lifetime is None:
         lifetime = client.db.config['WEB_CSRF_TOKEN_LIFETIME']
--- a/roundup/rest.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/roundup/rest.py	Sun Aug 07 01:51:11 2022 -0400
@@ -50,18 +50,6 @@
     basestring = str
     unicode = str
 
-import roundup.anypy.random_ as random_
-
-import logging
-logger = logging.getLogger('roundup.rest')
-
-if not random_.is_weak:
-    logger.debug("Importing good random generator")
-else:
-    logger.warning("**SystemRandom not available. Using poor random generator")
-
-chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
-
 
 def _data_decorator(func):
     """Wrap the returned data into an object."""
@@ -1140,9 +1128,7 @@
         """Get the Post Once Exactly token to create a new instance of class
            See https://tools.ietf.org/html/draft-nottingham-http-poe-00"""
         otks = self.db.Otk
-        poe_key = ''.join([random_.choice(chars) for x in range(40)])
-        while otks.exists(u2s(poe_key)):
-            poe_key = ''.join([random_.choice(chars) for x in range(40)])
+        poe_key = otks.getUniqueKey()
 
         try:
             lifetime = int(input['lifetime'].value)
--- a/test/session_common.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/test/session_common.py	Sun Aug 07 01:51:11 2022 -0400
@@ -24,6 +24,18 @@
 sessions_dbm.py.
 
 """
+
+import pytest, sys
+
+_py3 = sys.version_info[0] > 2
+if _py3:
+    skip_py2 = lambda func, *args, **kwargs: func
+else:
+    from .pytest_patcher import mark_class
+    skip_py2 = mark_class(pytest.mark.skip(
+        reason="Skipping log test, test doesn't work on python2"))
+
+
 class SessionTest(object):
     def setUp(self):
         # remove previous test, ignore errors
@@ -185,3 +197,75 @@
         week_ago =  time.time() - 60*60*24*7
         self.assertGreater(week_ago + 302, ts)
         self.assertLess(week_ago + 298, ts)
+
+    def testGetUniqueKey(self):
+        # 40 bytes of randomness gets larger when encoded
+        key = self.sessions.getUniqueKey()
+        self.assertEqual(len(key), 54)
+
+        # length is bytes of randomness
+        key = self.sessions.getUniqueKey(length=23)
+        self.assertEqual(len(key), 31)
+
+        key = self.sessions.getUniqueKey(length=200)
+        self.assertEqual(len(key), 267)
+
+    def testget_logger(self):
+        logger = self.sessions.get_logger()
+        # why do rdbms session use session/otk as the table name
+        # while dbm uses sessions/otks? In any case check both.
+        self.assertIn(logger.name, ["roundup.hyperdb.backends.sessions",
+                                    "roundup.hyperdb.backends.session"])
+
+        logger = self.otks.get_logger()
+        self.assertIn(logger.name, ["roundup.hyperdb.backends.otks",
+                                    "roundup.hyperdb.backends.otk"])
+
+    def testget_logger_name_test(self):
+        self.sessions.name="otks"
+        logger = self.sessions.get_logger()
+        self.assertEqual(logger.name, "roundup.hyperdb.backends.otks")
+
+    @skip_py2
+    def test_log_warning(self):
+        """Only python3 pytest has the right context handler for this,
+           so skip this on python2.
+        """
+
+        self.sessions.name = "newdb"
+
+        with self.assertLogs(logger="roundup.hyperdb.backends.newdb") as logs:
+            self.sessions.log_warning("hello world")
+
+        self.assertEqual(len(logs.records), 1)
+        self.assertEqual(logs.records[0].levelname, "WARNING")
+
+    @skip_py2
+    def test_log_info(self):
+        """Only python3 pytest has the right context handler for this,
+           so skip this on python2.
+        """
+
+        self.sessions.name = "newdb"
+
+        with self.assertLogs(logger="roundup.hyperdb.backends.newdb") as logs:
+            self.sessions.log_info("hello world")
+
+        self.assertEqual(len(logs.records), 1)
+        self.assertEqual(logs.records[0].levelname, "INFO")
+
+    @skip_py2
+    def test_log_debug(self):
+        """Only python3 pytest has the right context handler for this,
+           so skip this on python2.
+        """
+
+        self.sessions.name = "newdb"
+
+        with self.assertLogs(logger="roundup.hyperdb.backends.newdb",
+                             level='DEBUG') as logs:
+            self.sessions.log_debug("hello world")
+
+        self.assertEqual(len(logs.records), 1)
+        self.assertEqual(logs.records[0].levelname, "DEBUG")
+        
--- a/test/test_redis_session.py	Sun Aug 07 01:26:30 2022 -0400
+++ b/test/test_redis_session.py	Sun Aug 07 01:51:11 2022 -0400
@@ -23,7 +23,7 @@
 try:
     from roundup.backends.sessions_redis import Sessions, OneTimeKeys
     skip_redis = lambda func, *args, **kwargs: func
-except ModuleNotFoundError as e:
+except ImportError as e:
     from .pytest_patcher import mark_class
     skip_redis = mark_class(pytest.mark.skip(
         reason='Skipping redis tests: redis module not available'))
@@ -37,9 +37,22 @@
     def setUp(self):
         SessionTest.setUp(self)
 
+        import os
+        if 'pytest_redis_pw' in os.environ:
+            pw = os.environ['pytest_redis_pw']
+            if ':' in pw:
+                # pw is user:password
+                pw = "%s@" % pw
+            else:
+                # pw is just password
+                pw = ":%s@" % pw
+        else:
+            pw = ""
+
         # redefine the session db's as redis.
         self.db.config.SESSIONDB_BACKEND = "redis"
-        self.db.config.SESSIONDB_REDIS_URL = 'redis://localhost:6379/15?health_check_interval=2'
+        self.db.config.SESSIONDB_REDIS_URL = \
+                    'redis://%slocalhost:6379/15?health_check_interval=2' % pw
         self.db.Session = None
         self.db.Otk = None
         self.sessions = self.db.getSessionManager()

Roundup Issue Tracker: http://roundup-tracker.org/