99from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
1010from torch_tensorrt .dynamo .conversion .converter_utils import (
1111 broadcastable ,
12+ cast_trt_tensor ,
1213 get_positive_dim ,
1314 get_trt_tensor ,
1415 to_numpy ,
2021 set_layer_name ,
2122)
2223from torch_tensorrt .fx .types import Shape , TRTTensor
24+ from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
2325
2426_LOGGER : logging .Logger = logging .getLogger (__name__ )
2527
@@ -378,8 +380,8 @@ def scatter_value(
378380 source_ir : Optional [SourceIR ],
379381 name : str ,
380382 input : TRTTensor ,
381- dim : Shape ,
382- index : Shape ,
383+ dim : int ,
384+ index : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
383385 value : float ,
384386) -> TRTTensor :
385387 if not isinstance (input , TRTTensor ):
@@ -389,26 +391,34 @@ def scatter_value(
389391 )
390392 input_shape = input .shape
391393 index_shape = index .shape
394+ index_shape_list = list (index .shape )
395+ if not (isinstance (index , TRTTensor )):
396+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
392397 if len (input_shape ) != len (index_shape ):
393398 raise RuntimeError (f"The no of dimensions of input and index should be equal" )
394- ranks = len (input_shape )
395- dim = get_positive_dim (cast (int , dim ), ranks )
399+ dim = get_positive_dim (dim , len (input_shape ))
396400 dynamic_shape = has_dynamic_shape (input .shape )
397401 if dynamic_shape :
398402 # Check whether slice target dim is dynamic shape dim
399403 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
400404
401- input_dims = len (input . shape )
405+ input_dims = len (input_shape )
402406 for i in range (0 , input_dims ):
403- if index [i ] >= input .shape [i ]:
407+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
404408 raise RuntimeError (
405- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
409+ f"cannot have index size greater than the input size along dimension { dim } "
406410 )
407- value_tensor = value * torch .ones (index .shape )
411+
412+ value_tensor = get_trt_tensor (
413+ ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
414+ )
415+ value_tensor = cast_trt_tensor (
416+ ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
417+ )
408418 scatter_layer = ctx .net .add_scatter (
409- input , index , value_tensor , trt .tensorrt . ScatterModekELEMENT
419+ input , index , value_tensor , trt .ScatterMode . ELEMENT
410420 )
411- scatter_layer .set_axis ( dim )
421+ scatter_layer .axis = dim
412422 set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
413423 out = scatter_layer .get_output (0 )
414424 return out
@@ -432,6 +442,8 @@ def scatter_src(
432442 input_shape = input .shape
433443 index_shape = index .shape
434444 src_shape = src .shape
445+ if not (isinstance (index , TRTTensor )):
446+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
435447 if len (input_shape ) != len (index_shape ):
436448 raise RuntimeError (f"The no of dimensions of input and index should be equal" )
437449 if len (index_shape ) != len (src_shape ):
@@ -445,14 +457,23 @@ def scatter_src(
445457 assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
446458
447459 for i in range (0 , input_dims ):
448- if index [i ] >= input .shape [i ]:
460+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
449461 raise RuntimeError (
450- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
462+ f"cannot have index size greater than the input size along dimension { dim } "
451463 )
464+ input_dtype = input .dtype
465+ # required for cases where src is a constant
466+ src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
467+ if input_dtype != src_dtype :
468+ raise RuntimeError (f"The type of input and src should be made" )
469+ src_tensor = src
470+ if not (isinstance (src , TRTTensor )):
471+ src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
472+
452473 scatter_layer = ctx .net .add_scatter (
453- input , index , src , trt .tensorrt . ScatterModekELEMENT
474+ input , index , src_tensor , trt .ScatterMode . ELEMENT
454475 )
455- scatter_layer .set_axis ( dim )
476+ scatter_layer .axis = dim
456477 set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
457478 out = scatter_layer .get_output (0 )
458479 return out
0 commit comments