diff roundup/backends/rdbms_common.py @ 3963:3230f9c88086

Fix race condition for key properties in rdbms backends [SF#1876683]
author Richard Jones <richard@users.sourceforge.net>
date Thu, 07 Feb 2008 03:28:34 +0000
parents 9095a4da67f9
children b1e81ad3fa6a
line wrap: on
line diff
--- a/roundup/backends/rdbms_common.py	Thu Feb 07 01:03:39 2008 +0000
+++ b/roundup/backends/rdbms_common.py	Thu Feb 07 03:28:34 2008 +0000
@@ -15,7 +15,7 @@
 # BASIS, AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
 # SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
 #
-#$Id: rdbms_common.py,v 1.193 2007-10-25 07:26:11 richard Exp $
+#$Id: rdbms_common.py,v 1.194 2008-02-07 03:28:34 richard Exp $
 """ Relational database (SQL) backend common code.
 
 Basics:
@@ -29,7 +29,7 @@
 - journals are stored adjunct to the per-class tables
 - table names and columns have "_" prepended so the names can't clash with
   restricted names (like "order")
-- retirement is determined by the __retired__ column being true
+- retirement is determined by the __retired__ column being > 0
 
 Database-specific changes may generally be pushed out to the overridable
 sql_* methods, since everything else should be fairly generic. There's
@@ -42,6 +42,13 @@
 that maps to a table. If that information differs from the hyperdb schema,
 then we update it. We also store in the schema dict a version which
 allows us to upgrade the database schema when necessary. See upgrade_db().
+
+To force a unqiueness constraint on the key properties we put the item
+id into the __retired__ column duing retirement (so it's 0 for "active"
+items) and place a unqiueness constraint on key + __retired__. This is
+particularly important for the users class where multiple users may
+try to have the same username, with potentially many retired users with
+the same name.
 """
 __docformat__ = 'restructuredtext'
 
@@ -239,7 +246,8 @@
 
     # update this number when we need to make changes to the SQL structure
     # of the backen database
-    current_db_version = 4
+    current_db_version = 5
+    db_version_updated = False
     def upgrade_db(self):
         """ Update the SQL database to reflect changes in the backend code.
 
@@ -272,7 +280,11 @@
         if version < 4:
             self.fix_version_3_tables()
 
+        if version < 5:
+            self.fix_version_4_tables()
+
         self.database_schema['version'] = self.current_db_version
+        self.db_version_updated = True
         return 1
 
     def fix_version_3_tables(self):
@@ -283,9 +295,21 @@
             self.sql('ALTER TABLE %ss ADD %s_value TEXT'%(name, name))
 
     def fix_version_2_tables(self):
-        """Default (used by sqlite): NOOP"""
+        # Default (used by sqlite): NOOP
         pass
 
+    def fix_version_4_tables(self):
+        # note this is an explicit call now
+        c = self.cursor
+        for cn, klass in self.classes.items():
+            c.execute('select id from _%s where __retired__<>0'%(cn,))
+            for (id,) in c.fetchall():
+                c.execute('update _%s set __retired__=%s where id=%s'%(cn,
+                    self.arg, self.arg), (id, id))
+
+            if klass.key:
+                self.add_class_key_required_unique_constraint(cn, klass.key)
+
     def _convert_journal_tables(self):
         """Get current journal table contents, drop the table and re-create"""
         c = self.cursor
@@ -530,9 +554,18 @@
                         spec.classname, spec.key)
             self.sql(index_sql3)
 
+            # and the unique index for key / retired(id)
+            self.add_class_key_required_unique_constraint(spec.classname,
+                spec.key)
+
         # TODO: create indexes on (selected?) Link property columns, as
         # they're more likely to be used for lookup
 
+    def add_class_key_required_unique_constraint(self, cn, key):
+        sql = '''create unique index _%s_key_retired_idx 
+            on _%s(__retired__, _%s)'''%(cn, cn, key)
+        self.sql(sql)
+
     def drop_class_table_indexes(self, cn, key):
         # drop the old table indexes first
         l = ['_%s_id_idx'%cn, '_%s_retired_idx'%cn]
@@ -555,10 +588,15 @@
     def drop_class_table_key_index(self, cn, key):
         table_name = '_%s'%cn
         index_name = '_%s_%s_idx'%(cn, key)
-        if not self.sql_index_exists(table_name, index_name):
-            return
-        sql = 'drop index '+index_name
-        self.sql(sql)
+        if self.sql_index_exists(table_name, index_name):
+            sql = 'drop index '+index_name
+            self.sql(sql)
+
+        # and now the retired unique index too
+        index_name = '_%s_key_retired_idx'%cn
+        if self.sql_index_exists(table_name, index_name):
+            sql = 'drop index _%s_key_retired_idx'%cn
+            self.sql(sql)
 
     def create_journal_table(self, spec):
         """ create the journal table for a class given the spec and
@@ -1760,7 +1798,7 @@
         # conversion (hello, sqlite)
         sql = 'update _%s set __retired__=%s where id=%s'%(self.classname,
             self.db.arg, self.db.arg)
-        self.db.sql(sql, (1, nodeid))
+        self.db.sql(sql, (nodeid, nodeid))
         if self.do_journal:
             self.db.addjournal(self.classname, nodeid, ''"retired", None)
 
@@ -1802,7 +1840,7 @@
         sql = 'select __retired__ from _%s where id=%s'%(self.classname,
             self.db.arg)
         self.db.sql(sql, (nodeid,))
-        return int(self.db.sql_fetchone()[0])
+        return int(self.db.sql_fetchone()[0]) > 0
 
     def destroy(self, nodeid):
         """Destroy a node.
@@ -1880,9 +1918,9 @@
 
         # use the arg to handle any odd database type conversion (hello,
         # sqlite)
-        sql = "select id from _%s where _%s=%s and __retired__ <> %s"%(
+        sql = "select id from _%s where _%s=%s and __retired__=%s"%(
             self.classname, self.key, self.db.arg, self.db.arg)
-        self.db.sql(sql, (keyvalue, 1))
+        self.db.sql(sql, (keyvalue, 0))
 
         # see if there was a result that's not retired
         row = self.db.sql_fetchone()
@@ -1947,8 +1985,8 @@
                 s += '_%s in (%s)'%(prop, ','.join([a]*len(values)))
                 where.append('(' + s +')')
         if where:
-            allvalues = (1, ) + allvalues
-            sql.append("""select id from _%s where  __retired__ <> %s
+            allvalues = (0, ) + allvalues
+            sql.append("""select id from _%s where  __retired__=%s
                 and %s"""%(self.classname, a, ' and '.join(where)))
 
         # now multilinks
@@ -1957,7 +1995,7 @@
                 continue
             if not values:
                 continue
-            allvalues += (1, )
+            allvalues += (0, )
             if type(values) is type(''):
                 allvalues += (values,)
                 s = a
@@ -1965,7 +2003,7 @@
                 allvalues += tuple(values.keys())
                 s = ','.join([a]*len(values))
             tn = '%s_%s'%(self.classname, prop)
-            sql.append("""select id from _%s, %s where  __retired__ <> %s
+            sql.append("""select id from _%s, %s where  __retired__=%s
                   and id = %s.nodeid and %s.linkid in (%s)"""%(self.classname,
                   tn, a, tn, tn, s))
 
@@ -1996,9 +2034,9 @@
 
         # generate the where clause
         s = ' and '.join(['lower(_%s)=%s'%(col, self.db.arg) for col in where])
-        sql = 'select id from _%s where %s and __retired__<>%s'%(
+        sql = 'select id from _%s where %s and __retired__=%s'%(
             self.classname, s, self.db.arg)
-        args.append(1)
+        args.append(0)
         self.db.sql(sql, tuple(args))
         # XXX numeric ids
         l = [str(x[0]) for x in self.db.sql_fetchall()]
@@ -2017,12 +2055,13 @@
         """
         # flip the sense of the 'retired' flag if we don't want all of them
         if retired is not None:
+            args = (0, )
             if retired:
-                args = (0, )
+                compare = '>'
             else:
-                args = (1, )
-            sql = 'select id from _%s where __retired__ <> %s'%(self.classname,
-                self.db.arg)
+                compare = '='
+            sql = 'select id from _%s where __retired__%s%s'%(self.classname,
+                compare, self.db.arg)
         else:
             args = ()
             sql = 'select id from _%s'%self.classname
@@ -2276,7 +2315,7 @@
         props = self.getprops()
 
         # don't match retired nodes
-        where.append('_%s.__retired__ <> 1'%icn)
+        where.append('_%s.__retired__=0'%icn)
 
         # add results of full text search
         if search_matches is not None:
@@ -2341,7 +2380,7 @@
         The SQL select must include the item id as the first column.
 
         This function DOES NOT filter out retired items, add on a where
-        clause "__retired__ <> 1" if you don't want retired nodes.
+        clause "__retired__=0" if you don't want retired nodes.
         """
         if __debug__:
             start_t = time.time()
@@ -2502,7 +2541,7 @@
             # conversion (hello, sqlite)
             sql = 'update _%s set __retired__=%s where id=%s'%(self.classname,
                 self.db.arg, self.db.arg)
-            self.db.sql(sql, (1, newid))
+            self.db.sql(sql, (newid, newid))
         return newid
 
     def export_journals(self):

Roundup Issue Tracker: http://roundup-tracker.org/