diff --git a/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs b/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs
index b3bd3421a..b87f1bc22 100644
--- a/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs
+++ b/Emby.Server.Implementations/HttpServer/WebSocketConnection.cs
@@ -42,17 +42,14 @@ namespace Emby.Server.Implementations.HttpServer
/// The logger.
/// The socket.
/// The remote end point.
- /// The query.
public WebSocketConnection(
ILogger logger,
WebSocket socket,
- IPAddress? remoteEndPoint,
- IQueryCollection query)
+ IPAddress? remoteEndPoint)
{
_logger = logger;
_socket = socket;
RemoteEndPoint = remoteEndPoint;
- QueryString = query;
_jsonOptions = JsonDefaults.Options;
LastActivityDate = DateTime.Now;
@@ -81,12 +78,6 @@ namespace Emby.Server.Implementations.HttpServer
///
public DateTime LastKeepAliveDate { get; set; }
- ///
- /// Gets the query string.
- ///
- /// The query string.
- public IQueryCollection QueryString { get; }
-
///
/// Gets the state.
///
diff --git a/Emby.Server.Implementations/HttpServer/WebSocketManager.cs b/Emby.Server.Implementations/HttpServer/WebSocketManager.cs
index e99876dce..4f7d1c40a 100644
--- a/Emby.Server.Implementations/HttpServer/WebSocketManager.cs
+++ b/Emby.Server.Implementations/HttpServer/WebSocketManager.cs
@@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Threading.Tasks;
+using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
@@ -50,8 +51,7 @@ namespace Emby.Server.Implementations.HttpServer
using var connection = new WebSocketConnection(
_loggerFactory.CreateLogger(),
webSocket,
- context.Connection.RemoteIpAddress,
- context.Request.Query)
+ context.GetNormalizedRemoteIp())
{
OnReceive = ProcessWebSocketMessageReceived
};
@@ -59,7 +59,7 @@ namespace Emby.Server.Implementations.HttpServer
var tasks = new Task[_webSocketListeners.Length];
for (var i = 0; i < _webSocketListeners.Length; ++i)
{
- tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection);
+ tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context);
}
await Task.WhenAll(tasks).ConfigureAwait(false);
diff --git a/Emby.Server.Implementations/Session/SessionWebSocketListener.cs b/Emby.Server.Implementations/Session/SessionWebSocketListener.cs
index 2a14a8c7b..a085ee546 100644
--- a/Emby.Server.Implementations/Session/SessionWebSocketListener.cs
+++ b/Emby.Server.Implementations/Session/SessionWebSocketListener.cs
@@ -6,6 +6,7 @@ using System.Linq;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
+using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net;
using MediaBrowser.Controller.Session;
using MediaBrowser.Model.Net;
@@ -50,16 +51,10 @@ namespace Emby.Server.Implementations.Session
///
private readonly object _webSocketsLock = new object();
- ///
- /// The _session manager.
- ///
private readonly ISessionManager _sessionManager;
-
- ///
- /// The _logger.
- ///
private readonly ILogger _logger;
private readonly ILoggerFactory _loggerFactory;
+ private readonly IAuthorizationContext _authorizationContext;
///
/// The KeepAlive cancellation token.
@@ -72,14 +67,17 @@ namespace Emby.Server.Implementations.Session
/// The logger.
/// The session manager.
/// The logger factory.
+ /// The authorization context.
public SessionWebSocketListener(
ILogger logger,
ISessionManager sessionManager,
- ILoggerFactory loggerFactory)
+ ILoggerFactory loggerFactory,
+ IAuthorizationContext authorizationContext)
{
_logger = logger;
_sessionManager = sessionManager;
_loggerFactory = loggerFactory;
+ _authorizationContext = authorizationContext;
}
///
@@ -97,9 +95,9 @@ namespace Emby.Server.Implementations.Session
=> Task.CompletedTask;
///
- public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection)
+ public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext)
{
- var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false);
+ var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false);
if (session != null)
{
EnsureController(session, connection);
@@ -107,25 +105,28 @@ namespace Emby.Server.Implementations.Session
}
else
{
- _logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString);
+ _logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString);
}
}
- private Task GetSession(IQueryCollection queryString, string remoteEndpoint)
+ private async Task GetSession(HttpContext httpContext, string remoteEndpoint)
{
- if (queryString == null)
+ var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext)
+ .ConfigureAwait(false);
+
+ if (!authorizationInfo.IsAuthenticated)
{
return null;
}
- var token = queryString["api_key"];
- if (string.IsNullOrWhiteSpace(token))
+ var deviceId = authorizationInfo.DeviceId;
+ if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId))
{
- return null;
+ deviceId = queryDeviceId;
}
- var deviceId = queryString["deviceId"];
- return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint);
+ return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint)
+ .ConfigureAwait(false);
}
private void EnsureController(SessionInfo session, IWebSocketConnection connection)
diff --git a/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs b/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs
index 0813a8e7d..eadc09fd4 100644
--- a/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs
+++ b/MediaBrowser.Controller/Net/BasePeriodicWebSocketListener.cs
@@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Tasks;
using MediaBrowser.Model.Net;
using MediaBrowser.Model.Session;
+using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
namespace MediaBrowser.Controller.Net
@@ -95,7 +96,7 @@ namespace MediaBrowser.Controller.Net
}
///
- public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask;
+ public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask;
///
/// Starts sending messages over a web socket.
diff --git a/MediaBrowser.Controller/Net/IWebSocketConnection.cs b/MediaBrowser.Controller/Net/IWebSocketConnection.cs
index c8c5caf80..2c6483ae2 100644
--- a/MediaBrowser.Controller/Net/IWebSocketConnection.cs
+++ b/MediaBrowser.Controller/Net/IWebSocketConnection.cs
@@ -29,12 +29,6 @@ namespace MediaBrowser.Controller.Net
/// The date of last Keeplive received.
DateTime LastKeepAliveDate { get; set; }
- ///
- /// Gets the query string.
- ///
- /// The query string.
- IQueryCollection QueryString { get; }
-
///
/// Gets or sets the receive action.
///
diff --git a/MediaBrowser.Controller/Net/IWebSocketListener.cs b/MediaBrowser.Controller/Net/IWebSocketListener.cs
index f1a75d518..672bb8cbf 100644
--- a/MediaBrowser.Controller/Net/IWebSocketListener.cs
+++ b/MediaBrowser.Controller/Net/IWebSocketListener.cs
@@ -1,4 +1,5 @@
using System.Threading.Tasks;
+using Microsoft.AspNetCore.Http;
namespace MediaBrowser.Controller.Net
{
@@ -18,7 +19,8 @@ namespace MediaBrowser.Controller.Net
/// Processes a new web socket connection.
///
/// An instance of the interface.
+ /// The current http context.
/// Task.
- Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection);
+ Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext);
}
}
diff --git a/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs b/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs
index 1ce2096ea..ef8f7cd90 100644
--- a/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs
+++ b/tests/Jellyfin.Server.Implementations.Tests/HttpServer/WebSocketConnectionTests.cs
@@ -13,7 +13,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact]
public void DeserializeWebSocketMessage_SingleSegment_Success()
{
- var con = new WebSocketConnection(new NullLogger(), null!, null!, null!);
+ var con = new WebSocketConnection(new NullLogger(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed);
@@ -23,7 +23,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
public void DeserializeWebSocketMessage_MultipleSegments_Success()
{
const int SplitPos = 64;
- var con = new WebSocketConnection(new NullLogger(), null!, null!, null!);
+ var con = new WebSocketConnection(new NullLogger(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
var seg1 = new BufferSegment(new Memory(bytes, 0, SplitPos));
var seg2 = seg1.Append(new Memory(bytes, SplitPos, bytes.Length - SplitPos));
@@ -34,7 +34,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact]
public void DeserializeWebSocketMessage_ValidPartial_Success()
{
- var con = new WebSocketConnection(new NullLogger(), null!, null!, null!);
+ var con = new WebSocketConnection(new NullLogger(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed);
@@ -43,7 +43,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact]
public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
{
- var con = new WebSocketConnection(new NullLogger(), null!, null!, null!);
+ var con = new WebSocketConnection(new NullLogger(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
Assert.Throws(() => con.DeserializeWebSocketMessage(new ReadOnlySequence(bytes), out var bytesConsumed));
}