Skip to content

Commit 5f49d14

Browse files
Akshit Khuranafacebook-github-bot
authored andcommitted
Add mobile_optimized tag to optimized model. (#45479)
Summary: Pull Request resolved: #45479 Add a top level boolean attribute to the model called mobile_optimized that is set to true if it is optimized. Test Plan: buck test //caffe2/test:mobile passes Reviewed By: kimishpatel Differential Revision: D23956728 fbshipit-source-id: 79c5931702208b871454319ca2ab8633596b1eb8
1 parent 17be7c6 commit 5f49d14

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

test/test_mobile_optimizer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,23 @@ def forward(self, x):
131131
bn_input = torch.rand(1, 1, 6, 6)
132132
torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
133133

134+
class MyMobileOptimizedTagTest(torch.nn.Module):
135+
def __init__(self):
136+
super(MyMobileOptimizedTagTest, self).__init__()
137+
self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
138+
self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
139+
140+
def forward(self, x):
141+
o = F.linear(x, self.linear_weight, self.linear_bias)
142+
return F.relu(o)
143+
144+
mobile_optimized_tag_module = MyMobileOptimizedTagTest()
145+
m = torch.jit.script(mobile_optimized_tag_module)
146+
m.eval()
147+
opt_m = optimize_for_mobile(m)
148+
tag = getattr(opt_m, "mobile_optimized", None)
149+
self.assertTrue(tag)
150+
134151
class MyPreserveMethodsTest(torch.nn.Module):
135152
def __init__(self):
136153
super(MyPreserveMethodsTest, self).__init__()

torch/csrc/jit/passes/xnnpack_rewrite.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ script::Module optimizeForMobile(
390390
if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
391391
FuseAddRelu(cloned_module);
392392
}
393-
393+
cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
394394
return cloned_module;
395395
}
396396

0 commit comments

Comments
 (0)