Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions caffe2/distributed/file_store_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ int64_t FileStoreHandler::add(
return 0;
}

int64_t FileStoreHandler::getNumKeys() {
CHECK(false) << "getNumKeys not implemented for FileStoreHandler";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be TORCH_CHECK?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in the Caffe2 codebase so I don't think we should use TORCH_CHECK. We could use CAFFE_ENFORCE, but the other unimplemented function in this class used CHECK(false) so I used it as well.

return 0;
}

bool FileStoreHandler::check(const std::vector<std::string>& names) {
std::vector<std::string> paths;
for (const auto& name : names) {
Expand Down
2 changes: 2 additions & 0 deletions caffe2/distributed/file_store_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class CAFFE2_API FileStoreHandler : public StoreHandler {

virtual int64_t add(const std::string& name, int64_t value) override;

virtual int64_t getNumKeys() override;

virtual bool check(const std::vector<std::string>& names) override;

virtual void wait(
Expand Down
5 changes: 5 additions & 0 deletions caffe2/distributed/redis_store_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ int64_t RedisStoreHandler::add(const std::string& name, int64_t value) {
return reply->integer;
}

int64_t RedisStoreHandler::getNumKeys() {
CHECK(false) << "getNumKeys not implemented for RedisStoreHandler";
return 0;
}

bool RedisStoreHandler::check(const std::vector<std::string>& names) {
std::vector<std::string> args;
args.push_back("EXISTS");
Expand Down
2 changes: 2 additions & 0 deletions caffe2/distributed/redis_store_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class CAFFE2_API RedisStoreHandler : public StoreHandler {

virtual int64_t add(const std::string& name, int64_t value) override;

virtual int64_t getNumKeys() override;

virtual bool check(const std::vector<std::string>& names) override;

virtual void wait(
Expand Down
5 changes: 5 additions & 0 deletions caffe2/distributed/store_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class CAFFE2_API StoreHandler {
*/
virtual int64_t add(const std::string& name, int64_t value) = 0;

/*
* Returns the number of keys in this store.
*/
virtual int64_t getNumKeys() = 0;

/*
* Check if a keys exist in the store.
*/
Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ class PythonStore : public ::c10d::Store {
PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value);
}

int64_t getNumKeys() override {
PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys);
}

bool check(const std::vector<std::string>& keys) override {
PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys);
}
Expand Down Expand Up @@ -301,6 +305,10 @@ They are used in specifying strategies for reduction collectives, e.g.,
"add",
&::c10d::Store::add,
py::call_guard<py::gil_scoped_release>())
.def(
"num_keys",
&::c10d::Store::getNumKeys,
py::call_guard<py::gil_scoped_release>())
.def(
"set_timeout",
&::c10d::Store::setTimeout,
Expand Down
6 changes: 6 additions & 0 deletions torch/lib/c10d/FileStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <system_error>
#include <thread>

#include <c10/util/Exception.h>

#define SYSASSERT(rv, ...) \
if ((rv) < 0) { \
throw std::system_error(errno, std::system_category(), ##__VA_ARGS__); \
Expand Down Expand Up @@ -303,6 +305,10 @@ int64_t FileStore::add(const std::string& key, int64_t value) {
return addHelper(regKey, value);
}

int64_t FileStore::getNumKeys() {
TORCH_CHECK(false, "getNumKeys not implemented for FileStore");
}

bool FileStore::check(const std::vector<std::string>& keys) {
std::unique_lock<std::mutex> l(activeFileOpLock_);
File file(path_, O_RDONLY, timeout_);
Expand Down
2 changes: 2 additions & 0 deletions torch/lib/c10d/FileStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class FileStore : public Store {

int64_t add(const std::string& key, int64_t value) override;

int64_t getNumKeys() override;

bool check(const std::vector<std::string>& keys) override;

void wait(const std::vector<std::string>& keys) override;
Expand Down
6 changes: 6 additions & 0 deletions torch/lib/c10d/HashStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <cstdio>
#include <system_error>

#include <c10/util/Exception.h>

namespace c10d {

void HashStore::set(const std::string& key, const std::vector<uint8_t>& data) {
Expand Down Expand Up @@ -77,6 +79,10 @@ int64_t HashStore::add(const std::string& key, int64_t i) {
return ti;
}

int64_t HashStore::getNumKeys() {
TORCH_CHECK(false, "getNumKeys not implemented for HashStore");
}

bool HashStore::check(const std::vector<std::string>& keys) {
std::unique_lock<std::mutex> lock(m_);
for (const auto& key : keys) {
Expand Down
2 changes: 2 additions & 0 deletions torch/lib/c10d/HashStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class HashStore : public Store {

int64_t add(const std::string& key, int64_t value) override;

int64_t getNumKeys() override;

bool check(const std::vector<std::string>& keys) override;

protected:
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/PrefixStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ int64_t PrefixStore::add(const std::string& key, int64_t value) {
return store_->add(joinKey(key), value);
}

int64_t PrefixStore::getNumKeys() {
return store_->getNumKeys();
}

bool PrefixStore::check(const std::vector<std::string>& keys) {
auto joinedKeys = joinKeys(keys);
return store_->check(joinedKeys);
Expand Down
2 changes: 2 additions & 0 deletions torch/lib/c10d/PrefixStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class PrefixStore : public Store {

int64_t add(const std::string& key, int64_t value) override;

int64_t getNumKeys() override;

bool check(const std::vector<std::string>& keys) override;

void wait(const std::vector<std::string>& keys) override;
Expand Down
2 changes: 2 additions & 0 deletions torch/lib/c10d/Store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Store {

virtual bool check(const std::vector<std::string>& keys) = 0;

virtual int64_t getNumKeys() = 0;

virtual void wait(const std::vector<std::string>& keys) = 0;

virtual void wait(
Expand Down
14 changes: 13 additions & 1 deletion torch/lib/c10d/TCPStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace c10d {

namespace {

enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT };
enum class QueryType : uint8_t { SET, GET, ADD, CHECK, WAIT, GETNUMKEYS };

enum class CheckResponseType : uint8_t { READY, NOT_READY };

Expand Down Expand Up @@ -180,6 +180,9 @@ void TCPStoreDaemon::query(int socket) {
} else if (qt == QueryType::WAIT) {
waitHandler(socket);

} else if (qt == QueryType::GETNUMKEYS) {
getNumKeysHandler(socket);

} else {
throw std::runtime_error("Unexpected query type");
}
Expand Down Expand Up @@ -228,6 +231,10 @@ void TCPStoreDaemon::getHandler(int socket) const {
tcputil::sendVector<uint8_t>(socket, data);
}

void TCPStoreDaemon::getNumKeysHandler(int socket) const {
tcputil::sendValue<int64_t>(socket, tcpStore_.size());
}

void TCPStoreDaemon::checkHandler(int socket) const {
SizeType nargs;
tcputil::recvBytes<SizeType>(socket, &nargs, 1);
Expand Down Expand Up @@ -364,6 +371,11 @@ int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
return tcputil::recvValue<int64_t>(storeSocket_);
}

int64_t TCPStore::getNumKeys() {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::GETNUMKEYS);
return tcputil::recvValue<int64_t>(storeSocket_);
}

bool TCPStore::check(const std::vector<std::string>& keys) {
tcputil::sendValue<QueryType>(storeSocket_, QueryType::CHECK);
SizeType nkeys = keys.size();
Expand Down
3 changes: 3 additions & 0 deletions torch/lib/c10d/TCPStore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TCPStoreDaemon {
void addHandler(int socket);
void getHandler(int socket) const;
void checkHandler(int socket) const;
void getNumKeysHandler(int socket) const;
void waitHandler(int socket);

bool checkKeys(const std::vector<std::string>& keys) const;
Expand Down Expand Up @@ -63,6 +64,8 @@ class TCPStore : public Store {

bool check(const std::vector<std::string>& keys) override;

int64_t getNumKeys() override;

void wait(const std::vector<std::string>& keys) override;

void wait(
Expand Down
4 changes: 4 additions & 0 deletions torch/lib/c10d/test/TCPStoreTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ void testHelper(const std::string& prefix = "") {
c10d::test::check(*serverStore, "key0", "value0");
c10d::test::check(*serverStore, "key1", "value1");
c10d::test::check(*serverStore, "key2", "value2");
auto numKeys = serverStore->getNumKeys();
// We expect 5 keys since 3 are added above, 'counter' is added by the
// helper thread, and the init key to coordinate workers.
EXPECT_EQ(numKeys, 5);
});

// Hammer on TCPStore
Expand Down