@@ -8039,6 +8039,8 @@ def triplot(self, *args, **kwargs):
80398039 triplot .__doc__ = mtri .triplot .__doc__
80408040
80418041
8042+ from gridspec import GridSpec , SubplotSpec
8043+
80428044class SubplotBase :
80438045 """
80448046 Base class for subplots, which are :class:`Axes` instances with
@@ -8062,90 +8064,61 @@ def __init__(self, fig, *args, **kwargs):
80628064
80638065 self .figure = fig
80648066
8065- if len (args )== 1 :
8066- s = str (args [0 ])
8067- if len (s ) != 3 :
8068- raise ValueError ('Argument to subplot must be a 3 digits long' )
8069- rows , cols , num = map (int , s )
8067+ if len (args ) == 1 :
8068+ if isinstance (args [0 ], SubplotSpec ):
8069+ self ._subplotspec = args [0 ]
8070+
8071+ else :
8072+ s = str (args [0 ])
8073+ if len (s ) != 3 :
8074+ raise ValueError ('Argument to subplot must be a 3 digits long' )
8075+ rows , cols , num = map (int , s )
8076+ self ._subplotspec = GridSpec (rows , cols )[num - 1 ]
8077+ # num - 1 for converting from matlab to python indexing
80708078 elif len (args )== 3 :
80718079 rows , cols , num = args
8080+ if isinstance (num , tuple ) and len (num ) == 2 :
8081+ self ._subplotspec = GridSpec (rows , cols )[num [0 ]- 1 :num [1 ]]
8082+ else :
8083+ self ._subplotspec = GridSpec (rows , cols )[num - 1 ]
8084+ # num - 1 for converting from matlab to python indexing
80728085 else :
80738086 raise ValueError ( 'Illegal argument to subplot' )
80748087
80758088
8076- total = rows * cols
8077- num -= 1 # convert from matlab to python indexing
8078- # ie num in range(0,total)
8079- if num >= total :
8080- raise ValueError ( 'Subplot number exceeds total subplots' )
8081- self ._rows = rows
8082- self ._cols = cols
8083- self ._num = num
8084-
80858089 self .update_params ()
80868090
80878091 # _axes_class is set in the subplot_class_factory
80888092 self ._axes_class .__init__ (self , fig , self .figbox , ** kwargs )
80898093
8094+
8095+
80908096 def get_geometry (self ):
80918097 'get the subplot geometry, eg 2,2,3'
8092- return self ._rows , self ._cols , self ._num + 1
8098+ rows , cols , num1 , num2 = self .get_subplotspec ().get_geometry ()
8099+ return rows , cols , num1 + 1 # for compatibility
80938100
80948101 # COVERAGE NOTE: Never used internally or from examples
80958102 def change_geometry (self , numrows , numcols , num ):
80968103 'change subplot geometry, eg. from 1,1,1 to 2,2,3'
8097- self ._rows = numrows
8098- self ._cols = numcols
8099- self ._num = num - 1
8104+ self ._subplotspec = GridSpec (numrows , numcols )[num - 1 ]
81008105 self .update_params ()
81018106 self .set_position (self .figbox )
81028107
8108+ def get_subplotspec (self ):
8109+ 'get the SubplotSpec instance associated with the subplot'
8110+ return self ._subplotspec
8111+
8112+ def set_subplotspec (self , subplotspec ):
8113+ 'set the SubplotSpec instance associated with the subplot'
8114+ self ._subplotspec = subplotspec
8115+
81038116 def update_params (self ):
81048117 'update the subplot position from fig.subplotpars'
81058118
8106- rows = self ._rows
8107- cols = self ._cols
8108- num = self ._num
8109-
8110- pars = self .figure .subplotpars
8111- left = pars .left
8112- right = pars .right
8113- bottom = pars .bottom
8114- top = pars .top
8115- wspace = pars .wspace
8116- hspace = pars .hspace
8117- totWidth = right - left
8118- totHeight = top - bottom
8119-
8120- figH = totHeight / (rows + hspace * (rows - 1 ))
8121- sepH = hspace * figH
8122-
8123- figW = totWidth / (cols + wspace * (cols - 1 ))
8124- sepW = wspace * figW
8125-
8126- rowNum , colNum = divmod (num , cols )
8127-
8128- figBottom = top - (rowNum + 1 )* figH - rowNum * sepH
8129- figLeft = left + colNum * (figW + sepW )
8130-
8131- self .figbox = mtransforms .Bbox .from_bounds (figLeft , figBottom ,
8132- figW , figH )
8133- self .rowNum = rowNum
8134- self .colNum = colNum
8135- self .numRows = rows
8136- self .numCols = cols
8137-
8138- if 0 :
8139- print 'rcn' , rows , cols , num
8140- print 'lbrt' , left , bottom , right , top
8141- print 'self.figBottom' , self .figBottom
8142- print 'self.figLeft' , self .figLeft
8143- print 'self.figW' , self .figW
8144- print 'self.figH' , self .figH
8145- print 'self.rowNum' , self .rowNum
8146- print 'self.colNum' , self .colNum
8147- print 'self.numRows' , self .numRows
8148- print 'self.numCols' , self .numCols
8119+ self .figbox , self .rowNum , self .colNum , self .numRows , self .numCols = \
8120+ self .get_subplotspec ().get_position (self .figure ,
8121+ return_all = True )
81498122
81508123
81518124 def is_first_col (self ):
0 commit comments