Merge pull request #847 from Bond-009/async

Make websockets code async
This commit is contained in:
Vasily 2019-02-20 15:03:42 +03:00 committed by GitHub
commit 8ef41020d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 288 additions and 273 deletions

View File

@ -44,10 +44,11 @@ namespace Jellyfin.Server.SocketSharp
socket.OnMessage += OnSocketMessage; socket.OnMessage += OnSocketMessage;
socket.OnClose += OnSocketClose; socket.OnClose += OnSocketClose;
socket.OnError += OnSocketError; socket.OnError += OnSocketError;
WebSocket.ConnectAsServer();
} }
public Task ConnectAsServerAsync()
=> WebSocket.ConnectAsServer();
public Task StartReceive() public Task StartReceive()
{ {
return _taskCompletionSource.Task; return _taskCompletionSource.Task;
@ -133,7 +134,7 @@ namespace Jellyfin.Server.SocketSharp
_cancellationTokenSource.Cancel(); _cancellationTokenSource.Cancel();
WebSocket.Close(); WebSocket.CloseAsync().GetAwaiter().GetResult();
} }
_disposed = true; _disposed = true;

View File

@ -69,7 +69,7 @@ namespace Jellyfin.Server.SocketSharp
{ {
if (_listener == null) if (_listener == null)
{ {
_listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _networkManager, _streamHelper, _fileSystem, _environment); _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _streamHelper, _fileSystem, _environment);
} }
_listener.EnableDualMode = _enableDualMode; _listener.EnableDualMode = _enableDualMode;
@ -79,22 +79,14 @@ namespace Jellyfin.Server.SocketSharp
_listener.LoadCert(_certificate); _listener.LoadCert(_certificate);
} }
foreach (var prefix in urlPrefixes) _logger.LogInformation("Adding HttpListener prefixes {Prefixes}", urlPrefixes);
{ _listener.Prefixes.AddRange(urlPrefixes);
_logger.LogInformation("Adding HttpListener prefix " + prefix);
_listener.Prefixes.Add(prefix);
}
_listener.OnContext = ProcessContext; _listener.OnContext = async c => await InitTask(c, _disposeCancellationToken).ConfigureAwait(false);
_listener.Start(); _listener.Start();
} }
private void ProcessContext(HttpListenerContext context)
{
_ = Task.Run(async () => await InitTask(context, _disposeCancellationToken).ConfigureAwait(false));
}
private static void LogRequest(ILogger logger, HttpListenerRequest request) private static void LogRequest(ILogger logger, HttpListenerRequest request)
{ {
var url = request.Url.ToString(); var url = request.Url.ToString();
@ -151,10 +143,7 @@ namespace Jellyfin.Server.SocketSharp
Endpoint = endpoint Endpoint = endpoint
}; };
if (WebSocketConnecting != null) WebSocketConnecting?.Invoke(connectingArgs);
{
WebSocketConnecting(connectingArgs);
}
if (connectingArgs.AllowConnection) if (connectingArgs.AllowConnection)
{ {
@ -165,6 +154,7 @@ namespace Jellyfin.Server.SocketSharp
if (WebSocketConnected != null) if (WebSocketConnected != null)
{ {
var socket = new SharpWebSocket(webSocketContext.WebSocket, _logger); var socket = new SharpWebSocket(webSocketContext.WebSocket, _logger);
await socket.ConnectAsServerAsync().ConfigureAwait(false);
WebSocketConnected(new WebSocketConnectEventArgs WebSocketConnected(new WebSocketConnectEventArgs
{ {
@ -174,7 +164,7 @@ namespace Jellyfin.Server.SocketSharp
Endpoint = endpoint Endpoint = endpoint
}); });
await ReceiveWebSocket(ctx, socket).ConfigureAwait(false); await ReceiveWebSocketAsync(ctx, socket).ConfigureAwait(false);
} }
} }
else else
@ -192,7 +182,7 @@ namespace Jellyfin.Server.SocketSharp
} }
} }
private async Task ReceiveWebSocket(HttpListenerContext ctx, SharpWebSocket socket) private async Task ReceiveWebSocketAsync(HttpListenerContext ctx, SharpWebSocket socket)
{ {
try try
{ {

View File

@ -74,18 +74,20 @@ namespace SocketHttpListener
} }
} }
private static byte[] readBytes(this Stream stream, byte[] buffer, int offset, int length) private static async Task<byte[]> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length)
{ {
var len = stream.Read(buffer, offset, length); var len = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false);
if (len < 1) if (len < 1)
return buffer.SubArray(0, offset); return buffer.SubArray(0, offset);
var tmp = 0; var tmp = 0;
while (len < length) while (len < length)
{ {
tmp = stream.Read(buffer, offset + len, length - len); tmp = await stream.ReadAsync(buffer, offset + len, length - len).ConfigureAwait(false);
if (tmp < 1) if (tmp < 1)
{
break; break;
}
len += tmp; len += tmp;
} }
@ -95,10 +97,9 @@ namespace SocketHttpListener
: buffer; : buffer;
} }
private static bool readBytes( private static async Task<bool> ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length, Stream dest)
this Stream stream, byte[] buffer, int offset, int length, Stream dest)
{ {
var bytes = stream.readBytes(buffer, offset, length); var bytes = await stream.ReadBytesAsync(buffer, offset, length).ConfigureAwait(false);
var len = bytes.Length; var len = bytes.Length;
dest.Write(bytes, 0, len); dest.Write(bytes, 0, len);
@ -109,16 +110,16 @@ namespace SocketHttpListener
#region Internal Methods #region Internal Methods
internal static byte[] Append(this ushort code, string reason) internal static async Task<byte[]> AppendAsync(this ushort code, string reason)
{ {
using (var buffer = new MemoryStream()) using (var buffer = new MemoryStream())
{ {
var tmp = code.ToByteArrayInternally(ByteOrder.Big); var tmp = code.ToByteArrayInternally(ByteOrder.Big);
buffer.Write(tmp, 0, 2); await buffer.WriteAsync(tmp, 0, 2).ConfigureAwait(false);
if (reason != null && reason.Length > 0) if (reason != null && reason.Length > 0)
{ {
tmp = Encoding.UTF8.GetBytes(reason); tmp = Encoding.UTF8.GetBytes(reason);
buffer.Write(tmp, 0, tmp.Length); await buffer.WriteAsync(tmp, 0, tmp.Length).ConfigureAwait(false);
} }
return buffer.ToArray(); return buffer.ToArray();
@ -331,12 +332,10 @@ namespace SocketHttpListener
: string.Format("\"{0}\"", value.Replace("\"", "\\\"")); : string.Format("\"{0}\"", value.Replace("\"", "\\\""));
} }
internal static byte[] ReadBytes(this Stream stream, int length) internal static Task<byte[]> ReadBytesAsync(this Stream stream, int length)
{ => stream.ReadBytesAsync(new byte[length], 0, length);
return stream.readBytes(new byte[length], 0, length);
}
internal static byte[] ReadBytes(this Stream stream, long length, int bufferLength) internal static async Task<byte[]> ReadBytesAsync(this Stream stream, long length, int bufferLength)
{ {
using (var result = new MemoryStream()) using (var result = new MemoryStream())
{ {
@ -347,7 +346,7 @@ namespace SocketHttpListener
var end = false; var end = false;
for (long i = 0; i < count; i++) for (long i = 0; i < count; i++)
{ {
if (!stream.readBytes(buffer, 0, bufferLength, result)) if (!await stream.ReadBytesAsync(buffer, 0, bufferLength, result).ConfigureAwait(false))
{ {
end = true; end = true;
break; break;
@ -355,26 +354,14 @@ namespace SocketHttpListener
} }
if (!end && rem > 0) if (!end && rem > 0)
stream.readBytes(new byte[rem], 0, rem, result); {
await stream.ReadBytesAsync(new byte[rem], 0, rem, result).ConfigureAwait(false);
}
return result.ToArray(); return result.ToArray();
} }
} }
internal static async Task<byte[]> ReadBytesAsync(this Stream stream, int length)
{
var buffer = new byte[length];
var len = await stream.ReadAsync(buffer, 0, length).ConfigureAwait(false);
var bytes = len < 1
? new byte[0]
: len < length
? stream.readBytes(buffer, len, length - len)
: buffer;
return bytes;
}
internal static string RemovePrefix(this string value, params string[] prefixes) internal static string RemovePrefix(this string value, params string[] prefixes)
{ {
var i = 0; var i = 0;
@ -493,19 +480,16 @@ namespace SocketHttpListener
return string.Format("{0}; {1}", m, parameters.ToString("; ")); return string.Format("{0}; {1}", m, parameters.ToString("; "));
} }
internal static List<TSource> ToList<TSource>(this IEnumerable<TSource> source)
{
return new List<TSource>(source);
}
internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder) internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder)
{ {
return BitConverter.ToUInt16(src.ToHostOrder(srcOrder), 0); src.ToHostOrder(srcOrder);
return BitConverter.ToUInt16(src, 0);
} }
internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder) internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder)
{ {
return BitConverter.ToUInt64(src.ToHostOrder(srcOrder), 0); src.ToHostOrder(srcOrder);
return BitConverter.ToUInt64(src, 0);
} }
internal static string TrimEndSlash(this string value) internal static string TrimEndSlash(this string value)
@ -852,14 +836,17 @@ namespace SocketHttpListener
/// <exception cref="ArgumentNullException"> /// <exception cref="ArgumentNullException">
/// <paramref name="src"/> is <see langword="null"/>. /// <paramref name="src"/> is <see langword="null"/>.
/// </exception> /// </exception>
public static byte[] ToHostOrder(this byte[] src, ByteOrder srcOrder) public static void ToHostOrder(this byte[] src, ByteOrder srcOrder)
{ {
if (src == null) if (src == null)
{
throw new ArgumentNullException(nameof(src)); throw new ArgumentNullException(nameof(src));
}
return src.Length > 1 && !srcOrder.IsHostOrder() if (src.Length > 1 && !srcOrder.IsHostOrder())
? src.Reverse() {
: src; Array.Reverse(src);
}
} }
/// <summary> /// <summary>

View File

@ -3,7 +3,6 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Net; using System.Net;
using System.Security.Cryptography.X509Certificates; using System.Security.Cryptography.X509Certificates;
using MediaBrowser.Common.Net;
using MediaBrowser.Model.Cryptography; using MediaBrowser.Model.Cryptography;
using MediaBrowser.Model.IO; using MediaBrowser.Model.IO;
using MediaBrowser.Model.Net; using MediaBrowser.Model.Net;
@ -18,47 +17,55 @@ namespace SocketHttpListener.Net
internal ISocketFactory SocketFactory { get; private set; } internal ISocketFactory SocketFactory { get; private set; }
internal IFileSystem FileSystem { get; private set; } internal IFileSystem FileSystem { get; private set; }
internal IStreamHelper StreamHelper { get; private set; } internal IStreamHelper StreamHelper { get; private set; }
internal INetworkManager NetworkManager { get; private set; }
internal IEnvironmentInfo EnvironmentInfo { get; private set; } internal IEnvironmentInfo EnvironmentInfo { get; private set; }
public bool EnableDualMode { get; set; } public bool EnableDualMode { get; set; }
AuthenticationSchemes auth_schemes; private AuthenticationSchemes auth_schemes;
HttpListenerPrefixCollection prefixes; private HttpListenerPrefixCollection prefixes;
AuthenticationSchemeSelector auth_selector; private AuthenticationSchemeSelector auth_selector;
string realm; private string realm;
bool unsafe_ntlm_auth; private bool unsafe_ntlm_auth;
bool listening; private bool listening;
bool disposed; private bool disposed;
Dictionary<HttpListenerContext, HttpListenerContext> registry; // Dictionary<HttpListenerContext,HttpListenerContext> private Dictionary<HttpListenerContext, HttpListenerContext> registry;
Dictionary<HttpConnection, HttpConnection> connections; private Dictionary<HttpConnection, HttpConnection> connections;
private ILogger _logger; private ILogger _logger;
private X509Certificate _certificate; private X509Certificate _certificate;
public Action<HttpListenerContext> OnContext { get; set; } public Action<HttpListenerContext> OnContext { get; set; }
public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, public HttpListener(
INetworkManager networkManager, IStreamHelper streamHelper, IFileSystem fileSystem, ILogger logger,
ICryptoProvider cryptoProvider,
ISocketFactory socketFactory,
IStreamHelper streamHelper,
IFileSystem fileSystem,
IEnvironmentInfo environmentInfo) IEnvironmentInfo environmentInfo)
{ {
_logger = logger; _logger = logger;
CryptoProvider = cryptoProvider; CryptoProvider = cryptoProvider;
SocketFactory = socketFactory; SocketFactory = socketFactory;
NetworkManager = networkManager;
StreamHelper = streamHelper; StreamHelper = streamHelper;
FileSystem = fileSystem; FileSystem = fileSystem;
EnvironmentInfo = environmentInfo; EnvironmentInfo = environmentInfo;
prefixes = new HttpListenerPrefixCollection(logger, this); prefixes = new HttpListenerPrefixCollection(logger, this);
registry = new Dictionary<HttpListenerContext, HttpListenerContext>(); registry = new Dictionary<HttpListenerContext, HttpListenerContext>();
connections = new Dictionary<HttpConnection, HttpConnection>(); connections = new Dictionary<HttpConnection, HttpConnection>();
auth_schemes = AuthenticationSchemes.Anonymous; auth_schemes = AuthenticationSchemes.Anonymous;
} }
public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider, public HttpListener(
ISocketFactory socketFactory, INetworkManager networkManager, IStreamHelper streamHelper, ILogger logger,
IFileSystem fileSystem, IEnvironmentInfo environmentInfo) X509Certificate certificate,
: this(logger, cryptoProvider, socketFactory, networkManager, streamHelper, fileSystem, environmentInfo) ICryptoProvider cryptoProvider,
ISocketFactory socketFactory,
IStreamHelper streamHelper,
IFileSystem fileSystem,
IEnvironmentInfo environmentInfo)
: this(logger, cryptoProvider, socketFactory, streamHelper, fileSystem, environmentInfo)
{ {
_certificate = certificate; _certificate = certificate;
} }

View File

@ -7,18 +7,18 @@ namespace SocketHttpListener.Net
{ {
public class HttpListenerPrefixCollection : ICollection<string>, IEnumerable<string>, IEnumerable public class HttpListenerPrefixCollection : ICollection<string>, IEnumerable<string>, IEnumerable
{ {
List<string> prefixes = new List<string>(); private List<string> _prefixes = new List<string>();
HttpListener listener; private HttpListener _listener;
private ILogger _logger; private ILogger _logger;
internal HttpListenerPrefixCollection(ILogger logger, HttpListener listener) internal HttpListenerPrefixCollection(ILogger logger, HttpListener listener)
{ {
_logger = logger; _logger = logger;
this.listener = listener; _listener = listener;
} }
public int Count => prefixes.Count; public int Count => _prefixes.Count;
public bool IsReadOnly => false; public bool IsReadOnly => false;
@ -26,61 +26,90 @@ namespace SocketHttpListener.Net
public void Add(string uriPrefix) public void Add(string uriPrefix)
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
//ListenerPrefix.CheckUri(uriPrefix); //ListenerPrefix.CheckUri(uriPrefix);
if (prefixes.Contains(uriPrefix)) if (_prefixes.Contains(uriPrefix))
{
return; return;
}
prefixes.Add(uriPrefix); _prefixes.Add(uriPrefix);
if (listener.IsListening) if (_listener.IsListening)
HttpEndPointManager.AddPrefix(_logger, uriPrefix, listener); {
HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
}
}
public void AddRange(IEnumerable<string> uriPrefixes)
{
_listener.CheckDisposed();
foreach (var uriPrefix in uriPrefixes)
{
if (_prefixes.Contains(uriPrefix))
{
continue;
}
_prefixes.Add(uriPrefix);
if (_listener.IsListening)
{
HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener);
}
}
} }
public void Clear() public void Clear()
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
prefixes.Clear(); _prefixes.Clear();
if (listener.IsListening) if (_listener.IsListening)
HttpEndPointManager.RemoveListener(_logger, listener); {
HttpEndPointManager.RemoveListener(_logger, _listener);
}
} }
public bool Contains(string uriPrefix) public bool Contains(string uriPrefix)
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
return prefixes.Contains(uriPrefix); return _prefixes.Contains(uriPrefix);
} }
public void CopyTo(string[] array, int offset) public void CopyTo(string[] array, int offset)
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
prefixes.CopyTo(array, offset); _prefixes.CopyTo(array, offset);
} }
public void CopyTo(Array array, int offset) public void CopyTo(Array array, int offset)
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
((ICollection)prefixes).CopyTo(array, offset); ((ICollection)_prefixes).CopyTo(array, offset);
} }
public IEnumerator<string> GetEnumerator() public IEnumerator<string> GetEnumerator()
{ {
return prefixes.GetEnumerator(); return _prefixes.GetEnumerator();
} }
IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()
{ {
return prefixes.GetEnumerator(); return _prefixes.GetEnumerator();
} }
public bool Remove(string uriPrefix) public bool Remove(string uriPrefix)
{ {
listener.CheckDisposed(); _listener.CheckDisposed();
if (uriPrefix == null) if (uriPrefix == null)
{
throw new ArgumentNullException(nameof(uriPrefix)); throw new ArgumentNullException(nameof(uriPrefix));
}
bool result = prefixes.Remove(uriPrefix); bool result = _prefixes.Remove(uriPrefix);
if (result && listener.IsListening) if (result && _listener.IsListening)
HttpEndPointManager.RemovePrefix(_logger, uriPrefix, listener); {
HttpEndPointManager.RemovePrefix(_logger, uriPrefix, _listener);
}
return result; return result;
} }

View File

@ -30,9 +30,9 @@ namespace SocketHttpListener
private CookieCollection _cookies; private CookieCollection _cookies;
private AutoResetEvent _exitReceiving; private AutoResetEvent _exitReceiving;
private object _forConn; private object _forConn;
private object _forEvent; private readonly SemaphoreSlim _forEvent = new SemaphoreSlim(1, 1);
private object _forMessageEventQueue; private object _forMessageEventQueue;
private object _forSend; private readonly SemaphoreSlim _forSend = new SemaphoreSlim(1, 1);
private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
private Queue<MessageEventArgs> _messageEventQueue; private Queue<MessageEventArgs> _messageEventQueue;
private string _protocol; private string _protocol;
@ -109,12 +109,15 @@ namespace SocketHttpListener
#region Private Methods #region Private Methods
private void close(CloseStatusCode code, string reason, bool wait) private async Task CloseAsync(CloseStatusCode code, string reason, bool wait)
{ {
close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait); await CloseAsync(new PayloadData(
await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)),
!code.IsReserved(),
wait).ConfigureAwait(false);
} }
private void close(PayloadData payload, bool send, bool wait) private async Task CloseAsync(PayloadData payload, bool send, bool wait)
{ {
lock (_forConn) lock (_forConn)
{ {
@ -126,11 +129,12 @@ namespace SocketHttpListener
_readyState = WebSocketState.CloseSent; _readyState = WebSocketState.CloseSent;
} }
var e = new CloseEventArgs(payload); var e = new CloseEventArgs(payload)
e.WasClean = {
closeHandshake( WasClean = await CloseHandshakeAsync(
send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null, send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null,
wait ? 1000 : 0); wait ? 1000 : 0).ConfigureAwait(false)
};
_readyState = WebSocketState.Closed; _readyState = WebSocketState.Closed;
try try
@ -143,9 +147,9 @@ namespace SocketHttpListener
} }
} }
private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout) private async Task<bool> CloseHandshakeAsync(byte[] frameAsBytes, int millisecondsTimeout)
{ {
var sent = frameAsBytes != null && writeBytes(frameAsBytes); var sent = frameAsBytes != null && await WriteBytesAsync(frameAsBytes).ConfigureAwait(false);
var received = var received =
millisecondsTimeout == 0 || millisecondsTimeout == 0 ||
(sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout)); (sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout));
@ -189,11 +193,11 @@ namespace SocketHttpListener
_context = null; _context = null;
} }
private bool concatenateFragmentsInto(Stream dest) private async Task<bool> ConcatenateFragmentsIntoAsync(Stream dest)
{ {
while (true) while (true)
{ {
var frame = WebSocketFrame.Read(_stream, true); var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false);
if (frame.IsFinal) if (frame.IsFinal)
{ {
/* FINAL */ /* FINAL */
@ -221,7 +225,7 @@ namespace SocketHttpListener
// CLOSE // CLOSE
if (frame.IsClose) if (frame.IsClose)
return processCloseFrame(frame); return await ProcessCloseFrameAsync(frame).ConfigureAwait(false);
} }
else else
{ {
@ -236,10 +240,10 @@ namespace SocketHttpListener
} }
// ? // ?
return processUnsupportedFrame( return await ProcessUnsupportedFrameAsync(
frame, frame,
CloseStatusCode.IncorrectData, CloseStatusCode.IncorrectData,
"An incorrect data has been received while receiving fragmented data."); "An incorrect data has been received while receiving fragmented data.").ConfigureAwait(false);
} }
return true; return true;
@ -299,44 +303,42 @@ namespace SocketHttpListener
_compression = CompressionMethod.None; _compression = CompressionMethod.None;
_cookies = new CookieCollection(); _cookies = new CookieCollection();
_forConn = new object(); _forConn = new object();
_forEvent = new object();
_forSend = new object();
_messageEventQueue = new Queue<MessageEventArgs>(); _messageEventQueue = new Queue<MessageEventArgs>();
_forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot; _forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot;
_readyState = WebSocketState.Connecting; _readyState = WebSocketState.Connecting;
} }
private void open() private async Task OpenAsync()
{ {
try try
{ {
startReceiving(); startReceiving();
lock (_forEvent)
{
try
{
if (OnOpen != null)
{
OnOpen(this, EventArgs.Empty);
}
}
catch (Exception ex)
{
processException(ex, "An exception has occurred while OnOpen.");
}
}
} }
catch (Exception ex) catch (Exception ex)
{ {
processException(ex, "An exception has occurred while opening."); await ProcessExceptionAsync(ex, "An exception has occurred while opening.").ConfigureAwait(false);
}
await _forEvent.WaitAsync().ConfigureAwait(false);
try
{
OnOpen?.Invoke(this, EventArgs.Empty);
}
catch (Exception ex)
{
await ProcessExceptionAsync(ex, "An exception has occurred while OnOpen.").ConfigureAwait(false);
}
finally
{
_forEvent.Release();
} }
} }
private bool processCloseFrame(WebSocketFrame frame) private async Task<bool> ProcessCloseFrameAsync(WebSocketFrame frame)
{ {
var payload = frame.PayloadData; var payload = frame.PayloadData;
close(payload, !payload.ContainsReservedCloseStatusCode, false); await CloseAsync(payload, !payload.ContainsReservedCloseStatusCode, false).ConfigureAwait(false);
return false; return false;
} }
@ -352,7 +354,7 @@ namespace SocketHttpListener
return true; return true;
} }
private void processException(Exception exception, string message) private async Task ProcessExceptionAsync(Exception exception, string message)
{ {
var code = CloseStatusCode.Abnormal; var code = CloseStatusCode.Abnormal;
var reason = message; var reason = message;
@ -365,25 +367,31 @@ namespace SocketHttpListener
error(message ?? code.GetMessage(), exception); error(message ?? code.GetMessage(), exception);
if (_readyState == WebSocketState.Connecting) if (_readyState == WebSocketState.Connecting)
Close(HttpStatusCode.BadRequest); {
await CloseAsync(HttpStatusCode.BadRequest).ConfigureAwait(false);
}
else else
close(code, reason ?? code.GetMessage(), false); {
await CloseAsync(code, reason ?? code.GetMessage(), false).ConfigureAwait(false);
}
} }
private bool processFragmentedFrame(WebSocketFrame frame) private Task<bool> ProcessFragmentedFrameAsync(WebSocketFrame frame)
{ {
return frame.IsContinuation // Not first fragment return frame.IsContinuation // Not first fragment
? true ? Task.FromResult(true)
: processFragments(frame); : ProcessFragmentsAsync(frame);
} }
private bool processFragments(WebSocketFrame first) private async Task<bool> ProcessFragmentsAsync(WebSocketFrame first)
{ {
using (var buff = new MemoryStream()) using (var buff = new MemoryStream())
{ {
buff.WriteBytes(first.PayloadData.ApplicationData); buff.WriteBytes(first.PayloadData.ApplicationData);
if (!concatenateFragmentsInto(buff)) if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false))
{
return false; return false;
}
byte[] data; byte[] data;
if (_compression != CompressionMethod.None) if (_compression != CompressionMethod.None)
@ -412,36 +420,38 @@ namespace SocketHttpListener
return true; return true;
} }
private bool processUnsupportedFrame(WebSocketFrame frame, CloseStatusCode code, string reason) private async Task<bool> ProcessUnsupportedFrameAsync(WebSocketFrame frame, CloseStatusCode code, string reason)
{ {
processException(new WebSocketException(code, reason), null); await ProcessExceptionAsync(new WebSocketException(code, reason), null).ConfigureAwait(false);
return false; return false;
} }
private bool processWebSocketFrame(WebSocketFrame frame) private Task<bool> ProcessWebSocketFrameAsync(WebSocketFrame frame)
{ {
// TODO: @bond change to if/else chain
return frame.IsCompressed && _compression == CompressionMethod.None return frame.IsCompressed && _compression == CompressionMethod.None
? processUnsupportedFrame( ? ProcessUnsupportedFrameAsync(
frame, frame,
CloseStatusCode.IncorrectData, CloseStatusCode.IncorrectData,
"A compressed data has been received without available decompression method.") "A compressed data has been received without available decompression method.")
: frame.IsFragmented : frame.IsFragmented
? processFragmentedFrame(frame) ? ProcessFragmentedFrameAsync(frame)
: frame.IsData : frame.IsData
? processDataFrame(frame) ? Task.FromResult(processDataFrame(frame))
: frame.IsPing : frame.IsPing
? processPingFrame(frame) ? Task.FromResult(processPingFrame(frame))
: frame.IsPong : frame.IsPong
? processPongFrame(frame) ? Task.FromResult(processPongFrame(frame))
: frame.IsClose : frame.IsClose
? processCloseFrame(frame) ? ProcessCloseFrameAsync(frame)
: processUnsupportedFrame(frame, CloseStatusCode.PolicyViolation, null); : ProcessUnsupportedFrameAsync(frame, CloseStatusCode.PolicyViolation, null);
} }
private bool send(Opcode opcode, Stream stream) private async Task<bool> SendAsync(Opcode opcode, Stream stream)
{ {
lock (_forSend) await _forSend.WaitAsync().ConfigureAwait(false);
try
{ {
var src = stream; var src = stream;
var compressed = false; var compressed = false;
@ -454,7 +464,7 @@ namespace SocketHttpListener
compressed = true; compressed = true;
} }
sent = send(opcode, Mask.Unmask, stream, compressed); sent = await SendAsync(opcode, Mask.Unmask, stream, compressed).ConfigureAwait(false);
if (!sent) if (!sent)
error("Sending a data has been interrupted."); error("Sending a data has been interrupted.");
} }
@ -472,16 +482,20 @@ namespace SocketHttpListener
return sent; return sent;
} }
finally
{
_forSend.Release();
}
} }
private bool send(Opcode opcode, Mask mask, Stream stream, bool compressed) private async Task<bool> SendAsync(Opcode opcode, Mask mask, Stream stream, bool compressed)
{ {
var len = stream.Length; var len = stream.Length;
/* Not fragmented */ /* Not fragmented */
if (len == 0) if (len == 0)
return send(Fin.Final, opcode, mask, new byte[0], compressed); return await SendAsync(Fin.Final, opcode, mask, new byte[0], compressed).ConfigureAwait(false);
var quo = len / FragmentLength; var quo = len / FragmentLength;
var rem = (int)(len % FragmentLength); var rem = (int)(len % FragmentLength);
@ -490,26 +504,26 @@ namespace SocketHttpListener
if (quo == 0) if (quo == 0)
{ {
buff = new byte[rem]; buff = new byte[rem];
return stream.Read(buff, 0, rem) == rem && return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
send(Fin.Final, opcode, mask, buff, compressed); await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
} }
buff = new byte[FragmentLength]; buff = new byte[FragmentLength];
if (quo == 1 && rem == 0) if (quo == 1 && rem == 0)
return stream.Read(buff, 0, FragmentLength) == FragmentLength && return await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) == FragmentLength &&
send(Fin.Final, opcode, mask, buff, compressed); await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false);
/* Send fragmented */ /* Send fragmented */
// Begin // Begin
if (stream.Read(buff, 0, FragmentLength) != FragmentLength || if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
!send(Fin.More, opcode, mask, buff, compressed)) !await SendAsync(Fin.More, opcode, mask, buff, compressed).ConfigureAwait(false))
return false; return false;
var n = rem == 0 ? quo - 2 : quo - 1; var n = rem == 0 ? quo - 2 : quo - 1;
for (long i = 0; i < n; i++) for (long i = 0; i < n; i++)
if (stream.Read(buff, 0, FragmentLength) != FragmentLength || if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength ||
!send(Fin.More, Opcode.Cont, mask, buff, compressed)) !await SendAsync(Fin.More, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false))
return false; return false;
// End // End
@ -518,98 +532,88 @@ namespace SocketHttpListener
else else
buff = new byte[rem]; buff = new byte[rem];
return stream.Read(buff, 0, rem) == rem && return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem &&
send(Fin.Final, Opcode.Cont, mask, buff, compressed); await SendAsync(Fin.Final, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false);
} }
private bool send(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed) private Task<bool> SendAsync(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed)
{ {
lock (_forConn) lock (_forConn)
{ {
if (_readyState != WebSocketState.Open) if (_readyState != WebSocketState.Open)
{ {
return false; return Task.FromResult(false);
} }
return writeBytes( return WriteBytesAsync(
WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray()); WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray());
} }
} }
private Task sendAsync(Opcode opcode, Stream stream)
{
var completionSource = new TaskCompletionSource<bool>();
Task.Run(() =>
{
try
{
send(opcode, stream);
completionSource.TrySetResult(true);
}
catch (Exception ex)
{
completionSource.TrySetException(ex);
}
});
return completionSource.Task;
}
// As server // As server
private bool sendHttpResponse(HttpResponse response) private Task<bool> SendHttpResponseAsync(HttpResponse response)
{ => WriteBytesAsync(response.ToByteArray());
return writeBytes(response.ToByteArray());
}
private void startReceiving() private void startReceiving()
{ {
if (_messageEventQueue.Count > 0) if (_messageEventQueue.Count > 0)
{
_messageEventQueue.Clear(); _messageEventQueue.Clear();
}
_exitReceiving = new AutoResetEvent(false); _exitReceiving = new AutoResetEvent(false);
_receivePong = new AutoResetEvent(false); _receivePong = new AutoResetEvent(false);
Action receive = null; Action receive = null;
receive = () => WebSocketFrame.ReadAsync( receive = async () => await WebSocketFrame.ReadAsync(
_stream, _stream,
true, true,
frame => async frame =>
{ {
if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed) if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed)
{ {
receive(); receive();
if (!frame.IsData) if (!frame.IsData)
return; {
return;
}
lock (_forEvent) await _forEvent.WaitAsync().ConfigureAwait(false);
{
try try
{ {
var e = dequeueFromMessageEventQueue(); var e = dequeueFromMessageEventQueue();
if (e != null && _readyState == WebSocketState.Open) if (e != null && _readyState == WebSocketState.Open)
OnMessage.Emit(this, e); {
} OnMessage.Emit(this, e);
catch (Exception ex) }
{ }
processException(ex, "An exception has occurred while OnMessage."); catch (Exception ex)
} {
} await ProcessExceptionAsync(ex, "An exception has occurred while OnMessage.").ConfigureAwait(false);
} }
else if (_exitReceiving != null) finally
{ {
_exitReceiving.Set(); _forEvent.Release();
} }
},
ex => processException(ex, "An exception has occurred while receiving a message.")); }
else if (_exitReceiving != null)
{
_exitReceiving.Set();
}
},
async ex => await ProcessExceptionAsync(ex, "An exception has occurred while receiving a message.")).ConfigureAwait(false);
receive(); receive();
} }
private bool writeBytes(byte[] data) private async Task<bool> WriteBytesAsync(byte[] data)
{ {
try try
{ {
_stream.Write(data, 0, data.Length); await _stream.WriteAsync(data, 0, data.Length).ConfigureAwait(false);
return true; return true;
} }
catch (Exception) catch (Exception)
@ -623,10 +627,10 @@ namespace SocketHttpListener
#region Internal Methods #region Internal Methods
// As server // As server
internal void Close(HttpResponse response) internal async Task CloseAsync(HttpResponse response)
{ {
_readyState = WebSocketState.CloseSent; _readyState = WebSocketState.CloseSent;
sendHttpResponse(response); await SendHttpResponseAsync(response).ConfigureAwait(false);
closeServerResources(); closeServerResources();
@ -634,22 +638,20 @@ namespace SocketHttpListener
} }
// As server // As server
internal void Close(HttpStatusCode code) internal Task CloseAsync(HttpStatusCode code)
{ => CloseAsync(createHandshakeCloseResponse(code));
Close(createHandshakeCloseResponse(code));
}
// As server // As server
public void ConnectAsServer() public async Task ConnectAsServer()
{ {
try try
{ {
_readyState = WebSocketState.Open; _readyState = WebSocketState.Open;
open(); await OpenAsync().ConfigureAwait(false);
} }
catch (Exception ex) catch (Exception ex)
{ {
processException(ex, "An exception has occurred while connecting."); await ProcessExceptionAsync(ex, "An exception has occurred while connecting.").ConfigureAwait(false);
} }
} }
@ -660,18 +662,18 @@ namespace SocketHttpListener
/// <summary> /// <summary>
/// Closes the WebSocket connection, and releases all associated resources. /// Closes the WebSocket connection, and releases all associated resources.
/// </summary> /// </summary>
public void Close() public Task CloseAsync()
{ {
var msg = _readyState.CheckIfClosable(); var msg = _readyState.CheckIfClosable();
if (msg != null) if (msg != null)
{ {
error(msg); error(msg);
return; return Task.CompletedTask;
} }
var send = _readyState == WebSocketState.Open; var send = _readyState == WebSocketState.Open;
close(new PayloadData(), send, send); return CloseAsync(new PayloadData(), send, send);
} }
/// <summary> /// <summary>
@ -689,11 +691,11 @@ namespace SocketHttpListener
/// <param name="reason"> /// <param name="reason">
/// A <see cref="string"/> that represents the reason for the close. /// A <see cref="string"/> that represents the reason for the close.
/// </param> /// </param>
public void Close(CloseStatusCode code, string reason) public async Task CloseAsync(CloseStatusCode code, string reason)
{ {
byte[] data = null; byte[] data = null;
var msg = _readyState.CheckIfClosable() ?? var msg = _readyState.CheckIfClosable() ??
(data = ((ushort)code).Append(reason)).CheckIfValidControlData("reason"); (data = await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)).CheckIfValidControlData("reason");
if (msg != null) if (msg != null)
{ {
@ -703,7 +705,7 @@ namespace SocketHttpListener
} }
var send = _readyState == WebSocketState.Open && !code.IsReserved(); var send = _readyState == WebSocketState.Open && !code.IsReserved();
close(new PayloadData(data), send, send); await CloseAsync(new PayloadData(data), send, send).ConfigureAwait(false);
} }
/// <summary> /// <summary>
@ -728,7 +730,7 @@ namespace SocketHttpListener
throw new Exception(msg); throw new Exception(msg);
} }
return sendAsync(Opcode.Binary, new MemoryStream(data)); return SendAsync(Opcode.Binary, new MemoryStream(data));
} }
/// <summary> /// <summary>
@ -753,7 +755,7 @@ namespace SocketHttpListener
throw new Exception(msg); throw new Exception(msg);
} }
return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); return SendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data)));
} }
#endregion #endregion
@ -768,7 +770,7 @@ namespace SocketHttpListener
/// </remarks> /// </remarks>
void IDisposable.Dispose() void IDisposable.Dispose()
{ {
Close(CloseStatusCode.Away, null); CloseAsync(CloseStatusCode.Away, null).GetAwaiter().GetResult();
} }
#endregion #endregion

View File

@ -2,6 +2,7 @@ using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Threading.Tasks;
namespace SocketHttpListener namespace SocketHttpListener
{ {
@ -177,7 +178,7 @@ namespace SocketHttpListener
return opcode == Opcode.Text || opcode == Opcode.Binary; return opcode == Opcode.Text || opcode == Opcode.Binary;
} }
private static WebSocketFrame read(byte[] header, Stream stream, bool unmask) private static async Task<WebSocketFrame> ReadAsync(byte[] header, Stream stream, bool unmask)
{ {
/* Header */ /* Header */
@ -229,7 +230,7 @@ namespace SocketHttpListener
? 2 ? 2
: 8; : 8;
var extPayloadLen = size > 0 ? stream.ReadBytes(size) : new byte[0]; var extPayloadLen = size > 0 ? await stream.ReadBytesAsync(size).ConfigureAwait(false) : Array.Empty<byte>();
if (size > 0 && extPayloadLen.Length != size) if (size > 0 && extPayloadLen.Length != size)
throw new WebSocketException( throw new WebSocketException(
"The 'Extended Payload Length' of a frame cannot be read from the data source."); "The 'Extended Payload Length' of a frame cannot be read from the data source.");
@ -239,7 +240,7 @@ namespace SocketHttpListener
/* Masking Key */ /* Masking Key */
var masked = mask == Mask.Mask; var masked = mask == Mask.Mask;
var maskingKey = masked ? stream.ReadBytes(4) : new byte[0]; var maskingKey = masked ? await stream.ReadBytesAsync(4).ConfigureAwait(false) : Array.Empty<byte>();
if (masked && maskingKey.Length != 4) if (masked && maskingKey.Length != 4)
throw new WebSocketException( throw new WebSocketException(
"The 'Masking Key' of a frame cannot be read from the data source."); "The 'Masking Key' of a frame cannot be read from the data source.");
@ -264,8 +265,8 @@ namespace SocketHttpListener
"The length of 'Payload Data' of a frame is greater than the allowable length."); "The length of 'Payload Data' of a frame is greater than the allowable length.");
data = payloadLen > 126 data = payloadLen > 126
? stream.ReadBytes((long)len, 1024) ? await stream.ReadBytesAsync((long)len, 1024).ConfigureAwait(false)
: stream.ReadBytes((int)len); : await stream.ReadBytesAsync((int)len).ConfigureAwait(false);
//if (data.LongLength != (long)len) //if (data.LongLength != (long)len)
// throw new WebSocketException( // throw new WebSocketException(
@ -273,7 +274,7 @@ namespace SocketHttpListener
} }
else else
{ {
data = new byte[0]; data = Array.Empty<byte>();
} }
var payload = new PayloadData(data, masked); var payload = new PayloadData(data, masked);
@ -281,7 +282,7 @@ namespace SocketHttpListener
{ {
payload.Mask(maskingKey); payload.Mask(maskingKey);
frame._mask = Mask.Unmask; frame._mask = Mask.Unmask;
frame._maskingKey = new byte[0]; frame._maskingKey = Array.Empty<byte>();
} }
frame._payloadData = payload; frame._payloadData = payload;
@ -302,10 +303,10 @@ namespace SocketHttpListener
return new WebSocketFrame(Opcode.Close, mask, payload); return new WebSocketFrame(Opcode.Close, mask, payload);
} }
internal static WebSocketFrame CreateCloseFrame(Mask mask, CloseStatusCode code, string reason) internal static async Task<WebSocketFrame> CreateCloseFrameAsync(Mask mask, CloseStatusCode code, string reason)
{ {
return new WebSocketFrame( return new WebSocketFrame(
Opcode.Close, mask, new PayloadData(((ushort)code).Append(reason))); Opcode.Close, mask, new PayloadData(await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)));
} }
internal static WebSocketFrame CreatePingFrame(Mask mask) internal static WebSocketFrame CreatePingFrame(Mask mask)
@ -329,41 +330,39 @@ namespace SocketHttpListener
return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed); return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed);
} }
internal static WebSocketFrame Read(Stream stream) internal static Task<WebSocketFrame> ReadAsync(Stream stream)
{ => ReadAsync(stream, true);
return Read(stream, true);
}
internal static WebSocketFrame Read(Stream stream, bool unmask) internal static async Task<WebSocketFrame> ReadAsync(Stream stream, bool unmask)
{ {
var header = stream.ReadBytes(2); var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
if (header.Length != 2) if (header.Length != 2)
{
throw new WebSocketException( throw new WebSocketException(
"The header part of a frame cannot be read from the data source."); "The header part of a frame cannot be read from the data source.");
}
return read(header, stream, unmask); return await ReadAsync(header, stream, unmask).ConfigureAwait(false);
} }
internal static async void ReadAsync( internal static async Task ReadAsync(
Stream stream, bool unmask, Action<WebSocketFrame> completed, Action<Exception> error) Stream stream, bool unmask, Action<WebSocketFrame> completed, Action<Exception> error)
{ {
try try
{ {
var header = await stream.ReadBytesAsync(2).ConfigureAwait(false); var header = await stream.ReadBytesAsync(2).ConfigureAwait(false);
if (header.Length != 2) if (header.Length != 2)
{
throw new WebSocketException( throw new WebSocketException(
"The header part of a frame cannot be read from the data source."); "The header part of a frame cannot be read from the data source.");
}
var frame = read(header, stream, unmask); var frame = await ReadAsync(header, stream, unmask).ConfigureAwait(false);
if (completed != null) completed?.Invoke(frame);
completed(frame);
} }
catch (Exception ex) catch (Exception ex)
{ {
if (error != null) error.Invoke(ex);
{
error(ex);
}
} }
} }