Skip to content
2 changes: 2 additions & 0 deletions caffe2/proto/torch.proto
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ message ModuleDef {

// Used for retrieving module state from the pickled IValues table
optional int64 get_state_attribute_id = 10;

optional RecordRef torchscript_debug_arena = 11;
}

// Represents all non-module code that the model depends on.
Expand Down
26 changes: 26 additions & 0 deletions torch/csrc/jit/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/python_print.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/source_range_serializer.h>

#include <caffe2/core/types.h>
#include <caffe2/proto/caffe2_pb.h>
Expand Down Expand Up @@ -659,8 +660,11 @@ void ScriptModuleSerializer::convertClass(

std::vector<c10::NamedTypePtr> class_deps;
std::ostringstream class_stream;
// TODO: serialize for classes
SourceRangeRecords source_ranges;
PythonPrint(
class_stream,
source_ranges,
class_type,
tensor_table_,
class_deps,
Expand Down Expand Up @@ -905,9 +909,11 @@ void ScriptModuleSerializer::convertModule(

if (module.class_compilation_unit()->get_functions().size() > 0) {
std::ostringstream methods;
SourceRangeRecords source_ranges;
methods << "op_version_set = " << CURRENT_OP_VERSION_SET << "\n";
PythonPrint(
methods,
source_ranges,
*module.class_compilation_unit(),
/*is_method=*/true,
tensor_table_,
Expand All @@ -921,6 +927,26 @@ void ScriptModuleSerializer::convertModule(
writer_.writeRecord(
filename.str(), methods_str.c_str(), methods_str.size());
record->set_key(filename.str());

// Write out debug records
torch::RecordRef* debug_record =
module_def->mutable_torchscript_debug_arena();
Pickler p;
SourceRangeSerializer srs;
p.start();
p.startTuple();
for (const auto& range : source_ranges) {
std::vector<c10::IValue> row_elems{(int64_t)range.bytes,
srs.serialize(range.range)};
p.addIValue(c10::ivalue::Tuple::create(std::move(row_elems)));
}
p.endTuple();
p.finish();
std::stringstream debug_filename;
debug_filename << "debug/" << module_name.str() << ".pkl";
writer_.writeRecord(
debug_filename.str(), p.stack().data(), p.stack().size());
debug_record->set_key(debug_filename.str());
}

for (script::Slot s : module.get_module_slots()) {
Expand Down
Loading