Unlike `torch.save()` and the `pickle` API, `torch.load()` resets the file's position pointer to zero. This prevents the following from working:
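```python
import pickle
import tempfile

import torch

a = torch.zeros(2, 2)
with tempfile.TemporaryFile(mode='w+b') as tmp:
    i = 41
    pickle.dump(i, tmp)   # write an int first...
    torch.save(a, tmp)    # ...then the tensor right after it
    tmp.seek(0)
    j = pickle.load(tmp)  # reads 41 and leaves the pointer just past it
    b = torch.load(tmp)   # rewinds to offset 0 instead of reading the tensor
```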
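The sequential-read contract the report relies on can be checked with `pickle` alone: each `pickle.load()` leaves the pointer just past the object it read. The sketch below is a possible workaround, not a fix in PyTorch. It assumes you control the writer. The idea is to serialize the object into a private `io.BytesIO` buffer and store those bytes as one pickled record, so any rewind the loader performs only affects the private buffer. So that it runs without PyTorch, `rewinding_load` is a hypothetical stand-in that mimics the reported seek-to-zero behaviour using `pickle`. `save_embedded` and `load_embedded` are also illustrative names, not PyTorch APIs.

```python
import io
import pickle
import tempfile


def rewinding_load(f):
    """Hypothetical stand-in for the reported torch.load() behaviour:
    it rewinds the file to offset 0 before unpickling."""
    f.seek(0)
    return pickle.load(f)


def save_embedded(obj, f, save=pickle.dump):
    """Serialize obj into a private buffer, then pickle the raw bytes into f."""
    buf = io.BytesIO()
    save(obj, buf)
    pickle.dump(buf.getvalue(), f)


def load_embedded(f, load=rewinding_load):
    """Read one record written by save_embedded and run `load` on a fresh
    buffer, so any seek(0) it performs cannot move f's pointer."""
    payload = pickle.load(f)  # advances f past this record only
    return load(io.BytesIO(payload))


# The reported problem: a rewinding loader re-reads the first object.
with tempfile.TemporaryFile(mode="w+b") as tmp:
    pickle.dump(41, tmp)
    pickle.dump([[0.0, 0.0], [0.0, 0.0]], tmp)
    tmp.seek(0)
    first = pickle.load(tmp)
    second = rewinding_load(tmp)  # rewinds, so this returns 41 again

# The workaround: the rewinding loader only ever sees its own buffer.
with tempfile.TemporaryFile(mode="w+b") as tmp:
    pickle.dump(41, tmp)
    save_embedded([[0.0, 0.0], [0.0, 0.0]], tmp)
    pickle.dump("after", tmp)
    tmp.seek(0)
    j = pickle.load(tmp)
    b = load_embedded(tmp)
    tail = pickle.load(tmp)  # still positioned correctly after the record

print(first, second)
print(j, b, tail)
```

With PyTorch installed, the same helpers would be called as `save_embedded(a, tmp, save=torch.save)` and `load_embedded(tmp, load=torch.load)`. Both functions accept file-like objects such as `io.BytesIO`. The cost is one extra in-memory copy of each serialized tensor.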