Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -587,4 +587,35 @@ public static void SetPowerShellAssemblyLoadContext([MarshalAs(UnmanagedType.LPW
PowerShellAssemblyLoadContext.InitializeSingleton(basePaths);
}
}

/// <summary>
/// Provides helper functions to faciliate calling managed code from a native PowerShell host.
/// </summary>
public static unsafe class PowerShellUnsafeAssemblyLoad
{
/// <summary>
/// Load an assembly in memory from unmanaged code.
/// </summary>
/// <remarks>
/// This API is covered by the experimental feature 'PSLoadAssemblyFromNativeCode',
/// and it may be deprecated and removed in future.
/// </remarks>
/// <param name="data">Unmanaged pointer to assembly data buffer.</param>
/// <param name="size">Size in bytes of the assembly data buffer.</param>
/// <returns>Returns zero on success and non-zero on failure.</returns>
[UnmanagedCallersOnly]
public static int LoadAssemblyFromNativeMemory(IntPtr data, int size)
{
try
{
using var stream = new UnmanagedMemoryStream((byte*)data, size);
AssemblyLoadContext.Default.LoadFromStream(stream);
return 0;
}
catch
{
return -1;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,11 @@ static ExperimentalFeature()
new ExperimentalFeature(
name: PSNativeCommandArgumentPassingFeatureName,
description: "Use ArgumentList when invoking a native command"),
new ExperimentalFeature(
name: "PSLoadAssemblyFromNativeCode",
description: "Expose an API to allow assembly loading from native code"),
};

EngineExperimentalFeatures = new ReadOnlyCollection<ExperimentalFeature>(engineFeatures);

// Initialize the readonly dictionary 'EngineExperimentalFeatureMap'.
Expand Down
107 changes: 107 additions & 0 deletions test/xUnit/csharp/test_NativeInterop.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;
using System.Management.Automation;
using Xunit;

using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Emit;
using Microsoft.CodeAnalysis.Text;

namespace PSTests.Sequential
{
public static class NativeInterop
{
[Fact]
public static void TestLoadNativeInMemoryAssembly()
{
string tempDir = Path.Combine(Path.GetTempPath(), "TestLoadNativeInMemoryAssembly");
string testDll = Path.Combine(tempDir, "test.dll");

if (!File.Exists(testDll))
{
Directory.CreateDirectory(tempDir);
bool result = CreateTestDll(testDll);
Assert.True(result, "The call to 'CreateTestDll' should be successful and return true.");
Assert.True(File.Exists(testDll), "The test assembly should be created.");
}

var asmName = AssemblyName.GetAssemblyName(testDll);
string asmFullName = SearchAssembly(asmName.Name);
Assert.Null(asmFullName);

unsafe
{
int ret = LoadAssemblyTest(testDll);
Assert.Equal(0, ret);
}

asmFullName = SearchAssembly(asmName.Name);
Assert.Equal(asmName.FullName, asmFullName);
}

private static unsafe int LoadAssemblyTest(string assemblyPath)
{
// The 'LoadAssemblyFromNativeMemory' method is annotated with 'UnmanagedCallersOnly' attribute,
// so we have to use the 'unmanaged' function pointer to invoke it.
delegate* unmanaged<IntPtr, int, int> funcPtr = &PowerShellUnsafeAssemblyLoad.LoadAssemblyFromNativeMemory;

int length = 0;
IntPtr nativeMem = IntPtr.Zero;

try
{
using (var fileStream = new FileStream(assemblyPath, FileMode.Open, FileAccess.Read))
{
length = (int)fileStream.Length;
nativeMem = Marshal.AllocHGlobal(length);

using var unmanagedStream = new UnmanagedMemoryStream((byte*)nativeMem, length, length, FileAccess.Write);
fileStream.CopyTo(unmanagedStream);
}

// Call the function pointer.
return funcPtr(nativeMem, length);
}
finally
{
// Free the native memory
Marshal.FreeHGlobal(nativeMem);
}
}

private static string SearchAssembly(string assemblyName)
{
Assembly asm = AssemblyLoadContext.Default.Assemblies.FirstOrDefault(
assembly => assembly.FullName.StartsWith(assemblyName, StringComparison.OrdinalIgnoreCase));

return asm?.FullName;
}

private static bool CreateTestDll(string dllPath)
{
var parseOptions = CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest);
var compilationOptions = new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary);

List<SyntaxTree> syntaxTrees = new();
SourceText sourceText = SourceText.From("public class Utt { }");
syntaxTrees.Add(CSharpSyntaxTree.ParseText(sourceText, parseOptions));

var refs = new List<PortableExecutableReference> { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) };
Compilation compilation = CSharpCompilation.Create(
Path.GetRandomFileName(),
syntaxTrees: syntaxTrees,
references: refs,
options: compilationOptions);

using var fs = new FileStream(dllPath, FileMode.CreateNew, FileAccess.ReadWrite, FileShare.None);
EmitResult emitResult = compilation.Emit(peStream: fs, options: null);
return emitResult.Success;
}
}
}