diff --git a/Jellyfin.Server/SocketSharp/SharpWebSocket.cs b/Jellyfin.Server/SocketSharp/SharpWebSocket.cs index 6eee4cd12..9b0951857 100644 --- a/Jellyfin.Server/SocketSharp/SharpWebSocket.cs +++ b/Jellyfin.Server/SocketSharp/SharpWebSocket.cs @@ -44,10 +44,11 @@ namespace Jellyfin.Server.SocketSharp socket.OnMessage += OnSocketMessage; socket.OnClose += OnSocketClose; socket.OnError += OnSocketError; - - WebSocket.ConnectAsServer(); } + public Task ConnectAsServerAsync() + => WebSocket.ConnectAsServer(); + public Task StartReceive() { return _taskCompletionSource.Task; @@ -133,7 +134,7 @@ namespace Jellyfin.Server.SocketSharp _cancellationTokenSource.Cancel(); - WebSocket.Close(); + WebSocket.CloseAsync().GetAwaiter().GetResult(); } _disposed = true; diff --git a/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs b/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs index 58c4d38a2..736f9feef 100644 --- a/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs +++ b/Jellyfin.Server/SocketSharp/WebSocketSharpListener.cs @@ -69,7 +69,7 @@ namespace Jellyfin.Server.SocketSharp { if (_listener == null) { - _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _networkManager, _streamHelper, _fileSystem, _environment); + _listener = new HttpListener(_logger, _cryptoProvider, _socketFactory, _streamHelper, _fileSystem, _environment); } _listener.EnableDualMode = _enableDualMode; @@ -79,22 +79,14 @@ namespace Jellyfin.Server.SocketSharp _listener.LoadCert(_certificate); } - foreach (var prefix in urlPrefixes) - { - _logger.LogInformation("Adding HttpListener prefix " + prefix); - _listener.Prefixes.Add(prefix); - } + _logger.LogInformation("Adding HttpListener prefixes {Prefixes}", urlPrefixes); + _listener.Prefixes.AddRange(urlPrefixes); - _listener.OnContext = ProcessContext; + _listener.OnContext = async c => await InitTask(c, _disposeCancellationToken).ConfigureAwait(false); _listener.Start(); } - private void ProcessContext(HttpListenerContext context) - { - _ = Task.Run(async () => await InitTask(context, _disposeCancellationToken).ConfigureAwait(false)); - } - private static void LogRequest(ILogger logger, HttpListenerRequest request) { var url = request.Url.ToString(); @@ -151,10 +143,7 @@ namespace Jellyfin.Server.SocketSharp Endpoint = endpoint }; - if (WebSocketConnecting != null) - { - WebSocketConnecting(connectingArgs); - } + WebSocketConnecting?.Invoke(connectingArgs); if (connectingArgs.AllowConnection) { @@ -165,6 +154,7 @@ namespace Jellyfin.Server.SocketSharp if (WebSocketConnected != null) { var socket = new SharpWebSocket(webSocketContext.WebSocket, _logger); + await socket.ConnectAsServerAsync().ConfigureAwait(false); WebSocketConnected(new WebSocketConnectEventArgs { @@ -174,7 +164,7 @@ namespace Jellyfin.Server.SocketSharp Endpoint = endpoint }); - await ReceiveWebSocket(ctx, socket).ConfigureAwait(false); + await ReceiveWebSocketAsync(ctx, socket).ConfigureAwait(false); } } else @@ -192,7 +182,7 @@ namespace Jellyfin.Server.SocketSharp } } - private async Task ReceiveWebSocket(HttpListenerContext ctx, SharpWebSocket socket) + private async Task ReceiveWebSocketAsync(HttpListenerContext ctx, SharpWebSocket socket) { try { diff --git a/SocketHttpListener/Ext.cs b/SocketHttpListener/Ext.cs index a02b48061..2b3c67071 100644 --- a/SocketHttpListener/Ext.cs +++ b/SocketHttpListener/Ext.cs @@ -74,18 +74,20 @@ namespace SocketHttpListener } } - private static byte[] readBytes(this Stream stream, byte[] buffer, int offset, int length) + private static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length) { - var len = stream.Read(buffer, offset, length); + var len = await stream.ReadAsync(buffer, offset, length).ConfigureAwait(false); if (len < 1) return buffer.SubArray(0, offset); var tmp = 0; while (len < length) { - tmp = stream.Read(buffer, offset + len, length - len); + tmp = await stream.ReadAsync(buffer, offset + len, length - len).ConfigureAwait(false); if (tmp < 1) + { break; + } len += tmp; } @@ -95,10 +97,9 @@ namespace SocketHttpListener : buffer; } - private static bool readBytes( - this Stream stream, byte[] buffer, int offset, int length, Stream dest) + private static async Task ReadBytesAsync(this Stream stream, byte[] buffer, int offset, int length, Stream dest) { - var bytes = stream.readBytes(buffer, offset, length); + var bytes = await stream.ReadBytesAsync(buffer, offset, length).ConfigureAwait(false); var len = bytes.Length; dest.Write(bytes, 0, len); @@ -109,16 +110,16 @@ namespace SocketHttpListener #region Internal Methods - internal static byte[] Append(this ushort code, string reason) + internal static async Task AppendAsync(this ushort code, string reason) { using (var buffer = new MemoryStream()) { var tmp = code.ToByteArrayInternally(ByteOrder.Big); - buffer.Write(tmp, 0, 2); + await buffer.WriteAsync(tmp, 0, 2).ConfigureAwait(false); if (reason != null && reason.Length > 0) { tmp = Encoding.UTF8.GetBytes(reason); - buffer.Write(tmp, 0, tmp.Length); + await buffer.WriteAsync(tmp, 0, tmp.Length).ConfigureAwait(false); } return buffer.ToArray(); @@ -331,12 +332,10 @@ namespace SocketHttpListener : string.Format("\"{0}\"", value.Replace("\"", "\\\"")); } - internal static byte[] ReadBytes(this Stream stream, int length) - { - return stream.readBytes(new byte[length], 0, length); - } + internal static Task ReadBytesAsync(this Stream stream, int length) + => stream.ReadBytesAsync(new byte[length], 0, length); - internal static byte[] ReadBytes(this Stream stream, long length, int bufferLength) + internal static async Task ReadBytesAsync(this Stream stream, long length, int bufferLength) { using (var result = new MemoryStream()) { @@ -347,7 +346,7 @@ namespace SocketHttpListener var end = false; for (long i = 0; i < count; i++) { - if (!stream.readBytes(buffer, 0, bufferLength, result)) + if (!await stream.ReadBytesAsync(buffer, 0, bufferLength, result).ConfigureAwait(false)) { end = true; break; @@ -355,26 +354,14 @@ namespace SocketHttpListener } if (!end && rem > 0) - stream.readBytes(new byte[rem], 0, rem, result); + { + await stream.ReadBytesAsync(new byte[rem], 0, rem, result).ConfigureAwait(false); + } return result.ToArray(); } } - internal static async Task ReadBytesAsync(this Stream stream, int length) - { - var buffer = new byte[length]; - - var len = await stream.ReadAsync(buffer, 0, length).ConfigureAwait(false); - var bytes = len < 1 - ? new byte[0] - : len < length - ? stream.readBytes(buffer, len, length - len) - : buffer; - - return bytes; - } - internal static string RemovePrefix(this string value, params string[] prefixes) { var i = 0; @@ -493,19 +480,16 @@ namespace SocketHttpListener return string.Format("{0}; {1}", m, parameters.ToString("; ")); } - internal static List ToList(this IEnumerable source) - { - return new List(source); - } - internal static ushort ToUInt16(this byte[] src, ByteOrder srcOrder) { - return BitConverter.ToUInt16(src.ToHostOrder(srcOrder), 0); + src.ToHostOrder(srcOrder); + return BitConverter.ToUInt16(src, 0); } internal static ulong ToUInt64(this byte[] src, ByteOrder srcOrder) { - return BitConverter.ToUInt64(src.ToHostOrder(srcOrder), 0); + src.ToHostOrder(srcOrder); + return BitConverter.ToUInt64(src, 0); } internal static string TrimEndSlash(this string value) @@ -852,14 +836,17 @@ namespace SocketHttpListener /// /// is . /// - public static byte[] ToHostOrder(this byte[] src, ByteOrder srcOrder) + public static void ToHostOrder(this byte[] src, ByteOrder srcOrder) { if (src == null) + { throw new ArgumentNullException(nameof(src)); + } - return src.Length > 1 && !srcOrder.IsHostOrder() - ? src.Reverse() - : src; + if (src.Length > 1 && !srcOrder.IsHostOrder()) + { + Array.Reverse(src); + } } /// diff --git a/SocketHttpListener/Net/HttpListener.cs b/SocketHttpListener/Net/HttpListener.cs index b80180679..f17036a21 100644 --- a/SocketHttpListener/Net/HttpListener.cs +++ b/SocketHttpListener/Net/HttpListener.cs @@ -3,7 +3,6 @@ using System.Collections; using System.Collections.Generic; using System.Net; using System.Security.Cryptography.X509Certificates; -using MediaBrowser.Common.Net; using MediaBrowser.Model.Cryptography; using MediaBrowser.Model.IO; using MediaBrowser.Model.Net; @@ -18,47 +17,55 @@ namespace SocketHttpListener.Net internal ISocketFactory SocketFactory { get; private set; } internal IFileSystem FileSystem { get; private set; } internal IStreamHelper StreamHelper { get; private set; } - internal INetworkManager NetworkManager { get; private set; } internal IEnvironmentInfo EnvironmentInfo { get; private set; } public bool EnableDualMode { get; set; } - AuthenticationSchemes auth_schemes; - HttpListenerPrefixCollection prefixes; - AuthenticationSchemeSelector auth_selector; - string realm; - bool unsafe_ntlm_auth; - bool listening; - bool disposed; + private AuthenticationSchemes auth_schemes; + private HttpListenerPrefixCollection prefixes; + private AuthenticationSchemeSelector auth_selector; + private string realm; + private bool unsafe_ntlm_auth; + private bool listening; + private bool disposed; - Dictionary registry; // Dictionary - Dictionary connections; + private Dictionary registry; + private Dictionary connections; private ILogger _logger; private X509Certificate _certificate; public Action OnContext { get; set; } - public HttpListener(ILogger logger, ICryptoProvider cryptoProvider, ISocketFactory socketFactory, - INetworkManager networkManager, IStreamHelper streamHelper, IFileSystem fileSystem, + public HttpListener( + ILogger logger, + ICryptoProvider cryptoProvider, + ISocketFactory socketFactory, + IStreamHelper streamHelper, + IFileSystem fileSystem, IEnvironmentInfo environmentInfo) { _logger = logger; CryptoProvider = cryptoProvider; SocketFactory = socketFactory; - NetworkManager = networkManager; StreamHelper = streamHelper; FileSystem = fileSystem; EnvironmentInfo = environmentInfo; + prefixes = new HttpListenerPrefixCollection(logger, this); registry = new Dictionary(); connections = new Dictionary(); auth_schemes = AuthenticationSchemes.Anonymous; } - public HttpListener(ILogger logger, X509Certificate certificate, ICryptoProvider cryptoProvider, - ISocketFactory socketFactory, INetworkManager networkManager, IStreamHelper streamHelper, - IFileSystem fileSystem, IEnvironmentInfo environmentInfo) - : this(logger, cryptoProvider, socketFactory, networkManager, streamHelper, fileSystem, environmentInfo) + public HttpListener( + ILogger logger, + X509Certificate certificate, + ICryptoProvider cryptoProvider, + ISocketFactory socketFactory, + IStreamHelper streamHelper, + IFileSystem fileSystem, + IEnvironmentInfo environmentInfo) + : this(logger, cryptoProvider, socketFactory, streamHelper, fileSystem, environmentInfo) { _certificate = certificate; } diff --git a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs index 97dc6797c..400a1adb6 100644 --- a/SocketHttpListener/Net/HttpListenerPrefixCollection.cs +++ b/SocketHttpListener/Net/HttpListenerPrefixCollection.cs @@ -7,18 +7,18 @@ namespace SocketHttpListener.Net { public class HttpListenerPrefixCollection : ICollection, IEnumerable, IEnumerable { - List prefixes = new List(); - HttpListener listener; + private List _prefixes = new List(); + private HttpListener _listener; private ILogger _logger; internal HttpListenerPrefixCollection(ILogger logger, HttpListener listener) { _logger = logger; - this.listener = listener; + _listener = listener; } - public int Count => prefixes.Count; + public int Count => _prefixes.Count; public bool IsReadOnly => false; @@ -26,61 +26,90 @@ namespace SocketHttpListener.Net public void Add(string uriPrefix) { - listener.CheckDisposed(); + _listener.CheckDisposed(); //ListenerPrefix.CheckUri(uriPrefix); - if (prefixes.Contains(uriPrefix)) + if (_prefixes.Contains(uriPrefix)) + { return; + } - prefixes.Add(uriPrefix); - if (listener.IsListening) - HttpEndPointManager.AddPrefix(_logger, uriPrefix, listener); + _prefixes.Add(uriPrefix); + if (_listener.IsListening) + { + HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener); + } + } + + public void AddRange(IEnumerable uriPrefixes) + { + _listener.CheckDisposed(); + + foreach (var uriPrefix in uriPrefixes) + { + if (_prefixes.Contains(uriPrefix)) + { + continue; + } + + _prefixes.Add(uriPrefix); + if (_listener.IsListening) + { + HttpEndPointManager.AddPrefix(_logger, uriPrefix, _listener); + } + } } public void Clear() { - listener.CheckDisposed(); - prefixes.Clear(); - if (listener.IsListening) - HttpEndPointManager.RemoveListener(_logger, listener); + _listener.CheckDisposed(); + _prefixes.Clear(); + if (_listener.IsListening) + { + HttpEndPointManager.RemoveListener(_logger, _listener); + } } public bool Contains(string uriPrefix) { - listener.CheckDisposed(); - return prefixes.Contains(uriPrefix); + _listener.CheckDisposed(); + return _prefixes.Contains(uriPrefix); } public void CopyTo(string[] array, int offset) { - listener.CheckDisposed(); - prefixes.CopyTo(array, offset); + _listener.CheckDisposed(); + _prefixes.CopyTo(array, offset); } public void CopyTo(Array array, int offset) { - listener.CheckDisposed(); - ((ICollection)prefixes).CopyTo(array, offset); + _listener.CheckDisposed(); + ((ICollection)_prefixes).CopyTo(array, offset); } public IEnumerator GetEnumerator() { - return prefixes.GetEnumerator(); + return _prefixes.GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() { - return prefixes.GetEnumerator(); + return _prefixes.GetEnumerator(); } public bool Remove(string uriPrefix) { - listener.CheckDisposed(); + _listener.CheckDisposed(); if (uriPrefix == null) + { throw new ArgumentNullException(nameof(uriPrefix)); + } - bool result = prefixes.Remove(uriPrefix); - if (result && listener.IsListening) - HttpEndPointManager.RemovePrefix(_logger, uriPrefix, listener); + bool result = _prefixes.Remove(uriPrefix); + if (result && _listener.IsListening) + { + HttpEndPointManager.RemovePrefix(_logger, uriPrefix, _listener); + } return result; } diff --git a/SocketHttpListener/WebSocket.cs b/SocketHttpListener/WebSocket.cs index 128bc8b97..0dcb6a64b 100644 --- a/SocketHttpListener/WebSocket.cs +++ b/SocketHttpListener/WebSocket.cs @@ -30,9 +30,9 @@ namespace SocketHttpListener private CookieCollection _cookies; private AutoResetEvent _exitReceiving; private object _forConn; - private object _forEvent; + private readonly SemaphoreSlim _forEvent = new SemaphoreSlim(1, 1); private object _forMessageEventQueue; - private object _forSend; + private readonly SemaphoreSlim _forSend = new SemaphoreSlim(1, 1); private const string _guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; private Queue _messageEventQueue; private string _protocol; @@ -109,12 +109,15 @@ namespace SocketHttpListener #region Private Methods - private void close(CloseStatusCode code, string reason, bool wait) + private async Task CloseAsync(CloseStatusCode code, string reason, bool wait) { - close(new PayloadData(((ushort)code).Append(reason)), !code.IsReserved(), wait); + await CloseAsync(new PayloadData( + await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)), + !code.IsReserved(), + wait).ConfigureAwait(false); } - private void close(PayloadData payload, bool send, bool wait) + private async Task CloseAsync(PayloadData payload, bool send, bool wait) { lock (_forConn) { @@ -126,11 +129,12 @@ namespace SocketHttpListener _readyState = WebSocketState.CloseSent; } - var e = new CloseEventArgs(payload); - e.WasClean = - closeHandshake( + var e = new CloseEventArgs(payload) + { + WasClean = await CloseHandshakeAsync( send ? WebSocketFrame.CreateCloseFrame(Mask.Unmask, payload).ToByteArray() : null, - wait ? 1000 : 0); + wait ? 1000 : 0).ConfigureAwait(false) + }; _readyState = WebSocketState.Closed; try @@ -143,9 +147,9 @@ namespace SocketHttpListener } } - private bool closeHandshake(byte[] frameAsBytes, int millisecondsTimeout) + private async Task CloseHandshakeAsync(byte[] frameAsBytes, int millisecondsTimeout) { - var sent = frameAsBytes != null && writeBytes(frameAsBytes); + var sent = frameAsBytes != null && await WriteBytesAsync(frameAsBytes).ConfigureAwait(false); var received = millisecondsTimeout == 0 || (sent && _exitReceiving != null && _exitReceiving.WaitOne(millisecondsTimeout)); @@ -189,11 +193,11 @@ namespace SocketHttpListener _context = null; } - private bool concatenateFragmentsInto(Stream dest) + private async Task ConcatenateFragmentsIntoAsync(Stream dest) { while (true) { - var frame = WebSocketFrame.Read(_stream, true); + var frame = await WebSocketFrame.ReadAsync(_stream, true).ConfigureAwait(false); if (frame.IsFinal) { /* FINAL */ @@ -221,7 +225,7 @@ namespace SocketHttpListener // CLOSE if (frame.IsClose) - return processCloseFrame(frame); + return await ProcessCloseFrameAsync(frame).ConfigureAwait(false); } else { @@ -236,10 +240,10 @@ namespace SocketHttpListener } // ? - return processUnsupportedFrame( + return await ProcessUnsupportedFrameAsync( frame, CloseStatusCode.IncorrectData, - "An incorrect data has been received while receiving fragmented data."); + "An incorrect data has been received while receiving fragmented data.").ConfigureAwait(false); } return true; @@ -299,44 +303,42 @@ namespace SocketHttpListener _compression = CompressionMethod.None; _cookies = new CookieCollection(); _forConn = new object(); - _forEvent = new object(); - _forSend = new object(); _messageEventQueue = new Queue(); _forMessageEventQueue = ((ICollection)_messageEventQueue).SyncRoot; _readyState = WebSocketState.Connecting; } - private void open() + private async Task OpenAsync() { try { startReceiving(); - lock (_forEvent) - { - try - { - if (OnOpen != null) - { - OnOpen(this, EventArgs.Empty); - } - } - catch (Exception ex) - { - processException(ex, "An exception has occurred while OnOpen."); - } - } } catch (Exception ex) { - processException(ex, "An exception has occurred while opening."); + await ProcessExceptionAsync(ex, "An exception has occurred while opening.").ConfigureAwait(false); + } + + await _forEvent.WaitAsync().ConfigureAwait(false); + try + { + OnOpen?.Invoke(this, EventArgs.Empty); + } + catch (Exception ex) + { + await ProcessExceptionAsync(ex, "An exception has occurred while OnOpen.").ConfigureAwait(false); + } + finally + { + _forEvent.Release(); } } - private bool processCloseFrame(WebSocketFrame frame) + private async Task ProcessCloseFrameAsync(WebSocketFrame frame) { var payload = frame.PayloadData; - close(payload, !payload.ContainsReservedCloseStatusCode, false); + await CloseAsync(payload, !payload.ContainsReservedCloseStatusCode, false).ConfigureAwait(false); return false; } @@ -352,7 +354,7 @@ namespace SocketHttpListener return true; } - private void processException(Exception exception, string message) + private async Task ProcessExceptionAsync(Exception exception, string message) { var code = CloseStatusCode.Abnormal; var reason = message; @@ -365,25 +367,31 @@ namespace SocketHttpListener error(message ?? code.GetMessage(), exception); if (_readyState == WebSocketState.Connecting) - Close(HttpStatusCode.BadRequest); + { + await CloseAsync(HttpStatusCode.BadRequest).ConfigureAwait(false); + } else - close(code, reason ?? code.GetMessage(), false); + { + await CloseAsync(code, reason ?? code.GetMessage(), false).ConfigureAwait(false); + } } - private bool processFragmentedFrame(WebSocketFrame frame) + private Task ProcessFragmentedFrameAsync(WebSocketFrame frame) { return frame.IsContinuation // Not first fragment - ? true - : processFragments(frame); + ? Task.FromResult(true) + : ProcessFragmentsAsync(frame); } - private bool processFragments(WebSocketFrame first) + private async Task ProcessFragmentsAsync(WebSocketFrame first) { using (var buff = new MemoryStream()) { buff.WriteBytes(first.PayloadData.ApplicationData); - if (!concatenateFragmentsInto(buff)) + if (!await ConcatenateFragmentsIntoAsync(buff).ConfigureAwait(false)) + { return false; + } byte[] data; if (_compression != CompressionMethod.None) @@ -412,36 +420,38 @@ namespace SocketHttpListener return true; } - private bool processUnsupportedFrame(WebSocketFrame frame, CloseStatusCode code, string reason) + private async Task ProcessUnsupportedFrameAsync(WebSocketFrame frame, CloseStatusCode code, string reason) { - processException(new WebSocketException(code, reason), null); + await ProcessExceptionAsync(new WebSocketException(code, reason), null).ConfigureAwait(false); return false; } - private bool processWebSocketFrame(WebSocketFrame frame) + private Task ProcessWebSocketFrameAsync(WebSocketFrame frame) { + // TODO: @bond change to if/else chain return frame.IsCompressed && _compression == CompressionMethod.None - ? processUnsupportedFrame( + ? ProcessUnsupportedFrameAsync( frame, CloseStatusCode.IncorrectData, "A compressed data has been received without available decompression method.") : frame.IsFragmented - ? processFragmentedFrame(frame) + ? ProcessFragmentedFrameAsync(frame) : frame.IsData - ? processDataFrame(frame) + ? Task.FromResult(processDataFrame(frame)) : frame.IsPing - ? processPingFrame(frame) + ? Task.FromResult(processPingFrame(frame)) : frame.IsPong - ? processPongFrame(frame) + ? Task.FromResult(processPongFrame(frame)) : frame.IsClose - ? processCloseFrame(frame) - : processUnsupportedFrame(frame, CloseStatusCode.PolicyViolation, null); + ? ProcessCloseFrameAsync(frame) + : ProcessUnsupportedFrameAsync(frame, CloseStatusCode.PolicyViolation, null); } - private bool send(Opcode opcode, Stream stream) + private async Task SendAsync(Opcode opcode, Stream stream) { - lock (_forSend) + await _forSend.WaitAsync().ConfigureAwait(false); + try { var src = stream; var compressed = false; @@ -454,7 +464,7 @@ namespace SocketHttpListener compressed = true; } - sent = send(opcode, Mask.Unmask, stream, compressed); + sent = await SendAsync(opcode, Mask.Unmask, stream, compressed).ConfigureAwait(false); if (!sent) error("Sending a data has been interrupted."); } @@ -472,16 +482,20 @@ namespace SocketHttpListener return sent; } + finally + { + _forSend.Release(); + } } - private bool send(Opcode opcode, Mask mask, Stream stream, bool compressed) + private async Task SendAsync(Opcode opcode, Mask mask, Stream stream, bool compressed) { var len = stream.Length; /* Not fragmented */ if (len == 0) - return send(Fin.Final, opcode, mask, new byte[0], compressed); + return await SendAsync(Fin.Final, opcode, mask, new byte[0], compressed).ConfigureAwait(false); var quo = len / FragmentLength; var rem = (int)(len % FragmentLength); @@ -490,26 +504,26 @@ namespace SocketHttpListener if (quo == 0) { buff = new byte[rem]; - return stream.Read(buff, 0, rem) == rem && - send(Fin.Final, opcode, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem && + await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false); } buff = new byte[FragmentLength]; if (quo == 1 && rem == 0) - return stream.Read(buff, 0, FragmentLength) == FragmentLength && - send(Fin.Final, opcode, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) == FragmentLength && + await SendAsync(Fin.Final, opcode, mask, buff, compressed).ConfigureAwait(false); /* Send fragmented */ // Begin - if (stream.Read(buff, 0, FragmentLength) != FragmentLength || - !send(Fin.More, opcode, mask, buff, compressed)) + if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength || + !await SendAsync(Fin.More, opcode, mask, buff, compressed).ConfigureAwait(false)) return false; var n = rem == 0 ? quo - 2 : quo - 1; for (long i = 0; i < n; i++) - if (stream.Read(buff, 0, FragmentLength) != FragmentLength || - !send(Fin.More, Opcode.Cont, mask, buff, compressed)) + if (await stream.ReadAsync(buff, 0, FragmentLength).ConfigureAwait(false) != FragmentLength || + !await SendAsync(Fin.More, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false)) return false; // End @@ -518,98 +532,88 @@ namespace SocketHttpListener else buff = new byte[rem]; - return stream.Read(buff, 0, rem) == rem && - send(Fin.Final, Opcode.Cont, mask, buff, compressed); + return await stream.ReadAsync(buff, 0, rem).ConfigureAwait(false) == rem && + await SendAsync(Fin.Final, Opcode.Cont, mask, buff, compressed).ConfigureAwait(false); } - private bool send(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed) + private Task SendAsync(Fin fin, Opcode opcode, Mask mask, byte[] data, bool compressed) { lock (_forConn) { if (_readyState != WebSocketState.Open) { - return false; + return Task.FromResult(false); } - return writeBytes( + return WriteBytesAsync( WebSocketFrame.CreateWebSocketFrame(fin, opcode, mask, data, compressed).ToByteArray()); } } - private Task sendAsync(Opcode opcode, Stream stream) - { - var completionSource = new TaskCompletionSource(); - Task.Run(() => - { - try - { - send(opcode, stream); - completionSource.TrySetResult(true); - } - catch (Exception ex) - { - completionSource.TrySetException(ex); - } - }); - return completionSource.Task; - } - // As server - private bool sendHttpResponse(HttpResponse response) - { - return writeBytes(response.ToByteArray()); - } + private Task SendHttpResponseAsync(HttpResponse response) + => WriteBytesAsync(response.ToByteArray()); private void startReceiving() { if (_messageEventQueue.Count > 0) + { _messageEventQueue.Clear(); + } _exitReceiving = new AutoResetEvent(false); _receivePong = new AutoResetEvent(false); Action receive = null; - receive = () => WebSocketFrame.ReadAsync( - _stream, - true, - frame => - { - if (processWebSocketFrame(frame) && _readyState != WebSocketState.Closed) - { - receive(); + receive = async () => await WebSocketFrame.ReadAsync( + _stream, + true, + async frame => + { + if (await ProcessWebSocketFrameAsync(frame).ConfigureAwait(false) && _readyState != WebSocketState.Closed) + { + receive(); - if (!frame.IsData) - return; + if (!frame.IsData) + { + return; + } - lock (_forEvent) - { - try - { - var e = dequeueFromMessageEventQueue(); - if (e != null && _readyState == WebSocketState.Open) - OnMessage.Emit(this, e); - } - catch (Exception ex) - { - processException(ex, "An exception has occurred while OnMessage."); - } - } - } - else if (_exitReceiving != null) - { - _exitReceiving.Set(); - } - }, - ex => processException(ex, "An exception has occurred while receiving a message.")); + await _forEvent.WaitAsync().ConfigureAwait(false); + + try + { + var e = dequeueFromMessageEventQueue(); + if (e != null && _readyState == WebSocketState.Open) + { + OnMessage.Emit(this, e); + } + } + catch (Exception ex) + { + await ProcessExceptionAsync(ex, "An exception has occurred while OnMessage.").ConfigureAwait(false); + } + finally + { + _forEvent.Release(); + } + + } + else if (_exitReceiving != null) + { + _exitReceiving.Set(); + } + }, + async ex => await ProcessExceptionAsync(ex, "An exception has occurred while receiving a message.")).ConfigureAwait(false); receive(); } - private bool writeBytes(byte[] data) + private async Task WriteBytesAsync(byte[] data) { try { - _stream.Write(data, 0, data.Length); + await _stream.WriteAsync(data, 0, data.Length).ConfigureAwait(false); return true; } catch (Exception) @@ -623,10 +627,10 @@ namespace SocketHttpListener #region Internal Methods // As server - internal void Close(HttpResponse response) + internal async Task CloseAsync(HttpResponse response) { _readyState = WebSocketState.CloseSent; - sendHttpResponse(response); + await SendHttpResponseAsync(response).ConfigureAwait(false); closeServerResources(); @@ -634,22 +638,20 @@ namespace SocketHttpListener } // As server - internal void Close(HttpStatusCode code) - { - Close(createHandshakeCloseResponse(code)); - } + internal Task CloseAsync(HttpStatusCode code) + => CloseAsync(createHandshakeCloseResponse(code)); // As server - public void ConnectAsServer() + public async Task ConnectAsServer() { try { _readyState = WebSocketState.Open; - open(); + await OpenAsync().ConfigureAwait(false); } catch (Exception ex) { - processException(ex, "An exception has occurred while connecting."); + await ProcessExceptionAsync(ex, "An exception has occurred while connecting.").ConfigureAwait(false); } } @@ -660,18 +662,18 @@ namespace SocketHttpListener /// /// Closes the WebSocket connection, and releases all associated resources. /// - public void Close() + public Task CloseAsync() { var msg = _readyState.CheckIfClosable(); if (msg != null) { error(msg); - return; + return Task.CompletedTask; } var send = _readyState == WebSocketState.Open; - close(new PayloadData(), send, send); + return CloseAsync(new PayloadData(), send, send); } /// @@ -689,11 +691,11 @@ namespace SocketHttpListener /// /// A that represents the reason for the close. /// - public void Close(CloseStatusCode code, string reason) + public async Task CloseAsync(CloseStatusCode code, string reason) { byte[] data = null; var msg = _readyState.CheckIfClosable() ?? - (data = ((ushort)code).Append(reason)).CheckIfValidControlData("reason"); + (data = await ((ushort)code).AppendAsync(reason).ConfigureAwait(false)).CheckIfValidControlData("reason"); if (msg != null) { @@ -703,7 +705,7 @@ namespace SocketHttpListener } var send = _readyState == WebSocketState.Open && !code.IsReserved(); - close(new PayloadData(data), send, send); + await CloseAsync(new PayloadData(data), send, send).ConfigureAwait(false); } /// @@ -728,7 +730,7 @@ namespace SocketHttpListener throw new Exception(msg); } - return sendAsync(Opcode.Binary, new MemoryStream(data)); + return SendAsync(Opcode.Binary, new MemoryStream(data)); } /// @@ -753,7 +755,7 @@ namespace SocketHttpListener throw new Exception(msg); } - return sendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); + return SendAsync(Opcode.Text, new MemoryStream(Encoding.UTF8.GetBytes(data))); } #endregion @@ -768,7 +770,7 @@ namespace SocketHttpListener /// void IDisposable.Dispose() { - Close(CloseStatusCode.Away, null); + CloseAsync(CloseStatusCode.Away, null).GetAwaiter().GetResult(); } #endregion diff --git a/SocketHttpListener/WebSocketFrame.cs b/SocketHttpListener/WebSocketFrame.cs index 74ed23c45..8ec64026b 100644 --- a/SocketHttpListener/WebSocketFrame.cs +++ b/SocketHttpListener/WebSocketFrame.cs @@ -2,6 +2,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.IO; +using System.Threading.Tasks; namespace SocketHttpListener { @@ -177,7 +178,7 @@ namespace SocketHttpListener return opcode == Opcode.Text || opcode == Opcode.Binary; } - private static WebSocketFrame read(byte[] header, Stream stream, bool unmask) + private static async Task ReadAsync(byte[] header, Stream stream, bool unmask) { /* Header */ @@ -229,7 +230,7 @@ namespace SocketHttpListener ? 2 : 8; - var extPayloadLen = size > 0 ? stream.ReadBytes(size) : new byte[0]; + var extPayloadLen = size > 0 ? await stream.ReadBytesAsync(size).ConfigureAwait(false) : Array.Empty(); if (size > 0 && extPayloadLen.Length != size) throw new WebSocketException( "The 'Extended Payload Length' of a frame cannot be read from the data source."); @@ -239,7 +240,7 @@ namespace SocketHttpListener /* Masking Key */ var masked = mask == Mask.Mask; - var maskingKey = masked ? stream.ReadBytes(4) : new byte[0]; + var maskingKey = masked ? await stream.ReadBytesAsync(4).ConfigureAwait(false) : Array.Empty(); if (masked && maskingKey.Length != 4) throw new WebSocketException( "The 'Masking Key' of a frame cannot be read from the data source."); @@ -264,8 +265,8 @@ namespace SocketHttpListener "The length of 'Payload Data' of a frame is greater than the allowable length."); data = payloadLen > 126 - ? stream.ReadBytes((long)len, 1024) - : stream.ReadBytes((int)len); + ? await stream.ReadBytesAsync((long)len, 1024).ConfigureAwait(false) + : await stream.ReadBytesAsync((int)len).ConfigureAwait(false); //if (data.LongLength != (long)len) // throw new WebSocketException( @@ -273,7 +274,7 @@ namespace SocketHttpListener } else { - data = new byte[0]; + data = Array.Empty(); } var payload = new PayloadData(data, masked); @@ -281,7 +282,7 @@ namespace SocketHttpListener { payload.Mask(maskingKey); frame._mask = Mask.Unmask; - frame._maskingKey = new byte[0]; + frame._maskingKey = Array.Empty(); } frame._payloadData = payload; @@ -302,10 +303,10 @@ namespace SocketHttpListener return new WebSocketFrame(Opcode.Close, mask, payload); } - internal static WebSocketFrame CreateCloseFrame(Mask mask, CloseStatusCode code, string reason) + internal static async Task CreateCloseFrameAsync(Mask mask, CloseStatusCode code, string reason) { return new WebSocketFrame( - Opcode.Close, mask, new PayloadData(((ushort)code).Append(reason))); + Opcode.Close, mask, new PayloadData(await ((ushort)code).AppendAsync(reason).ConfigureAwait(false))); } internal static WebSocketFrame CreatePingFrame(Mask mask) @@ -329,41 +330,39 @@ namespace SocketHttpListener return new WebSocketFrame(fin, opcode, mask, new PayloadData(data), compressed); } - internal static WebSocketFrame Read(Stream stream) - { - return Read(stream, true); - } + internal static Task ReadAsync(Stream stream) + => ReadAsync(stream, true); - internal static WebSocketFrame Read(Stream stream, bool unmask) + internal static async Task ReadAsync(Stream stream, bool unmask) { - var header = stream.ReadBytes(2); + var header = await stream.ReadBytesAsync(2).ConfigureAwait(false); if (header.Length != 2) + { throw new WebSocketException( "The header part of a frame cannot be read from the data source."); + } - return read(header, stream, unmask); + return await ReadAsync(header, stream, unmask).ConfigureAwait(false); } - internal static async void ReadAsync( + internal static async Task ReadAsync( Stream stream, bool unmask, Action completed, Action error) { try { var header = await stream.ReadBytesAsync(2).ConfigureAwait(false); if (header.Length != 2) + { throw new WebSocketException( "The header part of a frame cannot be read from the data source."); + } - var frame = read(header, stream, unmask); - if (completed != null) - completed(frame); + var frame = await ReadAsync(header, stream, unmask).ConfigureAwait(false); + completed?.Invoke(frame); } catch (Exception ex) { - if (error != null) - { - error(ex); - } + error.Invoke(ex); } }