Merge pull request #7080 from crobibero/ws-token

This commit is contained in:
Cody Robibero 2022-01-03 17:48:21 -07:00 committed by GitHub
commit c6a1dcf420
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 43 deletions

View File

@ -42,17 +42,14 @@ namespace Emby.Server.Implementations.HttpServer
/// <param name="logger">The logger.</param> /// <param name="logger">The logger.</param>
/// <param name="socket">The socket.</param> /// <param name="socket">The socket.</param>
/// <param name="remoteEndPoint">The remote end point.</param> /// <param name="remoteEndPoint">The remote end point.</param>
/// <param name="query">The query.</param>
public WebSocketConnection( public WebSocketConnection(
ILogger<WebSocketConnection> logger, ILogger<WebSocketConnection> logger,
WebSocket socket, WebSocket socket,
IPAddress? remoteEndPoint, IPAddress? remoteEndPoint)
IQueryCollection query)
{ {
_logger = logger; _logger = logger;
_socket = socket; _socket = socket;
RemoteEndPoint = remoteEndPoint; RemoteEndPoint = remoteEndPoint;
QueryString = query;
_jsonOptions = JsonDefaults.Options; _jsonOptions = JsonDefaults.Options;
LastActivityDate = DateTime.Now; LastActivityDate = DateTime.Now;
@ -81,12 +78,6 @@ namespace Emby.Server.Implementations.HttpServer
/// <inheritdoc /> /// <inheritdoc />
public DateTime LastKeepAliveDate { get; set; } public DateTime LastKeepAliveDate { get; set; }
/// <summary>
/// Gets the query string.
/// </summary>
/// <value>The query string.</value>
public IQueryCollection QueryString { get; }
/// <summary> /// <summary>
/// Gets the state. /// Gets the state.
/// </summary> /// </summary>

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Threading.Tasks; using System.Threading.Tasks;
using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net; using MediaBrowser.Controller.Net;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
@ -50,8 +51,7 @@ namespace Emby.Server.Implementations.HttpServer
using var connection = new WebSocketConnection( using var connection = new WebSocketConnection(
_loggerFactory.CreateLogger<WebSocketConnection>(), _loggerFactory.CreateLogger<WebSocketConnection>(),
webSocket, webSocket,
context.Connection.RemoteIpAddress, context.GetNormalizedRemoteIp())
context.Request.Query)
{ {
OnReceive = ProcessWebSocketMessageReceived OnReceive = ProcessWebSocketMessageReceived
}; };
@ -59,7 +59,7 @@ namespace Emby.Server.Implementations.HttpServer
var tasks = new Task[_webSocketListeners.Length]; var tasks = new Task[_webSocketListeners.Length];
for (var i = 0; i < _webSocketListeners.Length; ++i) for (var i = 0; i < _webSocketListeners.Length; ++i)
{ {
tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection); tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context);
} }
await Task.WhenAll(tasks).ConfigureAwait(false); await Task.WhenAll(tasks).ConfigureAwait(false);

View File

@ -6,6 +6,7 @@ using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MediaBrowser.Common.Extensions;
using MediaBrowser.Controller.Net; using MediaBrowser.Controller.Net;
using MediaBrowser.Controller.Session; using MediaBrowser.Controller.Session;
using MediaBrowser.Model.Net; using MediaBrowser.Model.Net;
@ -50,16 +51,10 @@ namespace Emby.Server.Implementations.Session
/// </summary> /// </summary>
private readonly object _webSocketsLock = new object(); private readonly object _webSocketsLock = new object();
/// <summary>
/// The _session manager.
/// </summary>
private readonly ISessionManager _sessionManager; private readonly ISessionManager _sessionManager;
/// <summary>
/// The _logger.
/// </summary>
private readonly ILogger<SessionWebSocketListener> _logger; private readonly ILogger<SessionWebSocketListener> _logger;
private readonly ILoggerFactory _loggerFactory; private readonly ILoggerFactory _loggerFactory;
private readonly IAuthorizationContext _authorizationContext;
/// <summary> /// <summary>
/// The KeepAlive cancellation token. /// The KeepAlive cancellation token.
@ -72,14 +67,17 @@ namespace Emby.Server.Implementations.Session
/// <param name="logger">The logger.</param> /// <param name="logger">The logger.</param>
/// <param name="sessionManager">The session manager.</param> /// <param name="sessionManager">The session manager.</param>
/// <param name="loggerFactory">The logger factory.</param> /// <param name="loggerFactory">The logger factory.</param>
/// <param name="authorizationContext">The authorization context.</param>
public SessionWebSocketListener( public SessionWebSocketListener(
ILogger<SessionWebSocketListener> logger, ILogger<SessionWebSocketListener> logger,
ISessionManager sessionManager, ISessionManager sessionManager,
ILoggerFactory loggerFactory) ILoggerFactory loggerFactory,
IAuthorizationContext authorizationContext)
{ {
_logger = logger; _logger = logger;
_sessionManager = sessionManager; _sessionManager = sessionManager;
_loggerFactory = loggerFactory; _loggerFactory = loggerFactory;
_authorizationContext = authorizationContext;
} }
/// <inheritdoc /> /// <inheritdoc />
@ -97,9 +95,9 @@ namespace Emby.Server.Implementations.Session
=> Task.CompletedTask; => Task.CompletedTask;
/// <inheritdoc /> /// <inheritdoc />
public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext)
{ {
var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false); var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false);
if (session != null) if (session != null)
{ {
EnsureController(session, connection); EnsureController(session, connection);
@ -107,25 +105,28 @@ namespace Emby.Server.Implementations.Session
} }
else else
{ {
_logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString); _logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString);
} }
} }
private Task<SessionInfo> GetSession(IQueryCollection queryString, string remoteEndpoint) private async Task<SessionInfo> GetSession(HttpContext httpContext, string remoteEndpoint)
{ {
if (queryString == null) var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext)
.ConfigureAwait(false);
if (!authorizationInfo.IsAuthenticated)
{ {
return null; return null;
} }
var token = queryString["api_key"]; var deviceId = authorizationInfo.DeviceId;
if (string.IsNullOrWhiteSpace(token)) if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId))
{ {
return null; deviceId = queryDeviceId;
} }
var deviceId = queryString["deviceId"]; return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint)
return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint); .ConfigureAwait(false);
} }
private void EnsureController(SessionInfo session, IWebSocketConnection connection) private void EnsureController(SessionInfo session, IWebSocketConnection connection)

View File

@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using MediaBrowser.Model.Net; using MediaBrowser.Model.Net;
using MediaBrowser.Model.Session; using MediaBrowser.Model.Session;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
namespace MediaBrowser.Controller.Net namespace MediaBrowser.Controller.Net
@ -95,7 +96,7 @@ namespace MediaBrowser.Controller.Net
} }
/// <inheritdoc /> /// <inheritdoc />
public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask; public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask;
/// <summary> /// <summary>
/// Starts sending messages over a web socket. /// Starts sending messages over a web socket.

View File

@ -29,12 +29,6 @@ namespace MediaBrowser.Controller.Net
/// <value>The date of last Keeplive received.</value> /// <value>The date of last Keeplive received.</value>
DateTime LastKeepAliveDate { get; set; } DateTime LastKeepAliveDate { get; set; }
/// <summary>
/// Gets the query string.
/// </summary>
/// <value>The query string.</value>
IQueryCollection QueryString { get; }
/// <summary> /// <summary>
/// Gets or sets the receive action. /// Gets or sets the receive action.
/// </summary> /// </summary>

View File

@ -1,4 +1,5 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace MediaBrowser.Controller.Net namespace MediaBrowser.Controller.Net
{ {
@ -18,7 +19,8 @@ namespace MediaBrowser.Controller.Net
/// Processes a new web socket connection. /// Processes a new web socket connection.
/// </summary> /// </summary>
/// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param> /// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param>
/// <param name="httpContext">The current http context.</param>
/// <returns>Task.</returns> /// <returns>Task.</returns>
Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection); Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext);
} }
} }

View File

@ -13,7 +13,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact] [Fact]
public void DeserializeWebSocketMessage_SingleSegment_Success() public void DeserializeWebSocketMessage_SingleSegment_Success()
{ {
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json"); var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed); con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed); Assert.Equal(109, bytesConsumed);
@ -23,7 +23,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
public void DeserializeWebSocketMessage_MultipleSegments_Success() public void DeserializeWebSocketMessage_MultipleSegments_Success()
{ {
const int SplitPos = 64; const int SplitPos = 64;
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json"); var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos)); var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos)); var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
@ -34,7 +34,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact] [Fact]
public void DeserializeWebSocketMessage_ValidPartial_Success() public void DeserializeWebSocketMessage_ValidPartial_Success()
{ {
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json"); var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed); con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
Assert.Equal(109, bytesConsumed); Assert.Equal(109, bytesConsumed);
@ -43,7 +43,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
[Fact] [Fact]
public void DeserializeWebSocketMessage_Partial_ThrowJsonException() public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
{ {
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!); var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json"); var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed)); Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
} }