
Commit 30cf0e7

kulinseth and razarmehr authored
[MPS] Copy fixes for MPS backend (#95321)
* [MPS] Handle broadcasting by expanding src tensor in Copy.mm (#95272)

  Fixes #ISSUE_NUMBER
  Pull Request resolved: #95272
  Approved by: https://github.com/DenisVieriu97

* [MPS] Fix copy_cast_mps() on tensors with storage offset (#95093)

  - The copy_cast path requires storage_offset to be applied before casting
  - This should fix some correctness issues in transformer models

  Fixes #94980
  Pull Request resolved: #95093
  Approved by: https://github.com/kulinseth

---------

Co-authored-by: Ramin Azarmehr <razarmehr@apple.com>
1 parent 96f627d commit 30cf0e7
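
For context, here is a minimal sketch of the two user-visible cases these fixes target: broadcasting a smaller src into a larger-dimensional dst during copy, and dtype-casting a sliced view that has a non-zero storage offset. This is illustrative only and assumes a torch build with MPS available; the shapes and dtypes are not taken from the PR:

    import torch

    # Case 1 (#95272): dst has more dims than src, so the MPS copy path now
    # expands src to dst's shape before the element-wise copy.
    dst = torch.zeros(4, 3, device="mps")
    src = torch.arange(3, dtype=torch.float32)   # shape (3,), on CPU
    dst.copy_(src)                               # src is broadcast across dim 0

    # Case 2 (#95093): a slice is a view with a non-zero storage_offset;
    # casting it to another dtype (the copy_cast path) must honor that offset.
    base = torch.arange(16, dtype=torch.uint8, device="mps").reshape(4, 4)
    view = base[2:3, :]                          # storage_offset() == 8
    as_bool = view.to(torch.bool)                # cast of an offset view

    print(dst.cpu())
    print(as_bool.cpu())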

2 files changed: 23 additions & 5 deletions


aten/src/ATen/native/mps/operations/Copy.mm

Lines changed: 13 additions & 5 deletions
@@ -251,8 +251,11 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   bool returnGatherOutput = dst_.is_contiguous();
   Tensor src;
   auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
+  const bool sameDataType = src_.dtype() == dst_.dtype();
 
-  if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
+  if ((!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) ||
+      // the copy_cast path requires storage_offset to be applied before casting
+      (src_.storage_offset() && !sameDataType)) {
     Tensor emptyShell = Tensor();
     src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);

@@ -282,7 +285,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   src._set_neg(src_.is_neg());
 
   const size_t src_size = src.nbytes();
-  if (src.dtype() == dst_.dtype()) {
+  if (sameDataType) {
     MPSStream* stream = getCurrentMPSStream();
     // for GPU to GPU copies we only encode to stream's command buffer (no flushing)
     stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset);
@@ -297,22 +300,27 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   TORCH_CHECK(dst.defined(), "dst is undefined");
   TORCH_CHECK(src.defined(), "src is undefined");
 
+  bool needs_broadcasting = false;
+
   if (src.numel() == 0 || dst.is_same(src)) {
     return dst;
   }
   if (dst.numel() == 0) {
     dst.resize_as_(src);
   }
+  if (dst.dim() > src.dim()) {
+    needs_broadcasting = true;
+  }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kCPU) {
-    return copy_from_mps_(dst, src, non_blocking);
+    return copy_from_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   if (src.device().type() == at::kCPU && dst.device().type() == at::kMPS) {
-    return copy_to_mps_(dst, src, non_blocking);
+    return copy_to_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
-    return copy_kernel_mps(dst, src, non_blocking);
+    return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   TORCH_INTERNAL_ASSERT(
       src.device().type() == DeviceType::MPS,
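
For readers less familiar with expand_as: the ternary needs_broadcasting ? src.expand_as(dst) : src works because expanding returns a broadcast view (stride 0 along the broadcast dimension) rather than allocating a new tensor, so the element-wise copy sees matching shapes at no extra memory cost. A small CPU-only sketch of that behavior (illustrative, not part of the diff):

    import torch

    src = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)
    dst = torch.empty(2, 3)               # dst.dim() > src.dim()

    expanded = src.expand_as(dst)         # shape (2, 3); a view, no data is copied
    print(expanded.stride())              # (0, 1): dim 0 is broadcast
    dst.copy_(expanded)                   # copy now runs with matching shapes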

test/test_mps.py

Lines changed: 10 additions & 0 deletions
@@ -1786,6 +1786,15 @@ def test_slice_reshape(self):
         x_cpu = x_cpu + 2
         self.assertEqual(x, x_cpu)
 
+    def test_slice_casting(self):
+        # generate random binary numbers
+        cpu_in = torch.bernoulli(torch.empty(1, 1, 128, 128).uniform_(0, 1)).to(torch.uint8)
+        mps_in = cpu_in.detach().clone().to("mps")
+        # check copy_cast(uint8 -> bool) on tensors with storage offset
+        cpu_out = cpu_in[:, :, 11:12, :12].to(torch.bool)
+        mps_out = mps_in[:, :, 11:12, :12].to(torch.bool)
+        self.assertEqual(cpu_out, mps_out)
+
     def test_slice_reshape_contg_view(self):
         import torch
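
A side note on why this slice exercises the fixed path: indexing [:, :, 11:12, :12] yields a view that starts 11 full rows into the underlying storage, so its storage_offset is non-zero and the cast to bool has to read from that offset. A quick CPU illustration with the same shape as the test (not part of the diff):

    import torch

    cpu_in = torch.zeros(1, 1, 128, 128, dtype=torch.uint8)
    view = cpu_in[:, :, 11:12, :12]
    print(view.shape)             # torch.Size([1, 1, 1, 12])
    print(view.storage_offset())  # 1408, i.e. 11 * 128 elements into the storage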

@@ -9304,6 +9313,7 @@ class TestConsistency(TestCaseMPS):
         'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'linalg.matrix_norm': ['f16'],
+        'linalg.matrix_power': ['f32'],
         'linalg.svd': ['f32'],
         'linalg.vector_norm': ['f16', 'f32'],
         'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
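
The new test can be run on its own on an MPS-capable machine, assuming a source checkout, with something like python test/test_mps.py -k test_slice_casting (or an equivalent pytest invocation); the exact command may vary with your environment.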
