Problem
The following code pulls data from an exchange over a WebSocket. It keeps track of the channel subscriptions because, when the WebSocket client reconnects, all subscriptions are lost and we need to resubscribe. When a message is received, it is parsed directly from a `ReadOnlyMemory<T>` to avoid string allocations, and if its channel matches a registered subscription, that subscription's callback is invoked. I would like a review of this code because managing WebSocket subscriptions is a common task, and this code is going to be reused in other exchanges' implementations.
/// <summary>
/// Pairs a subscription request with the callback that should receive its notifications.
/// Every instance is stamped with a freshly generated <see cref="Guid"/> when it is constructed.
/// </summary>
public class Subscription
{
    /// <summary>Creates a subscription and generates a new <see cref="Id"/> for it.</summary>
    /// <param name="request">The request sent to the exchange for this subscription.</param>
    /// <param name="action">Callback invoked with notifications for the request's channels.</param>
    public Subscription(SubscriptionRequest request, Action<Notification> action)
    {
        Request = request;
        Action = action;
    }

    /// <summary>Identifier generated once, at construction time.</summary>
    public Guid Id { get; } = Guid.NewGuid();

    /// <summary>The request that describes what is being subscribed to.</summary>
    public SubscriptionRequest Request { get; }

    /// <summary>The callback to run for each matching notification.</summary>
    public Action<Notification> Action { get; }
}
/// <summary>
/// Lock-protected registry of active subscriptions. The client replays these after a reconnect,
/// because the server forgets them when the socket drops.
/// </summary>
public class SubscriptionManager
{
    private readonly ILogger<SubscriptionManager> _logger;
    private readonly List<Subscription> _subscriptions = new List<Subscription>();

    // Dedicated lock object: locking on the list itself would let any code that obtained a
    // reference to it take the same monitor.
    private readonly object _gate = new object();

    /// <summary>Creates an empty manager.</summary>
    /// <param name="loggerFactory">Optional logger factory; a no-op factory is used when omitted.</param>
    public SubscriptionManager(ILoggerFactory? loggerFactory = default)
    {
        _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<SubscriptionManager>();
    }

    /// <summary>Registers <paramref name="subscription"/> unless that same instance is already registered.</summary>
    public void AddSubscription(Subscription subscription)
    {
        lock (_gate)
        {
            if (!_subscriptions.Contains(subscription))
            {
                _subscriptions.Add(subscription);
            }
        }
    }

    /// <summary>Unregisters <paramref name="subscription"/>; does nothing if it is not registered.</summary>
    public void RemoveSubscription(Subscription subscription)
    {
        lock (_gate)
        {
            // List<T>.Remove already returns false for a missing item, so no Contains pre-check is needed.
            _subscriptions.Remove(subscription);
        }
    }

    /// <summary>Returns the subscription with the given <paramref name="id"/>, or <c>null</c> if none matches.</summary>
    public Subscription? GetSubscription(Guid id)
    {
        lock (_gate)
        {
            return _subscriptions.Find(x => x.Id == id);
        }
    }

    /// <summary>Returns a point-in-time copy of all registered subscriptions.</summary>
    /// <remarks>
    /// A copy is returned on purpose. Handing out the live list would let callers enumerate it after
    /// the lock is released while other threads add or remove entries, which throws
    /// <see cref="InvalidOperationException"/> or yields torn reads.
    /// </remarks>
    public IList<Subscription> GetSubscriptions()
    {
        lock (_gate)
        {
            return new List<Subscription>(_subscriptions);
        }
    }

    /// <summary>Returns the callbacks of every subscription whose request includes <paramref name="channel"/>.</summary>
    /// <remarks>
    /// The result is materialized under the lock. An iterator that yields inside <c>lock</c> would keep
    /// the monitor held for as long as the caller enumerates, and forever if the enumerator is never disposed.
    /// </remarks>
    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        lock (_gate)
        {
            var callbacks = new List<Action<Notification>>();
            foreach (var subscription in _subscriptions)
            {
                if (subscription.Request.Channels.Contains(channel))
                {
                    callbacks.Add(subscription.Action);
                }
            }

            return callbacks;
        }
    }

    /// <summary>Removes every registered subscription.</summary>
    public void Reset()
    {
        lock (_gate)
        {
            _subscriptions.Clear();
        }
    }
}
/// <summary>
/// Client for the bit.com REST and WebSocket endpoints. Subscriptions made through
/// <see cref="SendAsync"/> are recorded in a <see cref="SubscriptionManager"/> and re-sent
/// (without being re-registered) every time the WebSocket connects.
/// </summary>
public sealed class BitClient : IDisposable
{
    private readonly ILogger<BitClient> _logger;
    private readonly string _accessKey;
    private readonly string _secretKey;
    private readonly RestClient _restClient;
    private readonly StsWebSocketClient _socketClient;
    private readonly SubscriptionManager _subscriptionManager;

    /// <summary>Creates a client bound to the REST and WebSocket URLs of <paramref name="endpointType"/>.</summary>
    /// <param name="endpointType">Environment to connect to (production or testnet).</param>
    /// <param name="accessKey">API access key; stored on the instance.</param>
    /// <param name="secretKey">API secret key; stored on the instance.</param>
    /// <param name="loggerFactory">Optional logger factory; a no-op factory is used when omitted.</param>
    /// <exception cref="NotSupportedException"><paramref name="endpointType"/> is not a known environment.</exception>
    public BitClient(BitEndpointType endpointType, string accessKey, string secretKey, ILoggerFactory? loggerFactory = default)
    {
        _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<BitClient>();

        // Resolve both URLs in one place, so an unsupported value throws before any client is created.
        var (restUrl, socketUrl) = endpointType switch
        {
            BitEndpointType.Production => ("https://api.bit.com", "wss://spot-ws.bit.com"),
            BitEndpointType.Testnet => ("https://betaapi.bitexch.dev", "wss://betaspot-ws.bitexch.dev"),
            _ => throw new NotSupportedException($"Endpoint type '{endpointType}' is not supported.")
        };

        _restClient = new RestClient(restUrl);
        _socketClient = new StsWebSocketClient(socketUrl, loggerFactory);
        _socketClient.Connected += OnConnected;
        _socketClient.Disconnected += OnDisconnected;
        _socketClient.MessageReceived += OnMessageReceived;
        _accessKey = accessKey;
        _secretKey = secretKey;
        _subscriptionManager = new SubscriptionManager(loggerFactory);
    }

    /// <summary>Detaches the socket event handlers and disposes both underlying clients.</summary>
    public void Dispose()
    {
        _restClient.Dispose();
        _socketClient.Connected -= OnConnected;
        _socketClient.Disconnected -= OnDisconnected;
        _socketClient.MessageReceived -= OnMessageReceived;
        _socketClient.Dispose();
    }

    /// <summary>Starts the WebSocket client.</summary>
    public Task StartAsync()
    {
        return _socketClient.StartAsync();
    }

    /// <summary>Forgets all remembered subscriptions, then stops the WebSocket client.</summary>
    public Task StopAsync()
    {
        _subscriptionManager.Reset();
        return _socketClient.StopAsync();
    }

    /// <summary>
    /// Sends <paramref name="request"/> and remembers it together with <paramref name="callback"/>,
    /// so it is replayed on reconnect and its notifications are dispatched to the callback.
    /// </summary>
    /// <remarks>
    /// The subscription is registered <em>before</em> the request goes out. Otherwise notifications
    /// that arrive immediately after the server accepts it (e.g. an initial snapshot) would find no
    /// callback. If the send throws, the registration is rolled back, so a failed call leaves no
    /// trace (as before).
    /// </remarks>
    public async Task SendAsync(SubscriptionRequest request, Action<Notification> callback)
    {
        var subscription = new Subscription(request, callback);
        _subscriptionManager.AddSubscription(subscription);
        try
        {
            await SendRequestAsync(request).ConfigureAwait(false);
        }
        catch
        {
            _subscriptionManager.RemoveSubscription(subscription);
            throw;
        }
    }

    /// <summary>Subscribes to the "ticker" channel for <paramref name="pairs"/>.</summary>
    public Task SubscribeToTickerAsync(string[] pairs, Action<Ticker> callback)
    {
        var request = new SubscriptionRequest(SubscriptionType.Subscribe, pairs, new[] { "ticker" }, IntervalType.Raw, null);
        return SendAsync(request, n => callback(n.Data.Deserialize<Ticker>()!));
    }

    /// <summary>Subscribes to the "depth" channel for <paramref name="pairs"/>.</summary>
    public Task SubscribeToDepthAsync(string[] pairs, Action<Depth> callback)
    {
        var request = new SubscriptionRequest(SubscriptionType.Subscribe, pairs, new[] { "depth" }, IntervalType.Raw, null);
        return SendAsync(request, n => callback(n.Data.Deserialize<Depth>()!));
    }

    /// <summary>Subscribes to the authenticated "user_trade" channel.</summary>
    /// <remarks>
    /// NOTE(review): the token is embedded in the stored request and replayed verbatim on every
    /// reconnect. If these tokens expire, resubscription will be rejected; confirm the token lifetime
    /// and fetch a fresh token on reconnect if needed.
    /// </remarks>
    public async Task SubscribeToUserTradesAsync(Action<IEnumerable<UserTrade>> callback)
    {
        var token = await GetAuthenticationTokenAsync().ConfigureAwait(false);
        var request = new SubscriptionRequest(SubscriptionType.Subscribe, null, new[] { "user_trade" }, IntervalType.Raw, token);
        await SendAsync(request, n => callback(n.Data.Deserialize<IEnumerable<UserTrade>>()!)).ConfigureAwait(false);
    }

    // Serializes a request straight to UTF-8 bytes (no intermediate string) and sends it on the socket.
    // Deliberately does NOT touch the subscription manager.
    private async Task SendRequestAsync(SubscriptionRequest request)
    {
        var message = new Message(JsonSerializer.SerializeToUtf8Bytes(request));
        await _socketClient.SendAsync(message).ConfigureAwait(false);
    }

    // Replays every remembered subscription after (re)connecting. This calls SendRequestAsync rather
    // than SendAsync: SendAsync would register a brand-new Subscription (new Id) each time, duplicating
    // callbacks on every reconnect and mutating the collection being enumerated.
    private void OnConnected(object? sender, EventArgs e)
    {
        foreach (var subscription in _subscriptionManager.GetSubscriptions())
        {
            var request = subscription.Request;
            _ = Task.Run(async () =>
            {
                try
                {
                    await SendRequestAsync(request).ConfigureAwait(false);
                }
                catch (Exception ex)
                {
                    // Fire-and-forget task: without this catch the failure would go unobserved.
                    _logger.LogError(ex, "OnConnected: Error while resubscribing");
                }
            });
        }
    }

    private void OnDisconnected(object? sender, EventArgs e)
    {
        // Nothing to do: subscriptions are kept and replayed by OnConnected.
    }

    // Dispatches a notification to every callback registered for its channel.
    private void OnNotification(Notification notification)
    {
        foreach (var callback in _subscriptionManager.GetCallbacks(notification.Channel))
        {
            // The try/catch must live INSIDE the task. A callback's exception is raised on the thread
            // pool, not at the Task.Run call site, so a catch around Task.Run never sees it.
            // NOTE(review): one Task.Run per message gives no ordering guarantee between consecutive
            // notifications (e.g. depth updates); consider a per-subscription queue if order matters.
            _ = Task.Run(() =>
            {
                try
                {
                    callback(notification);
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, "OnNotification: Error during event callback call");
                }
            });
        }
    }

    // Parses an incoming frame directly from its buffer and routes it by its "channel" property.
    private void OnMessageReceived(object? sender, MessageReceivedEventArgs e)
    {
        try
        {
            using var document = JsonDocument.Parse(e.Message.Buffer);
            var root = document.RootElement;

            // TryGetProperty throws on non-objects, and GetString/ValueEquals throw on non-strings,
            // so check the kinds first and ignore frames without a string "channel".
            if (root.ValueKind != JsonValueKind.Object
                || !root.TryGetProperty("channel", out var channelElement)
                || channelElement.ValueKind != JsonValueKind.String)
            {
                return;
            }

            // ValueEquals compares against the raw UTF-8 without allocating a string.
            if (channelElement.ValueEquals("subscription"))
            {
                // Response to a subscription request, not a notification: nothing to dispatch.
                if (_logger.IsEnabled(LogLevel.Debug))
                {
                    _logger.LogDebug("OnMessageReceived: Subscription response {Response}", root.GetRawText());
                }

                return;
            }

            if (root.Deserialize<Notification>() is { } notification)
            {
                OnNotification(notification);
            }
        }
        catch (JsonException ex)
        {
            // A malformed frame must not propagate out of the socket client's event invocation.
            _logger.LogError(ex, "OnMessageReceived: Failed to parse message");
        }
    }
}
Solution
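If you replaced the `List<Subscription>` + `lock` combo inside your `SubscriptionManager` with a `ConcurrentDictionary<Guid, Subscription>`, the class implementation could be greatly simplified and made more concise. Note that `GetSubscriptions` and `GetCallbacks` below return materialized lists (snapshots), so callers can enumerate them safely while other threads add or remove subscriptions.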
/// <summary>
/// Subscription registry backed by a <see cref="ConcurrentDictionary{TKey,TValue}"/> keyed by
/// <see cref="Subscription.Id"/>, so no explicit locking is required.
/// </summary>
public class SubscriptionManager
{
    private readonly ILogger<SubscriptionManager> _logger;
    private readonly ConcurrentDictionary<Guid, Subscription> _subscriptions =
        new ConcurrentDictionary<Guid, Subscription>();

    /// <summary>Creates an empty registry.</summary>
    /// <param name="loggerFactory">Optional logger factory; a no-op factory is used when omitted.</param>
    public SubscriptionManager(ILoggerFactory? loggerFactory = default)
    {
        ILoggerFactory factory = loggerFactory ?? NullLoggerFactory.Instance;
        _logger = factory.CreateLogger<SubscriptionManager>();
    }

    /// <summary>Registers <paramref name="subscription"/>; ignored if its Id is already present.</summary>
    public void AddSubscription(Subscription subscription)
    {
        _subscriptions.TryAdd(subscription.Id, subscription);
    }

    /// <summary>Unregisters the subscription with the same Id, if any.</summary>
    public void RemoveSubscription(Subscription subscription)
    {
        _subscriptions.TryRemove(subscription.Id, out _);
    }

    /// <summary>Looks up a subscription by Id; returns <c>null</c> when absent.</summary>
    public Subscription? GetSubscription(Guid id)
    {
        if (_subscriptions.TryGetValue(id, out var found))
        {
            return found;
        }

        return null;
    }

    /// <summary>Returns a new list holding a snapshot of all registered subscriptions.</summary>
    public IList<Subscription> GetSubscriptions()
    {
        return new List<Subscription>(_subscriptions.Values);
    }

    /// <summary>Collects, into a new list, the callbacks whose request includes <paramref name="channel"/>.</summary>
    public IEnumerable<Action<Notification>> GetCallbacks(string channel)
    {
        var matches = new List<Action<Notification>>();
        foreach (var subscription in _subscriptions.Values)
        {
            if (subscription.Request.Channels.Contains(channel))
            {
                matches.Add(subscription.Action);
            }
        }

        return matches;
    }

    /// <summary>Removes every registered subscription.</summary>
    public void Reset()
    {
        _subscriptions.Clear();
    }
}