Skip to content

fix: fix the bug of load LSTM model and add test#1144

Merged
Oceania2018 merged 4 commits intoSciSharp:masterfrom
Wanglongzhi2001:master
Jul 13, 2023
Merged

fix: fix the bug of load LSTM model and add test#1144
Oceania2018 merged 4 commits intoSciSharp:masterfrom
Wanglongzhi2001:master

Conversation

@Wanglongzhi2001
Copy link
Contributor

@Wanglongzhi2001 Wanglongzhi2001 commented Jul 11, 2023

在使用反射来根据 metadata 来恢复 layer 的时候使用的参数是Tensorflow.Keras.ArgsDefinition.{class_name}Args,没有多余的.rnn,因此放在命名空间Tensorflow.Keras.ArgsDefinition.Rnn里的Args都读不出来,所以将所有的 layer 的 Args 都应该放在原来的Tensorflow.Keras.ArgsDefinition命名空间。同理,layer 也应该放在Tensorflow.Keras.Layers命名空间。并且所有需要恢复的 layer 的 args 里的参数都应该加上[JsonProperty] attribute,否则无法反序列化成功:

public static Layer deserialize_keras_object(string class_name, JToken config)
{
var argType = Assembly.Load("Tensorflow.Binding").GetType($"Tensorflow.Keras.ArgsDefinition.{class_name}Args");
if(argType is null)
{
return null;
}
var deserializationMethod = typeof(JToken).GetMethods(BindingFlags.Instance | BindingFlags.Public)
.Single(x => x.Name == "ToObject" && x.IsGenericMethodDefinition && x.GetParameters().Count() == 0);
var deserializationGenericMethod = deserializationMethod.MakeGenericMethod(argType);
var args = deserializationGenericMethod.Invoke(config, null);
var layer = Assembly.Load("Tensorflow.Keras").CreateInstance($"Tensorflow.Keras.Layers.{class_name}", true, BindingFlags.Default, null, new object[] { args }, null, null);

@Wanglongzhi2001 Wanglongzhi2001 changed the title fix:fix the bug of load LSTM model fix:fix the bug of load LSTM model and add test Jul 11, 2023
@Wanglongzhi2001 Wanglongzhi2001 changed the title fix:fix the bug of load LSTM model and add test fix: fix the bug of load LSTM model and add test Jul 11, 2023
@SanftMonster
Copy link
Collaborator

plz fix the ci error

@Oceania2018
Copy link
Member

  Failed LSTMLoad [26 s]
  Error Message:
   Test method Tensorflow.Keras.UnitTest.Model.ModelLoadTest.LSTMLoad threw exception: 
Google.Protobuf.InvalidProtocolBufferException: Protocol message contained an invalid tag (zero).
  Stack Trace:
      at Google.Protobuf.ParsingPrimitives.ParseTag(ReadOnlySpan`1& buffer, ParserInternalState& state)
   at Google.Protobuf.CodedInputStream.ReadTag()
   at ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedMetadata.MergeFrom(CodedInputStream input) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Protobuf\SavedMetadata.cs:line 158
   at Google.Protobuf.MessageExtensions.MergeFrom(IMessage message, Stream input, Boolean discardUnknownFields, ExtensionRegistry registry)
   at Google.Protobuf.MessageExtensions.MergeFrom(IMessage message, Stream input)
   at Tensorflow.Keras.Saving.SavedModel.KerasLoadModelUtils.load(String path, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Saving\SavedModel\load.cs:line 55
   at Tensorflow.Keras.Saving.SavedModel.KerasLoadModelUtils.load_model(String filepath, IDictionary`2 custom_objects, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Saving\SavedModel\load.cs:line 37
   at Tensorflow.Keras.Models.ModelsApi.load_model(String filepath, Boolean compile, LoadOptions options) in D:\a\TensorFlow.NET\TensorFlow.NET\src\TensorFlowNET.Keras\Models\ModelsApi.cs:line 19
   at Tensorflow.Keras.UnitTest.Model.ModelLoadTest.LSTMLoad() in D:\a\TensorFlow.NET\TensorFlow.NET\test\TensorFlowNET.Keras.UnitTest\Model\ModelLoadTest.cs:line 87

@Oceania2018 Oceania2018 merged commit d452d8c into SciSharp:master Jul 13, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working keras model save/load

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants