@@ -5,7 +5,10 @@
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import extend_attr_to_tuple
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    extend_attr_to_tuple,
+    get_positive_dim,
+)
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
@@ -116,37 +119,69 @@ def adaptive_avg_poolNd(
     output_size: Sequence[int],
 ) -> TRTTensor:
     input_rank = len(input.shape)
-    if input_rank == 3:
-        input = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape))
+
+    if input_rank == 3:  # TRT doesn't support 3D pooling
+        input = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape", input, (1, *input.shape)
+        )
 
     extend_len = len(output_size)
+    output_size = list(output_size)
+    original_input = input
 
-    # pad the input based on output_size if the dim of output is larger than input
-    pad = []
+    # repeat_interleave the input if the dim of output is larger than input
     input_shape = input.shape
-    for i in range(1, extend_len + 1):
-        input_dim = input_shape[-i]
-        output_dim = output_size[-i]
+    insert_axises = []
+    for axis in range(1, extend_len + 1):
+        axis = -axis
+        positive_axis = get_positive_dim(
+            axis, input_rank
+        )  # this is for calculating new shapes below
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
         diff = output_dim - input_dim
-        if diff > 0:
-            if diff % 2 == 0:
-                pad.append(diff // 2)
-                pad.append(diff // 2)
-            else:
-                pad.append(diff // 2 + 1)
-                pad.append(diff // 2 + 1)
-        else:
-            pad.append(0)
-            pad.append(0)
-
-    input = impl.pad.replication_padNd(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_replication_padNd",
-        input,
-        pad,
-    )
+        if diff > 0:  # the dim of output is larger than input
+            times = output_dim // input_dim
+            remainder = output_dim % input_dim
+            if (
+                diff == 2 and remainder == 2
+            ):  # case 1: output_dim - input_dim == 2 and is not an integral multiple
+                insert_axises.append(axis)
+                remainder -= 1
+                output_size[axis] -= 1
+
+            if (
+                remainder + 1 == input_dim
+            ):  # case 2: remainder + 1 == input_dim, we will repeat_interleave the whole input
+                remainder = 0
+                times += 1
+
+            flags = []
+            concat_list = []
+            for j in range(input_dim):
+                single_elem = impl.select.select(
+                    ctx, target, source_ir, f"{name}_select_{axis}_{j}", input, axis, j
+                )
+                new_shape = list(single_elem.shape)
+                new_shape.insert(positive_axis, 1)
+                single_elem = impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_{axis}_{j}",
+                    single_elem,
+                    new_shape,
+                )
+                if remainder > 0 or j in flags:
+                    concat_list.extend([single_elem] * (times + 1))
+                    remainder -= 2
+                    flags.append(input_dim - j - 1)
+                else:
+                    concat_list.extend([single_elem] * times)
+            out = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", concat_list, axis
+            )
+            input = out
 
     stride = tuple(
         input.shape[-extend_len + i] // output_size[i] for i in range(extend_len)
@@ -155,6 +190,20 @@ def adaptive_avg_poolNd(
         input.shape[-extend_len + i] - (output_size[i] - 1) * stride[i]
         for i in range(extend_len)
     )
+
+    # Don't have to pool, directly return
+    if all(s == 1 for s in stride) and all(k == 1 for k in kernel_size):
+        if input_rank == 3:  # reshape back to 3D
+            input = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_back",
+                input,
+                (*input.shape[1:],),
+            )
+        return input
+
     layer = ctx.net.add_pooling_nd(
         input=input, type=trt.PoolingType.AVERAGE, window_size=kernel_size
     )
@@ -163,7 +212,78 @@ def adaptive_avg_poolNd(
 
     output = layer.get_output(0)
 
-    if input_rank == 3:
-        output = impl.shuffle.reshape(ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],))
+    # For case 1, we need to split the output and insert the mid of input
+    for axis in insert_axises:
+        positive_axis = get_positive_dim(axis, input_rank)
+        input_dim = input_shape[axis]
+        output_dim = output_size[axis]
+        if input_dim % 2 == 1:
+            mid = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            new_shape = list(mid.shape)
+            new_shape.insert(positive_axis, 1)
+            mid = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid, new_shape
+            )
+            split_output = impl.split.split(
+                ctx, target, source_ir, f"{name}_split_{axis}", output, 2, axis
+            )
+            split_output.insert(1, mid)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+        else:
+            mid1 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2 - 1,
+            )
+            new_shape = list(mid1.shape)
+            new_shape.insert(positive_axis, 1)
+            mid1 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid1, new_shape
+            )
+            mid2 = impl.select.select(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_select_{axis}",
+                original_input,
+                axis,
+                input_dim // 2,
+            )
+            mid2 = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_{axis}", mid2, new_shape
+            )
+            split_output = impl.split.split(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_split_{axis}",
+                output,
+                [output_dim // 2, 1, output_dim // 2],
+                axis,
+            )
+            split_output[1] = mid1
+            split_output.insert(2, mid2)
+            output = impl.cat.cat(
+                ctx, target, source_ir, f"{name}_cat_{axis}", split_output, axis
+            )
+
+    if input_rank == 3:  # reshape back to 3D
+        output = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_back", output, (*output.shape[1:],)
+        )
 
     return output
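
For context on the upsampling pass this diff introduces: when the requested output size is an integral multiple of the input size along an axis, every bin of adaptive average pooling covers exactly one input element, so the op degenerates to repeating each element along that axis. A minimal sketch checking that property with stock PyTorch ops (plain torch, not the converter itself):

import torch
import torch.nn.functional as F

# Output length 6 is an integral multiple of input length 3, so each
# adaptive-pool bin holds exactly one element and the pooling reduces to
# repeating every element twice along the last axis.
x = torch.randn(1, 3, 3)  # (N, C, L)
expected = F.adaptive_avg_pool1d(x, 6)
assert torch.allclose(expected, torch.repeat_interleave(x, 2, dim=-1))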
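The unchanged stride/kernel_size lines are the usual reduction of adaptive pooling to a fixed pool, which is exact only when the input size divides evenly by the output size; restoring that divisibility is what the repeat_interleave pass is for. A small sketch of the same arithmetic, where fixed_pool_params is a hypothetical helper name, not converter code:

import torch
import torch.nn.functional as F

def fixed_pool_params(in_dim: int, out_dim: int) -> tuple[int, int]:
    # Same formula as the converter: the stride floors the ratio, and the
    # kernel is whatever remains after placing (out_dim - 1) strided windows.
    stride = in_dim // out_dim
    kernel = in_dim - (out_dim - 1) * stride
    return stride, kernel

x = torch.randn(1, 3, 8, 12)
(sh, kh), (sw, kw) = fixed_pool_params(8, 4), fixed_pool_params(12, 6)
fixed = F.avg_pool2d(x, kernel_size=(kh, kw), stride=(sh, sw))
assert torch.allclose(fixed, F.adaptive_avg_pool2d(x, (4, 6)))  # exact when divisible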