Skip to content

Commit 029a968

Browse files
ezyangfacebook-github-bot
authored andcommitted
Define __setstate__ on _ConvNd to handle pre-padding_mode pickles. (#21687)
Summary: Pull Request resolved: #21687 ghimport-source-id: df49530 Differential Revision: D15807402 Pulled By: ezyang fbshipit-source-id: f51b221444afc4e017db7544642a9c0a7d2a3efb
1 parent 7284f44 commit 029a968

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ test/.coverage
3131
test/.hypothesis/
3232
test/cpp/api/mnist
3333
test/custom_operator/model.pt
34-
test/data/gpu_tensors.pt
3534
test/data/legacy_modules.t7
36-
test/data/legacy_serialized.pt
37-
test/data/linear.pt
35+
test/data/*.pt
3836
dropout_model.pt
3937
test/generated_type_hints_smoketest.py
4038
test/htmlcov

test/test_nn.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import sys
23
import random
34
import string
45
import unittest
@@ -579,6 +580,26 @@ def test_module_backcompat(self):
579580
input = torch.randn(2, 3, dtype=torch.float)
580581
self.assertEqual(m(input).size(), (2, 5))
581582

583+
def test_conv_backcompat(self):
584+
from torch.serialization import SourceChangeWarning
585+
# This file was generated by running on PyTorch 1.0.1 on Python 2:
586+
#
587+
# import torch
588+
# from torch import nn
589+
# m = nn.Conv2d(1, 1, 1)
590+
# torch.save(m, 'legacy_conv2d.pt')
591+
#
592+
# NB: This Pickle also contains some Unicode data!
593+
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
594+
with warnings.catch_warnings():
595+
warnings.simplefilter('ignore', SourceChangeWarning)
596+
if sys.version_info[0] == 2:
597+
m = torch.load(path)
598+
else:
599+
m = torch.load(path, encoding='utf-8')
600+
input = torch.randn((1, 1, 1, 1), dtype=torch.float)
601+
self.assertEqual(m(input).size(), (1, 1, 1, 1))
602+
582603
def test_share_memory(self):
583604
class Net(nn.Module):
584605
def __init__(self):

torch/nn/modules/conv.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def extra_repr(self):
6868
s += ', bias=False'
6969
return s.format(**self.__dict__)
7070

71+
def __setstate__(self, state):
72+
super(_ConvNd, self).__setstate__(state)
73+
if not hasattr(self, 'padding_mode'):
74+
self.padding_mode = 'zeros'
75+
7176

7277
@weak_module
7378
class Conv1d(_ConvNd):

0 commit comments

Comments
 (0)