-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathSequence.cs
More file actions
126 lines (107 loc) · 3.74 KB
/
Sequence.cs
File metadata and controls
126 lines (107 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright (c) TensorStack. All rights reserved.
// Licensed under the Apache 2.0 License.
using Microsoft.ML.OnnxRuntime;
using System;
using System.Collections.Generic;
using TensorStack.Common.Tensor;
using TensorStack.TextGeneration.Cache;
namespace TensorStack.TextGeneration
{
public sealed class Sequence : IDisposable
{
private IKVCache _cache;
private Tensor<float> _lastHiddenState;
/// <summary>
/// Initializes a new instance of the <see cref="Sequence"/> class.
/// </summary>
/// <param name="cache">The cache.</param>
/// <param name="bos">The bos.</param>
public Sequence(IKVCache cache, params List<long> startSequence)
{
_cache = cache;
Tokens = startSequence;
}
/// <summary>
/// Initializes a new instance of the <see cref="Sequence"/> class.
/// </summary>
/// <param name="tokens">The tokens.</param>
/// <param name="score">The score.</param>
/// <param name="cache">The cache.</param>
private Sequence(List<long> tokens, float score, IKVCache cache)
{
Score = score;
Tokens = tokens;
_cache = cache;
}
/// <summary>
/// Gets the tokens.
/// </summary>
public List<long> Tokens { get; }
/// <summary>
/// Gets or sets a value indicating whether this sequence is complete.
/// </summary>
public bool IsComplete { get; set; }
/// <summary>
/// Gets the sequence length.
/// </summary>
public int Length => Tokens.Count;
/// <summary>
/// Gets the cache.
/// </summary>
public OrtValue[] Cache => _cache.Values;
/// <summary>
/// Gets the LastHiddenState.
/// </summary>
public Tensor<float> LastHiddenState => _lastHiddenState;
/// <summary>
/// Gets or sets the sequnece score.
/// </summary>
public float Score { get; set; }
/// <summary>
/// Gets or sets the penalty score.
/// </summary>
public float PenaltyScore { get; set; }
/// <summary>
/// Returns true if the sequence is valid.
/// </summary>
public bool IsValid => !float.IsNegativeInfinity(Score);
/// <summary>
/// Initializes the sequence with the specified initial length.
/// </summary>
/// <param name="initialLength">The initial length.</param>
public bool Initialize(int initialLength)
{
var isInitialized = _cache.IsInitialized;
if (!isInitialized)
_cache.Initialize(initialLength);
return isInitialized;
}
/// <summary>
/// Updates the cache.
/// </summary>
/// <param name="currentValues">The current values.</param>
/// <param name="useBranchCache">if set to <c>true</c> use branch cache.</param>
public void UpdateCache(OrtValue[] currentValues, bool useBranchCache, Tensor<float> lastHiddenState = default)
{
_lastHiddenState = lastHiddenState;
_cache.Update(currentValues, useBranchCache);
}
/// <summary>
/// Clones this sequence.
/// </summary>
/// <returns>Sequence.</returns>
public Sequence Clone()
{
return new Sequence([.. Tokens], Score, _cache.Clone());
}
/// <summary>
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
/// </summary>
public void Dispose()
{
Tokens.Clear();
_cache?.Dispose();
_cache = null;
}
}
}