@@ -199,7 +199,7 @@ def insert (self, values ):
199199 sql += " VALUES "
200200 sql += self .as_list (values )
201201
202- self .glue .mysql_query ( sql )
202+ self .glue .mysql_update ( sql )
203203
204204
205205
@@ -234,7 +234,7 @@ def insert (self, values ):
234234
235235 def flush (self ):
236236 if len (self .buffer )> 0 :
237- self .glue .mysql_query ( self .buffer )
237+ self .glue .mysql_update ( self .buffer )
238238 self .buffer = ""
239239
240240
@@ -273,7 +273,7 @@ def __init__( self, inserter, glue, table ):
273273 def drop (self ):
274274 sql = "DROP TEMPORARY TABLE IF EXISTS %s" % self .table .get_name ()
275275
276- ok = self .glue .mysql_query ( sql )
276+ ok = self .glue .mysql_update ( sql )
277277 return ok
278278
279279
@@ -302,7 +302,9 @@ def __init__(self, transport, graphname = None ):
302302 super (MySQLGlue , self ).__init__ (transport , graphname )
303303
304304 self .connection = None
305+
305306 self .unbuffered = False
307+ self ._update_cursor = None
306308
307309 self .temp_table_prefix = "gp_temp_"
308310 self .temp_table_db = None
@@ -341,30 +343,48 @@ def mysql_connect( self, server, username, password, db, port = 3306 ):
341343 def mysql_unbuffered_query ( self , sql ):
342344 return self .mysql_query ( sql , True )
343345
344- def mysql_query ( self , sql , unbuffered = None , dict_rows = False ):
346+ def mysql_update ( self , sql ): #TODO: port to PHP; use in PHP!
347+ if not self .update_cursor :
348+ self ._update_cursor = MySQLdb .cursors .SSCursor (self .connection )
349+
350+ self .mysql_query ( sql , True , False , self ._update_cursor )
351+
352+ return self .connection .affected_rows ()
353+
354+ def mysql_query ( self , sql , unbuffered = None , dict_rows = False , cursor = None ):
345355 if unbuffered is None :
346356 unbuffered = self .unbuffered
347357
348- if unbuffered :
349- if dict_rows :
350- # no buffering, returns dicts
351- cursor = MySQLdb .cursors .SSDictCursor (self .connection ) # TESTME
352- else :
353- # no buffering, returns tuples
354- cursor = MySQLdb .cursors .SSCursor (self .connection ) # TESTME
358+ if cursor :
359+ using_new_cursor = False
355360 else :
356- if dict_rows :
357- # buffers result, returns dicts
358- cursor = MySQLdb .cursors .DictCursor (self .connection ) # TESTME
361+ using_new_cursor = True
362+
363+ if unbuffered :
364+ if dict_rows :
365+ # no buffering, returns dicts
366+ cursor = MySQLdb .cursors .SSDictCursor (self .connection ) # TESTME
367+ else :
368+ # no buffering, returns tuples
369+ cursor = MySQLdb .cursors .SSCursor (self .connection ) # TESTME
359370 else :
360- # default: buffered tuples
361- cursor = MySQLdb .cursors .Cursor (self .connection )
371+ if dict_rows :
372+ # buffers result, returns dicts
373+ cursor = MySQLdb .cursors .DictCursor (self .connection ) # TESTME
374+ else :
375+ # default: buffered tuples
376+ cursor = MySQLdb .cursors .Cursor (self .connection )
362377
363378 with warnings .catch_warnings ():
364379 #ignore MySQL warnings. use cursor.nfo() to get them.
365380 warnings .simplefilter ("ignore" )
366-
367- cursor .execute ( sql )
381+
382+ try :
383+ cursor .execute ( sql )
384+ except :
385+ if using_new_cursor :
386+ cursor .close () #NOTE: *always* close the cursor if an exception ocurred.
387+ raise
368388
369389 if not dict_rows :
370390 # HACK: glue a fetch_dict method to a cursor that natively returns sequences from fetchone()
@@ -516,7 +536,7 @@ def next_id (self):
516536
517537 def drop_temp_table (self , spec ):
518538 sql = "DROP TEMPORARY TABLE %s" % spec .get_name ()
519- self .mysql_query (sql )
539+ self .mysql_update (sql )
520540
521541
522542 def make_temp_table (self , spec ):
@@ -533,7 +553,7 @@ def make_temp_table (self, spec ):
533553 sql += spec .get_field_definitions ()
534554 sql += ")"
535555
536- self .mysql_query (sql )
556+ self .mysql_update (sql )
537557
538558 return MySQLTable (table , spec .get_fields ())
539559
@@ -546,8 +566,10 @@ def mysql_query_value (self, sql ):
546566
547567 def mysql_query_record (self , sql ):
548568 cursor = self .mysql_query ( sql )
549- a = cursor .fetchone ()
550- cursor .close ()
569+ try :
570+ a = cursor .fetchone ()
571+ finally :
572+ cursor .close ()
551573
552574 if ( not a ): return None
553575 else : return a
@@ -623,7 +645,10 @@ def query_to_file (self, query, file, remote = False ):
623645 query += " INTO %s DATA OUTFILE " % r #TESTME
624646 query += self .quote_string (file )
625647
626- return self .mysql_query (query )
648+ cursor = self .mysql_query (query )
649+ cursor .close ()
650+
651+ return self .connection .affected_rows ()
627652
628653
629654 def insert_from_file (self , table , file , remote = False ):
@@ -634,10 +659,20 @@ def insert_from_file (self, table, file, remote = False ):
634659 query += self .quote_string (file )
635660 query += " INTO TABLE %s " % table
636661
637- return self .mysql_query (query )
662+ cursor = self .mysql_query (query )
663+ cursor .close ()
664+
665+ return self .connection .affected_rows ()
638666
639667
640668 def close (self ):
669+ if self ._update_cursor :
670+ try :
671+ self ._update_cursor .close ()
672+ except Exception as e :
673+ self ._trace (__function__ (), "failed to close mysql cursor: %s" % e )
674+ #XXX: do we really not care? can we go on? could there have been a commit pending?
675+
641676 if self .connection :
642677 try :
643678 self ._trace (__function__ (), "closing mysql connection" )
@@ -656,7 +691,7 @@ def new_client_connection(graphname, host = False, port = False ):
656691
657692 @staticmethod
658693 def new_slave_connection (command , cwd = None , env = None ):
659- return MySQLGlue ( SlaveTransport (command , cwd , env ) )
694+ return MySQLGlue ( SlaveTransport (command , cwd , env ), None )
660695
661696
662697 def dump_query (self , sql ):
@@ -665,7 +700,10 @@ def dump_query (self, sql ):
665700 res = self .mysql_query ( sql )
666701 if ( not res ): return False
667702
668- return self .dump_result ( res )
703+ c = self .dump_result ( res )
704+ res .close ()
705+
706+ return c
669707
670708
671709 def dump_result (self , res ):
0 commit comments