88from lxml import etree
99from os .path import dirname , join , exists
1010import unittest
11+ from defusedxml .lxml import fromstring
1112from xml .dom .minidom import parseString
1213
1314from onelogin .saml2 import compat
@@ -613,7 +614,7 @@ def testDecryptElement(self):
613614 encrypted_nameid_nodes = dom_nameid_enc .find ('.//saml:EncryptedID' , namespaces = OneLogin_Saml2_Constants .NSMAP )
614615 encrypted_data = encrypted_nameid_nodes [0 ]
615616 decrypted_nameid = OneLogin_Saml2_Utils .decrypt_element (encrypted_data , key )
616- self .assertEqual ('{%s} NameID' % OneLogin_Saml2_Constants . NS_SAML , decrypted_nameid .tag )
617+ self .assertEqual ('saml: NameID' , decrypted_nameid .tag )
617618 self .assertEqual ('2de11defd199f8d5bb63f9b7deb265ba5c675c10' , decrypted_nameid .text )
618619
619620 xml_assertion_enc = b64decode (self .file_contents (join (self .data_path , 'responses' , 'valid_encrypted_assertion_encrypted_nameid.xml.base64' )))
@@ -636,24 +637,71 @@ def testDecryptElement(self):
636637 key2 = f .read ()
637638 f .close ()
638639
639- self .assertRaises (Exception , OneLogin_Saml2_Utils .decrypt_element , encrypted_data , key2 )
640+ # sp.key and sp2.key are equivalent we should be able to decrypt the nameID again
641+ decrypted_nameid = OneLogin_Saml2_Utils .decrypt_element (encrypted_data , key2 )
642+ self .assertIn ('{%s}NameID' % (OneLogin_Saml2_Constants .NS_SAML ), decrypted_nameid .tag )
643+ self .assertEqual ('457bdb600de717891c77647b0806ce59c089d5b8' , decrypted_nameid .text )
640644
641- key_3_file_name = join (self .data_path , 'misc' , 'sp2 .key' )
645+ key_3_file_name = join (self .data_path , 'misc' , 'sp3 .key' )
642646 f = open (key_3_file_name , 'r' )
643647 key3 = f .read ()
644648 f .close ()
645- self .assertRaises (Exception , OneLogin_Saml2_Utils .decrypt_element , encrypted_data , key3 )
649+
650+ # sp.key and sp3.key are equivalent we should be able to decrypt the nameID again
651+ decrypted_nameid = OneLogin_Saml2_Utils .decrypt_element (encrypted_data , key3 )
652+ self .assertIn ('{%s}NameID' % (OneLogin_Saml2_Constants .NS_SAML ), decrypted_nameid .tag )
653+ self .assertEqual ('457bdb600de717891c77647b0806ce59c089d5b8' , decrypted_nameid .text )
654+
655+ key_4_file_name = join (self .data_path , 'misc' , 'sp4.key' )
656+ f = open (key_4_file_name , 'r' )
657+ key4 = f .read ()
658+ f .close ()
659+
660+ with self .assertRaisesRegex (Exception , "(1, 'failed to decrypt')" ):
661+ OneLogin_Saml2_Utils .decrypt_element (encrypted_data , key4 )
662+
646663 xml_nameid_enc_2 = b64decode (self .file_contents (join (self .data_path , 'responses' , 'invalids' , 'encrypted_nameID_without_EncMethod.xml.base64' )))
647- dom_nameid_enc_2 = etree .fromstring (xml_nameid_enc_2 )
648- encrypted_nameid_nodes_2 = dom_nameid_enc_2 .find ('.//saml:EncryptedID' , namespaces = OneLogin_Saml2_Constants .NSMAP )
649- encrypted_data_2 = encrypted_nameid_nodes_2 [0 ]
650- self .assertRaises (Exception , OneLogin_Saml2_Utils .decrypt_element , encrypted_data_2 , key )
664+ dom_nameid_enc_2 = parseString (xml_nameid_enc_2 )
665+ encrypted_nameid_nodes_2 = dom_nameid_enc_2 .getElementsByTagName ('saml:EncryptedID' )
666+ encrypted_data_2 = encrypted_nameid_nodes_2 [0 ].firstChild
667+
668+ with self .assertRaisesRegex (Exception , "(1, 'failed to decrypt')" ):
669+ OneLogin_Saml2_Utils .decrypt_element (encrypted_data_2 , key )
651670
652671 xml_nameid_enc_3 = b64decode (self .file_contents (join (self .data_path , 'responses' , 'invalids' , 'encrypted_nameID_without_keyinfo.xml.base64' )))
653- dom_nameid_enc_3 = etree .fromstring (xml_nameid_enc_3 )
654- encrypted_nameid_nodes_3 = dom_nameid_enc_3 .find ('.//saml:EncryptedID' , namespaces = OneLogin_Saml2_Constants .NSMAP )
655- encrypted_data_3 = encrypted_nameid_nodes_3 [0 ]
656- self .assertRaises (Exception , OneLogin_Saml2_Utils .decrypt_element , encrypted_data_3 , key )
672+ dom_nameid_enc_3 = parseString (xml_nameid_enc_3 )
673+ encrypted_nameid_nodes_3 = dom_nameid_enc_3 .getElementsByTagName ('saml:EncryptedID' )
674+ encrypted_data_3 = encrypted_nameid_nodes_3 [0 ].firstChild
675+
676+ with self .assertRaisesRegex (Exception , "(1, 'failed to decrypt')" ):
677+ OneLogin_Saml2_Utils .decrypt_element (encrypted_data_3 , key )
678+
679+ def testDecryptElementInplace (self ):
680+ """
681+ Tests the decrypt_element method of the OneLogin_Saml2_Utils with inplace=True
682+ """
683+ settings = OneLogin_Saml2_Settings (self .loadSettingsJSON ())
684+
685+ key = settings .get_sp_key ()
686+
687+ xml_nameid_enc = b64decode (self .file_contents (join (self .data_path , 'responses' , 'response_encrypted_nameid.xml.base64' )))
688+ dom = fromstring (xml_nameid_enc )
689+ encrypted_node = dom .xpath ('//saml:EncryptedID/xenc:EncryptedData' , namespaces = OneLogin_Saml2_Constants .NSMAP )[0 ]
690+
691+ # can be decrypted twice when copy the node first
692+ for _ in range (2 ):
693+ decrypted_nameid = OneLogin_Saml2_Utils .decrypt_element (encrypted_node , key , inplace = False )
694+ self .assertIn ('NameID' , decrypted_nameid .tag )
695+ self .assertEqual ('2de11defd199f8d5bb63f9b7deb265ba5c675c10' , decrypted_nameid .text )
696+
697+ # can only be decrypted once in place
698+ decrypted_nameid = OneLogin_Saml2_Utils .decrypt_element (encrypted_node , key , inplace = True )
699+ self .assertIn ('NameID' , decrypted_nameid .tag )
700+ self .assertEqual ('2de11defd199f8d5bb63f9b7deb265ba5c675c10' , decrypted_nameid .text )
701+
702+ # can't be decrypted twice since it has been decrypted inplace
703+ with self .assertRaisesRegex (Exception , "(1, 'failed to decrypt')" ):
704+ OneLogin_Saml2_Utils .decrypt_element (encrypted_node , key , inplace = True )
657705
658706 def testAddSign (self ):
659707 """
0 commit comments