Skip to content

Commit beb4f5b

Browse files
razarmehr authored and pytorchmergebot committed
[MPS] Add Python Module Bindings for the MPS backend (#94417)
- This PR is a prerequisite for the upcoming Memory Leak Detection PR.
- Enable global manual seeding via `torch.manual_seed()` + test case
- Add `torch.mps.synchronize()` to wait for MPS stream to finish + test case
- Enable the following python interfaces for MPS: `torch.mps.[get_rng_state(), set_rng_state(), synchronize(), manual_seed(), seed()]`
- Added some test cases in test_mps.py
- Added `mps.rst` to document the `torch.mps` module.
- Fixed the failure with `test_public_bindings.py`

Description of new files added:
- `torch/csrc/mps/Module.cpp`: implements `torch._C` module functions for `torch.mps` and `torch.backends.mps`.
- `torch/mps/__init__.py`: implements Python bindings for `torch.mps` module.

Pull Request resolved: #94417
Approved by: https://github.com/albanD
1 parent d0cff06 commit beb4f5b

File tree

15 files changed

+225
-16
lines changed

15 files changed

+225
-16
lines changed

aten/src/ATen/detail/MPSHooksInterface.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,21 @@ struct TORCH_API MPSHooksInterface {
2828
return false;
2929
}
3030

31+
virtual bool isOnMacOS13orNewer() const {
32+
AT_ERROR("MPS backend is not available.");
33+
}
34+
3135
virtual const Generator& getDefaultMPSGenerator() const {
3236
AT_ERROR("Cannot get default MPS generator without MPS backend.");
3337
}
3438

3539
virtual Allocator* getMPSDeviceAllocator() const {
3640
AT_ERROR("MPSDeviceAllocator requires MPS.");
3741
}
42+
43+
virtual void deviceSynchronize() const {
44+
AT_ERROR("Cannot synchronize MPS device without MPS backend.");
45+
}
3846
};
3947

4048
struct TORCH_API MPSHooksArgs {};

aten/src/ATen/mps/MPSDevice.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class TORCH_API MPSDevice {
7878

7979
TORCH_API bool is_available();
8080
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
81-
81+
TORCH_API void device_synchronize();
8282
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
8383

8484
} // namespace mps

aten/src/ATen/mps/MPSDevice.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <c10/util/CallOnce.h>
44

55
#include <ATen/mps/MPSDevice.h>
6+
#include <ATen/mps/MPSStream.h>
67
#include <ATen/mps/MPSAllocatorInterface.h>
78
#include <ATen/mps/IndexKernels.h>
89

@@ -118,5 +119,9 @@ bool is_macos_13_or_newer(MacOSVersion version) {
118119
return MPSDevice::getInstance()->isMacOS13Plus(version);
119120
}
120121

122+
void device_synchronize() {
123+
getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
124+
}
125+
121126
} // namespace mps
122127
} // namespace at

aten/src/ATen/mps/MPSHooks.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ bool MPSHooks::hasMPS() const {
1616
return at::mps::is_available();
1717
}
1818

19+
bool MPSHooks::isOnMacOS13orNewer() const {
20+
return at::mps::is_macos_13_or_newer();
21+
}
22+
1923
Allocator* MPSHooks::getMPSDeviceAllocator() const {
2024
return at::mps::GetMPSAllocator();
2125
}
@@ -24,6 +28,10 @@ const Generator& MPSHooks::getDefaultMPSGenerator() const {
2428
return at::mps::detail::getDefaultMPSGenerator();
2529
}
2630

31+
void MPSHooks::deviceSynchronize() const {
32+
at::mps::device_synchronize();
33+
}
34+
2735
using at::MPSHooksRegistry;
2836
using at::RegistererMPSHooksRegistry;
2937

aten/src/ATen/mps/MPSHooks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ struct MPSHooks : public at::MPSHooksInterface {
1313
MPSHooks(at::MPSHooksArgs) {}
1414
void initMPS() const override;
1515
bool hasMPS() const override;
16+
bool isOnMacOS13orNewer() const override;
1617
Allocator* getMPSDeviceAllocator() const override;
1718
const Generator& getDefaultMPSGenerator() const override;
19+
void deviceSynchronize() const override;
1820
};
1921

2022
}} // at::mps

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ libtorch_python_core_sources = [
822822
"torch/csrc/dynamo/guards.cpp",
823823
"torch/csrc/dynamo/init.cpp",
824824
"torch/csrc/functorch/init.cpp",
825+
"torch/csrc/mps/Module.cpp",
825826
"torch/csrc/jit/backends/backend_init.cpp",
826827
"torch/csrc/jit/python/init.cpp",
827828
"torch/csrc/jit/passes/onnx.cpp",

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ Features described in this documentation are classified by release status:
8181
torch.autograd <autograd>
8282
torch.library <library>
8383
cuda
84+
mps
8485
torch.backends <backends>
8586
torch.distributed <distributed>
8687
torch.distributed.algorithms.join <distributed.algorithms.join>

docs/source/mps.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
torch.mps
2+
===================================
3+
.. automodule:: torch.mps
4+
.. currentmodule:: torch.mps
5+
6+
.. autosummary::
7+
:toctree: generated
8+
:nosignatures:
9+
10+
synchronize
11+
get_rng_state
12+
set_rng_state
13+
manual_seed
14+
seed

test/test_mps.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5836,6 +5836,45 @@ def test_mps_generator(self):
58365836
mps_x = torch.randn(5, device='mps', generator=g_mps)
58375837
self.assertEqual(mps_x, mps_y)
58385838

5839+
def test_default_mps_generator(self):
5840+
# manual seeding on the "default" MPS generator using
5841+
# the global torch.manual_seed()
5842+
torch.manual_seed(230)
5843+
mps_x = torch.randn(5, device='mps')
5844+
# manual seeding using torch.mps.manual_seed()
5845+
# which should set the "default" MPS generator
5846+
# like the global torch.manual_seed()
5847+
torch.mps.manual_seed(230)
5848+
mps_y = torch.randn(5, device='mps')
5849+
# seed values were the same, so the random tensor contents should match
5850+
self.assertEqual(mps_x, mps_y)
5851+
5852+
# save the default generator's state to restore it later
5853+
g_state = torch.mps.get_rng_state()
5854+
5855+
# generate random numbers without seeding
5856+
mps_x = torch.randn(5, device='mps')
5857+
# in this case, the random results must differ from the last generated random results
5858+
self.assertNotEqual(mps_x, mps_y)
5859+
5860+
# restore the previously saved state, and the results should match again
5861+
torch.mps.set_rng_state(g_state)
5862+
mps_x = torch.randn(5, device='mps')
5863+
self.assertEqual(mps_x, mps_y)
5864+
5865+
def test_device_synchronize(self):
5866+
# just running some ops each followed by a synchronize to wait for
5867+
# MPS stream to finish running each of them
5868+
net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
5869+
.to(device='mps', dtype=torch.float)
5870+
5871+
x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
5872+
torch.mps.synchronize()
5873+
x = net1(x)
5874+
torch.mps.synchronize()
5875+
x.backward(torch.randn_like(x))
5876+
torch.mps.synchronize()
5877+
58395878
# Test random_.to and random_.from
58405879
def test_random(self):
58415880
def helper(shape, low, high, dtype=torch.int32):

torch/_C/__init__.pyi.in

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -903,8 +903,6 @@ def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: T
903903
def _disabled_torch_dispatch_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_dispatch_function
904904
def _get_linalg_preferred_backend() -> torch._C._LinalgBackend: ...
905905
def _set_linalg_preferred_backend(arg: torch._C._LinalgBackend): ...
906-
def _is_mps_available() -> _bool: ...
907-
def _is_mps_on_macos_13_or_newer() -> _bool: ...
908906
class _LinalgBackend:
909907
Default: _LinalgBackend
910908
Cusolver: _LinalgBackend
@@ -1200,6 +1198,12 @@ class _TensorBase(metaclass=_TensorMeta):
12001198
# Defined in torch/csrc/multiprocessing/init.cpp
12011199
def _multiprocessing_init() -> None: ...
12021200

1201+
# Defined in torch/csrc/mps/Module.cpp
1202+
def _mps_synchronize() -> None: ...
1203+
def _mps_get_default_generator() -> Generator: ...
1204+
def _is_mps_available() -> _bool: ...
1205+
def _is_mps_on_macos_13_or_newer() -> _bool: ...
1206+
12031207
# Defined in torch/csrc/cuda/Module.cpp
12041208
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
12051209
def _cuda_getCurrentRawStream(device: _int) -> _int: ...

0 commit comments

Comments (0)