
Commit bf94eba

Author: lixinyu (committed)
C++ APIs Transformer NN Module Top Layer
ghstack-source-id: 2936c7f Pull Request resolved: #44333
1 parent d232fec commit bf94eba

File tree: 10 files changed, +544 / -158 lines

test/cpp/api/transformer.cpp

Lines changed: 127 additions & 48 deletions
@@ -8,6 +8,16 @@ using namespace torch::nn;
 
 struct TransformerTest : torch::test::SeedingFixture {};
 
+// a generic function to set constants for parameters so we have fixed result for deterministic test
+template<typename Model>
+void set_parameter_to_constants(Model& model, const torch::TensorOptions& tensor_options) {
+  torch::NoGradGuard guard;
+  for (auto& p : model->parameters()) {
+    auto sz = p.view(-1).size(0);
+    p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
+  }
+}
+
 // a generic function to provide consistent encoder/decoder layer for all the transformer tests
 template<typename T_LAYER, typename T_OPTIONS>
 T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options) {
@@ -23,13 +33,7 @@ T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options) {
   }
 
   // set constant weights of the model
-  {
-    torch::NoGradGuard guard;
-    for (auto& p : layer->parameters()) {
-      auto sz = p.view(-1).size(0);
-      p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
-    }
-  }
+  set_parameter_to_constants<T_LAYER>(layer, tensor_options);
 
   return layer;
 }
@@ -579,25 +583,26 @@ TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
     ")");
 }
 
+
 TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
   ASSERT_EQ(
     c10::str(TransformerDecoderLayer(4, 2)),
     "torch::nn::TransformerDecoderLayerImpl(\n"
     "  (self_attn): torch::nn::MultiheadAttention(\n"
     "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "  )\n"
-    "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
     "  (multihead_attn): torch::nn::MultiheadAttention(\n"
     "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "  )\n"
-    "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
     "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
     "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
     "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
-    "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
+    "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
     "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
+    "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
+    "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
     ")");
 }
 
@@ -978,8 +983,7 @@ void transformer_decoder_test_helper(bool is_cuda) {
 
   // Multiple layers with norm
   norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
-  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)
-      .norm(AnyModule(norm)));
+  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
   if (is_cuda) {
     model->to(torch::kCUDA);
   }
@@ -1036,54 +1040,129 @@ TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
     "  (layers): torch::nn::ModuleList(\n"
     "    (0): torch::nn::TransformerDecoderLayerImpl(\n"
     "      (self_attn): torch::nn::MultiheadAttention(\n"
-    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
-    " bias=true)\n"
+    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "      )\n"
-    "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm1): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
     "      (multihead_attn): torch::nn::MultiheadAttention(\n"
-    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
-    " bias=true)\n"
+    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "      )\n"
-    "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm2): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
-    "      (linear1): torch::nn::Linear(in_features=4, out_features=2048,"
-    " bias=true)\n"
+    "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
     "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (linear2): torch::nn::Linear(in_features=2048, out_features=4,"
-    " bias=true)\n"
+    "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
+    "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
+    "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
     "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm3): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
     "    )\n"
     "    (1): torch::nn::TransformerDecoderLayerImpl(\n"
     "      (self_attn): torch::nn::MultiheadAttention(\n"
-    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
-    " bias=true)\n"
+    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "      )\n"
-    "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm1): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
     "      (multihead_attn): torch::nn::MultiheadAttention(\n"
-    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
-    " bias=true)\n"
+    "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
     "      )\n"
-    "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm2): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
-    "      (linear1): torch::nn::Linear(in_features=4, out_features=2048,"
-    " bias=true)\n"
+    "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
     "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (linear2): torch::nn::Linear(in_features=2048, out_features=4,"
-    " bias=true)\n"
+    "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
+    "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+    "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
+    "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
     "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
-    "      (norm3): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
     "    )\n"
     "  )\n"
-    "  (norm): torch::nn::LayerNorm([4], eps=1e-05,"
-    " elementwise_affine=true)\n"
+    "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
     ")");
 }
+
+void transformer_test_helper(bool is_cuda) {
+  // this is a deterministic test for Transformer
+  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
+  torch::TensorOptions tensor_options = torch::TensorOptions().dtype(torch::kFloat32).device(device);
+
+  // transformer created encoder/decoder
+  Transformer model(TransformerOptions()
+    .d_model(4)
+    .nhead(2)
+    .num_encoder_layers(2)
+    .num_decoder_layers(1)
+    .dim_feedforward(16)
+    .dropout(0.0)
+    .activation(torch::kReLU));
+
+  set_parameter_to_constants<Transformer>(model, tensor_options);
+  if (tensor_options.device() == torch::kCUDA) {
+    model->to(torch::kCUDA);
+  }
+
+  // transformer with customized encoder/decoder
+  LayerNorm enorm(LayerNormOptions({4}));
+  TransformerEncoder encoder(TransformerEncoderOptions(
+    TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 2).norm(AnyModule(enorm)));
+
+  LayerNorm dnorm(LayerNormOptions({4}));
+  TransformerDecoder decoder(TransformerDecoderOptions(
+    TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 1).norm(AnyModule(dnorm)));
+
+  Transformer model_cus(TransformerOptions()
+    .d_model(4)
+    .nhead(2)
+    .custom_encoder(AnyModule(encoder))
+    .custom_decoder(AnyModule(decoder)));
+
+  set_parameter_to_constants<Transformer>(model_cus, tensor_options);
+  if (tensor_options.device() == torch::kCUDA) {
+    model_cus->to(torch::kCUDA);
+  }
+
+  // test cases
+  torch::Tensor src = torch::tensor({
+    {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
+    {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
+    {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}}, tensor_options);
+
+  torch::Tensor tgt = torch::tensor({
+    {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
+    {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}}, tensor_options);
+
+  torch::Tensor ref_output = torch::tensor({
+    {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}},
+    {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  torch::Tensor result = model(src, tgt);
+  torch::Tensor result_cus = model_cus(src, tgt);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+
+  torch::Tensor src_mask = Transformer::Impl::generate_square_subsequent_mask(src.size(0)).to(tensor_options);
+  ref_output = torch::tensor({
+    {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}},
+    {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  result = model(src, tgt, src_mask);
+  result_cus = model_cus(src, tgt, src_mask);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+
+  torch::Tensor tgt_key_padding_mask = torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
+  tgt_key_padding_mask[0][0] = 1;
+  tgt_key_padding_mask[1][1] = 1;
+  ref_output = torch::tensor({
+    {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}},
+    {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  result = model(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask);
+  result_cus = model_cus(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+}
+
+TEST_F(TransformerTest, Transformer) {
+  transformer_test_helper(false);
+}
+
+TEST_F(TransformerTest, Transformer_CUDA) {
+  transformer_test_helper(true);
+}
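
For orientation, the new top-level module is driven much like in the tests above. The following standalone sketch is not part of this commit; option values and tensor shapes are illustrative only. It constructs a small Transformer, builds a causal mask with the new static helper, and runs one forward pass:

// Sketch only: assumes the torch::nn::Transformer API added by this commit.
#include <torch/torch.h>
#include <iostream>

int main() {
  using namespace torch::nn;

  // Small model: d_model=4, nhead=2; dropout disabled for reproducibility.
  Transformer model(TransformerOptions()
                        .d_model(4)
                        .nhead(2)
                        .num_encoder_layers(2)
                        .num_decoder_layers(1)
                        .dim_feedforward(16)
                        .dropout(0.0));

  // src: (S, N, E) = (3, 2, 4); tgt: (T, N, E) = (2, 2, 4)
  torch::Tensor src = torch::rand({3, 2, 4});
  torch::Tensor tgt = torch::rand({2, 2, 4});

  // Additive causal mask over the source sequence, shape (S, S).
  torch::Tensor src_mask =
      TransformerImpl::generate_square_subsequent_mask(src.size(0));

  // Output has shape (T, N, E) = (2, 2, 4); trailing mask arguments default to empty tensors.
  torch::Tensor out = model(src, tgt, src_mask);
  std::cout << out.sizes() << std::endl;
}

As in the test, a custom encoder/decoder pair can be swapped in through TransformerOptions::custom_encoder / custom_decoder, each wrapped in an AnyModule.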

torch/csrc/api/include/torch/nn/modules.h

Lines changed: 1 addition & 0 deletions
@@ -32,3 +32,4 @@
 #include <torch/nn/modules/normalization.h>
 #include <torch/nn/modules/transformerlayer.h>
 #include <torch/nn/modules/transformercoder.h>
+#include <torch/nn/modules/transformer.h>
torch/csrc/api/include/torch/nn/modules/transformer.h

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+#pragma once
+
+#include <torch/nn/cloneable.h>
+#include <torch/nn/module.h>
+#include <torch/nn/options/transformer.h>
+#include <torch/nn/pimpl.h>
+#include <torch/nn/modules/common.h>
+
+#include <torch/types.h>
+
+#include <ostream>
+
+namespace torch {
+namespace nn {
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// A transformer model. Users are able to modify the attributes as needed. The architecture
+/// is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
+/// Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
+/// Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
+/// Processing Systems, pages 6000-6010. Users can build the BERT (https://arxiv.org/abs/1810.04805)
+/// model with corresponding parameters.
+///
+/// See https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html to
+/// learn about the exact behavior of this transformer model.
+///
+/// See the documentation for the `torch::nn::Transformer` class to learn what
+/// constructor arguments are supported for this transformer model.
+///
+/// Example:
+/// ```
+/// Transformer trans(TransformerOptions(512, 8));
+/// ```
+class TORCH_API TransformerImpl : public Cloneable<TransformerImpl> {
+
+  public:
+    explicit TransformerImpl(TransformerOptions options_);
+
+    /// forward function for Transformer Module
+    /// Args:
+    ///   src: the sequence to the encoder (required).
+    ///   tgt: the sequence to the decoder (required).
+    ///   src_mask: the additive mask for the src sequence (optional).
+    ///   tgt_mask: the additive mask for the tgt sequence (optional).
+    ///   memory_mask: the additive mask for the encoder output (optional).
+    ///   src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
+    ///   tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
+    ///   memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
+    ///
+    /// Shape:
+    ///   src: `(S, N, E)`
+    ///   tgt: `(T, N, E)`
+    ///   src_mask: `(S, S)`
+    ///   tgt_mask: `(T, T)`
+    ///   memory_mask: `(T, S)`
+    ///   src_key_padding_mask: `(N, S)`
+    ///   tgt_key_padding_mask: `(N, T)`
+    ///   memory_key_padding_mask: `(N, S)`
+    ///
+    ///   Note:
+    ///     [src/tgt/memory]_mask ensures that position i is allowed to attend to the unmasked
+    ///     positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
+    ///     while the zero positions will be unchanged. If a BoolTensor is provided, positions with `True`
+    ///     are not allowed to attend while `False` values will be unchanged. If a FloatTensor
+    ///     is provided, it will be added to the attention weight.
+    ///
+    ///     [src/tgt/memory]_key_padding_mask allows specified elements in the key to be ignored by
+    ///     the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
+    ///     positions will be unchanged. If a BoolTensor is provided, positions with the value of `True`
+    ///     will be ignored while positions with the value of `False` will be unchanged.
+    ///
+    ///   output: `(T, N, E)`
+    ///
+    ///   Note:
+    ///     Due to the multi-head attention architecture in the transformer model,
+    ///     the output sequence length of a transformer is the same as the input sequence
+    ///     (i.e. target) length of the decoder.
+    ///
+    ///   where
+    ///     S is the source sequence length,
+    ///     T is the target sequence length,
+    ///     N is the batch size,
+    ///     E is the feature number.
+    Tensor forward(
+      const Tensor& src,
+      const Tensor& tgt,
+      const Tensor& src_mask = {},
+      const Tensor& tgt_mask = {},
+      const Tensor& memory_mask = {},
+      const Tensor& src_key_padding_mask = {},
+      const Tensor& tgt_key_padding_mask = {},
+      const Tensor& memory_key_padding_mask = {});
+
+    void reset() override;
+
+    void reset_parameters();
+
+    /// Generate a square mask for the sequence.
+    /// The masked positions are filled with `-inf` in float type.
+    /// Unmasked positions are filled with `0.0` in float type.
+    /// Note:
+    ///   1. This function will always return a CPU tensor.
+    ///   2. This function requires the platform to support IEEE 754, since `-inf` is guaranteed to
+    ///      be valid only when IEEE 754 is supported. If the platform doesn't support IEEE 754,
+    ///      this function will fill the mask with the smallest float number instead of `-inf`,
+    ///      and a one-time warning will be emitted as well.
+    static Tensor generate_square_subsequent_mask(int64_t sz);
+
+  protected:
+    FORWARD_HAS_DEFAULT_ARGS(
+      {2, AnyValue(Tensor())},
+      {3, AnyValue(Tensor())},
+      {4, AnyValue(Tensor())},
+      {5, AnyValue(Tensor())},
+      {6, AnyValue(Tensor())},
+      {7, AnyValue(Tensor())})
+
+  public:
+    /// options with which this `Transformer` was constructed
+    TransformerOptions options;
+
+    /// encoder module
+    AnyModule encoder;
+
+    /// decoder module
+    AnyModule decoder;
+};
+
+/// A `ModuleHolder` subclass for `TransformerImpl`.
+/// See the documentation for the `TransformerImpl` class to learn what
+/// methods it provides, and examples of how to use `Transformer` with
+/// `torch::nn::TransformerOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
+TORCH_MODULE(Transformer);
+
+} // namespace nn
+} // namespace torch
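
The documented semantics of generate_square_subsequent_mask can be summarized with a small sketch. This is not the commit's implementation (which additionally warns and falls back to the lowest finite float on non-IEEE-754 platforms); it only illustrates the mask contents described in the docblock above:

#include <torch/torch.h>
#include <limits>

// Illustrative equivalent: position i may attend to positions <= i;
// future positions receive -inf, allowed positions receive 0.0.
torch::Tensor square_subsequent_mask_sketch(int64_t sz) {
  // Upper triangle above the diagonal marks the "future" positions.
  torch::Tensor future = torch::triu(torch::ones({sz, sz}), /*diagonal=*/1).to(torch::kBool);
  return torch::zeros({sz, sz})
      .masked_fill(future, -std::numeric_limits<float>::infinity());
}

// For sz = 3 the result is:
//   [[0.0, -inf, -inf],
//    [0.0,  0.0, -inf],
//    [0.0,  0.0,  0.0]]

A mask of this form is what the tests pass as src_mask, matching the `(S, S)` / `(T, T)` shapes listed in the forward() documentation.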
