Skip to content

Commit 0595ea1

Browse files
committed
2 parents 512f3ce + f9d2604 commit 0595ea1

File tree

8 files changed

+26
-12
lines changed

8 files changed

+26
-12
lines changed

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
4848

4949
<ItemGroup>
5050
<PackageReference Include="Google.Protobuf" Version="3.8.0" />
51-
<PackageReference Include="NumSharp" Version="0.10.2" />
51+
<PackageReference Include="NumSharp" Version="0.10.3" />
5252
</ItemGroup>
5353

5454
<ItemGroup>

src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = S
1616
_write_version = write_version;
1717
}
1818

19+
/// <summary>
20+
/// Create an Op to save 'saveables'.
21+
/// </summary>
22+
/// <param name="filename_tensor"></param>
23+
/// <param name="saveables"></param>
24+
/// <returns></returns>
1925
public virtual Operation save_op(Tensor filename_tensor, SaveableObject[] saveables)
2026
{
2127
var tensor_names = new List<string>();
@@ -105,6 +111,10 @@ public virtual SaverDef _build_internal(VariableV1[] names_to_saveables,
105111
}
106112

107113
var graph = ops.get_default_graph();
114+
// Do some sanity checking on collections containing
115+
// PartitionedVariables. If a saved collection has a PartitionedVariable,
116+
// the GraphDef needs to include concat ops to get the value (or there'll
117+
// be a lookup error on load).
108118
var check_collection_list = graph.get_all_collection_keys();
109119
foreach (var collection_type in check_collection_list)
110120
{

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,10 @@ public string save(Session sess,
158158
string model_checkpoint_path = "";
159159
string checkpoint_file = "";
160160

161-
checkpoint_file = $"{save_path}-{global_step}";
161+
if (global_step > 0)
162+
checkpoint_file = $"{save_path}-{global_step}";
163+
else
164+
checkpoint_file = save_path;
162165

163166
var save_path_parent = Path.GetDirectoryName(save_path);
164167

@@ -291,15 +294,13 @@ private void _RecordLastCheckpoint(string latest_save_path)
291294
if (_saver_def.MaxToKeep <= 0) return;
292295

293296
// Remove first from list if the same name was used before.
294-
foreach (var p in _last_checkpoints)
295-
if (latest_save_path == _CheckpointFilename((p.Key, p.Value)))
296-
_last_checkpoints.Remove(p.Key);
297-
298-
// Append new path to list
299-
_last_checkpoints.Add(latest_save_path, Python.time());
297+
var _existed_checkpoints = _last_checkpoints.FirstOrDefault(p => latest_save_path == _CheckpointFilename((p.Key, p.Value)));
298+
if (_existed_checkpoints.Key != null)
299+
_last_checkpoints.Remove(_existed_checkpoints.Key);
300+
_last_checkpoints.Add(latest_save_path, time());
300301

301302
// If more than max_to_keep, remove oldest.
302-
if(_last_checkpoints.Count > _saver_def.MaxToKeep)
303+
if (_last_checkpoints.Count > _saver_def.MaxToKeep)
303304
{
304305
var first = _last_checkpoints.First();
305306
_last_checkpoints.Remove(first.Key);

src/TensorFlowNET.Core/Train/Saving/saver.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public static (Saver, object) _import_meta_graph_with_return_elements(string met
2525
var saver = _create_saver_from_imported_meta_graph(
2626
meta_graph_def, import_scope, imported_vars);
2727

28-
return (saver, null);
28+
return (saver, imported_return_elements);
2929
}
3030

3131
/// <summary>

test/KerasNET.Test/Keras.UnitTest.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
</PropertyGroup>
2727

2828
<ItemGroup>
29-
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.0" />
29+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
3030
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
3131
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
3232
</ItemGroup>

test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ public bool Run()
105105
// Create a train saver that is used to restore values into an eval graph
106106
// when exporting models.
107107
var train_saver = tf.train.Saver();
108+
train_saver.save(sess, CHECKPOINT_NAME);
109+
108110
sw.Restart();
109111

110112
for (int i = 0; i < how_many_training_steps; i++)

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />
1818
<PackageReference Include="SharpZipLib" Version="1.1.0" />
1919
<PackageReference Include="System.Drawing.Common" Version="4.5.1" />
20+
<PackageReference Include="TensorFlow.NET" Version="0.8.0" />
2021
</ItemGroup>
2122

2223
<ItemGroup>

test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
</PropertyGroup>
1717

1818
<ItemGroup>
19-
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
19+
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.1.1" />
2020
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
2121
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
2222
</ItemGroup>

0 commit comments

Comments
 (0)