@@ -2279,6 +2279,51 @@ def test_binops_dtype_precedence(self):
                     getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop)
                     (torch.full(full_shape, val2, dtype=dtype2, device='cpu')))
 
+    def test_nansum(self):
+        def helper(dtype, noncontiguous, dim):
+            zero_cpu = torch.zeros((), dtype=dtype)
+
+            # Randomly scale the values
+            scale = random.randint(10, 100)
+            x_cpu: torch.Tensor = make_tensor(
+                (5, 5), dtype=dtype, device='cpu',
+                low=-scale, high=scale, noncontiguous=noncontiguous)
+
+            if dtype.is_floating_point:
+                nan_mask_cpu = x_cpu < (0.2 * scale)
+                x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
+                x_cpu[nan_mask_cpu] = np.nan
+            else:
+                x_no_nan_cpu = x_cpu
+
+            x_mps = x_cpu.to('mps')
+            actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
+            expect_out_cpu = torch.empty(0, dtype=dtype)
+            dim_kwargs = {"dim": dim} if dim is not None else {}
+            expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
+
+            actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
+            # Sanity check on CPU
+            self.assertEqual(expect, actual_cpu)
+
+            # Test MPS
+            actual_mps = torch.nansum(x_mps, **dim_kwargs)
+            # Test out= variant
+            torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
+            torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
+            self.assertEqual(expect, actual_mps)
+            self.assertEqual(expect_out_cpu, actual_out_mps)
+
+        args = itertools.product(
+            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
+            (True, False),  # noncontiguous
+            (0, 1, None),  # dim
+        )
+
+        for dtype, noncontiguous, dim in args:
+            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
+                helper(dtype, noncontiguous, dim)
+
 
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
@@ -8252,53 +8297,6 @@ def test_serialization_map_location(self):
         self.assertEqual(x2.device.type, "cpu")
 
 
-class TestNanSum(TestCase):
-
-    def helper(self, dtype, noncontiguous, dim):
-        zero_cpu = torch.zeros((), dtype=dtype)
-
-        # Randomly scale the values
-        scale = random.randint(10, 100)
-        x_cpu: torch.Tensor = make_tensor(
-            (5, 5), dtype=dtype, device='cpu',
-            low=-scale, high=scale, noncontiguous=noncontiguous)
-
-        if dtype.is_floating_point:
-            nan_mask_cpu = x_cpu < (0.2 * scale)
-            x_no_nan_cpu = torch.where(nan_mask_cpu, zero_cpu, x_cpu)
-            x_cpu[nan_mask_cpu] = np.nan
-        else:
-            x_no_nan_cpu = x_cpu
-
-        x_mps = x_cpu.to('mps')
-        actual_out_mps = torch.empty(0, dtype=dtype, device='mps')
-        expect_out_cpu = torch.empty(0, dtype=dtype)
-        dim_kwargs = {"dim": dim} if dim is not None else {}
-        expect = torch.sum(x_no_nan_cpu, **dim_kwargs)
-
-        actual_cpu = torch.nansum(x_cpu, **dim_kwargs)
-        # Sanity check on CPU
-        self.assertEqual(expect, actual_cpu)
-
-        # Test MPS
-        actual_mps = torch.nansum(x_mps, **dim_kwargs)
-        # Test out= variant
-        torch.nansum(x_mps, out=actual_out_mps, **dim_kwargs)
-        torch.nansum(x_cpu, out=expect_out_cpu, **dim_kwargs)
-        self.assertEqual(expect, actual_mps)
-        self.assertEqual(expect_out_cpu, actual_out_mps)
-
-    def test_nansum(self):
-        args = itertools.product(
-            (torch.float16, torch.float32, torch.int32, torch.int64),  # dtype
-            (True, False),  # noncontiguous
-            (0, 1, None),  # dim
-        )
-
-        for dtype, noncontiguous, dim in args:
-            with self.subTest(dtype=dtype, noncontiguous=noncontiguous, dim=dim):
-                self.helper(dtype, noncontiguous, dim)
-
 MPS_DTYPES = get_all_dtypes()
 for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
     del MPS_DTYPES[MPS_DTYPES.index(t)]