diff roundup/backends/rdbms_common.py @ 2098:18addf2a8596

Implemented proper datatypes in mysql and postgresql backends (well, sqlite too, but that doesn't care). Probably should use BOOLEAN instead of INTEGER for the Boolean props. Need to fix a bizarro MySQL error (gee, how unusual). Need to finish MySQL migration from "version 1" database schemas.
author Richard Jones <richard@users.sourceforge.net>
date Mon, 22 Mar 2004 07:45:40 +0000
parents 3f6024ab2c7a
children 62ed6505cbec
line wrap: on
line diff
--- a/roundup/backends/rdbms_common.py	Mon Mar 22 00:28:04 2004 +0000
+++ b/roundup/backends/rdbms_common.py	Mon Mar 22 07:45:40 2004 +0000
@@ -1,4 +1,4 @@
-# $Id: rdbms_common.py,v 1.83 2004-03-21 23:39:08 richard Exp $
+# $Id: rdbms_common.py,v 1.84 2004-03-22 07:45:39 richard Exp $
 ''' Relational database (SQL) backend common code.
 
 Basics:
@@ -46,6 +46,13 @@
 # number of rows to keep in memory
 ROW_CACHE_SIZE = 100
 
+def _num_cvt(num):  # convert a DB numeric value to int, else float
+    num = str(num)  # normalise whatever the driver returned to text
+    try:
+        return int(num)
+    except ValueError:  # not an integer literal; let float() decide
+        return float(num)
+
 class Database(FileStorage, hyperdb.Database, roundupdb.Database):
     ''' Wrapper around an SQL database that presents a hyperdb interface.
 
@@ -212,22 +219,43 @@
                 klass.index(nodeid)
         self.indexer.save_index()
 
+
+    hyperdb_to_sql_datatypes = {
+        hyperdb.String : 'VARCHAR(255)',
+        hyperdb.Date   : 'TIMESTAMP',
+        hyperdb.Link   : 'INTEGER',
+        hyperdb.Interval  : 'VARCHAR(255)',
+        hyperdb.Password  : 'VARCHAR(255)',
+        hyperdb.Boolean   : 'INTEGER',
+        hyperdb.Number    : 'REAL',
+    }
     def determine_columns(self, properties):
         ''' Figure the column names and multilink properties from the spec
 
             "properties" is a list of (name, prop) where prop may be an
             instance of a hyperdb "type" _or_ a string repr of that type.
         '''
-        cols = ['_actor', '_activity', '_creator', '_creation']
+        cols = [
+            ('_actor', 'INTEGER'),
+            ('_activity', 'TIMESTAMP'),   # same SQL type as hyperdb.Date
+            ('_creator', 'INTEGER'),
+            ('_creation', 'TIMESTAMP')    # DATE would drop the time part
+        ]
         mls = []
         # add the multilinks separately
         for col, prop in properties:
             if isinstance(prop, Multilink):
                 mls.append(col)
-            elif isinstance(prop, type('')) and prop.find('Multilink') != -1:
-                mls.append(col)
-            else:
-                cols.append('_'+col)
+                continue
+
+            if isinstance(prop, type('')):
+                raise ValueError, "string property spec!"
+                #and prop.find('Multilink') != -1:
+                #mls.append(col)
+
+            datatype = self.hyperdb_to_sql_datatypes[prop.__class__]
+            cols.append(('_'+col, datatype))
+
         cols.sort()
         return cols, mls
 
@@ -315,11 +343,11 @@
         cols, mls = self.determine_columns(spec.properties.items())
 
         # add on our special columns
-        cols.append('id')
-        cols.append('__retired__')
+        cols.append(('id', 'INTEGER PRIMARY KEY'))
+        cols.append(('__retired__', 'INTEGER DEFAULT 0'))
 
         # create the base table
-        scols = ','.join(['%s varchar'%x for x in cols])
+        scols = ','.join(['%s %s'%x for x in cols])
         sql = 'create table _%s (%s)'%(spec.classname, scols)
         if __debug__:
             print >>hyperdb.DEBUG, 'create_class', (self, sql)
@@ -332,13 +360,6 @@
     def create_class_table_indexes(self, spec):
         ''' create the class table for the given spec
         '''
-        # create id index
-        index_sql1 = 'create index _%s_id_idx on _%s(id)'%(
-                        spec.classname, spec.classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_index', (self, index_sql1)
-        self.cursor.execute(index_sql1)
-
         # create __retired__ index
         index_sql2 = 'create index _%s_retired_idx on _%s(__retired__)'%(
                         spec.classname, spec.classname)
@@ -376,14 +397,10 @@
     def create_class_table_key_index(self, cn, key):
         ''' create the class table for the given spec
         '''
+        sql = 'create index _%s_%s_idx on _%s(_%s)'%(cn, key, cn, key)
         if __debug__:
-            print >>hyperdb.DEBUG, 'update_class setting keyprop %r'% \
-                key
-        index_sql3 = 'create index _%s_%s_idx on _%s(_%s)'%(cn, key,
-            cn, key)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_index', (self, index_sql3)
-        self.cursor.execute(index_sql3)
+            print >>hyperdb.DEBUG, 'create_class_tab_key_index', (self, sql)
+        self.cursor.execute(sql)
 
     def drop_class_table_key_index(self, cn, key):
         table_name = '_%s'%cn
@@ -392,7 +409,7 @@
             return
         sql = 'drop index '+index_name
         if __debug__:
-            print >>hyperdb.DEBUG, 'drop_index', (self, sql)
+            print >>hyperdb.DEBUG, 'drop_class_tab_key_index', (self, sql)
         self.cursor.execute(sql)
 
     def create_journal_table(self, spec):
@@ -402,9 +419,11 @@
         # journal table
         cols = ','.join(['%s varchar'%x
             for x in 'nodeid date tag action params'.split()])
-        sql = 'create table %s__journal (%s)'%(spec.classname, cols)
+        sql = '''create table %s__journal (
+            nodeid integer, date timestamp, tag varchar(255),
+            action varchar(255), params text)'''%spec.classname
         if __debug__:
-            print >>hyperdb.DEBUG, 'create_class', (self, sql)
+            print >>hyperdb.DEBUG, 'create_journal_table', (self, sql)
         self.cursor.execute(sql)
         self.create_journal_table_indexes(spec)
 
@@ -476,13 +495,6 @@
         for ml in mls:
             self.create_multilink_table(spec, ml)
 
-        # ID counter
-        sql = 'insert into ids (name, num) values (%s,%s)'%(self.arg, self.arg)
-        vals = (spec.classname, 1)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'create_class', (self, sql, vals)
-        self.cursor.execute(sql, vals)
-
     def drop_class(self, cn, spec):
         ''' Drop the given table from the database.
 
@@ -497,10 +509,8 @@
 
         # drop class table and indexes
         self.drop_class_table_indexes(cn, spec[0])
-        sql = 'drop table _%s'%cn
-        if __debug__:
-            print >>hyperdb.DEBUG, 'drop_class', (self, sql)
-        self.cursor.execute(sql)
+
+        self.drop_class_table(cn)
 
         # drop journal table and indexes
         self.drop_journal_table_indexes(cn)
@@ -517,6 +527,12 @@
                 print >>hyperdb.DEBUG, 'drop_class', (self, sql)
             self.cursor.execute(sql)
 
+    def drop_class_table(self, cn):
+        sql = 'drop table _%s'%cn
+        if __debug__:
+            print >>hyperdb.DEBUG, 'drop_class', (self, sql)
+        self.cursor.execute(sql)
+
     #
     # Classes
     #
@@ -581,40 +597,18 @@
             self.cursor.execute(sql)
 
     #
-    # Node IDs
-    #
-    def newid(self, classname):
-        ''' Generate a new id for the given class
-        '''
-        # get the next ID
-        sql = 'select num from ids where name=%s'%self.arg
-        if __debug__:
-            print >>hyperdb.DEBUG, 'newid', (self, sql, classname)
-        self.cursor.execute(sql, (classname, ))
-        newid = int(self.cursor.fetchone()[0])
-
-        # update the counter
-        sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
-        vals = (int(newid)+1, classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'newid', (self, sql, vals)
-        self.cursor.execute(sql, vals)
-
-        # return as string
-        return str(newid)
-
-    def setid(self, classname, setid):
-        ''' Set the id counter: used during import of database
-        '''
-        sql = 'update ids set num=%s where name=%s'%(self.arg, self.arg)
-        vals = (setid, classname)
-        if __debug__:
-            print >>hyperdb.DEBUG, 'setid', (self, sql, vals)
-        self.cursor.execute(sql, vals)
-
-    #
     # Nodes
     #
+
+    hyperdb_to_sql_value = {
+        hyperdb.String : str,
+        hyperdb.Date   : lambda x: x.formal(sep=' ', sec='%f'),
+        hyperdb.Link   : int,
+        hyperdb.Interval  : lambda x: x.serialise(),
+        hyperdb.Password  : str,
+        hyperdb.Boolean   : int,
+        hyperdb.Number    : lambda x: x,
+    }
     def addnode(self, classname, nodeid, node):
         ''' Add the specified node to its class's db.
         '''
@@ -626,20 +620,24 @@
         cols, mls = self.determine_columns(cl.properties.items())
 
         # we'll be supplied these props if we're doing an import
-        if not node.has_key('creator'):
+        values = node.copy()
+        if not values.has_key('creator'):
             # add in the "calculated" properties (dupe so we don't affect
             # calling code's node assumptions)
-            node = node.copy()
-            node['creation'] = node['activity'] = date.Date()
-            node['actor'] = node['creator'] = self.getuid()
+            values['creation'] = values['activity'] = date.Date()
+            values['actor'] = values['creator'] = self.getuid()
+
+        cl = self.classes[classname]
+        props = cl.getprops(protected=1)
+        del props['id']
 
         # default the non-multilink columns
-        for col, prop in cl.properties.items():
-            if not node.has_key(col):
+        for col, prop in props.items():
+            if not values.has_key(col):
                 if isinstance(prop, Multilink):
-                    node[col] = []
+                    values[col] = []
                 else:
-                    node[col] = None
+                    values[col] = None
 
         # clear this node out of the cache if it's in there
         key = (classname, nodeid)
@@ -647,13 +645,20 @@
             del self.cache[key]
             self.cache_lru.remove(key)
 
-        # make the node data safe for the DB
-        node = self.serialise(classname, node)
+        # figure the values to insert
+        vals = []
+        for col,dt in cols:
+            prop = props[col[1:]]
+            value = values[col[1:]]
+            if value is not None:
+                value = self.hyperdb_to_sql_value[prop.__class__](value)
+            vals.append(value)
+        vals.append(nodeid)
+        vals = tuple(vals)
 
         # make sure the ordering is correct for column name -> column value
-        vals = tuple([node[col[1:]] for col in cols]) + (nodeid, 0)
-        s = ','.join([self.arg for x in cols]) + ',%s,%s'%(self.arg, self.arg)
-        cols = ','.join(cols) + ',id,__retired__'
+        s = ','.join([self.arg for x in cols]) + ',%s'%self.arg
+        cols = ','.join([col for col,dt in cols]) + ',id'
 
         # perform the inserts
         sql = 'insert into _%s (%s) values (%s)'%(classname, cols, s)
@@ -689,34 +694,42 @@
         values['activity'] = date.Date()
         values['actor'] = self.getuid()
 
-        # make db-friendly
-        values = self.serialise(classname, values)
+        cl = self.classes[classname]
+        props = cl.getprops()
 
-        cl = self.classes[classname]
         cols = []
         mls = []
         # add the multilinks separately
-        props = cl.getprops()
         for col in values.keys():
             prop = props[col]
             if isinstance(prop, Multilink):
                 mls.append(col)
             else:
-                cols.append('_'+col)
+                cols.append(col)
         cols.sort()
 
+        # figure the values to insert
+        vals = []
+        for col in cols:
+            prop = props[col]
+            value = values[col]
+            if value is not None:
+                value = self.hyperdb_to_sql_value[prop.__class__](value)
+            vals.append(value)
+        vals.append(int(nodeid))
+        vals = tuple(vals)
+
         # if there's any updates to regular columns, do them
         if cols:
             # make sure the ordering is correct for column name -> column value
-            sqlvals = tuple([values[col[1:]] for col in cols]) + (nodeid,)
-            s = ','.join(['%s=%s'%(x, self.arg) for x in cols])
+            s = ','.join(['_%s=%s'%(x, self.arg) for x in cols])
             cols = ','.join(cols)
 
             # perform the update
             sql = 'update _%s set %s where id=%s'%(classname, s, self.arg)
             if __debug__:
-                print >>hyperdb.DEBUG, 'setnode', (self, sql, sqlvals)
-            self.cursor.execute(sql, sqlvals)
+                print >>hyperdb.DEBUG, 'setnode', (self, sql, vals)
+            self.cursor.execute(sql, vals)
 
         # now the fun bit, updating the multilinks ;)
         for col, (add, remove) in multilink_changes.items():
@@ -725,16 +738,28 @@
                 sql = 'insert into %s (nodeid, linkid) values (%s,%s)'%(tn,
                     self.arg, self.arg)
                 for addid in add:
-                    self.sql(sql, (nodeid, addid))
+                    # XXX numeric ids
+                    self.sql(sql, (int(nodeid), int(addid)))
             if remove:
                 sql = 'delete from %s where nodeid=%s and linkid=%s'%(tn,
                     self.arg, self.arg)
                 for removeid in remove:
-                    self.sql(sql, (nodeid, removeid))
+                    # XXX numeric ids
+                    self.sql(sql, (int(nodeid), int(removeid)))
 
         # make sure we do the commit-time extra stuff for this node
         self.transactions.append((self.doSaveNode, (classname, nodeid, values)))
 
+    sql_to_hyperdb_value = {
+        hyperdb.String : str,
+        hyperdb.Date   : date.Date,
+#        hyperdb.Link   : int,      # XXX numeric ids
+        hyperdb.Link   : str,
+        hyperdb.Interval  : date.Interval,
+        hyperdb.Password  : lambda x: password.Password(encrypted=x),
+        hyperdb.Boolean   : int,
+        hyperdb.Number    : _num_cvt,
+    }
     def getnode(self, classname, nodeid):
         ''' Get a node from the database.
         '''
@@ -753,7 +778,7 @@
         # figure the columns we're fetching
         cl = self.classes[classname]
         cols, mls = self.determine_columns(cl.properties.items())
-        scols = ','.join(cols)
+        scols = ','.join([col for col,dt in cols])
 
         # perform the basic property fetch
         sql = 'select %s from _%s where id=%s'%(scols, classname, self.arg)
@@ -765,8 +790,14 @@
 
         # make up the node
         node = {}
+        props = cl.getprops(protected=1)
         for col in range(len(cols)):
-            node[cols[col][1:]] = values[col]
+            name = cols[col][0][1:]
+            value = values[col]
+            if value is not None:
+                value = self.sql_to_hyperdb_value[props[name].__class__](value)
+            node[name] = value
+
 
         # now the multilinks
         for col in mls:
@@ -775,10 +806,8 @@
                 self.arg)
             self.cursor.execute(sql, (nodeid,))
             # extract the first column from the result
-            node[col] = [x[0] for x in self.cursor.fetchall()]
-
-        # un-dbificate the node data
-        node = self.unserialise(classname, node)
+            # XXX numeric ids
+            node[col] = [str(x[0]) for x in self.cursor.fetchall()]
 
         # save off in the cache
         key = (classname, nodeid)
@@ -826,71 +855,6 @@
         sql = 'delete from %s__journal where nodeid=%s'%(classname, self.arg)
         self.sql(sql, (nodeid,))
 
-    def serialise(self, classname, node):
-        '''Copy the node contents, converting non-marshallable data into
-           marshallable data.
-        '''
-        if __debug__:
-            print >>hyperdb.DEBUG, 'serialise', classname, node
-        properties = self.getclass(classname).getprops()
-        d = {}
-        for k, v in node.items():
-            # if the property doesn't exist, or is the "retired" flag then
-            # it won't be in the properties dict
-            if not properties.has_key(k):
-                d[k] = v
-                continue
-
-            # get the property spec
-            prop = properties[k]
-
-            if isinstance(prop, Password) and v is not None:
-                d[k] = str(v)
-            elif isinstance(prop, Date) and v is not None:
-                d[k] = v.serialise()
-            elif isinstance(prop, Interval) and v is not None:
-                d[k] = v.serialise()
-            else:
-                d[k] = v
-        return d
-
-    def unserialise(self, classname, node):
-        '''Decode the marshalled node data
-        '''
-        if __debug__:
-            print >>hyperdb.DEBUG, 'unserialise', classname, node
-        properties = self.getclass(classname).getprops()
-        d = {}
-        for k, v in node.items():
-            # if the property doesn't exist, or is the "retired" flag then
-            # it won't be in the properties dict
-            if not properties.has_key(k):
-                d[k] = v
-                continue
-
-            # get the property spec
-            prop = properties[k]
-
-            if isinstance(prop, Date) and v is not None:
-                d[k] = date.Date(v)
-            elif isinstance(prop, Interval) and v is not None:
-                d[k] = date.Interval(v)
-            elif isinstance(prop, Password) and v is not None:
-                p = password.Password()
-                p.unpack(v)
-                d[k] = p
-            elif isinstance(prop, Boolean) and v is not None:
-                d[k] = int(v)
-            elif isinstance(prop, Number) and v is not None:
-                # try int first, then assume it's a float
-                try:
-                    d[k] = int(v)
-                except ValueError:
-                    d[k] = float(v)
-            else:
-                d[k] = v
-        return d
-
     def hasnode(self, classname, nodeid):
         ''' Determine if the database has a given node.
         '''
@@ -930,9 +894,9 @@
         else:
             journaltag = self.getuid()
         if creation:
-            journaldate = creation.serialise()
+            journaldate = creation
         else:
-            journaldate = date.Date().serialise()
+            journaldate = date.Date()
 
         # create the journal entry
         cols = ','.join('nodeid date tag action params'.split())
@@ -960,12 +924,13 @@
         '''
         # make the params db-friendly
         params = repr(params)
-        entry = (nodeid, journaldate, journaltag, action, params)
+        dc = self.hyperdb_to_sql_value[hyperdb.Date]
+        entry = (nodeid, dc(journaldate), journaltag, action, params)
 
         # do the insert
         a = self.arg
-        sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(classname,
-            cols, a, a, a, a, a)
+        sql = 'insert into %s__journal (%s) values (%s,%s,%s,%s,%s)'%(
+            classname, cols, a, a, a, a, a)
         if __debug__:
             print >>hyperdb.DEBUG, 'addjournal', (self, sql, entry)
         self.cursor.execute(sql, entry)
@@ -980,9 +945,11 @@
             print >>hyperdb.DEBUG, 'load_journal', (self, sql, nodeid)
         self.cursor.execute(sql, (nodeid,))
         res = []
+        dc = self.sql_to_hyperdb_value[hyperdb.Date]
         for nodeid, date_stamp, user, action, params in self.cursor.fetchall():
             params = eval(params)
-            res.append((nodeid, date.Date(date_stamp), user, action, params))
+            # XXX numeric ids
+            res.append((str(nodeid), dc(date_stamp), user, action, params))
         return res
 
     def pack(self, pack_before):
@@ -1278,7 +1245,8 @@
         if self.do_journal:
             self.db.addjournal(self.classname, newid, 'create', {})
 
-        return newid
+        # XXX numeric ids
+        return str(newid)
 
     def export_list(self, propnames, nodeid):
         ''' Export a node - generate a list of CSV-able data in the order
@@ -1844,7 +1812,8 @@
                 keyvalue, self.classname)
 
         # return the id
-        return row[0]
+        # XXX numeric ids
+        return str(row[0])
 
     def find(self, **propspec):
         '''Get the ids of nodes in this class which link to the given nodes.
@@ -1923,9 +1892,11 @@
         else:
             o = o[0]
         t = ', '.join(tables)
-        sql = 'select distinct(id) from %s where __retired__ <> %s and %s'%(t, a, o)
+        sql = 'select distinct(id) from %s where __retired__ <> %s and %s'%(
+            t, a, o)
         self.db.sql(sql, allvalues)
-        l = [x[0] for x in self.db.sql_fetchall()]
+        # XXX numeric ids
+        l = [str(x[0]) for x in self.db.sql_fetchall()]
         if __debug__:
             print >>hyperdb.DEBUG, 'find ... ', l
         return l
@@ -1953,7 +1924,8 @@
             s, self.db.arg)
         args.append(0)
         self.db.sql(sql, tuple(args))
-        l = [x[0] for x in self.db.sql_fetchall()]
+        # XXX numeric ids
+        l = [str(x[0]) for x in self.db.sql_fetchall()]
         if __debug__:
             print >>hyperdb.DEBUG, 'find ... ', l
         return l
@@ -1983,7 +1955,8 @@
         if __debug__:
             print >>hyperdb.DEBUG, 'getnodeids', (self, sql, retired)
         self.db.cursor.execute(sql, args)
-        ids = [x[0] for x in self.db.cursor.fetchall()]
+        # XXX numeric ids
+        ids = [str(x[0]) for x in self.db.cursor.fetchall()]
         return ids
 
     def filter(self, search_matches, filterspec, sort=(None,None),
@@ -2077,20 +2050,21 @@
                         where.append('_%s=%s'%(k, a))
                         args.append(v)
             elif isinstance(propclass, Date):
+                dc = self.db.hyperdb_to_sql_value[hyperdb.Date]
                 if isinstance(v, type([])):
                     s = ','.join([a for x in v])
                     where.append('_%s in (%s)'%(k, s))
-                    args = args + [date.Date(x).serialise() for x in v]
+                    args = args + [dc(date.Date(x)) for x in v]
                 else:
                     try:
                         # Try to filter on range of dates
                         date_rng = Range(v, date.Date, offset=timezone)
-                        if (date_rng.from_value):
+                        if date_rng.from_value:
                             where.append('_%s >= %s'%(k, a))                            
-                            args.append(date_rng.from_value.serialise())
-                        if (date_rng.to_value):
+                            args.append(dc(date_rng.from_value))
+                        if date_rng.to_value:
                             where.append('_%s <= %s'%(k, a))
-                            args.append(date_rng.to_value.serialise())
+                            args.append(dc(date_rng.to_value))
                     except ValueError:
                         # If range creation fails - ignore that search parameter
                         pass                        
@@ -2103,10 +2077,10 @@
                     try:
                         # Try to filter on range of intervals
                         date_rng = Range(v, date.Interval)
-                        if (date_rng.from_value):
+                        if date_rng.from_value:
                             where.append('_%s >= %s'%(k, a))
                             args.append(date_rng.from_value.serialise())
-                        if (date_rng.to_value):
+                        if date_rng.to_value:
                             where.append('_%s <= %s'%(k, a))
                             args.append(date_rng.to_value.serialise())
                     except ValueError:
@@ -2188,7 +2162,8 @@
         l = self.db.sql_fetchall()
 
         # return the IDs (the first column)
-        return [row[0] for row in l]
+        # XXX numeric ids
+        return [str(row[0]) for row in l]
 
     def count(self):
         '''Get the number of nodes in this class.

Roundup Issue Tracker: http://roundup-tracker.org/