@@ -151,7 +151,7 @@ def __enter__(self):
151151 return self
152152
153153 def __exit__ (self , exc_type , exc_val , exc_tb ):
154- if self .__session .in_transaction :
154+ if self .__session ._in_transaction :
155155 if exc_val is None :
156156 self .__session .commit_transaction ()
157157 else :
@@ -175,8 +175,6 @@ def __init__(self, client, server_session, options, authset):
175175 self ._cluster_time = None
176176 self ._operation_time = None
177177 self ._transaction = None
178- if self .options .auto_start_transaction :
179- self .start_transaction ()
180178
181179 def end_session (self ):
182180 """Finish this session. If a transaction has started, abort it.
@@ -191,7 +189,7 @@ def end_session(self):
191189 def _end_session (self , lock ):
192190 if self ._server_session is not None :
193191 try :
194- if self .in_transaction :
192+ if self ._in_transaction :
195193 self .abort_transaction ()
196194 finally :
197195 self ._client ._return_server_session (self ._server_session , lock )
@@ -259,7 +257,7 @@ def start_transaction(self, read_concern=None, write_concern=None):
259257 """
260258 self ._check_ended ()
261259
262- if self .in_transaction :
260+ if self ._in_transaction :
263261 raise InvalidOperation ("Transaction already in progress" )
264262
265263 read_concern = self ._inherit_option ("read_concern" , read_concern )
@@ -284,7 +282,7 @@ def abort_transaction(self):
284282 def _finish_transaction (self , command_name ):
285283 self ._check_ended ()
286284
287- if not self .in_transaction :
285+ if not self ._in_transaction_or_auto_start () :
288286 raise InvalidOperation ("No transaction started" )
289287
290288 try :
@@ -293,19 +291,12 @@ def _finish_transaction(self, command_name):
293291 self ._server_session ._transaction_id += 1
294292 return
295293
296- write_concern = self ._transaction .opts .write_concern
297- if write_concern is None :
298- write_concern = self .client .write_concern
299-
300294 # TODO: retryable. And it's weird to pass parse_write_concern_error
301295 # from outside database.py.
302296 self ._client .admin .command (
303297 command_name ,
304- txnNumber = self ._server_session .transaction_id ,
305- stmtId = self ._server_session .statement_id ,
306298 session = self ,
307- write_concern = write_concern ,
308- read_preference = ReadPreference .PRIMARY ,
299+ write_concern = self ._transaction .opts .write_concern ,
309300 parse_write_concern_error = True )
310301 finally :
311302 self ._server_session .reset_transaction ()
@@ -361,15 +352,22 @@ def has_ended(self):
361352 return self ._server_session is None
362353
363354 @property
364- def in_transaction (self ):
355+ def _in_transaction (self ):
365356 """True if this session has an active multi-statement transaction."""
366357 return self ._transaction is not None
367358
359+ def _in_transaction_or_auto_start (self ):
360+ """True if this session has an active transaction or will have one."""
361+ if self ._in_transaction :
362+ return True
363+ if self .options .auto_start_transaction :
364+ self .start_transaction ()
365+ return True
366+ return False
367+
368368 def _apply_to (self , command , is_retryable , read_preference ):
369369 self ._check_ended ()
370-
371- if self .options .auto_start_transaction and not self .in_transaction :
372- self .start_transaction ()
370+ self ._in_transaction_or_auto_start ()
373371
374372 self ._server_session .last_use = monotonic .time ()
375373 command ['lsid' ] = self ._server_session .session_id
@@ -379,7 +377,7 @@ def _apply_to(self, command, is_retryable, read_preference):
379377 command ['txnNumber' ] = self ._server_session .transaction_id
380378 return
381379
382- if self .in_transaction :
380+ if self ._in_transaction :
383381 if read_preference != ReadPreference .PRIMARY :
384382 raise InvalidOperation (
385383 'read preference in a transaction must be primary, not: '
0 commit comments