Use IAuthorizationContext for websocket
This commit is contained in:
parent
1725b2ee69
commit
0765fd568f
|
@ -42,17 +42,14 @@ namespace Emby.Server.Implementations.HttpServer
|
|||
/// <param name="logger">The logger.</param>
|
||||
/// <param name="socket">The socket.</param>
|
||||
/// <param name="remoteEndPoint">The remote end point.</param>
|
||||
/// <param name="query">The query.</param>
|
||||
public WebSocketConnection(
|
||||
ILogger<WebSocketConnection> logger,
|
||||
WebSocket socket,
|
||||
IPAddress? remoteEndPoint,
|
||||
IQueryCollection query)
|
||||
IPAddress? remoteEndPoint)
|
||||
{
|
||||
_logger = logger;
|
||||
_socket = socket;
|
||||
RemoteEndPoint = remoteEndPoint;
|
||||
QueryString = query;
|
||||
|
||||
_jsonOptions = JsonDefaults.Options;
|
||||
LastActivityDate = DateTime.Now;
|
||||
|
@ -81,12 +78,6 @@ namespace Emby.Server.Implementations.HttpServer
|
|||
/// <inheritdoc />
|
||||
public DateTime LastKeepAliveDate { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the query string.
|
||||
/// </summary>
|
||||
/// <value>The query string.</value>
|
||||
public IQueryCollection QueryString { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the state.
|
||||
/// </summary>
|
||||
|
|
|
@ -7,6 +7,7 @@ using System.Collections.Generic;
|
|||
using System.Linq;
|
||||
using System.Net.WebSockets;
|
||||
using System.Threading.Tasks;
|
||||
using MediaBrowser.Common.Extensions;
|
||||
using MediaBrowser.Controller.Net;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
@ -50,8 +51,7 @@ namespace Emby.Server.Implementations.HttpServer
|
|||
using var connection = new WebSocketConnection(
|
||||
_loggerFactory.CreateLogger<WebSocketConnection>(),
|
||||
webSocket,
|
||||
context.Connection.RemoteIpAddress,
|
||||
context.Request.Query)
|
||||
context.GetNormalizedRemoteIp())
|
||||
{
|
||||
OnReceive = ProcessWebSocketMessageReceived
|
||||
};
|
||||
|
@ -59,7 +59,7 @@ namespace Emby.Server.Implementations.HttpServer
|
|||
var tasks = new Task[_webSocketListeners.Length];
|
||||
for (var i = 0; i < _webSocketListeners.Length; ++i)
|
||||
{
|
||||
tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection);
|
||||
tasks[i] = _webSocketListeners[i].ProcessWebSocketConnectedAsync(connection, context);
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks).ConfigureAwait(false);
|
||||
|
|
|
@ -6,6 +6,7 @@ using System.Linq;
|
|||
using System.Net.WebSockets;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using MediaBrowser.Common.Extensions;
|
||||
using MediaBrowser.Controller.Net;
|
||||
using MediaBrowser.Controller.Session;
|
||||
using MediaBrowser.Model.Net;
|
||||
|
@ -50,16 +51,10 @@ namespace Emby.Server.Implementations.Session
|
|||
/// </summary>
|
||||
private readonly object _webSocketsLock = new object();
|
||||
|
||||
/// <summary>
|
||||
/// The _session manager.
|
||||
/// </summary>
|
||||
private readonly ISessionManager _sessionManager;
|
||||
|
||||
/// <summary>
|
||||
/// The _logger.
|
||||
/// </summary>
|
||||
private readonly ILogger<SessionWebSocketListener> _logger;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly IAuthorizationContext _authorizationContext;
|
||||
|
||||
/// <summary>
|
||||
/// The KeepAlive cancellation token.
|
||||
|
@ -72,14 +67,17 @@ namespace Emby.Server.Implementations.Session
|
|||
/// <param name="logger">The logger.</param>
|
||||
/// <param name="sessionManager">The session manager.</param>
|
||||
/// <param name="loggerFactory">The logger factory.</param>
|
||||
/// <param name="authorizationContext">The authorization context.</param>
|
||||
public SessionWebSocketListener(
|
||||
ILogger<SessionWebSocketListener> logger,
|
||||
ISessionManager sessionManager,
|
||||
ILoggerFactory loggerFactory)
|
||||
ILoggerFactory loggerFactory,
|
||||
IAuthorizationContext authorizationContext)
|
||||
{
|
||||
_logger = logger;
|
||||
_sessionManager = sessionManager;
|
||||
_loggerFactory = loggerFactory;
|
||||
_authorizationContext = authorizationContext;
|
||||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
|
@ -97,9 +95,9 @@ namespace Emby.Server.Implementations.Session
|
|||
=> Task.CompletedTask;
|
||||
|
||||
/// <inheritdoc />
|
||||
public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection)
|
||||
public async Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext)
|
||||
{
|
||||
var session = await GetSession(connection.QueryString, connection.RemoteEndPoint.ToString()).ConfigureAwait(false);
|
||||
var session = await GetSession(httpContext, connection.RemoteEndPoint?.ToString()).ConfigureAwait(false);
|
||||
if (session != null)
|
||||
{
|
||||
EnsureController(session, connection);
|
||||
|
@ -107,25 +105,28 @@ namespace Emby.Server.Implementations.Session
|
|||
}
|
||||
else
|
||||
{
|
||||
_logger.LogWarning("Unable to determine session based on query string: {0}", connection.QueryString);
|
||||
_logger.LogWarning("Unable to determine session based on query string: {0}", httpContext.Request.QueryString);
|
||||
}
|
||||
}
|
||||
|
||||
private Task<SessionInfo> GetSession(IQueryCollection queryString, string remoteEndpoint)
|
||||
private async Task<SessionInfo> GetSession(HttpContext httpContext, string remoteEndpoint)
|
||||
{
|
||||
if (queryString == null)
|
||||
var authorizationInfo = await _authorizationContext.GetAuthorizationInfo(httpContext)
|
||||
.ConfigureAwait(false);
|
||||
|
||||
if (!authorizationInfo.IsAuthenticated)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var token = queryString["api_key"];
|
||||
if (string.IsNullOrWhiteSpace(token))
|
||||
var deviceId = authorizationInfo.DeviceId;
|
||||
if (httpContext.Request.Query.TryGetValue("deviceId", out var queryDeviceId))
|
||||
{
|
||||
return null;
|
||||
deviceId = queryDeviceId;
|
||||
}
|
||||
|
||||
var deviceId = queryString["deviceId"];
|
||||
return _sessionManager.GetSessionByAuthenticationToken(token, deviceId, remoteEndpoint);
|
||||
return await _sessionManager.GetSessionByAuthenticationToken(authorizationInfo.Token, deviceId, remoteEndpoint)
|
||||
.ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private void EnsureController(SessionInfo session, IWebSocketConnection connection)
|
||||
|
|
|
@ -11,6 +11,7 @@ using System.Threading;
|
|||
using System.Threading.Tasks;
|
||||
using MediaBrowser.Model.Net;
|
||||
using MediaBrowser.Model.Session;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace MediaBrowser.Controller.Net
|
||||
|
@ -95,7 +96,7 @@ namespace MediaBrowser.Controller.Net
|
|||
}
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection) => Task.CompletedTask;
|
||||
public Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext) => Task.CompletedTask;
|
||||
|
||||
/// <summary>
|
||||
/// Starts sending messages over a web socket.
|
||||
|
|
|
@ -29,12 +29,6 @@ namespace MediaBrowser.Controller.Net
|
|||
/// <value>The date of last Keeplive received.</value>
|
||||
DateTime LastKeepAliveDate { get; set; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the query string.
|
||||
/// </summary>
|
||||
/// <value>The query string.</value>
|
||||
IQueryCollection QueryString { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets or sets the receive action.
|
||||
/// </summary>
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
|
||||
namespace MediaBrowser.Controller.Net
|
||||
{
|
||||
|
@ -18,7 +19,8 @@ namespace MediaBrowser.Controller.Net
|
|||
/// Processes a new web socket connection.
|
||||
/// </summary>
|
||||
/// <param name="connection">An instance of the <see cref="IWebSocketConnection"/> interface.</param>
|
||||
/// <param name="httpContext">The current http context.</param>
|
||||
/// <returns>Task.</returns>
|
||||
Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection);
|
||||
Task ProcessWebSocketConnectedAsync(IWebSocketConnection connection, HttpContext httpContext);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
|
|||
[Fact]
|
||||
public void DeserializeWebSocketMessage_SingleSegment_Success()
|
||||
{
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
|
||||
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
|
||||
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
|
||||
Assert.Equal(109, bytesConsumed);
|
||||
|
@ -23,7 +23,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
|
|||
public void DeserializeWebSocketMessage_MultipleSegments_Success()
|
||||
{
|
||||
const int SplitPos = 64;
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
|
||||
var bytes = File.ReadAllBytes("Test Data/HttpServer/ForceKeepAlive.json");
|
||||
var seg1 = new BufferSegment(new Memory<byte>(bytes, 0, SplitPos));
|
||||
var seg2 = seg1.Append(new Memory<byte>(bytes, SplitPos, bytes.Length - SplitPos));
|
||||
|
@ -34,7 +34,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
|
|||
[Fact]
|
||||
public void DeserializeWebSocketMessage_ValidPartial_Success()
|
||||
{
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
|
||||
var bytes = File.ReadAllBytes("Test Data/HttpServer/ValidPartial.json");
|
||||
con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed);
|
||||
Assert.Equal(109, bytesConsumed);
|
||||
|
@ -43,7 +43,7 @@ namespace Jellyfin.Server.Implementations.Tests.HttpServer
|
|||
[Fact]
|
||||
public void DeserializeWebSocketMessage_Partial_ThrowJsonException()
|
||||
{
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!, null!);
|
||||
var con = new WebSocketConnection(new NullLogger<WebSocketConnection>(), null!, null!);
|
||||
var bytes = File.ReadAllBytes("Test Data/HttpServer/Partial.json");
|
||||
Assert.Throws<JsonException>(() => con.DeserializeWebSocketMessage(new ReadOnlySequence<byte>(bytes), out var bytesConsumed));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user