Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -276,26 +276,55 @@ internal static class StreamHelper

#region Static Methods

internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, CancellationToken cancellationToken)
internal static void WriteToStream(Stream input, Stream output, WebRequestPSCmdlet cmdlet, CancellationToken cancellationToken)
{
if (cmdlet == null)
{
throw new ArgumentNullException(nameof(cmdlet));
}

Task copyTask = input.CopyToAsync(output, cancellationToken);

ProgressRecord record = new(
ActivityId,
WebCmdletStrings.WriteRequestProgressActivity,
WebCmdletStrings.WriteRequestProgressStatus);

// If the client loses network connectivity for over 1 minute, the task returned from
// 'CopyToAsync()' will never complete and therefore the loop below will be infinite.
// Short explaination:
// - CopyToAsync() uses ReadAsync() in loop until ReadAsync() returns 0, indicating EOF.
// - ReadAsync() read from network Socket.
// - If losing network connectivity for over 1 minute, the Socket will lose connect to target service
// but OS never close the Socket (tested on Windows).
// - Then Socket never return anything, and thus ReadAsync() will be infinitely blocked on the Socket
// and CopyToAsync() task will be never completed.
//
// Since cancelation token in CopyToAsync() applies to whole CopyToAsync()
// and it is not restarted for very internal ReadAsync()
// we have to use a workaround to cancel the ReadAsync() which is blocked on Socket longer Timeout.
// The workaround is to reset the cancelation timer
// while the length of the output file is constantly increasing.
long previousLength = 0;
int timeout = cmdlet.TimeoutSec == 0 ? Timeout.Infinite : cmdlet.TimeoutSec * 1000;
using var timeoutCTS = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
timeoutCTS.CancelAfter(timeout);
Task copyTask = input.CopyToAsync(output, timeoutCTS.Token);

try
{
while (!copyTask.Wait(1000, cancellationToken))
// The Wait(TimeSpan, CancellationToken) overload checks CTS cancellation first to avoid race condition.
TimeSpan waitTime = new TimeSpan(0, 0, seconds: 1);
while (!copyTask.Wait(waitTime, timeoutCTS.Token))
{
record.StatusDescription = StringUtil.Format(WebCmdletStrings.WriteRequestProgressStatus, output.Position);
cmdlet.WriteProgress(record);

if (previousLength != output.Length)
{
// Reset cancelation timer while information continues to flow from network.
// Cancelation timer applies only during no network connectivity.
previousLength = output.Length;
timeoutCTS.CancelAfter(timeout);
}
}

if (copyTask.IsCompleted)
Expand All @@ -304,8 +333,10 @@ internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet,
cmdlet.WriteProgress(record);
}
}
catch (OperationCanceledException)
catch (OperationCanceledException exc) when (!cancellationToken.IsCancellationRequested && timeoutCTS.IsCancellationRequested)
{
ErrorRecord er = new(new TimeoutException(message: WebCmdletStrings.RequestTimeout, exc), "OperationTimeout", ErrorCategory.OperationTimeout, input);
cmdlet.WriteError(er);
}
}

Expand All @@ -317,7 +348,7 @@ internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet,
/// <param name="filePath">Output file name.</param>
/// <param name="cmdlet">Current cmdlet (Invoke-WebRequest or Invoke-RestMethod).</param>
/// <param name="cancellationToken">CancellationToken to track the cmdlet cancellation.</param>
internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet, CancellationToken cancellationToken)
internal static void SaveStreamToFile(Stream stream, string filePath, WebRequestPSCmdlet cmdlet, CancellationToken cancellationToken)
{
// If the web cmdlet should resume, append the file instead of overwriting.
FileMode fileMode = cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume ? FileMode.Append : FileMode.Create;
Expand Down