@@ -8,6 +8,16 @@ using namespace torch::nn;

struct TransformerTest : torch::test::SeedingFixture {};

+// a generic function that sets all parameters to constants so the result is
+// fixed for deterministic tests
+template <typename Model>
+void set_parameter_to_constants(Model& model, const torch::TensorOptions& tensor_options) {
+  torch::NoGradGuard guard;
+  for (auto& p : model->parameters()) {
+    auto sz = p.view(-1).size(0);
+    p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
+  }
+}
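+// usage sketch (mirrors the call in get_a_test_layer below):
+//   set_parameter_to_constants<T_LAYER>(layer, tensor_options);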
+
// a generic function to provide consistent encoder/decoder layer for all the transformer tests
template <typename T_LAYER, typename T_OPTIONS>
T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options) {
@@ -23,13 +33,7 @@ T_LAYER get_a_test_layer(const torch::TensorOptions& tensor_options) {
  }

  // set constant weights of the model
-  {
-    torch::NoGradGuard guard;
-    for (auto& p : layer->parameters()) {
-      auto sz = p.view(-1).size(0);
-      p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
-    }
-  }
+  set_parameter_to_constants<T_LAYER>(layer, tensor_options);

  return layer;
}
@@ -579,25 +583,26 @@ TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
579583 " )" );
580584}
581585
+
TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
  ASSERT_EQ(
      c10::str(TransformerDecoderLayer(4, 2)),
      "torch::nn::TransformerDecoderLayerImpl(\n"
      "  (self_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
-      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
-      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (multihead_attn): torch::nn::MultiheadAttention(\n"
      "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
      "  )\n"
-      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
-      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
      "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
      "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
-      "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
+      "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+      "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
      "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
+      "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
+      "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
+      "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
      ")");
}

@@ -978,8 +983,7 @@ void transformer_decoder_test_helper(bool is_cuda) {

  // Multiple layers with norm
  norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
-  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)
-      .norm(AnyModule(norm)));
+  model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
  if (is_cuda) {
    model->to(torch::kCUDA);
  }
@@ -1036,54 +1040,129 @@ TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
10361040 " (layers): torch::nn::ModuleList(\n "
10371041 " (0): torch::nn::TransformerDecoderLayerImpl(\n "
10381042 " (self_attn): torch::nn::MultiheadAttention(\n "
1039- " (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
1040- " bias=true)\n "
1043+ " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n "
10411044 " )\n "
1042- " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n "
1043- " (norm1): torch::nn::LayerNorm([4], eps=1e-05,"
1044- " elementwise_affine=true)\n "
10451045 " (multihead_attn): torch::nn::MultiheadAttention(\n "
1046- " (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
1047- " bias=true)\n "
1046+ " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n "
10481047 " )\n "
1049- " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n "
1050- " (norm2): torch::nn::LayerNorm([4], eps=1e-05,"
1051- " elementwise_affine=true)\n "
1052- " (linear1): torch::nn::Linear(in_features=4, out_features=2048,"
1053- " bias=true)\n "
1048+ " (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n "
10541049 " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n "
1055- " (linear2): torch::nn::Linear(in_features=2048, out_features=4,"
1056- " bias=true)\n "
1050+ " (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n "
1051+ " (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1052+ " (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1053+ " (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1054+ " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n "
1055+ " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n "
10571056 " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n "
1058- " (norm3): torch::nn::LayerNorm([4], eps=1e-05,"
1059- " elementwise_affine=true)\n "
10601057 " )\n "
10611058 " (1): torch::nn::TransformerDecoderLayerImpl(\n "
10621059 " (self_attn): torch::nn::MultiheadAttention(\n "
1063- " (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
1064- " bias=true)\n "
1060+ " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n "
10651061 " )\n "
1066- " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n "
1067- " (norm1): torch::nn::LayerNorm([4], eps=1e-05,"
1068- " elementwise_affine=true)\n "
10691062 " (multihead_attn): torch::nn::MultiheadAttention(\n "
1070- " (out_proj): torch::nn::Linear(in_features=4, out_features=4,"
1071- " bias=true)\n "
1063+ " (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n "
10721064 " )\n "
1073- " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n "
1074- " (norm2): torch::nn::LayerNorm([4], eps=1e-05,"
1075- " elementwise_affine=true)\n "
1076- " (linear1): torch::nn::Linear(in_features=4, out_features=2048,"
1077- " bias=true)\n "
1065+ " (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n "
10781066 " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n "
1079- " (linear2): torch::nn::Linear(in_features=2048, out_features=4,"
1080- " bias=true)\n "
1067+ " (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n "
1068+ " (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1069+ " (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1070+ " (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
1071+ " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n "
1072+ " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n "
10811073 " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n "
1082- " (norm3): torch::nn::LayerNorm([4], eps=1e-05,"
1083- " elementwise_affine=true)\n "
10841074 " )\n "
10851075 " )\n "
1086- " (norm): torch::nn::LayerNorm([4], eps=1e-05,"
1087- " elementwise_affine=true)\n "
1076+ " (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n "
10881077 " )" );
10891078}
+
+void transformer_test_helper(bool is_cuda) {
+  // this is a deterministic test for Transformer
+  torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
+  torch::TensorOptions tensor_options =
+      torch::TensorOptions().dtype(torch::kFloat32).device(device);
+
+  // transformer with encoder/decoder created by the Transformer itself
+  Transformer model(TransformerOptions()
+                        .d_model(4)
+                        .nhead(2)
+                        .num_encoder_layers(2)
+                        .num_decoder_layers(1)
+                        .dim_feedforward(16)
+                        .dropout(0.0)
+                        .activation(torch::kReLU));
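+  // note: dropout is 0.0 above, so the forward pass is deterministic and the
+  // hard-coded reference outputs below stay reproducible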
+
+  set_parameter_to_constants<Transformer>(model, tensor_options);
+  if (tensor_options.device() == torch::kCUDA) {
+    model->to(torch::kCUDA);
+  }
+
+  // transformer with customized encoder/decoder
+  LayerNorm enorm(LayerNormOptions({4}));
+  TransformerEncoder encoder(TransformerEncoderOptions(
+      TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 2)
+      .norm(AnyModule(enorm)));
+
+  LayerNorm dnorm(LayerNormOptions({4}));
+  TransformerDecoder decoder(TransformerDecoderOptions(
+      TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0), 1)
+      .norm(AnyModule(dnorm)));
+
+  Transformer model_cus(TransformerOptions()
+                            .d_model(4)
+                            .nhead(2)
+                            .custom_encoder(AnyModule(encoder))
+                            .custom_decoder(AnyModule(decoder)));
+
+  set_parameter_to_constants<Transformer>(model_cus, tensor_options);
+  if (tensor_options.device() == torch::kCUDA) {
+    model_cus->to(torch::kCUDA);
+  }
+
+  // test cases
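+  // src is (src_seq_len=3, batch=2, d_model=4) and tgt is (tgt_seq_len=2, batch=2, d_model=4),
+  // following the (seq, batch, feature) layout the transformer modules expect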
+  torch::Tensor src = torch::tensor({
+      {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
+      {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
+      {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}}, tensor_options);
+
+  torch::Tensor tgt = torch::tensor({
+      {{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
+      {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}}, tensor_options);
+
+  torch::Tensor ref_output = torch::tensor({
+      {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}},
+      {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  torch::Tensor result = model(src, tgt);
+  torch::Tensor result_cus = model_cus(src, tgt);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+
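+  // generate_square_subsequent_mask builds a causal mask: position i may only
+  // attend to positions <= i (later positions are masked out)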
+  torch::Tensor src_mask =
+      Transformer::Impl::generate_square_subsequent_mask(src.size(0)).to(tensor_options);
+  ref_output = torch::tensor({
+      {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}},
+      {{2.695875, 0.347114, -0.044355, -0.549541}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  result = model(src, tgt, src_mask);
+  result_cus = model_cus(src, tgt, src_mask);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+
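+  // the key padding mask is (batch, seq); true entries mark padded positions
+  // that attention should ignore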
+  torch::Tensor tgt_key_padding_mask =
+      torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
+  tgt_key_padding_mask[0][0] = 1;
+  tgt_key_padding_mask[1][1] = 1;
+  ref_output = torch::tensor({
+      {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}},
+      {{2.696114, 0.347004, -0.044813, -0.548417}, {2.696091, 0.347015, -0.044770, -0.548522}}}, tensor_options);
+  result = model(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask);
+  result_cus = model_cus(src, tgt, src_mask, torch::Tensor(), torch::Tensor(), torch::Tensor(), tgt_key_padding_mask);
+  ASSERT_EQ(result.sizes(), ref_output.sizes());
+  ASSERT_TRUE(result.equal(result_cus));
+  ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
+}
+
+TEST_F(TransformerTest, Transformer) {
+  transformer_test_helper(false);
+}
+
+TEST_F(TransformerTest, Transformer_CUDA) {
+  transformer_test_helper(true);
+}