@@ -146,36 +146,55 @@ def summary(self, top: int = 20) -> None:
146146 print ("------------------------------------------------" )
147147
148148 @no_type_check
149- def show_traces (self ) -> None :
150- """
151- Show the traces of ``memory_allocated``, ``memory_active`` and ``memory_reserved`` at
152- operator level and the marker 'fw_bw_boundary' at the boundary of forward pass
153- and backward pass.
154- """
149+ def show_traces (self , path : str = "" ) -> None :
150+ from itertools import chain
151+
155152 import matplotlib .pyplot as plt
156153
157- y_1 = [mb for (name , mb ) in self .memories_allocated .values ()]
158- y_2 = [mb for (name , mb ) in self .memories_active .values ()]
159- y_3 = [mb for (name , mb ) in self .memories_reserved .values ()]
160- min_val = min (y_1 + y_2 + y_3 )
161- max_val = max (y_1 + y_2 + y_3 )
162- x = list (i for i in range (len (y_1 )))
163- fig = plt .figure (figsize = (16 , 8 ))
164- plt .plot (x , list (y_1 ), label = "memory_allocated" )
165- plt .plot (x , list (y_2 ), label = "memory_active" )
166- plt .plot (x , list (y_3 ), label = "memory_reserved" )
167- plt .xlabel ("# Operator Calls" )
168- plt .ylabel ("Memory (MB)" )
169- for marker_name , marker in self ._markers .items ():
170- if marker_name == "fw_bw_boundary" :
171- plt .plot (
172- [marker , marker ], [min_val , max_val ], "r" , lw = 2 , label = marker_name
173- )
174- else :
175- plt .plot (
176- [marker , marker ], [min_val , max_val ], "k-" , lw = 2 , label = marker_name
177- )
178- plt .legend ()
154+ def _plot_figure (x , y_values , labels ):
155+ min_val = min (list (chain (* y_values ))) * 0.999
156+ max_val = max (list (chain (* y_values ))) * 1.001
157+ plt .figure ()
158+ for y , label in zip (y_values , labels ):
159+ plt .plot (x , y , label = label )
160+ plt .xlabel ("# Operator Calls" )
161+ plt .ylabel ("Memory (MB)" )
162+ plt .legend ()
163+ for marker_name , marker in self ._markers .items ():
164+ if marker_name == "fw_bw_boundary" :
165+ plt .plot (
166+ [marker , marker ],
167+ [min_val , max_val ],
168+ "r" ,
169+ lw = 2 ,
170+ label = marker_name ,
171+ )
172+ else :
173+ plt .plot (
174+ [marker , marker ],
175+ [min_val , max_val ],
176+ "k-" ,
177+ lw = 2 ,
178+ label = marker_name ,
179+ )
180+
181+ if path != "" :
182+ self .load (path )
183+
184+ y_1 = [gb for (name , gb ) in self .memories_allocated .values ()]
185+ y_2 = [gb for (name , gb ) in self .memories_active .values ()]
186+ y_3 = [gb for (name , gb ) in self .memories_reserved .values ()]
187+ x = list (range (len (y_1 )))
188+ # Split figures when there is big difference between
189+ # "reserved_memory" and "allocated_memory" or "active_memory".
190+ _plot_figure (
191+ x ,
192+ [list (y_1 ), list (y_2 ), list (y_3 )],
193+ ["allocated_memory" , "active_memory" , "reserved_memory" ],
194+ )
195+ _plot_figure (x , [list (y_1 )], ["allocated_memory" ])
196+ _plot_figure (x , [list (y_2 )], ["active_memory" ])
197+ _plot_figure (x , [list (y_3 )], ["reserved_memory" ])
179198
180199 def save_stats (self , path : str ) -> None :
181200 """
@@ -190,7 +209,7 @@ def save_stats(self, path: str) -> None:
190209 }
191210
192211 with open (path , "wb" ) as f :
193- pickle .dump (stats , f )
212+ pickle .dump (stats , f , pickle . HIGHEST_PROTOCOL )
194213
195214 def load (self , path : str ) -> None :
196215 """
0 commit comments