Skip to content

Commit f1c1d1a

Browse files
houseroadfacebook-github-bot
authored andcommitted
Export the cosine_similarity op as an ATenOp correctly (#21884)
Summary: cosine_similarity has two non-tensor parameters, needs some special handling. Add the support for its export in this diff. Pull Request resolved: #21884 Reviewed By: zrphercule Differential Revision: D15866807 Pulled By: houseroad fbshipit-source-id: a165fbc00c65c44b276df89ae705ca8960349d48
1 parent 3ed8acd commit f1c1d1a

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

test/onnx/test_pytorch_onnx_caffe2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,13 @@ def forward(self, x):
891891
x = torch.randn(*shape)
892892
self.run_model_test(MyModel(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False)
893893

894+
def test_cosine_similarity(self):
895+
shape = (100, 128)
896+
x = torch.randn(*shape)
897+
y = torch.randn(*shape)
898+
self.run_model_test(torch.nn.CosineSimilarity(dim=1, eps=1e-6), train=False,
899+
input=(x, y), batch_size=BATCH_SIZE, use_gpu=False)
900+
894901
def test_lstm_constant_folding(self):
895902
class LstmNet(nn.Module):
896903
def __init__(self, input_size, hidden_size, num_layers, bidirectional):

torch/onnx/symbolic_opset9.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,11 @@ def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable):
958958
eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm")
959959

960960

961+
@parse_args('v', 'v', 'i', 'f')
962+
def cosine_similarity(g, x1, x2, dim, eps):
963+
return g.op("ATen", x1, x2, dim_i=dim, eps_f=eps, operator_s="cosine_similarity")
964+
965+
961966
# ignore clone operators that are inserted by PyTorch autograd
962967
def clone(g, input):
963968
return input

0 commit comments

Comments
 (0)