Skip to content

Commit f6f9d22

Browse files
neginraoof authored and facebook-github-bot committed
[ONNX] Export KLDivLoss (#41858)
Summary: Enable export for KLDivLoss Pull Request resolved: #41858 Reviewed By: mrshenli Differential Revision: D22918004 Pulled By: bzinodev fbshipit-source-id: e3debf77a4cf0eae0df6ed5a72ee91c43e482b62
1 parent 4716284 commit f6f9d22

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3670,6 +3670,72 @@ def forward(self, input, target):
36703670

36713671
self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y))
36723672

3673+
@skipIfUnsupportedMinOpsetVersion(9)
def test_kldiv_loss(self):
    """Exercise KLDivLoss ONNX export on 1-D, 3-D, and 4-D random inputs."""
    for shape in [(5,), (2, 3, 5), (2, 3, 5, 7)]:
        self._kldiv_loss(torch.randn(shape), torch.randn(shape))
3687+
3688+
def _kldiv_loss(self, x, y):
    """Run ONNX export checks for KLDivLoss over every reduction mode.

    Builds one wrapper module per KLDivLoss configuration and verifies
    each one round-trips through export with inputs (x, y).
    """

    class KLDivLossNone(torch.nn.Module):
        """KLDivLoss, no reduction, target given in log space."""

        def __init__(self):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction='none', log_target=True)

        def forward(self, input, target):
            return self.loss(input, target)

    class KLDivLossMean(torch.nn.Module):
        """KLDivLoss, element-wise mean reduction, probability-space target."""

        def __init__(self):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction='mean', log_target=False)

        def forward(self, input, target):
            return self.loss(input, target)

    class KLDivLossSum(torch.nn.Module):
        """KLDivLoss, sum reduction, target given in log space."""

        def __init__(self):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction='sum', log_target=True)

        def forward(self, input, target):
            return self.loss(input, target)

    class KLDivLossBatchMean(torch.nn.Module):
        """KLDivLoss, batchmean reduction, probability-space target."""

        def __init__(self):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=False)

        def forward(self, input, target):
            return self.loss(input, target)

    class KLDivLossMiniBatchMean(torch.nn.Module):
        """KLDivLoss combining batchmean with the legacy size_average flag.

        NOTE(review): passing size_average=False routes through the
        legacy-reduction arguments and overrides reduction='batchmean' —
        presumably intentional, to exercise that code path; confirm.
        """

        def __init__(self):
            super().__init__()
            self.loss = torch.nn.KLDivLoss(reduction='batchmean', size_average=False, log_target=True)

        def forward(self, input, target):
            return self.loss(input, target)

    # Same order as the individual run_test calls in the original layout.
    for model in (KLDivLossNone(), KLDivLossMean(), KLDivLossSum(),
                  KLDivLossBatchMean(), KLDivLossMiniBatchMean()):
        self.run_test(model, input=(x, y))
3738+
36733739
@skipIfUnsupportedMinOpsetVersion(12)
36743740
def test_nllloss(self):
36753741
class NLLModel(torch.nn.Module):

torch/onnx/symbolic_opset9.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2530,6 +2530,42 @@ def take(g, self, index):
25302530
out = reshape_as(g, out, index)
25312531
return out
25322532

2533+
2534+
def _kl_div_log_target_impl(g, input, target):
    """Emit the pointwise KL-divergence term for a log-space target.

    Builds ONNX ops computing exp(target) * (target - input) into graph g
    and returns the resulting value.
    """
    delta = sub(g, target, input)
    scale = exp(g, target)
    return mul(g, scale, delta)
2539+
2540+
2541+
def _kl_div_non_log_target_impl(g, input, target):
    """Emit the pointwise KL-divergence term for a probability-space target.

    Builds ONNX ops computing target * (log(target) - input), then zeroes
    every position where target <= 0 — presumably to match PyTorch's
    convention that such entries contribute 0 rather than NaN/-inf.
    """
    log_tgt = log(g, target)
    delta = sub(g, log_tgt, input)
    pointwise = mul(g, target, delta)
    fallback = zeros_like(g, pointwise)
    positive = gt(g, target, g.op("Constant", value_t=torch.tensor(0)))
    return where(g, positive, pointwise, fallback)
2549+
2550+
2551+
@parse_args('v', 'v', 'i', 'b')
def kl_div(g, input, target, reduction, log_target):
    """Symbolic for aten::kl_div: pointwise loss term plus the requested reduction.

    reduction uses the aten encoding: 0 = none, 1 = mean, 2 = sum.
    """
    # Pick the pointwise formula according to whether target is in log space.
    impl = _kl_div_log_target_impl if log_target else _kl_div_non_log_target_impl
    pointwise = impl(g, input, target)

    if reduction == 0:
        return pointwise
    if reduction == 1:
        return g.op("ReduceMean", pointwise, keepdims_i=0)
    if reduction == 2:
        return g.op("ReduceSum", pointwise, keepdims_i=0)
    return sym_help._onnx_unsupported("kl_div with reduction other than none, mean, or sum. Please open a bug to "
                                      "request ONNX export support for the missing reduction type.")
2567+
2568+
25332569
@parse_args('v', 'v', 'is', 'i')
25342570
def as_strided(g, self, sizes, strides, offset=None):
25352571
sizes = sym_help._maybe_get_const(sizes, 'is')

0 commit comments

Comments
 (0)