Skip to content

Commit 3f4b5b3

Browse files
committed
add transpose API eagerly
1 parent f7e61b0 commit 3f4b5b3

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,15 @@ public static Tensor tile<T>(Tensor input, T multiples, string name = null)
487487

488488
public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null)
489489
{
490+
if (tf.context.executing_eagerly())
491+
{
492+
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
493+
"Transpose", name,
494+
null,
495+
x, perm);
496+
497+
return results[0];
498+
}
490499
var _op = tf._op_def_lib._apply_op_helper("Transpose", name, new { x, perm });
491500
return _op.outputs[0];
492501
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using FluentAssertions;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
3+
using NumSharp;
4+
using System.Linq;
5+
using Tensorflow;
6+
using static Tensorflow.Binding;
7+
8+
namespace Tensorflow.UnitTest.TF_API
9+
{
10+
[TestClass]
11+
public class TensorOperate
12+
{
13+
[TestMethod]
14+
public void TransposeTest()
15+
{
16+
var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } },
17+
{ { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } }));
18+
var b = tf.transpose(a, new[] { 3, 1, 2, 0 });
19+
var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } },
20+
{ { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } },
21+
{ { { 22, 66 } }, { { 44, 88 } } } }));
22+
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape));
23+
Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>()));
24+
}
25+
26+
27+
}
28+
}

0 commit comments

Comments
 (0)