-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathDeviceManager.cs
More file actions
140 lines (120 loc) · 5.12 KB
/
DeviceManager.cs
File metadata and controls
140 lines (120 loc) · 5.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
namespace TensorStack.Common
{
    /// <summary>
    /// Static holder for the ONNX Runtime environment and the devices discovered for the
    /// configured execution provider. Initialization may only happen once.
    /// </summary>
    public static class DeviceManager
    {
        private static OrtEnv _environment;
        private static EnvironmentCreationOptions _environmentOptions;
        private static IReadOnlyList<Device> _devices;
        private static string _deviceProvider;

        /// <summary>
        /// Initializes the ONNX Runtime environment with the default TensorStack options:
        /// log id "TensorStack", global spin control enabled, and one global inter-op and
        /// one global intra-op thread.
        /// </summary>
        /// <param name="executionProvider">
        /// Name of the execution provider to use; it must appear (case-insensitively) in
        /// <c>GetAvailableProviders()</c>.
        /// </param>
        /// <param name="libraryPath">
        /// Optional path to an execution provider library to register; ignored when null or empty.
        /// </param>
        /// <exception cref="ArgumentException"><paramref name="executionProvider"/> is null or whitespace.</exception>
        /// <exception cref="InvalidOperationException">
        /// The environment is already initialized, or the provider is not available.
        /// </exception>
        public static void Initialize(string executionProvider, string libraryPath = default)
        {
            Initialize(new EnvironmentCreationOptions
            {
                logId = "TensorStack",
                threadOptions = new OrtThreadingOptions
                {
                    GlobalSpinControl = true,
                    GlobalInterOpNumThreads = 1,
                    GlobalIntraOpNumThreads = 1
                }
            }, executionProvider, libraryPath);
        }

        /// <summary>
        /// Initializes the ONNX Runtime environment with the specified options, verifies the
        /// execution provider is available, optionally registers a provider library, and
        /// populates <see cref="Devices"/> with all CPU devices plus all devices whose EP name
        /// matches <paramref name="executionProvider"/> (case-insensitive).
        /// </summary>
        /// <param name="environmentOptions">The environment options passed to <c>OrtEnv.CreateInstanceWithOptions</c>.</param>
        /// <param name="executionProvider">
        /// Name of the execution provider to use; it must appear (case-insensitively) in
        /// <c>GetAvailableProviders()</c>.
        /// </param>
        /// <param name="libraryPath">
        /// Optional path to an execution provider library to register; ignored when null or empty.
        /// </param>
        /// <exception cref="ArgumentException"><paramref name="executionProvider"/> is null or whitespace.</exception>
        /// <exception cref="InvalidOperationException">
        /// The environment is already initialized, or the provider is not available.
        /// </exception>
        public static void Initialize(EnvironmentCreationOptions environmentOptions, string executionProvider, string libraryPath = default)
        {
            if (_environment is not null)
                throw new InvalidOperationException("Environment is already initialized.");

            // Validate before creating the environment: once the OrtEnv singleton exists,
            // any later failure leaves this class permanently "already initialized".
            if (string.IsNullOrWhiteSpace(executionProvider))
                throw new ArgumentException("An execution provider name is required.", nameof(executionProvider));

            _deviceProvider = executionProvider;
            _environmentOptions = environmentOptions;
            _environment = OrtEnv.CreateInstanceWithOptions(ref _environmentOptions);

            var providers = _environment.GetAvailableProviders();
            if (!providers.Contains(_deviceProvider, StringComparer.OrdinalIgnoreCase))
                throw new InvalidOperationException($"Provider {_deviceProvider} was not found in GetAvailableProviders().");

            if (!string.IsNullOrEmpty(libraryPath))
                _environment.RegisterExecutionProviderLibrary(_deviceProvider, libraryPath);

            // Keep every CPU device, plus every device exposed by the selected provider.
            _devices = _environment.GetEpDevices()
                .Where(epDevice => epDevice.HardwareDevice.Type == OrtHardwareDeviceType.CPU
                    || epDevice.EpName.Equals(_deviceProvider, StringComparison.OrdinalIgnoreCase))
                .Select(CreateDevice)
                .ToList();
        }

        /// <summary>
        /// Gets the devices populated by <c>Initialize</c>; <c>null</c> until initialization completes.
        /// </summary>
        public static IReadOnlyList<Device> Devices => _devices;

        /// <summary>
        /// The cpu provider name
        /// </summary>
        public const string CPUProviderName = "CPUExecutionProvider";

        /// <summary>
        /// Builds a <see cref="Device"/> from an ONNX Runtime EP device, reading identifiers,
        /// name and memory from the hardware device's metadata entries. Missing or unparsable
        /// metadata values fall back to 0 / empty string.
        /// </summary>
        /// <param name="epDevice">The ep device.</param>
        /// <returns>Device.</returns>
        /// <exception cref="ArgumentException">
        /// <see cref="DeviceType"/> has no member named like the hardware device type (from <c>Enum.Parse</c>).
        /// </exception>
        private static Device CreateDevice(OrtEpDevice epDevice)
        {
            var device = epDevice.HardwareDevice;
            var metadata = device.Metadata.Entries;
            return new Device
            {
                Id = metadata.ParseOrDefault("DxgiAdapterNumber", 0),
                DeviceId = metadata.ParseOrDefault("DxgiHighPerformanceIndex", 0),
                Type = Enum.Parse<DeviceType>(device.Type.ToString()),
                Name = metadata.ParseOrDefault("Description", string.Empty),
                // The " MB" suffix is stripped before parsing; presumably the value is in megabytes.
                Memory = metadata.ParseOrDefault("DxgiVideoMemory", 0, " MB"),
                // NOTE(review): parsed as int, so a LUID that exceeds Int32 range silently becomes 0.
                // Confirm the declared type of Device.HardwareLUID and widen if it is 64-bit.
                HardwareLUID = metadata.ParseOrDefault("LUID", 0),
                HardwareID = (int)device.DeviceId,
                HardwareVendor = device.Vendor,
                HardwareVendorId = (int)device.VendorId,
            };
        }

        /// <summary>
        /// Reads a metadata value and converts it to <typeparamref name="T"/>.
        /// The raw value is trimmed, then every occurrence of <paramref name="replace"/> (if given) is removed.
        /// Supported target types are <see cref="string"/>, <see cref="int"/> (invariant culture) and any enum type;
        /// any other type, a missing key, a null value, or a failed parse yields <paramref name="defaultValue"/>.
        /// </summary>
        /// <typeparam name="T">The target type.</typeparam>
        /// <param name="metadata">The metadata.</param>
        /// <param name="key">The key.</param>
        /// <param name="defaultValue">The value returned when the key is absent or cannot be converted.</param>
        /// <param name="replace">Optional substring removed from the value before conversion.</param>
        /// <returns>The converted value, or <paramref name="defaultValue"/>.</returns>
        private static T ParseOrDefault<T>(this IReadOnlyDictionary<string, string> metadata, string key, T defaultValue, string replace = null)
        {
            if (!metadata.TryGetValue(key, out var rawValue) || rawValue is null)
                return defaultValue;

            var value = rawValue.Trim();
            if (!string.IsNullOrEmpty(replace))
                value = value.Replace(replace, string.Empty);

            if (typeof(T) == typeof(string))
                return (T)(object)value;

            if (typeof(T) == typeof(int))
            {
                // Metadata is machine-generated, so parse with the invariant culture.
                return int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out var intResult)
                    ? (T)(object)intResult
                    : defaultValue;
            }

            // typeof(T) == typeof(Enum) is never true for a concrete enum type; IsEnum is.
            if (typeof(T).IsEnum)
            {
                return Enum.TryParse(typeof(T), value, out var enumResult)
                    ? (T)enumResult
                    : defaultValue;
            }

            return defaultValue;
        }
    }
}