Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ public int MaximumFollowRelLink

#region Helper Methods

private bool TryProcessFeedStream(BufferingStreamReader responseStream)
private bool TryProcessFeedStream(Stream responseStream)
{
bool isRssOrFeed = false;

Expand Down Expand Up @@ -382,95 +382,95 @@ internal override void ProcessResponse(HttpResponseMessage response)
{
if (response == null) { throw new ArgumentNullException("response"); }

using (BufferingStreamReader responseStream = new BufferingStreamReader(StreamHelper.GetResponseStream(response)))
var baseResponseStream = StreamHelper.GetResponseStream(response);

if (ShouldWriteToPipeline)
{
if (ShouldWriteToPipeline)
using var responseStream = new BufferingStreamReader(baseResponseStream);

// First see if it is an RSS / ATOM feed, in which case we can
// stream it - unless the user has overridden it with a return type of "XML"
if (TryProcessFeedStream(responseStream))
{
// First see if it is an RSS / ATOM feed, in which case we can
// stream it - unless the user has overridden it with a return type of "XML"
if (TryProcessFeedStream(responseStream))
// Do nothing, content has been processed.
}
else
{
// determine the response type
RestReturnType returnType = CheckReturnType(response);

// Try to get the response encoding from the ContentType header.
Encoding encoding = null;
string charSet = response.Content.Headers.ContentType?.CharSet;
if (!string.IsNullOrEmpty(charSet))
{
// Do nothing, content has been processed.
// NOTE: Don't use ContentHelper.GetEncoding; it returns a
// default which bypasses checking for a meta charset value.
StreamHelper.TryGetEncoding(charSet, out encoding);
}
else
{
// determine the response type
RestReturnType returnType = CheckReturnType(response);

// Try to get the response encoding from the ContentType header.
Encoding encoding = null;
string charSet = response.Content.Headers.ContentType?.CharSet;
if (!string.IsNullOrEmpty(charSet))
{
// NOTE: Don't use ContentHelper.GetEncoding; it returns a
// default which bypasses checking for a meta charset value.
StreamHelper.TryGetEncoding(charSet, out encoding);
}

if (string.IsNullOrEmpty(charSet) && returnType == RestReturnType.Json)
{
encoding = Encoding.UTF8;
}

object obj = null;
Exception ex = null;
if (string.IsNullOrEmpty(charSet) && returnType == RestReturnType.Json)
{
encoding = Encoding.UTF8;
}

string str = StreamHelper.DecodeStream(responseStream, ref encoding);
object obj = null;
Exception ex = null;

string encodingVerboseName;
try
{
encodingVerboseName = string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName;
}
catch (NotSupportedException)
{
encodingVerboseName = encoding.EncodingName;
}
// NOTE: Tests use this verbose output to verify the encoding.
WriteVerbose(string.Format
(
System.Globalization.CultureInfo.InvariantCulture,
"Content encoding: {0}",
encodingVerboseName)
);
bool convertSuccess = false;

if (returnType == RestReturnType.Json)
{
convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex);
}
// default to try xml first since it's more common
else
{
convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex);
}
string str = StreamHelper.DecodeStream(responseStream, ref encoding);

if (!convertSuccess)
{
// fallback to string
obj = str;
}
string encodingVerboseName;
try
{
encodingVerboseName = string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName;
}
catch (NotSupportedException)
{
encodingVerboseName = encoding.EncodingName;
}
// NOTE: Tests use this verbose output to verify the encoding.
WriteVerbose(string.Format
(
System.Globalization.CultureInfo.InvariantCulture,
"Content encoding: {0}",
encodingVerboseName)
);
bool convertSuccess = false;

if (returnType == RestReturnType.Json)
{
convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex);
}
// default to try xml first since it's more common
else
{
convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex);
}

WriteObject(obj);
if (!convertSuccess)
{
// fallback to string
obj = str;
}
}

if (ShouldSaveToOutFile)
{
StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this);
WriteObject(obj);
}
}
else if (ShouldSaveToOutFile)
{
StreamHelper.SaveStreamToFile(baseResponseStream, QualifiedOutFile, this, _cancelToken.Token);
}

if (!string.IsNullOrEmpty(StatusCodeVariable))
{
PSVariableIntrinsics vi = SessionState.PSVariable;
vi.Set(StatusCodeVariable, (int)response.StatusCode);
}
if (!string.IsNullOrEmpty(StatusCodeVariable))
{
PSVariableIntrinsics vi = SessionState.PSVariable;
vi.Set(StatusCodeVariable, (int)response.StatusCode);
}

if (!string.IsNullOrEmpty(ResponseHeadersVariable))
{
PSVariableIntrinsics vi = SessionState.PSVariable;
vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response));
}
if (!string.IsNullOrEmpty(ResponseHeadersVariable))
{
PSVariableIntrinsics vi = SessionState.PSVariable;
vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,7 @@ public abstract partial class WebRequestPSCmdlet : PSCmdlet
/// <summary>
/// Cancellation token source.
/// </summary>
private CancellationTokenSource _cancelToken = null;
internal CancellationTokenSource _cancelToken = null;

/// <summary>
/// Parse Rel Links.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ internal override void ProcessResponse(HttpResponseMessage response)

if (ShouldSaveToOutFile)
{
StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this);
StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this, _cancelToken.Token);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
using System.Net.Http;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.PowerShell.Commands
{
Expand Down Expand Up @@ -99,7 +101,7 @@ public override long Length
/// <param name="bufferSize"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public override System.Threading.Tasks.Task CopyToAsync(Stream destination, int bufferSize, System.Threading.CancellationToken cancellationToken)
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
Initialize();
return base.CopyToAsync(destination, bufferSize, cancellationToken);
Expand All @@ -124,7 +126,7 @@ public override int Read(byte[] buffer, int offset, int count)
/// <param name="count"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public override System.Threading.Tasks.Task<int> ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken)
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Initialize();
return base.ReadAsync(buffer, offset, count, cancellationToken);
Expand Down Expand Up @@ -175,7 +177,7 @@ public override void Write(byte[] buffer, int offset, int count)
/// <param name="count"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken)
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Initialize();
return base.WriteAsync(buffer, offset, count, cancellationToken);
Expand Down Expand Up @@ -273,73 +275,55 @@ internal static class StreamHelper

#region Static Methods

internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet)
internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, CancellationToken cancellationToken)
{
byte[] data = new byte[ChunkSize];
if (cmdlet == null)
{
throw new ArgumentNullException(nameof(cmdlet));
}

int read = 0;
long totalWritten = 0;
do
Task copyTask = input.CopyToAsync(output, cancellationToken);

ProgressRecord record = new ProgressRecord(
ActivityId,
WebCmdletStrings.WriteRequestProgressActivity,
WebCmdletStrings.WriteRequestProgressStatus);
try
{
if (cmdlet != null)
do
{
ProgressRecord record = new ProgressRecord(ActivityId,
WebCmdletStrings.WriteRequestProgressActivity,
StringUtil.Format(WebCmdletStrings.WriteRequestProgressStatus, totalWritten));
record.StatusDescription = StringUtil.Format(WebCmdletStrings.WriteRequestProgressStatus, output.Position);
cmdlet.WriteProgress(record);
}

read = input.Read(data, 0, ChunkSize);
Task.Delay(1000).Wait(cancellationToken);
}
while (!copyTask.IsCompleted && !cancellationToken.IsCancellationRequested);

if (0 < read)
if (copyTask.IsCompleted)
{
output.Write(data, 0, read);
totalWritten += read;
record.StatusDescription = StringUtil.Format(WebCmdletStrings.WriteRequestComplete, output.Position);
cmdlet.WriteProgress(record);
}
} while (read != 0);

if (cmdlet != null)
}
catch (OperationCanceledException)
{
ProgressRecord record = new ProgressRecord(ActivityId,
WebCmdletStrings.WriteRequestProgressActivity,
StringUtil.Format(WebCmdletStrings.WriteRequestComplete, totalWritten));
record.RecordType = ProgressRecordType.Completed;
cmdlet.WriteProgress(record);
}

output.Flush();
}

internal static void WriteToStream(byte[] input, Stream output)
{
output.Write(input, 0, input.Length);
output.Flush();
}

/// <summary>
/// Saves content from stream into filePath.
/// Caller need to ensure <paramref name="stream"/> position is properly set.
/// </summary>
/// <param name="stream"></param>
/// <param name="filePath"></param>
/// <param name="cmdlet"></param>
internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet)
/// <param name="stream">Input stream.</param>
/// <param name="filePath">Output file name.</param>
/// <param name="cmdlet">Current cmdlet (Invoke-WebRequest or Invoke-RestMethod).</param>
/// <param name="cancellationToken">CancellationToken to track the cmdlet cancellation.</param>
internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet, CancellationToken cancellationToken)
{
// If the web cmdlet should resume, append the file instead of overwriting.
if (cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume)
{
using (FileStream output = new FileStream(filePath, FileMode.Append, FileAccess.Write, FileShare.Read))
{
WriteToStream(stream, output, cmdlet);
}
}
else
{
using (FileStream output = File.Create(filePath))
{
WriteToStream(stream, output, cmdlet);
}
}
FileMode fileMode = cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume ? FileMode.Append : FileMode.Create;
using FileStream output = new FileStream(filePath, fileMode, FileAccess.Write, FileShare.Read);
WriteToStream(stream, output, cmdlet, cancellationToken);
}

private static string StreamToString(Stream stream, Encoding encoding)
Expand Down