Skip to content

Commit 410c721

Browse files
davidriazatifacebook-github-bot
authored andcommitted
Add save() to torch._C.Function (#20386)
Summary: Fixes #20017 This wraps the `torch._C.Function` currently returned from `torch.jit.script` and `torch.jit.trace` in a `ScriptFunction` and `TracedFunction` respectively, both of which are just wrappers to hold the function. ](https://our.intern.facebook.com/intern/diff/15403161/) Pull Request resolved: #20386 Pulled By: driazati Differential Revision: D15403161 fbshipit-source-id: 94fb9f32929e62a00be6cf7512ea144ec9b91e0b
1 parent 987f1cc commit 410c721

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

test/test_jit.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,6 +3127,21 @@ def fn(x):
31273127
warns = [str(w.message) for w in warns]
31283128
self.assertEqual(len(warns), 0)
31293129

3130+
@unittest.skipIf(sys.platform == "win32", "temp file name on windows")
3131+
def test_trace_save(self):
3132+
def fn(x):
3133+
return x + 2
3134+
3135+
def check(func):
3136+
with tempfile.NamedTemporaryFile() as f:
3137+
func.save(f.name)
3138+
loaded = torch.jit.load(f.name)
3139+
input = torch.randn(2, 2)
3140+
self.assertEqual(func(input), loaded(input))
3141+
3142+
out = torch.jit.trace(fn, (torch.ones(2, 2),))
3143+
check(out)
3144+
31303145
@unittest.skipIf(sys.platform == "win32", "TODO: need to fix this test case for Windows")
31313146
def test_torch_load_error(self):
31323147
class J(torch.jit.ScriptModule):

torch/csrc/jit/script/init.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,17 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
233233
return retval;
234234
}
235235

236+
void addFunctionToModule(
237+
Module& module,
238+
const std::shared_ptr<Function>& func) {
239+
// Make a graph with a fake self argument
240+
auto graph = func->graph()->copy();
241+
auto v = graph->insertInput(0, "self");
242+
v->setType(module.module_object()->type());
243+
module.module_object()->type()->compilation_unit().create_function(
244+
"forward", graph);
245+
}
246+
236247
void initJitScriptBindings(PyObject* module) {
237248
auto m = py::handle(module).cast<py::module>();
238249

@@ -481,6 +492,27 @@ void initJitScriptBindings(PyObject* module) {
481492
}
482493
return result;
483494
})
495+
.def(
496+
"save",
497+
[](std::shared_ptr<Function> self,
498+
const std::string& filename,
499+
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
500+
Module module;
501+
addFunctionToModule(module, self);
502+
module.save(filename, _extra_files);
503+
},
504+
py::arg("filename"),
505+
py::arg("_extra_files") = ExtraFilesMap())
506+
.def(
507+
"save_to_buffer",
508+
[](std::shared_ptr<Function> self,
509+
const ExtraFilesMap& _extra_files = ExtraFilesMap()) {
510+
std::ostringstream buf;
511+
Module module;
512+
addFunctionToModule(module, self);
513+
return py::bytes(buf.str());
514+
},
515+
py::arg("_extra_files") = ExtraFilesMap())
484516
.def_property_readonly("graph", &Function::graph)
485517
.def_property_readonly("schema", &Function::getSchema)
486518
.def_property_readonly(

torch/csrc/jit/script/module.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,9 @@ std::pair<std::shared_ptr<Graph>, std::vector<Slot>> lower_graph(
195195
}
196196
if (e.n->kind() != prim::GetAttr) {
197197
throw ErrorReport(e.n->sourceRange())
198-
<< "temporary: the only valid use of a module is looking up an attribute";
198+
<< "temporary: the only valid use of a module is looking up an "
199+
"attribute but found "
200+
<< *e.n;
199201
}
200202
Slot slot(e.mod, e.mod->type()->getAttributeSlot(e.n->s(attr::name)));
201203
if (ClassTypePtr c = e.n->output()->type()->cast<ClassType>()) {

0 commit comments

Comments
 (0)