25 changes: 25 additions & 0 deletions docs/source/distributed.rst
@@ -270,6 +270,31 @@ The machine with rank 0 will be used to set up all connections.
This is the default method, meaning that ``init_method`` does not have to be specified (or
can be ``env://``).

Distributed Key-Value Store
---------------------------

The distributed package comes with a distributed key-value store, which can be
used to share information between processes in the group as well as to
initialize the distributed package in
:func:`torch.distributed.init_process_group` (by explicitly creating the store
as an alternative to specifying ``init_method``). There are three choices of
key-value store: :class:`~torch.distributed.TCPStore`,
:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`.
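
To make the semantics of these stores concrete, the following is a toy, single-process model written only with the Python standard library. It is not PyTorch's implementation: the class name ``ToyStore`` and all of its internals are hypothetical, and it only mirrors the behavior described in the docstrings of this change (``set`` overwrites, ``get`` and ``wait`` block until a timeout, ``add`` maintains a counter).

```python
import threading
from datetime import timedelta


class ToyStore:
    """Hypothetical in-memory stand-in for a key-value store (illustration only)."""

    def __init__(self, timeout=timedelta(seconds=30)):
        self._data = {}
        self._cond = threading.Condition()
        self._timeout = timeout

    def set(self, key, value):
        # Overwrite any existing value and wake up blocked readers.
        with self._cond:
            self._data[key] = value.encode()
            self._cond.notify_all()

    def wait(self, keys, timeout=None):
        # Block until every key exists, or raise once the timeout elapses.
        chosen = timeout if timeout is not None else self._timeout
        limit = chosen.total_seconds()
        with self._cond:
            ready = self._cond.wait_for(
                lambda: all(k in self._data for k in keys), timeout=limit
            )
        if not ready:
            raise TimeoutError(f"keys not set within {limit}s: {keys}")

    def get(self, key):
        # Wait for the key (up to the default timeout), then return bytes.
        self.wait([key])
        with self._cond:
            return self._data[key]

    def add(self, key, amount):
        # The first call creates the counter; later calls increment it.
        # A key holding a non-numeric value makes int() raise ValueError.
        with self._cond:
            total = int(self._data.get(key, b"0")) + amount
            self._data[key] = str(total).encode()
            self._cond.notify_all()
            return total


store = ToyStore(timeout=timedelta(seconds=1))
store.set("first_key", "first_value")
store.add("counter", 1)
store.add("counter", 6)
```

In this sketch values are kept as bytes, so ``store.get("first_key")`` yields a bytes object, and waiting on a key that is never set raises ``TimeoutError`` once the timeout has passed.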

.. autoclass:: Store
.. autoclass:: TCPStore
Contributor

Should we also add Store here? Similar to torch.optim where there is documentation about the base class.

.. autoclass:: HashStore
.. autoclass:: FileStore
.. autoclass:: PrefixStore

.. autofunction:: torch.distributed.Store.set
.. autofunction:: torch.distributed.Store.get
.. autofunction:: torch.distributed.Store.add
.. autofunction:: torch.distributed.Store.wait
.. autofunction:: torch.distributed.Store.num_keys
.. autofunction:: torch.distributed.Store.delete_key
.. autofunction:: torch.distributed.Store.set_timeout
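
The ``PrefixStore`` wrapper listed above prepends a prefix to every key before it reaches the underlying store. A minimal sketch of that idea follows, using a plain dictionary as a stand-in for a real store. ``DictStore`` and ``ToyPrefixStore`` are hypothetical names, and the ``/`` separator is an assumption made for readability, not something this change specifies.

```python
class DictStore:
    """Hypothetical minimal store backed by a dict (illustration only)."""

    def __init__(self):
        self.data = {}

    def set(self, key, value):
        self.data[key] = value

    def get(self, key):
        return self.data[key]


class ToyPrefixStore:
    """Hypothetical wrapper that namespaces the keys of an underlying store."""

    def __init__(self, prefix, store):
        self._prefix = prefix
        self._store = store

    def _full_key(self, key):
        # Assumed separator; chosen here only to keep the keys readable.
        return f"{self._prefix}/{key}"

    def set(self, key, value):
        self._store.set(self._full_key(key), value)

    def get(self, key):
        return self._store.get(self._full_key(key))


base = DictStore()
group_a = ToyPrefixStore("group_a", base)
group_b = ToyPrefixStore("group_b", base)

# Both wrappers use the key "rank" without colliding in the shared store.
group_a.set("rank", "0")
group_b.set("rank", "1")
```

Because each wrapper rewrites keys before delegating, two users can share one underlying store and still pick the same key names independently.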

Groups
------

233 changes: 219 additions & 14 deletions torch/csrc/distributed/c10d/init.cpp
@@ -284,7 +284,12 @@ They are used in specifying strategies for reduction collectives, e.g.,

auto store =
py::class_<::c10d::Store, std::shared_ptr<::c10d::Store>, PythonStore>(
module, "Store")
module, "Store",
R"(
Base class for all store implementations, such as the three provided by PyTorch
distributed: :class:`~torch.distributed.TCPStore`, :class:`~torch.distributed.FileStore`,
and :class:`~torch.distributed.HashStore`.
)")
// Default constructor.
.def(py::init<>())
// Convert from std::string to std::vector<uint8>.
@@ -296,7 +301,23 @@ They are used in specifying strategies for reduction collectives, e.g.,
std::vector<uint8_t> value_(value.begin(), value.end());
store.set(key, value_);
},
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Inserts the key-value pair into the store based on the supplied ``key`` and
``value``. If ``key`` already exists in the store, it will overwrite the old
value with the new supplied ``value``.

Arguments:
key (str): The key to be added to the store.
value (str): The value associated with ``key`` to be added to the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return b"first_value"
>>> store.get("first_key")
)")
// Convert from std::vector<uint8_t> to py::bytes.
// The returned value is not guaranteed to be valid UTF-8.
.def(
@@ -306,46 +327,221 @@ They are used in specifying strategies for reduction collectives, e.g.,
return py::bytes(
reinterpret_cast<char*>(value.data()), value.size());
},
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Retrieves the value associated with the given ``key`` in the store. If ``key`` is not
present in the store, the function will wait for ``timeout``, which is defined
when initializing the store, before throwing an exception.

Arguments:
key (str): The function will return the value associated with this key.

Returns:
Value associated with ``key`` if ``key`` is in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # Should return b"first_value"
>>> store.get("first_key")
)")
.def(
"add",
&::c10d::Store::add,
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
The first call to add for a given ``key`` creates a counter associated
with ``key`` in the store, initialized to ``amount``. Subsequent calls to add
with the same ``key`` increment the counter by the specified ``amount``.
Calling :meth:`~torch.distributed.Store.add` with a key that has already
been set in the store by :meth:`~torch.distributed.Store.set` will result
in an exception.

Arguments:
key (str): The key in the store whose counter will be incremented.
amount (int): The quantity by which the counter will be incremented.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.add("first_key", 1)
>>> store.add("first_key", 6)
>>> # Should return 7
>>> store.get("first_key")
)")
.def(
"delete_key",
&::c10d::Store::deleteKey,
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Deletes the key-value pair associated with ``key`` from the store. Returns
``True`` if the key was successfully deleted, and ``False`` if it was not.

.. warning::
The ``delete_key`` API is only supported by the :class:`~torch.distributed.TCPStore`. Using this API
with the :class:`~torch.distributed.FileStore` or :class:`~torch.distributed.HashStore` will result in an exception.

Arguments:
key (str): The key to be deleted from the store.

Returns:
``True`` if ``key`` was deleted, otherwise ``False``.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return True
>>> store.delete_key("first_key")
>>> # This should return False
>>> store.delete_key("bad_key")
)")
.def(
"num_keys",
&::c10d::Store::getNumKeys,
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Returns the number of keys set in the store. Note that this number will typically
be one greater than the number of keys added by :meth:`~torch.distributed.Store.set`
and :meth:`~torch.distributed.Store.add`, since one key is used to coordinate all
the workers using the store.

.. warning::
The ``num_keys`` API is only supported by the :class:`~torch.distributed.TCPStore`. Using this API
with the :class:`~torch.distributed.FileStore` or :class:`~torch.distributed.HashStore` will result in an exception.

Returns:
The number of keys present in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set("first_key", "first_value")
>>> # This should return 2
>>> store.num_keys()
)")
.def(
"set_timeout",
&::c10d::Store::setTimeout,
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Sets the store's default timeout. This timeout is used during initialization and in
:meth:`~torch.distributed.Store.wait` and :meth:`~torch.distributed.Store.get`.

Arguments:
timeout (timedelta): The timeout to be set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> store.set_timeout(timedelta(seconds=10))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"])
)")
.def(
"wait",
[](::c10d::Store& store, const std::vector<std::string>& keys) {
store.wait(keys);
},
py::call_guard<py::gil_scoped_release>())
py::call_guard<py::gil_scoped_release>(),
R"(
Waits for each key in ``keys`` to be added to the store. If not all keys are
set before the ``timeout`` (set during store initialization), then ``wait``
will throw an exception.

Arguments:
keys (list): List of keys on which to wait until they are set in the store.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 30 seconds
>>> store.wait(["bad_key"])
)")
.def(
"wait",
[](::c10d::Store& store,
const std::vector<std::string>& keys,
const std::chrono::milliseconds& timeout) {
store.wait(keys, timeout);
},
py::call_guard<py::gil_scoped_release>());

shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store)
py::call_guard<py::gil_scoped_release>(),
R"(
Waits for each key in ``keys`` to be added to the store, and throws an exception
if the keys have not been set by the supplied ``timeout``.

Arguments:
keys (list): List of keys on which to wait until they are set in the store.
timeout (timedelta): Time to wait for the keys to be added before throwing an exception.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Using TCPStore as an example, other store types can also be used
>>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
>>> # This will throw an exception after 10 seconds
>>> store.wait(["bad_key"], timedelta(seconds=10))
)");

shared_ptr_class_<::c10d::FileStore>(module, "FileStore", store,
R"(
A store implementation that uses a file to store the underlying key-value pairs.

Arguments:
file_name (str): The path of the file in which to store the key-value pairs.
world_size (int): The total number of processes using the store.

Example::
>>> import torch.distributed as dist
>>> store1 = dist.FileStore("/tmp/filestore", 2)
>>> store2 = dist.FileStore("/tmp/filestore", 2)
>>> # Use any of the store methods from either store instance after initialization
>>> store1.set("first_key", "first_value")
>>> store2.get("first_key")

)")
.def(py::init<const std::string&, int>());

#ifndef _WIN32
shared_ptr_class_<::c10d::HashStore>(module, "HashStore", store)
shared_ptr_class_<::c10d::HashStore>(module, "HashStore", store,
R"(
A thread-safe store implementation based on an underlying hashmap. This store can be used
within the same process (for example, by other threads), but cannot be used across processes.

Example::
>>> import torch.distributed as dist
>>> store = dist.HashStore()
>>> # store can be used from other threads
>>> # Use any of the store methods after initialization
>>> store.set("first_key", "first_value")
)")
.def(py::init<>());

shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store)
shared_ptr_class_<::c10d::TCPStore>(module, "TCPStore", store,
R"(
A TCP-based distributed key-value store implementation. The server store holds
the data, while the client stores can connect to the server store over TCP and
perform actions such as :meth:`~torch.distributed.Store.set` to insert a key-value
pair, :meth:`~torch.distributed.Store.get` to retrieve a key-value pair, etc.

Arguments:
host_name (str): The hostname or IP address the server store should run on.
port (int): The port on which the server store should listen for incoming requests.
world_size (int): The total number of store users (number of clients + 1 for the server).
is_master (bool): True when initializing the server store, False for client stores.
timeout (timedelta): Timeout used by the store during initialization and for methods such as :meth:`~torch.distributed.Store.get` and :meth:`~torch.distributed.Store.wait`.

Example::
>>> import torch.distributed as dist
>>> from datetime import timedelta
>>> # Run on process 1 (server); the port 1234 is an arbitrary example value
>>> server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
>>> # Run on process 2 (client)
>>> client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
>>> # Use any of the store methods from either the client or server after initialization
>>> server_store.set("first_key", "first_value")
>>> client_store.get("first_key")
)")
.def(
py::init<
const std::string&,
@@ -361,7 +557,16 @@ They are used in specifying strategies for reduction collectives, e.g.,
std::chrono::milliseconds(::c10d::Store::kDefaultTimeout));
#endif

shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store)
shared_ptr_class_<::c10d::PrefixStore>(module, "PrefixStore", store,
R"(
A wrapper around any of the three key-value stores (:class:`~torch.distributed.TCPStore`,
:class:`~torch.distributed.FileStore`, and :class:`~torch.distributed.HashStore`)
that adds a prefix to each key inserted into the store.

Arguments:
prefix (str): The prefix string that is prepended to each key before it is inserted into the store.
store (torch.distributed.Store): A store object that forms the underlying key-value store.
)")
.def(py::init<const std::string&, std::shared_ptr<::c10d::Store>>());

auto processGroup =