@@ -90,7 +90,7 @@ def index(
9090 # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
9191 # If any is not this flag will be set to False
9292 _LOGGER .debug (
93- f "Determining whether aten.index constant-index optimization can be invoked"
93+ "Determining whether aten.index constant-index optimization can be invoked"
9494 )
9595 is_numpy = all (
9696 isinstance (ind , (torch .Tensor , np .ndarray )) for ind in index if ind is not None
@@ -123,7 +123,7 @@ def index(
123123 return identity_layer .get_output (0 )
124124 elif len (tensor_indices ) == 1 :
125125 indices_tensor = get_trt_tensor (
126- ctx , tensor_indices [0 ], name + f "_parameter_to_fp32_tensor"
126+ ctx , tensor_indices [0 ], name + "_parameter_to_fp32_tensor"
127127 )
128128 index = adv_indx_indices [0 ]
129129 _LOGGER .debug (f"The advanced index indices is { adv_indx_indices } " )
@@ -204,7 +204,7 @@ def index(
204204 cum_adv_index = cum_adv_index + adv_index
205205 multiplier = multiplier * input_shape [adv_indx_indices [i ]]
206206 cum_adv_index = get_trt_tensor (
207- ctx , cum_adv_index , name + f "_index_sum_intermediate"
207+ ctx , cum_adv_index , name + "_index_sum_intermediate"
208208 )
209209 else :
210210 multiplier = get_trt_tensor (
@@ -263,7 +263,7 @@ def index(
263263 adv_indx_count
264264 == adv_indx_indices [adv_indx_count - 1 ] - adv_indx_indices [0 ] + 1
265265 ):
266- _LOGGER .debug (f "The indices are continuous in this case" )
266+ _LOGGER .debug ("The indices are continuous in this case" )
267267 concat_tensor_reshape .append (
268268 get_trt_tensor (ctx , - 1 , name + "_dynamic_concat" )
269269 )
@@ -287,7 +287,7 @@ def index(
287287 source_ir ,
288288 )
289289 unfold_tensor = regular_index_shuffle_layer .get_output (0 )
290- _LOGGER .debug (f "The tensor is unfolded now" )
290+ _LOGGER .debug ("The tensor is unfolded now" )
291291 _LOGGER .debug (f"The unfolded tensor shape is { unfold_tensor .shape } " )
292292
293293 # Transpose folded advanced indexed axis to its original location.
@@ -342,7 +342,7 @@ def index(
342342 reshape_output = unfold_advanced_shuffle_layer .get_output (0 )
343343
344344 else :
345- _LOGGER .debug (f "The indices are not continuous in this case" )
345+ _LOGGER .debug ("The indices are not continuous in this case" )
346346 concat_final_tensor = []
347347 concat_final_tensor .append (cum_adv_index_shape_tensor )
348348 for i in range (0 , rank ):
@@ -370,3 +370,21 @@ def index(
370370 reshape_output = reshape_layer .get_output (0 )
371371
372372 return reshape_output
373+
374+
def index_select(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dim: int,
    index: TRTTensor,
) -> TRTTensor:
    """Gather slices of ``input`` along ``dim`` at the positions in ``index``.

    Lowers ``aten.index_select``-style selection to a single TensorRT gather
    layer. A negative ``dim`` is first normalized against the input's rank.
    """
    # Convert a possibly-negative axis into its positive equivalent before
    # handing it to TensorRT, which expects a non-negative gather axis.
    rank = len(input.shape)
    dim = get_positive_dim(dim, rank)

    gather_layer = ctx.net.add_gather(input, index, axis=dim)
    set_layer_name(gather_layer, target, f"{name}_gather", source_ir)

    return gather_layer.get_output(0)
0 commit comments