Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions caffe2/serialize/inline_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

namespace caffe2 {
namespace serialize {
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");

size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
Expand Down Expand Up @@ -222,6 +223,10 @@ size_t getPadding(

bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);

if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
return false;
}
std::string ss = archive_name_plus_slash_ + name;
mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
const mz_zip_error err = mz_zip_get_last_error(ar_.get());
Expand Down Expand Up @@ -255,8 +260,11 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
": ",
buf);
}
// NOLINTNEXTLINE(modernize-use-emplace)
out.push_back(buf + archive_name_plus_slash_.size());
if ((load_debug_symbol_) ||
(!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
// NOLINTNEXTLINE(modernize-use-emplace)
out.push_back(buf + archive_name_plus_slash_.size());
}
}
return out;
}
Expand All @@ -276,6 +284,10 @@ size_t PyTorchStreamReader::getRecordID(const std::string& name) {
// return dataptr, size
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
at::DataPtr retval;
return std::make_tuple(std::move(retval), 0);
}
size_t key = getRecordID(name);
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), key, &stat);
Expand Down
5 changes: 5 additions & 0 deletions caffe2/serialize/inline_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class TORCH_API PyTorchStreamReader final {
return version_;
}

void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
load_debug_symbol_ = should_load_debug_symbol;
}

private:
void init();
size_t read(uint64_t pos, char* buf, size_t n);
Expand All @@ -124,6 +128,7 @@ class TORCH_API PyTorchStreamReader final {
std::shared_ptr<ReadAdapterInterface> in_;
int64_t version_;
std::mutex reader_lock_;
bool load_debug_symbol_ = true;
};

class TORCH_API PyTorchStreamWriter final {
Expand Down
47 changes: 47 additions & 0 deletions caffe2/serialize/inline_container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,53 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
EXPECT_TRUE(reader.hasRecord("key1"));
}

TEST(PytorchStreamWriterAndReader, SkipDebugRecords) {
std::ostringstream oss;
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
oss.write(static_cast<const char*>(b), n);
return oss ? n : 0;
});
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
std::array<char, 127> data1;

for (auto i: c10::irange(data1.size())) {
data1[i] = data1.size() - i;
}
writer.writeRecord("key1.debug_pkl", data1.data(), data1.size());

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init,cppcoreguidelines-avoid-magic-numbers)
std::array<char, 64> data2;
for (auto i: c10::irange(data2.size())) {
data2[i] = data2.size() - i;
}
writer.writeRecord("key2.debug_pkl", data2.data(), data2.size());

const std::unordered_set<std::string>& written_records =
writer.getAllWrittenRecords();
ASSERT_EQ(written_records.size(), 2);
ASSERT_EQ(written_records.count("key1.debug_pkl"), 1);
ASSERT_EQ(written_records.count("key2.debug_pkl"), 1);
writer.writeEndOfFile();

std::string the_file = oss.str();
std::ofstream foo("output2.zip");
foo.write(the_file.c_str(), the_file.size());
foo.close();

std::istringstream iss(the_file);

// read records through readers
PyTorchStreamReader reader(&iss);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)

reader.setShouldLoadDebugSymbol(false);
EXPECT_FALSE(reader.hasRecord("key1.debug_pkl"));
at::DataPtr ptr;
size_t size;
std::tie(ptr, size) = reader.getRecord("key1.debug_pkl");
EXPECT_EQ(size, 0);
}

} // namespace
} // namespace serialize
} // namespace caffe2
46 changes: 46 additions & 0 deletions test/cpp/jit/test_save_load.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include <gtest/gtest.h>

#include <test/cpp/jit/test_utils.h>
#include <iostream>
#include <sstream>

#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/calculate_necessary_args.h>
#include <torch/csrc/jit/serialization/export.h>
Expand Down Expand Up @@ -272,5 +274,49 @@ TEST(SerializationTest, CalculateNecessaryArgsTest) {
EXPECT_EQ(0, necessary.second);
}

TEST(TestSaveLoad, LoadWithoutDebugInfo) { // NOLINT (use =delete in gtest)
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(
R"(
def test_func(self, x):
b = 4
return self.foo + x + b
)");
m.define(
R"(
def exception(self):
assert False, "message"
)");
std::stringstream ss;
m.save(ss);
ss.seekg(0);
caffe2::serialize::PyTorchStreamReader reader(&ss);
reader.setShouldLoadDebugSymbol(true);
EXPECT_TRUE(reader.hasRecord("code/__torch__.py.debug_pkl"));
reader.setShouldLoadDebugSymbol(false);
EXPECT_FALSE(reader.hasRecord("code/__torch__.py.debug_pkl"));
ss.seekg(0);
Module m2 = torch::jit::load(ss);
std::string error_msg = R"(
def exception(self):
assert False, "message"
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE)";
ASSERT_THROWS_WITH_MESSAGE(m2.run_method("exception"), error_msg);

ss.seekg(0);
// NO DEBUG trace so error message points to torchscript generated
// source instead of original python source.
std::string error2 = R"(
def exception(self: __torch__.m) -> NoneType:
_0 = uninitialized(NoneType)
ops.prim.RaiseException("AssertionError: message")
~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
return _0
)";
Module m3 = torch::jit::load(ss, c10::nullopt, false);
ASSERT_THROWS_WITH_MESSAGE(m3.run_method("exception"), error2);
}

} // namespace jit
} // namespace torch
74 changes: 51 additions & 23 deletions torch/csrc/jit/serialization/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,11 @@ Module ScriptModuleDeserializer::deserialize(
Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::istream& in,
c10::optional<at::Device> device) {
c10::optional<at::Device> device,
bool load_debug_files) {
ExtraFilesMap extra_files;
return import_ir_module(std::move(cu), in, device, extra_files);
return import_ir_module(
std::move(cu), in, device, extra_files, load_debug_files);
}

static Module _load_jit_module_from_bytes(
Expand Down Expand Up @@ -344,12 +346,14 @@ Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
in.seekg(0, in.beg);
// NOTE: Zipformat can be large files. So using stream version directly
// instead of reading the file all at once.
if (getFileFormat(in) != FileFormat::FlatbufferFileFormat) {
auto reader = torch::make_unique<PyTorchStreamReader>(&in);
reader->setShouldLoadDebugSymbol(load_debug_files);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
Expand Down Expand Up @@ -379,20 +383,24 @@ Module import_ir_module(
Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
const std::string& filename,
c10::optional<at::Device> device) {
c10::optional<at::Device> device,
bool load_debug_files) {
ExtraFilesMap extra_files;
return import_ir_module(std::move(cu), filename, device, extra_files);
return import_ir_module(
std::move(cu), filename, device, extra_files, load_debug_files);
}

Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
// NOTE: Zipformat can be large files. So using stream version directly
// instead of reading the file all at once.
if (getFileFormat(filename) != FileFormat::FlatbufferFileFormat) {
auto reader = torch::make_unique<PyTorchStreamReader>(filename);
reader->setShouldLoadDebugSymbol(load_debug_files);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}
Expand All @@ -405,70 +413,90 @@ Module import_ir_module(
Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device) {
c10::optional<at::Device> device,
bool load_debug_files) {
ExtraFilesMap extra_files;
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
return import_ir_module(
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
}

Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::unique_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
std::shared_ptr<ReadAdapterInterface> rai_shared = std::move(rai);
return import_ir_module(cu, rai_shared, device, extra_files);
return import_ir_module(
cu, rai_shared, device, extra_files, load_debug_files);
}

Module import_ir_module(
std::shared_ptr<CompilationUnit> cu,
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
auto reader = std::make_shared<PyTorchStreamReader>(std::move(rai));
reader->setShouldLoadDebugSymbol(load_debug_files);
ScriptModuleDeserializer deserializer(std::move(cu), std::move(reader));
return deserializer.deserialize(device, extra_files);
}

Module load(std::istream& in, c10::optional<at::Device> device) {
Module load(
std::istream& in,
c10::optional<at::Device> device,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), in, device);
return import_ir_module(std::move(cu), in, device, load_debug_files);
}

Module load(
std::istream& in,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), in, device, extra_files);
return import_ir_module(
std::move(cu), in, device, extra_files, load_debug_files);
}

Module load(const std::string& filename, c10::optional<at::Device> device) {
Module load(
const std::string& filename,
c10::optional<at::Device> device,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), filename, device);
return import_ir_module(std::move(cu), filename, device, load_debug_files);
}

Module load(
const std::string& filename,
c10::optional<at::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), filename, device, extra_files);
return import_ir_module(
std::move(cu), filename, device, extra_files, load_debug_files);
}

Module load(
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device) {
c10::optional<c10::Device> device,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
ExtraFilesMap extra_files;
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
return import_ir_module(
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
}

Module load(
std::shared_ptr<ReadAdapterInterface> rai,
c10::optional<c10::Device> device,
ExtraFilesMap& extra_files) {
ExtraFilesMap& extra_files,
bool load_debug_files) {
auto cu = std::make_shared<CompilationUnit>();
return import_ir_module(std::move(cu), std::move(rai), device, extra_files);
return import_ir_module(
std::move(cu), std::move(rai), device, extra_files, load_debug_files);
}

Module _load_jit_module_from_bytes(
Expand Down
Loading