diff --git a/acl-groovy-dsl/src/main/resources/acl.gdsl b/acl-groovy-dsl/src/main/resources/acl.gdsl index c564ff32..3c4c46af 100644 --- a/acl-groovy-dsl/src/main/resources/acl.gdsl +++ b/acl-groovy-dsl/src/main/resources/acl.gdsl @@ -33,7 +33,7 @@ contributor(context(scope: scriptScope())) { && !enclosingCall("allOf") && !enclosingCall("anyOf")) { method name: 'topicFilter', type: 'javasabr.mqtt.service.acl.builder.SubscribeRuleBuilder', - params: [string: 'javasabr.mqtt.model.acl.matcher.ValueMatcher...'], + params: [string: 'javasabr.mqtt.model.acl.matcher.TopicFilterMatcher...'], doc: 'Set of topic filters matching by rule' method name: 'match', type: 'javasabr.mqtt.model.acl.matcher.TopicFilterMatcher', diff --git a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java index a4f79d94..d45c5d55 100644 --- a/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java +++ b/application/src/main/java/javasabr/mqtt/broker/application/config/MqttBrokerSpringConfig.java @@ -20,6 +20,7 @@ import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; import javasabr.mqtt.service.PublishReceivingService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.TopicService; import javasabr.mqtt.service.handler.client.ExternalNetworkMqttUserReleaseHandler; @@ -28,6 +29,7 @@ import javasabr.mqtt.service.impl.DefaultMqttConnectionFactory; import javasabr.mqtt.service.impl.DefaultPublishDeliveringService; import javasabr.mqtt.service.impl.DefaultPublishReceivingService; +import javasabr.mqtt.service.impl.DefaultRetainMessageService; import javasabr.mqtt.service.impl.DefaultTopicService; import javasabr.mqtt.service.impl.DisabledAuthorizationService; import javasabr.mqtt.service.impl.ExternalNetworkMqttUserFactory; @@ -114,8 +116,13 @@ AuthorizationService authorizationService() { } @Bean - SubscriptionService subscriptionService() { - return new InMemorySubscriptionService(); + SubscriptionService subscriptionService(RetainMessageService retainMessageService) { + return new InMemorySubscriptionService(retainMessageService); + } + + @Bean + RetainMessageService retainMessageService(PublishDeliveringService publishDeliveringService) { + return new DefaultRetainMessageService(publishDeliveringService); } @Bean @@ -210,7 +217,8 @@ MqttInMessageHandler publishMqttInMessageHandler( return new PublishMqttInMessageHandler( publishReceivingService, messageOutFactoryService, - topicService, authorizationService, + topicService, + authorizationService, fieldValidators); } @@ -254,24 +262,18 @@ ConnectionService externalMqttConnectionService(Collection findSubscribers(TopicName topicName) { return findSubscribersTo(MutableArray.ofType(SingleSubscriber.class), topicName); } diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultRetainMessageService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultRetainMessageService.java new file mode 100644 index 00000000..9056d870 --- /dev/null +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/DefaultRetainMessageService.java @@ -0,0 +1,47 @@ +package javasabr.mqtt.service.impl; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.subscriber.Subscriber; +import javasabr.mqtt.model.subscription.Subscription; +import javasabr.mqtt.model.topic.tree.ConcurrentRetainedMessageTree; +import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; +import javasabr.mqtt.service.publish.handler.PublishHandlingResult; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class DefaultRetainMessageService implements RetainMessageService { + + PublishDeliveringService publishDeliveringService; + ConcurrentRetainedMessageTree retainedMessageTree; + + public DefaultRetainMessageService(PublishDeliveringService publishDeliveringService) { + this.publishDeliveringService = publishDeliveringService; + this.retainedMessageTree = new ConcurrentRetainedMessageTree(); + } + + @Override + public void retainMessage(Publish publish) { + retainedMessageTree.retainMessage(publish); + + } + + @Override + public void deliverRetainedMessages(Subscriber subscriber) { + SingleSubscriber singleSubscriber = subscriber.resolveSingle(); + Subscription subscription = singleSubscriber.subscription(); + boolean retainAsPublished = subscription.retainAsPublished(); + Array retainedMessages = retainedMessageTree.getRetainedMessage(subscription.topicFilter()); + MutableArray result = MutableArray.ofType(PublishHandlingResult.class); + for (Publish message : retainedMessages) { + if (!retainAsPublished) { + message = message.withoutRetain(); + } + publishDeliveringService.startDelivering(message, singleSubscriber); + } + } +} diff --git a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java index b9bd0e86..90e7c840 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java +++ b/core-service/src/main/java/javasabr/mqtt/service/impl/InMemorySubscriptionService.java @@ -1,20 +1,25 @@ package javasabr.mqtt.service.impl; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND; +import static javasabr.mqtt.model.SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.NO_SUBSCRIPTION_EXISTED; import static javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode.SUCCESS; import javasabr.mqtt.model.MqttClientConnectionConfig; import javasabr.mqtt.model.MqttUser; +import javasabr.mqtt.model.SubscribeRetainHandling; import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode; import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode; import javasabr.mqtt.model.session.ActiveSubscriptions; import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.model.subscriber.SingleSubscriber; +import javasabr.mqtt.model.subscriber.Subscriber; import javasabr.mqtt.model.subscriber.tree.ConcurrentSubscriberTree; import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.rlib.collections.array.Array; import javasabr.rlib.collections.array.ArrayFactory; @@ -22,6 +27,7 @@ import lombok.AccessLevel; import lombok.CustomLog; import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; /** * In memory subscription service based on {@link ConcurrentSubscriberTree} @@ -30,10 +36,12 @@ @FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) public class InMemorySubscriptionService implements SubscriptionService { + RetainMessageService retainMessageService; ConcurrentSubscriberTree subscriberTree; - public InMemorySubscriptionService() { + public InMemorySubscriptionService(RetainMessageService retainMessageService) { this.subscriberTree = new ConcurrentSubscriberTree(); + this.retainMessageService = retainMessageService; } @Override @@ -71,9 +79,13 @@ private SubscribeAckReasonCode addSubscription(MqttUser user, MqttSession sessio return SubscribeAckReasonCode.WILDCARD_SUBSCRIPTIONS_NOT_SUPPORTED; } ActiveSubscriptions activeSubscriptions = session.activeSubscriptions(); - SingleSubscriber previous = subscriberTree.subscribe(user, subscription); - if (previous != null) { - activeSubscriptions.remove(previous.subscription()); + SingleSubscriber newSubscriber = new SingleSubscriber(user, subscription); + SingleSubscriber previousSubscriber = subscriberTree.subscribe(newSubscriber); + if (previousSubscriber != null) { + activeSubscriptions.remove(previousSubscriber.subscription()); + } + if (isRetainHandlingRequired(subscription, previousSubscriber)) { + retainMessageService.deliverRetainedMessages(newSubscriber); } activeSubscriptions.add(subscription); return subscription.qos().subscribeAckReasonCode(); @@ -125,7 +137,18 @@ public void restoreSubscriptions(MqttUser user, MqttSession session) { .activeSubscriptions() .subscriptions(); for (Subscription subscription : subscriptions) { - subscriberTree.subscribe(user, subscription); + subscriberTree.subscribe(new SingleSubscriber(user, subscription)); + } + } + + private static boolean isRetainHandlingRequired( + Subscription newSubscription, + @Nullable Subscriber previousSubscriber) { + if (newSubscription.topicFilter().isShared() || !newSubscription.qos().isValid()) { + return false; } + SubscribeRetainHandling retainHandling = newSubscription.retainHandling(); + return retainHandling == SEND || (retainHandling == SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST + && previousSubscriber == null); } } diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java index ab91e42d..a56f51aa 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishInMessageHandler.java @@ -10,6 +10,7 @@ import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.MqttPublishInMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; @@ -29,6 +30,7 @@ public abstract class AbstractMqttPublishInMessageHandler subscribers = subscriptionService.findSubscribers(topicName); if (subscribers.isEmpty()) { diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java index f0c0538f..96c37a69 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/AbstractMqttPublishOutMessageHandler.java @@ -7,7 +7,6 @@ import javasabr.mqtt.network.message.out.MqttOutMessage; import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.MqttPublishOutMessageHandler; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -23,7 +22,6 @@ public abstract class AbstractMqttPublishOutMessageHandler expectedUserType; - SubscriptionService subscriptionService; MessageOutFactoryService messageOutFactoryService; @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java index 5a297f59..1a697ca9 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandler.java @@ -10,6 +10,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; public class Qos0MqttPublishInMessageHandler extends AbstractMqttPublishInMessageHandler { @@ -17,8 +18,14 @@ public class Qos0MqttPublishInMessageHandler extends AbstractMqttPublishInMessag public Qos0MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java index 8803aa3a..0e61d3f3 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandler.java @@ -6,16 +6,12 @@ import javasabr.mqtt.model.session.MqttSession; import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import org.jspecify.annotations.Nullable; -public class Qos0MqttPublishOutMessageHandler - extends AbstractMqttPublishOutMessageHandler { +public class Qos0MqttPublishOutMessageHandler extends AbstractMqttPublishOutMessageHandler { - public Qos0MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + public Qos0MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(ExternalNetworkMqttUser.class, messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java index 76c1446d..b5b76333 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandler.java @@ -11,6 +11,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -24,8 +25,14 @@ public class Qos1MqttPublishInMessageHandler extends TrackableMqttPublishInMessa public Qos1MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override @@ -53,10 +60,7 @@ protected boolean validateImpl(ExternalNetworkMqttUser user, NetworkMqttSession } @Override - protected void handleNoMatchedSubscribers( - ExternalNetworkMqttUser user, - NetworkMqttSession session, - Publish publish) { + protected void handleNoMatchedSubscribers(ExternalNetworkMqttUser user, NetworkMqttSession session, Publish publish) { super.handleNoMatchedSubscribers(user, session, publish); int messageId = publish.messageId(); MqttOutMessage response = messageOutFactoryService diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java index 072f7dca..38a8ff52 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandler.java @@ -10,17 +10,14 @@ import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; import javasabr.mqtt.network.message.in.PublishAckMqttInMessage; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import lombok.CustomLog; import org.jspecify.annotations.Nullable; @CustomLog public class Qos1MqttPublishOutMessageHandler extends TrackableMqttPublishOutMessageHandler { - public Qos1MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos1MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java index 532b0b8f..df446c5d 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandler.java @@ -19,6 +19,7 @@ import javasabr.mqtt.network.session.NetworkMqttSession; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -34,8 +35,14 @@ public class Qos2MqttPublishInMessageHandler extends TrackableMqttPublishInMessa public Qos2MqttPublishInMessageHandler( SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + ExternalNetworkMqttUser.class, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); this.trackableMessageCallback = this::handleReceivedTrackableMessage; } @@ -67,17 +74,15 @@ protected boolean validateImpl(ExternalNetworkMqttUser user, NetworkMqttSession } @Override - protected void handleNoMatchedSubscribers( - ExternalNetworkMqttUser user, - NetworkMqttSession session, - Publish publish) { + protected void handleNoMatchedSubscribers(ExternalNetworkMqttUser user, NetworkMqttSession session, Publish publish) { super.handleNoMatchedSubscribers(user, session, publish); var reasonCode = PublishReceivedReasonCode.NO_MATCHING_SUBSCRIBERS; updateSessionState(session, publish, reasonCode); sendFeedback( - user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(publish.messageId(), reasonCode)); + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(publish.messageId(), reasonCode)); } @Override @@ -90,9 +95,10 @@ protected void handleSuccess( var reasonCode = PublishReceivedReasonCode.SUCCESS; updateSessionState(session, publish, reasonCode); sendFeedback( - user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(publish.messageId(), PublishReceivedReasonCode.SUCCESS)); + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(publish.messageId(), PublishReceivedReasonCode.SUCCESS)); } private void updateSessionState(NetworkMqttSession session, Publish publish, PublishReceivedReasonCode reasonCode) { @@ -118,22 +124,25 @@ protected void handleError( MessageTacker messageTacker = session.inMessageTracker(); messageTacker.update(messageId, MqttMessageType.PUBLISH, reasonCode); - sendFeedback(user, session, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(messageId, reasonCode), messageId); + sendFeedback( + user, + session, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(messageId, reasonCode), + messageId); } - private void handleDuplicated( - ExternalNetworkMqttUser user, - int messageId, - TrackedMessageMeta alreadyInProcess) { + private void handleDuplicated(ExternalNetworkMqttUser user, int messageId, TrackedMessageMeta alreadyInProcess) { PublishReceivedReasonCode reasonCode = PublishReceivedReasonCode.SUCCESS; if (alreadyInProcess.reasonCode() instanceof PublishReceivedReasonCode receivedReasonCode) { reasonCode = receivedReasonCode; } - sendFeedback(user, messageOutFactoryService - .resolveFactory(user) - .newPublishReceived(messageId, reasonCode)); + sendFeedback( + user, + messageOutFactoryService + .resolveFactory(user) + .newPublishReceived(messageId, reasonCode)); } private void handleMessageIdIsInUse(ExternalNetworkMqttUser user, int messageId) { @@ -154,7 +163,10 @@ private boolean handleReceivedTrackableMessage(MqttUser user, MqttSession sessio } if (messageMeta.messageType() != MqttMessageType.PUBLISH) { - log.warning(networkUser.clientId(), messageMeta, messageId, + log.warning( + networkUser.clientId(), + messageMeta, + messageId, "[%s] Not expected tracked message meta:[%s] for messageId:[%d]"::formatted); return true; } else if (!(message instanceof PublishReleaseMqttInMessage release)) { @@ -162,10 +174,7 @@ private boolean handleReceivedTrackableMessage(MqttUser user, MqttSession sessio return true; } - messageTacker.update( - messageId, - MqttMessageType.PUBLISH_COMPLETE, - PublishCompletedReasonCode.SUCCESS); + messageTacker.update(messageId, MqttMessageType.PUBLISH_COMPLETE, PublishCompletedReasonCode.SUCCESS); MqttOutMessage response = messageOutFactoryService .resolveFactory(networkUser) diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java index 040a8721..38c4949c 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandler.java @@ -13,17 +13,14 @@ import javasabr.mqtt.network.message.in.PublishCompleteMqttInMessage; import javasabr.mqtt.network.message.in.PublishReceivedMqttInMessage; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import lombok.CustomLog; import org.jspecify.annotations.Nullable; @CustomLog public class Qos2MqttPublishOutMessageHandler extends TrackableMqttPublishOutMessageHandler { - public Qos2MqttPublishOutMessageHandler( - SubscriptionService subscriptionService, - MessageOutFactoryService messageOutFactoryService) { - super(subscriptionService, messageOutFactoryService); + public Qos2MqttPublishOutMessageHandler(MessageOutFactoryService messageOutFactoryService) { + super(messageOutFactoryService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java index 9b768a65..57e0b1a1 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishInMessageHandler.java @@ -11,17 +11,24 @@ import javasabr.mqtt.network.user.NetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; import javasabr.mqtt.service.PublishDeliveringService; +import javasabr.mqtt.service.RetainMessageService; import javasabr.mqtt.service.SubscriptionService; -public abstract class TrackableMqttPublishInMessageHandler - extends AbstractMqttPublishInMessageHandler { +public abstract class TrackableMqttPublishInMessageHandler extends + AbstractMqttPublishInMessageHandler { public TrackableMqttPublishInMessageHandler( Class expectedClientType, SubscriptionService subscriptionService, PublishDeliveringService publishDeliveringService, - MessageOutFactoryService messageOutFactoryService) { - super(expectedClientType, subscriptionService, publishDeliveringService, messageOutFactoryService); + MessageOutFactoryService messageOutFactoryService, + RetainMessageService retainMessageService) { + super( + expectedClientType, + subscriptionService, + publishDeliveringService, + messageOutFactoryService, + retainMessageService); } @Override diff --git a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishOutMessageHandler.java b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishOutMessageHandler.java index 7977f9a7..d2444e9f 100644 --- a/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishOutMessageHandler.java +++ b/core-service/src/main/java/javasabr/mqtt/service/publish/handler/impl/TrackableMqttPublishOutMessageHandler.java @@ -15,7 +15,6 @@ import javasabr.mqtt.model.session.TrackedMessageMeta; import javasabr.mqtt.network.impl.ExternalNetworkMqttUser; import javasabr.mqtt.service.MessageOutFactoryService; -import javasabr.mqtt.service.SubscriptionService; import javasabr.mqtt.service.message.out.factory.MqttMessageOutFactory; import javasabr.mqtt.service.publish.handler.PublishHandlingResult; import lombok.AccessLevel; @@ -32,9 +31,8 @@ public abstract class TrackableMqttPublishOutMessageHandler extends PublishRetryer publishRetryer; protected TrackableMqttPublishOutMessageHandler( - SubscriptionService subscriptionService, MessageOutFactoryService messageOutFactoryService) { - super(ExternalNetworkMqttUser.class, subscriptionService, messageOutFactoryService); + super(ExternalNetworkMqttUser.class, messageOutFactoryService); this.trackableMessageCallback = this::handleReceivedTrackableMessage; this.publishRetryer = this::retryDelivering; } diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy index 2da0a9f0..87b1bc8c 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/IntegrationServiceSpecification.groovy @@ -12,6 +12,7 @@ import javasabr.mqtt.network.user.NetworkMqttUser import javasabr.mqtt.service.impl.DefaultMessageOutFactoryService import javasabr.mqtt.service.impl.DefaultPublishDeliveringService import javasabr.mqtt.service.impl.DefaultPublishReceivingService +import javasabr.mqtt.service.impl.DefaultRetainMessageService import javasabr.mqtt.service.impl.DefaultTopicService import javasabr.mqtt.service.impl.DisabledAuthorizationService import javasabr.mqtt.service.impl.InMemorySubscriptionService @@ -48,14 +49,11 @@ abstract class IntegrationServiceSpecification extends Specification { def testPayload = "testpayload".getBytes(StandardCharsets.UTF_8) @Shared - def clientIdGenerator = new AtomicInteger(); + def clientIdGenerator = new AtomicInteger() @Shared def defaultTopicService = new DefaultTopicService() - @Shared - def defaultSubscriptionService = new InMemorySubscriptionService() - @Shared def defaultMessageOutFactoryService = new DefaultMessageOutFactoryService([ new Mqtt311MessageOutFactory(), @@ -64,16 +62,23 @@ abstract class IntegrationServiceSpecification extends Specification { @Shared def defaultPublishDeliveringService = new DefaultPublishDeliveringService([ - new Qos0MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos1MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService), - new Qos2MqttPublishOutMessageHandler(defaultSubscriptionService, defaultMessageOutFactoryService) + new Qos0MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService), + new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) ]) + @Shared + def defaultRetainMessageService = new DefaultRetainMessageService(defaultPublishDeliveringService) + + @Shared + def defaultSubscriptionService = new InMemorySubscriptionService(defaultRetainMessageService) + @Shared def qos0MqttPublishInMessageHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService); + defaultMessageOutFactoryService, + defaultRetainMessageService) @Shared def publishReceivingService = new DefaultPublishReceivingService([ @@ -81,25 +86,27 @@ abstract class IntegrationServiceSpecification extends Specification { new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService), + defaultMessageOutFactoryService, + defaultRetainMessageService), new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) ]) @Shared - def defaultPublishReleaseMqttInMessageHandler = new PublishReleaseMqttInMessageHandler(defaultMessageOutFactoryService); + def defaultPublishReleaseMqttInMessageHandler = new PublishReleaseMqttInMessageHandler(defaultMessageOutFactoryService) @Shared def defaultBufferAllocator = new DefaultBufferAllocator(SimpleServerNetworkConfig.builder().build()) @Shared - def defaultMqttSessionService = new InMemoryMqttSessionService(60_000); - + def defaultMqttSessionService = new InMemoryMqttSessionService(60_000) + @Shared def disabledAclService = new DisabledAuthorizationService() - + @Shared List> publishInFieldValidators = [ new PublishRetainMqttInMessageFieldValidator(defaultMessageOutFactoryService), diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy index 6e97e1de..344edd52 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/impl/InMemorySubscriptionServiceTest.groovy @@ -6,13 +6,20 @@ import javasabr.mqtt.model.SubscribeRetainHandling import javasabr.mqtt.model.reason.code.SubscribeAckReasonCode import javasabr.mqtt.model.reason.code.UnsubscribeAckReasonCode import javasabr.mqtt.model.subscription.Subscription +import javasabr.mqtt.model.subscription.TestPublishFactory +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.model.topic.TopicName +import javasabr.mqtt.network.handler.NetworkMqttUserReleaseHandler +import javasabr.mqtt.network.impl.InternalNetworkMqttUser +import javasabr.mqtt.network.message.out.PublishMqtt5OutMessage import javasabr.mqtt.service.IntegrationServiceSpecification -import javasabr.mqtt.service.SubscriptionService +import javasabr.mqtt.service.TestExternalNetworkMqttUser import javasabr.rlib.collections.array.Array class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { - SubscriptionService subscriptionService = new InMemorySubscriptionService() + def retainMessageService = new DefaultRetainMessageService(defaultPublishDeliveringService) + def subscriptionService = new InMemorySubscriptionService(retainMessageService) def "should subscribe with expected results in default settings"() { given: @@ -325,4 +332,249 @@ class InMemorySubscriptionServiceTest extends IntegrationServiceSpecification { storedSubscriptions.size() == 3 storedSubscriptions ==~ resultSubscriptions } + + def "should only deliver 'send-if-subscription-does-not-exist' Subscribe Retain Handling once"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST, + true, + true) + def subscriptions = Array.of( + subscription, + new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/2"), + 30, + QoS.AT_LEAST_ONCE, + SubscribeRetainHandling.SEND, + true, + true), + new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3"), + 30, + QoS.EXACTLY_ONCE, + SubscribeRetainHandling.DO_NOT_SEND, + true, + true)) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + when: + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + def publishMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + publishMessage.payload() == publishWithRetain.payload() + and: + mqttUser.isEmpty() + } + + def "should always deliver 'send' Subscribe Retain Handling"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true) + def subscriptions = Array.of(subscription) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + when: + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + def firstSentMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + firstSentMessage.payload() == publishWithRetain.payload() + and: + def thirdSentMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + thirdSentMessage.payload() == publishWithRetain.payload() + and: + mqttUser.isEmpty() + } + + def "should not deliver 'do-not-send' Subscribe Retain Handling"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.DO_NOT_SEND, + true, + true) + def subscriptions = Array.of( + subscription, + new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/2"), + 30, + QoS.AT_LEAST_ONCE, + SubscribeRetainHandling.SEND_IF_SUBSCRIPTION_DOES_NOT_EXIST, + true, + true), + new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/3"), + 30, + QoS.EXACTLY_ONCE, + SubscribeRetainHandling.SEND, + true, + true)) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + and: + def publishWithoutRetain = TestPublishFactory.makePublishWithoutRetain("topic/filter/1", "payload2") + retainMessageService.retainMessage(publishWithoutRetain) + when: + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + mqttUser.isEmpty() + } + + def "should reset retain flag if 'retain as published' is false"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + false) + def subscriptions = Array.of(subscription) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + when: + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + def sentMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + sentMessage.payload() == publishWithRetain.payload() + !sentMessage.retain() + and: + mqttUser.isEmpty() + } + + def "should keep retain flag if 'retain as published' is true"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true) + def subscriptions = Array.of(subscription) + when: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + def secondSentMessage = mqttUser.nextSentMessage(PublishMqtt5OutMessage) + secondSentMessage.payload() == publishWithRetain.payload() + secondSentMessage.retain() + and: + mqttUser.isEmpty() + } + + def "should not send retained messages in case of invalid QoS"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.INVALID, + SubscribeRetainHandling.SEND, + true, + true) + def subscriptions = Array.of(subscription) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + when: + subscriptionService.subscribe(mqttUser, mqttUser.session(), subscriptions) + then: + mqttUser.isEmpty() + } + + def "should clean and restore subscriptions"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def expectedUser = mqttConnection.user() as TestExternalNetworkMqttUser + def expectedSubscription = new Subscription( + TopicFilter.valueOf("topic"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true) + when: + subscriptionService.subscribe(expectedUser, expectedUser.session(), Array.of(expectedSubscription)) + def subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + !subscribers.isEmpty() + with(subscribers[0]) { + user() == expectedUser + subscription() == expectedSubscription + } + when: + subscriptionService.cleanSubscriptions(expectedUser, expectedUser.session()) + subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + subscribers.isEmpty() + + when: + subscriptionService.restoreSubscriptions(expectedUser, expectedUser.session()) + subscribers = subscriptionService.findSubscribers(TopicName.valueOf("topic")) + then: + !subscribers.isEmpty() + with(subscribers[0]) { + user() == expectedUser + subscription() == expectedSubscription + } + } + + def "should suppress retained message delivering failure"() { + given: + def serverConfig = defaultExternalServerConnectionConfig + def mqttConnection = mockedExternalConnection(serverConfig, MqttVersion.MQTT_5) + def mqttUser = mqttConnection.user() as TestExternalNetworkMqttUser + def anotherUser = new InternalNetworkMqttUser(mqttConnection, Mock(NetworkMqttUserReleaseHandler)) + def subscription = new Subscription( + defaultTopicService.createTopicFilter(mqttUser, "topic/filter/1"), + 30, + QoS.AT_MOST_ONCE, + SubscribeRetainHandling.SEND, + true, + true) + def subscriptions = Array.of(subscription) + and: + def publishWithRetain = TestPublishFactory.makePublishWithRetain("topic/filter/1", "payload1") + retainMessageService.retainMessage(publishWithRetain) + when: + subscriptionService.subscribe(anotherUser, mqttUser.session(), subscriptions) + then: + mqttUser.isEmpty() + } } diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy index 2ebca07c..febfe532 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/message/handler/impl/UnsubscribeMqttInMessageHandlerTest.groovy @@ -69,7 +69,7 @@ class UnsubscribeMqttInMessageHandlerTest extends IntegrationServiceSpecificatio def "should response with expected results"() { given: def mqttConnection = mockedExternalConnection(MqttVersion.MQTT_5) - def subscriptionService = new InMemorySubscriptionService() + def subscriptionService = new InMemorySubscriptionService(defaultRetainMessageService) def messageHandler = new UnsubscribeMqttInMessageHandler( subscriptionService, defaultMessageOutFactoryService, diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy index fd2d8600..349f9fe4 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishInMessageHandlerTest.groovy @@ -15,7 +15,8 @@ class Qos0MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -50,7 +51,8 @@ class Qos0MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos0MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos0MqttPublishInMessageHandlerTest/2") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandlerTest.groovy index 2781da71..16496512 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos0MqttPublishOutMessageHandlerTest.groovy @@ -14,9 +14,7 @@ class Qos0MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should deliver publish to subscriber"() { given: - def publishOutHandler = new Qos0MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos0MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def testTopicName = defaultTopicService.createTopicName(user, "Qos0MqttPublishOutMessageHandlerTest/1") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy index d86426e6..75df3392 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishInMessageHandlerTest.groovy @@ -23,7 +23,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -68,7 +69,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/2") @@ -92,7 +94,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/3") @@ -115,7 +118,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/4") @@ -141,7 +145,8 @@ class Qos1MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos1MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishInMessageHandlerTest/5") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandlerTest.groovy index d1adf56a..406e7936 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos1MqttPublishOutMessageHandlerTest.groovy @@ -22,9 +22,7 @@ class Qos1MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should deliver publish to subscriber"() { given: - def publishOutHandler = new Qos1MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def testTopicName = defaultTopicService.createTopicName(user, "Qos1MqttPublishOutMessageHandlerTest/1") @@ -50,9 +48,7 @@ class Qos1MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should wait for ack response for publish"() { given: - def publishOutHandler = new Qos1MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -98,9 +94,7 @@ class Qos1MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should correctly handle publish ack when no stored trackable meta about the publish"() { given: - def publishOutHandler = new Qos1MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -149,9 +143,7 @@ class Qos1MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error receiving unexpected response message"() { given: - def publishOutHandler = new Qos1MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -191,9 +183,7 @@ class Qos1MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error for unexpected flow state"() { given: - def publishOutHandler = new Qos1MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos1MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy index fdae20e6..2e6ea731 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishInMessageHandlerTest.groovy @@ -27,7 +27,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def subscriber1 = mockedExternalConnection(MqttVersion.MQTT_5) def subscriber2 = mockedExternalConnection(MqttVersion.MQTT_5) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) @@ -87,7 +88,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/2") @@ -126,7 +128,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/3") @@ -149,7 +152,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/4") @@ -175,7 +179,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/5") @@ -203,7 +208,8 @@ class Qos2MqttPublishInMessageHandlerTest extends QosMqttPublishInMessageHandler def publishInHandler = new Qos2MqttPublishInMessageHandler( defaultSubscriptionService, defaultPublishDeliveringService, - defaultMessageOutFactoryService) + defaultMessageOutFactoryService, + defaultRetainMessageService) def publisher = mockedExternalConnection(MqttVersion.MQTT_5) def user = publisher.user() as TestExternalNetworkMqttUser def topicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishInMessageHandlerTest/5") diff --git a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandlerTest.groovy b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandlerTest.groovy index d6d0ef56..1de5d99b 100644 --- a/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandlerTest.groovy +++ b/core-service/src/test/groovy/javasabr/mqtt/service/publish/handler/impl/Qos2MqttPublishOutMessageHandlerTest.groovy @@ -25,9 +25,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should deliver publish to subscriber"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def testTopicName = defaultTopicService.createTopicName(user, "Qos2MqttPublishOutMessageHandlerTest/1") @@ -53,9 +51,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should wait for receive-complete responses for publish"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -124,9 +120,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should correctly handle publish receive when no stored trackable meta about the publish"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -176,9 +170,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error receiving unexpected response message for first stage"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -218,9 +210,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error receiving unexpected response message for second stage"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -271,9 +261,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error for unexpected flow state for publish received"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() @@ -316,9 +304,7 @@ class Qos2MqttPublishOutMessageHandlerTest extends QosMqttPublishOutMessageHandl def "should handle as protocol error for unexpected flow state for publish complete"() { given: - def publishOutHandler = new Qos2MqttPublishOutMessageHandler( - defaultSubscriptionService, - defaultMessageOutFactoryService) + def publishOutHandler = new Qos2MqttPublishOutMessageHandler(defaultMessageOutFactoryService) def connection = mockedExternalConnection(MqttVersion.MQTT_5) def user = connection.user() as TestExternalNetworkMqttUser def session = user.session() diff --git a/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java b/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java new file mode 100644 index 00000000..75f835a2 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/AbstractTrieNode.java @@ -0,0 +1,93 @@ +package javasabr.mqtt.model; + +import java.util.Collection; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.rlib.collections.dictionary.DictionaryFactory; +import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; +import org.jspecify.annotations.Nullable; + +public abstract class AbstractTrieNode { + + @Nullable + volatile LockableRefToRefDictionary childNodes; + + protected abstract Supplier getNodeFactory(); + + private LockableRefToRefDictionary getOrCreateChildNodes() { + var current = childNodes; + if (current != null) { + return current; + } + synchronized (this) { + current = childNodes; + if (current == null) { + current = DictionaryFactory.stampedLockBasedRefToRefDictionary(); + childNodes = current; + } + return current; + } + } + + protected T getOrCreateChildNode(String segment) { + var childNodes = getOrCreateChildNodes(); + long stamp = childNodes.readLock(); + try { + T topicFilterNode = childNodes.get(segment); + if (topicFilterNode != null) { + return topicFilterNode; + } + } finally { + childNodes.readUnlock(stamp); + } + stamp = childNodes.writeLock(); + try { + return childNodes.getOrCompute(segment, getNodeFactory()); + } finally { + childNodes.writeUnlock(stamp); + } + } + + protected void collectChildNodes(Collection resultCollection) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return; + } + long stamp = localChildNodes.readLock(); + try { + localChildNodes.values(resultCollection); + } finally { + localChildNodes.readUnlock(stamp); + } + } + + @Nullable + protected Collection getChildNodes(Supplier> resultCollectionFactory) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return null; + } + Collection resultCollection = resultCollectionFactory.get(); + collectChildNodes(resultCollection); + return resultCollection; + } + + @Nullable + protected T getChildNode(String segment) { + var localChildNodes = childNodes; + if (localChildNodes == null) { + return null; + } + long stamp = localChildNodes.readLock(); + try { + return localChildNodes.get(segment); + } finally { + localChildNodes.readUnlock(stamp); + } + } + + @Override + public String toString() { + return DebugUtils.toJsonString(this); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/QoS.java b/model/src/main/java/javasabr/mqtt/model/QoS.java index 8da7dd94..d8159f69 100644 --- a/model/src/main/java/javasabr/mqtt/model/QoS.java +++ b/model/src/main/java/javasabr/mqtt/model/QoS.java @@ -19,8 +19,7 @@ public enum QoS implements NumberedEnum { EXACTLY_ONCE(2, SubscribeAckReasonCode.GRANTED_QOS_2), INVALID(3, SubscribeAckReasonCode.IMPLEMENTATION_SPECIFIC_ERROR); - private static final NumberedEnumMap NUMBERED_MAP = - new NumberedEnumMap<>(QoS.class); + private static final NumberedEnumMap NUMBERED_MAP = new NumberedEnumMap<>(QoS.class); public static QoS ofCode(int level) { return NUMBERED_MAP.resolve(level, QoS.INVALID); @@ -45,4 +44,8 @@ public boolean isLowerThan(QoS another) { public boolean isHigherThan(QoS another) { return level > another.level; } + + public boolean isValid() { + return this != INVALID; + } } diff --git a/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java b/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java index 2e10c730..db8d646d 100644 --- a/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java +++ b/model/src/main/java/javasabr/mqtt/model/publishing/Publish.java @@ -92,6 +92,24 @@ public Publish with(int messageId, QoS qos, boolean duplicated, int topicAlias) userProperties); } + public Publish withoutRetain() { + return new Publish( + messageId, + qos, + topicName, + responseTopicName, + payload, + duplicated, + false, + contentType, + subscriptionIds, + correlationData, + messageExpiryInterval, + topicAlias, + payloadFormat, + userProperties); + } + @Override public String toString() { return DebugUtils.toJsonString(this); diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java index 307db58c..1f6d4c74 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/ConcurrentSubscriberTree.java @@ -2,7 +2,6 @@ import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.subscriber.SingleSubscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.rlib.collections.array.Array; @@ -22,8 +21,8 @@ public ConcurrentSubscriberTree() { } @Nullable - public SingleSubscriber subscribe(MqttUser user, Subscription subscription) { - return rootNode.subscribe(0, user, subscription, subscription.topicFilter()); + public SingleSubscriber subscribe(SingleSubscriber subscriber) { + return rootNode.subscribe(0, subscriber, subscriber.subscription().topicFilter()); } public boolean unsubscribe(MqttUser user, TopicFilter topicFilter) { diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java index 4a6579c2..5d3c74bb 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberNode.java @@ -5,14 +5,11 @@ import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.mqtt.model.topic.TopicName; import javasabr.rlib.collections.array.ArrayFactory; import javasabr.rlib.collections.array.LockableArray; import javasabr.rlib.collections.array.MutableArray; -import javasabr.rlib.collections.dictionary.DictionaryFactory; -import javasabr.rlib.collections.dictionary.LockableRefToRefDictionary; import lombok.AccessLevel; import lombok.Getter; import lombok.experimental.Accessors; @@ -24,30 +21,33 @@ @FieldDefaults(level = AccessLevel.PRIVATE) class SubscriberNode extends SubscriberTreeBase { - private final static Supplier SUBSCRIBER_NODE_FACTORY = SubscriberNode::new; + private final static Supplier NODE_FACTORY = SubscriberNode::new; static { DebugUtils.registerIncludedFields("childNodes", "subscribers"); } - @Nullable - volatile LockableRefToRefDictionary childNodes; @Nullable volatile LockableArray subscribers; + @Override + protected Supplier getNodeFactory() { + return NODE_FACTORY; + } + /** * @return the previous subscription from the same owner */ @Nullable - public SingleSubscriber subscribe(int level, MqttUser owner, Subscription subscription, TopicFilter topicFilter) { + protected SingleSubscriber subscribe(int level, SingleSubscriber subscriber, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { - return addSubscriber(getOrCreateSubscribers(), owner, subscription, topicFilter); + return addSubscriber(getOrCreateSubscribers(), subscriber, topicFilter); } SubscriberNode childNode = getOrCreateChildNode(topicFilter.segment(level)); - return childNode.subscribe(level + 1, owner, subscription, topicFilter); + return childNode.subscribe(level + 1, subscriber, topicFilter); } - public boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { + protected boolean unsubscribe(int level, MqttUser owner, TopicFilter topicFilter) { if (level == topicFilter.levelsCount()) { return removeSubscriber(subscribers(), owner, topicFilter); } @@ -67,7 +67,7 @@ private void exactlyTopicMatch( int lastLevel, MutableArray result) { String segment = topicName.segment(level); - SubscriberNode subscriberNode = childNode(segment); + SubscriberNode subscriberNode = getChildNode(segment); if (subscriberNode == null) { return; } @@ -83,7 +83,7 @@ private void singleWildcardTopicMatch( TopicName topicName, int lastLevel, MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.SINGLE_LEVEL_WILDCARD); + SubscriberNode subscriberNode = getChildNode(TopicFilter.SINGLE_LEVEL_WILDCARD); if (subscriberNode == null) { return; } @@ -95,71 +95,24 @@ private void singleWildcardTopicMatch( } private void multiWildcardTopicMatch(MutableArray result) { - SubscriberNode subscriberNode = childNode(TopicFilter.MULTI_LEVEL_WILDCARD); + SubscriberNode subscriberNode = getChildNode(TopicFilter.MULTI_LEVEL_WILDCARD); if (subscriberNode != null) { appendSubscribersTo(result, subscriberNode); } } - private SubscriberNode getOrCreateChildNode(String segment) { - LockableRefToRefDictionary childNodes = getOrCreateChildNodes(); - long stamp = childNodes.readLock(); - try { - SubscriberNode subscriberNode = childNodes.get(segment); - if (subscriberNode != null) { - return subscriberNode; - } - } finally { - childNodes.readUnlock(stamp); - } - stamp = childNodes.writeLock(); - try { - return childNodes.getOrCompute(segment, SUBSCRIBER_NODE_FACTORY); - } finally { - childNodes.writeUnlock(stamp); - } - } - - @Nullable - private SubscriberNode childNode(String segment) { - LockableRefToRefDictionary childNodes = childNodes(); - if (childNodes == null) { - return null; - } - long stamp = childNodes.readLock(); - try { - return childNodes.get(segment); - } finally { - childNodes.readUnlock(stamp); - } - } - - private LockableRefToRefDictionary getOrCreateChildNodes() { - if (childNodes == null) { - synchronized (this) { - if (childNodes == null) { - childNodes = DictionaryFactory.stampedLockBasedRefToRefDictionary(); - } - } - } - //noinspection ConstantConditions - return childNodes; - } - private LockableArray getOrCreateSubscribers() { - if (subscribers == null) { - synchronized (this) { - if (subscribers == null) { - subscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); - } + LockableArray localSubscribers = subscribers; + if (localSubscribers != null) { + return localSubscribers; + } + synchronized (this) { + localSubscribers = subscribers; + if (localSubscribers == null) { + localSubscribers = ArrayFactory.stampedLockBasedArray(Subscriber.class); + subscribers = localSubscribers; } + return localSubscribers; } - //noinspection ConstantConditions - return subscribers; - } - - @Override - public String toString() { - return DebugUtils.toJsonString(this); } } diff --git a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java index 972b696b..15b4aa8d 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java +++ b/model/src/main/java/javasabr/mqtt/model/subscriber/tree/SubscriberTreeBase.java @@ -1,12 +1,12 @@ package javasabr.mqtt.model.subscriber.tree; import java.util.Objects; +import javasabr.mqtt.model.AbstractTrieNode; import javasabr.mqtt.model.MqttUser; import javasabr.mqtt.model.QoS; import javasabr.mqtt.model.subscriber.SharedSubscriber; import javasabr.mqtt.model.subscriber.SingleSubscriber; import javasabr.mqtt.model.subscriber.Subscriber; -import javasabr.mqtt.model.subscription.Subscription; import javasabr.mqtt.model.topic.SharedTopicFilter; import javasabr.mqtt.model.topic.TopicFilter; import javasabr.rlib.collections.array.LockableArray; @@ -18,7 +18,7 @@ @RequiredArgsConstructor @FieldDefaults(level = AccessLevel.PROTECTED, makeFinal = true) -abstract class SubscriberTreeBase { +abstract class SubscriberTreeBase extends AbstractTrieNode { /** * @return previous subscriber with the same user @@ -26,17 +26,16 @@ abstract class SubscriberTreeBase { @Nullable protected static SingleSubscriber addSubscriber( LockableArray subscribers, - MqttUser user, - Subscription subscription, + SingleSubscriber subscriber, TopicFilter topicFilter) { long stamp = subscribers.writeLock(); try { if (topicFilter instanceof SharedTopicFilter stf) { - addSharedSubscriber(subscribers, user, subscription, stf); + addSharedSubscriber(subscribers, subscriber, stf); return null; } else { - SingleSubscriber previous = removePreviousIfExist(subscribers, user); - subscribers.add(new SingleSubscriber(user, subscription)); + SingleSubscriber previous = removePreviousIfExist(subscribers, subscriber.user()); + subscribers.add(subscriber); return previous; } } finally { @@ -45,9 +44,7 @@ protected static SingleSubscriber addSubscriber( } @Nullable - private static SingleSubscriber removePreviousIfExist( - LockableArray subscribers, - MqttUser user) { + private static SingleSubscriber removePreviousIfExist(LockableArray subscribers, MqttUser user) { int index = subscribers.indexOf(Subscriber::resolveUser, user); if (index < 0) { return null; @@ -59,8 +56,7 @@ private static SingleSubscriber removePreviousIfExist( private static void addSharedSubscriber( LockableArray subscribers, - MqttUser user, - Subscription subscription, + SingleSubscriber subscriber, SharedTopicFilter sharedTopicFilter) { String group = sharedTopicFilter.shareName(); @@ -73,7 +69,7 @@ private static void addSharedSubscriber( subscribers.add(sharedSubscriber); } - sharedSubscriber.addSubscriber(new SingleSubscriber(user, subscription)); + sharedSubscriber.addSubscriber(subscriber); } protected static void appendSubscribersTo(MutableArray result, SubscriberNode subscriberNode) { @@ -84,10 +80,7 @@ protected static void appendSubscribersTo(MutableArray result, long stamp = subscribers.readLock(); try { for (Subscriber subscriber : subscribers) { - SingleSubscriber singleSubscriber = subscriber.resolveSingle(); - if (removeDuplicateWithLowerQoS(result, singleSubscriber)) { - result.add(singleSubscriber); - } + addOrReplaceIfLowerQos(result, subscriber); } } finally { subscribers.readUnlock(stamp); @@ -141,23 +134,18 @@ private static boolean isSharedSubscriberWithGroup(Subscriber subscriber, String return subscriber instanceof SharedSubscriber shared && Objects.equals(group, shared.group()); } - private static boolean removeDuplicateWithLowerQoS( - MutableArray result, SingleSubscriber candidate) { - + private static void addOrReplaceIfLowerQos(MutableArray result, Subscriber subscriber) { + SingleSubscriber candidate = subscriber.resolveSingle(); int found = result.indexOf(SingleSubscriber::user, candidate.user()); if (found == -1) { - return true; + result.add(candidate); + return; } - QoS candidateQos = candidate.qos(); - SingleSubscriber exist = result.get(found); - QoS existeQos = exist.qos(); - - if (existeQos.ordinal() < candidateQos.ordinal()) { + QoS existedQos = result.get(found).qos(); + if (existedQos.isLowerThan(candidateQos)) { result.remove(found); - return true; + result.add(candidate); } - - return false; } } diff --git a/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java b/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java index 69439978..c66a6548 100644 --- a/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java +++ b/model/src/main/java/javasabr/mqtt/model/subscription/Subscription.java @@ -33,8 +33,12 @@ public record Subscription( boolean noLocal, /* If true, Application Messages forwarded using this subscription keep the RETAIN flag they were published with. If - false, Application Messages forwarded using this subscription have the RETAIN flag set to 0. Retained messages sent - when the subscription is established have the RETAIN flag set to 1. + false, Application Messages forwarded using this subscription have the RETAIN flag set to 0. + + Bit 3 of the Subscription Options represents the Retain As Published option. + If 1, Application Messages forwarded using this subscription keep the RETAIN flag they were published with. + If 0, Application Messages forwarded using this subscription have the RETAIN flag set to 0. + Retained messages sent when the subscription is established have the RETAIN flag set to 1. */ boolean retainAsPublished) { diff --git a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java index 2b1f313e..ece98cca 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/AbstractTopic.java @@ -32,6 +32,10 @@ protected AbstractTopic(String rawTopicName) { rawTopic = rawTopicName; } + public boolean isShared(){ + return false; + } + public String segment(int level) { return segments[level]; } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java b/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java index e9e3ba7d..b68c4689 100644 --- a/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java +++ b/model/src/main/java/javasabr/mqtt/model/topic/SharedTopicFilter.java @@ -28,6 +28,11 @@ public static SharedTopicFilter valueOf(String rawSharedTopicFilter) { return new SharedTopicFilter(rawTopicFilter, shareName); } + @Override + public boolean isShared(){ + return true; + } + public static boolean isShared(String rawTopicFilter) { return rawTopicFilter.startsWith(SharedTopicFilter.SHARE_KEYWORD); } diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java new file mode 100644 index 00000000..ea45b880 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/ConcurrentRetainedMessageTree.java @@ -0,0 +1,29 @@ +package javasabr.mqtt.model.topic.tree; + +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.rlib.collections.array.Array; +import javasabr.rlib.collections.array.MutableArray; +import javasabr.rlib.common.ThreadSafe; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; + +@FieldDefaults(level = AccessLevel.PRIVATE, makeFinal = true) +public class ConcurrentRetainedMessageTree implements ThreadSafe { + + RetainedMessageNode rootNode; + + public ConcurrentRetainedMessageTree() { + this.rootNode = new RetainedMessageNode(); + } + + public void retainMessage(Publish message) { + rootNode.retainMessage(0, message, message.topicName()); + } + + public Array getRetainedMessage(TopicFilter topicFilter) { + var resultArray = MutableArray.ofType(Publish.class); + rootNode.collectRetainedMessages(0, topicFilter, resultArray); + return Array.copyOf(resultArray); + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java new file mode 100644 index 00000000..5b2f81c0 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/RetainedMessageNode.java @@ -0,0 +1,118 @@ +package javasabr.mqtt.model.topic.tree; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import javasabr.mqtt.base.util.DebugUtils; +import javasabr.mqtt.model.AbstractTrieNode; +import javasabr.mqtt.model.publishing.Publish; +import javasabr.mqtt.model.topic.TopicFilter; +import javasabr.mqtt.model.topic.TopicName; +import javasabr.rlib.collections.array.ArrayFactory; +import javasabr.rlib.collections.array.MutableArray; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.experimental.Accessors; +import lombok.experimental.FieldDefaults; +import org.jspecify.annotations.Nullable; + +@Getter(AccessLevel.PACKAGE) +@Accessors(fluent = true, chain = false) +@FieldDefaults(level = AccessLevel.PRIVATE) +class RetainedMessageNode extends AbstractTrieNode { + + private final static Supplier NODE_FACTORY = RetainedMessageNode::new; + + static { + DebugUtils.registerIncludedFields("childNodes", "retainedMessage"); + } + + private static MutableArray childNodesFactory() { + return ArrayFactory.mutableArray(RetainedMessageNode.class); + } + + final AtomicReference<@Nullable Publish> retainedMessage = new AtomicReference<>(); + + @Override + protected Supplier getNodeFactory() { + return NODE_FACTORY; + } + + public void retainMessage(int level, Publish message, TopicName topicName) { + var child = getOrCreateChildNode(topicName.segment(level)); + boolean isLastLevel = (level + 1 == topicName.levelsCount()); + if (isLastLevel) { + if (message.payload().length == 0) { + child.clearRetainedMessage(); + } else { + child.setRetainedMessage(message); + } + } else { + child.retainMessage(level + 1, message, topicName); + } + } + + private void setRetainedMessage(Publish value) { + retainedMessage.set(value); + } + + private void clearRetainedMessage() { + retainedMessage.set(null); + } + + public void collectRetainedMessages(int level, TopicFilter topicFilter, MutableArray result) { + if (level == topicFilter.levelsCount()) { + Publish publish = retainedMessage.get(); + if (publish != null) { + result.add(publish); + } + return; + } + String segment = topicFilter.segment(level); + boolean isOneChar = segment.length() == 1; + if (isOneChar && segment.charAt(0) == TopicFilter.SINGLE_LEVEL_WILDCARD_CHAR) { + collectAllChildren(level, topicFilter, result); + } else if (isOneChar && segment.charAt(0) == TopicFilter.MULTI_LEVEL_WILDCARD_CHAR) { + collectEverything(this, result); + } else { + collectExactSegment(level, segment, topicFilter, result); + } + } + + private void collectExactSegment( + int level, + String segment, + TopicFilter topicFilter, + MutableArray result) { + RetainedMessageNode retainedMessageNode = getChildNode(segment); + if (retainedMessageNode != null) { + retainedMessageNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + + private void collectAllChildren(int level, TopicFilter topicFilter, MutableArray result) { + var localChildNodes = getChildNodes(RetainedMessageNode::childNodesFactory); + if (localChildNodes != null) { + for (RetainedMessageNode childNode : localChildNodes) { + childNode.collectRetainedMessages(level + 1, topicFilter, result); + } + } + } + + private void collectEverything(RetainedMessageNode node, MutableArray result) { + collectEverythingDfs(node, result); + } + + private void collectEverythingDfs(RetainedMessageNode node, MutableArray result) { + Publish message = node.retainedMessage.get(); + if (message != null) { + result.add(message); + } + + var childNodes = node.getChildNodes(RetainedMessageNode::childNodesFactory); + if (childNodes != null) { + for (RetainedMessageNode childNode : childNodes) { + collectEverythingDfs(childNode, result); + } + } + } +} diff --git a/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java new file mode 100644 index 00000000..1df48806 --- /dev/null +++ b/model/src/main/java/javasabr/mqtt/model/topic/tree/package-info.java @@ -0,0 +1,4 @@ +@NullMarked +package javasabr.mqtt.model.topic.tree; + +import org.jspecify.annotations.NullMarked; diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy new file mode 100644 index 00000000..c82034c6 --- /dev/null +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/RetainedMessageTreeTest.groovy @@ -0,0 +1,96 @@ +package javasabr.mqtt.model.topic.tree + +import javasabr.mqtt.model.subscription.TestPublishFactory +import javasabr.mqtt.model.topic.TopicFilter +import javasabr.mqtt.test.support.UnitSpecification + +class RetainedMessageTreeTest extends UnitSpecification { + + def "should fetch retained messages by topic filter"( + List messages, + String rawTopicFilter, + List expectedMessages) { + given: + ConcurrentRetainedMessageTree retainedMessageTree = new ConcurrentRetainedMessageTree() + messages.collect(TestPublishFactory::makePublish).each(retainedMessageTree::retainMessage) + def topicFilter = TopicFilter.valueOf(rawTopicFilter) + when: + def retainedMessages = retainedMessageTree.getRetainedMessage(topicFilter) + then: + retainedMessages.size() == expectedMessages.size() + verifyEach(retainedMessages) { publish, index -> + publish.topicName().rawTopic() == expectedMessages[index] + } + where: + rawTopicFilter << [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/+/segment2", + "/topic/#" + ] + //noinspection GroovyAssignabilityCheck + messages << [ + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/", + "/topic" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/", + "/topic/segment2", + "/", + "/topic/segment2/segment1" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment3", + "/topic/segment3", + "/topic/segment3", + "/topic/segment3" + ], + [ + "/topic/segment1", + "/topic/segment2", + "/topic/segment1/segment2", + "/topic/segment500/segment2", + "/topic/", + "/topic" + ], + [ + "/topic1/segment1", + "/topic/segment2", + "/topic2/segment1/segment2", + "/topic/segment3", + "/topic/segment1/segment2" + ] + ] + //noinspection GroovyAssignabilityCheck + expectedMessages << [ + [ + "/topic/segment1" + ], + [ + "/topic/segment2" + ], + [ + "/topic/segment3" + ], + [ + "/topic/segment1/segment2", + "/topic/segment500/segment2" + ], + [ + "/topic/segment1/segment2", + "/topic/segment2", + "/topic/segment3" + ] + ] + } +} diff --git a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy index cfb17623..fe048055 100644 --- a/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy +++ b/model/src/test/groovy/javasabr/mqtt/model/topic/tree/SubscriberTreeTest.groovy @@ -15,6 +15,18 @@ import javasabr.mqtt.test.support.UnitSpecification class SubscriberTreeTest extends UnitSpecification { + static SingleSubscriber createSubscriber(String clientId, String rawTopicFilter) { + return createSubscriber(clientId, rawTopicFilter, QoS.AT_LEAST_ONCE.number()) + } + + static SingleSubscriber createSubscriber(String clientId, String rawTopicFilter, int qos) { + return new SingleSubscriber(makeUser(clientId), makeSubscription(rawTopicFilter, qos)) + } + + static SingleSubscriber createShareSubscriber(String clientId, String rawTopicFilter) { + return new SingleSubscriber(makeUser(clientId), makeSharedSubscription(rawTopicFilter)) + } + def "should match simple topic correctly"( List subscriptions, List users, @@ -23,7 +35,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -36,6 +48,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment2", "/topic/segment3" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1"), @@ -62,6 +75,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/segment3") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -88,6 +102,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id4") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1") @@ -111,7 +126,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -124,6 +139,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment2", "/topic/segment3" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1"), @@ -156,6 +172,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic2/+") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -188,6 +205,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id8") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1"), @@ -216,7 +234,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -229,6 +247,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment3/segment4", "/topic/segment2" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1/segment2"), @@ -264,6 +283,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/segment3/#") ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -299,6 +319,7 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id9") ] ] + //noinspection GroovyAssignabilityCheck expectedUsers << [ [ makeUser("id1"), @@ -330,7 +351,7 @@ class SubscriberTreeTest extends UnitSpecification { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() subscriptions.eachWithIndex { Subscription subscription, int i -> - subscriberTree.subscribe(users.get(i), subscription) + subscriberTree.subscribe(new SingleSubscriber(users.get(i), subscription)) } when: def found = subscriberTree.matches(TopicName.valueOf(topicName)) @@ -342,6 +363,7 @@ class SubscriberTreeTest extends UnitSpecification { "/topic/segment3", "/topic/segment2/" ] + //noinspection GroovyAssignabilityCheck subscriptions << [ [ makeSubscription("/topic/segment1/segment2", 2), @@ -377,6 +399,7 @@ class SubscriberTreeTest extends UnitSpecification { makeSubscription("/topic/#", 0) ] ] + //noinspection GroovyAssignabilityCheck users << [ [ makeUser("id1"), @@ -412,21 +435,22 @@ class SubscriberTreeTest extends UnitSpecification { makeUser("id3") ] ] + //noinspection GroovyAssignabilityCheck expectedSubscribers << [ [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/segment1/segment2", 2)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/segment1/#", 1)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/#", 0)), + createSubscriber("id1", "/topic/segment1/segment2", 2), + createSubscriber("id2", "/topic/segment1/#", 1), + createSubscriber("id3", "/topic/#", 0), ], [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/#", 0)), + createSubscriber("id1", "/topic/#", 0), + createSubscriber("id2", "/topic/#", 0), + createSubscriber("id3", "/topic/#", 0), ], [ - new SingleSubscriber(makeUser("id1"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id2"), makeSubscription("/topic/#", 0)), - new SingleSubscriber(makeUser("id3"), makeSubscription("/topic/segment2/#", 1)), + createSubscriber("id1", "/topic/#", 0), + createSubscriber("id2", "/topic/#", 0), + createSubscriber("id3", "/topic/segment2/#", 1), ] ] } @@ -436,16 +460,16 @@ class SubscriberTreeTest extends UnitSpecification { def group1 = ["id1", "id2", "id3", "id4", "id5"] def group2 = ["id6", "id7", "id8", "id9", "id10"] ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id4"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id5"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id6"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id7"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id8"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id9"), makeSharedSubscription('$share/group2/topic/name1')) - subscriberTree.subscribe(makeUser("id10"), makeSharedSubscription('$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id1", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id2", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id3", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id4", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id5", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id6", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id7", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id8", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id9", '$share/group2/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id10", '$share/group2/topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -469,9 +493,9 @@ class SubscriberTreeTest extends UnitSpecification { def "should subscribe and unsubscribe simple topic correctly correctly"() { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSubscription('topic/name1')) + subscriberTree.subscribe(createSubscriber("id1", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id2", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id3", 'topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -506,9 +530,9 @@ class SubscriberTreeTest extends UnitSpecification { def "should subscribe and unsubscribe shared topic correctly correctly"() { given: ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() - subscriberTree.subscribe(makeUser("id1"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id2"), makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id1", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id2", '$share/group1/topic/name1')) + subscriberTree.subscribe(createShareSubscriber("id3", '$share/group1/topic/name1')) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -517,8 +541,12 @@ class SubscriberTreeTest extends UnitSpecification { then: matched.size() == 1 when: - def id2WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id2"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) - def id3WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id3"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id2WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id2"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id3WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id3"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .collect { it.user().toString() } @@ -528,8 +556,12 @@ class SubscriberTreeTest extends UnitSpecification { id2WasUnsubscribed id3WasUnsubscribed when: - def id1WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id1"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) - id3WasUnsubscribed = subscriberTree.unsubscribe(makeUser("id3"), SharedTopicFilter.valueOf('$share/group1/topic/name1')) + def id1WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id1"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) + id3WasUnsubscribed = subscriberTree.unsubscribe( + makeUser("id3"), + SharedTopicFilter.valueOf('$share/group1/topic/name1')) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .collect { it.user().toString() } @@ -546,10 +578,10 @@ class SubscriberTreeTest extends UnitSpecification { def owner1 = makeUser("id1") def originalSub = makeSubscription('topic/name1') def replacementSub = makeSubscription('topic/name1') - subscriberTree.subscribe(makeUser("id2"), makeSubscription('topic/name1')) - subscriberTree.subscribe(makeUser("id3"), makeSubscription('topic/name1')) + subscriberTree.subscribe(createSubscriber("id2", 'topic/name1')) + subscriberTree.subscribe(createSubscriber("id3", 'topic/name1')) when: - def previous = subscriberTree.subscribe(owner1, originalSub) + def previous = subscriberTree.subscribe(new SingleSubscriber(owner1, originalSub)) def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() @@ -557,7 +589,7 @@ class SubscriberTreeTest extends UnitSpecification { matched.size() == 3 previous == null; when: - previous = subscriberTree.subscribe(owner1, replacementSub) + previous = subscriberTree.subscribe(new SingleSubscriber(owner1, replacementSub)) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() @@ -573,8 +605,8 @@ class SubscriberTreeTest extends UnitSpecification { ConcurrentSubscriberTree subscriberTree = new ConcurrentSubscriberTree() def owner1 = makeUser("id1") def owner2 = makeUser("id2") - subscriberTree.subscribe(owner1, makeSharedSubscription('$share/group1/topic/name1')) - subscriberTree.subscribe(owner2, makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(new SingleSubscriber(owner1, makeSharedSubscription('$share/group1/topic/name1'))) + subscriberTree.subscribe(new SingleSubscriber(owner2, makeSharedSubscription('$share/group1/topic/name1'))) when: def matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) @@ -597,7 +629,7 @@ class SubscriberTreeTest extends UnitSpecification { matched.size() == 1 matched.first().user() == owner2 when: - subscriberTree.subscribe(owner1, makeSharedSubscription('$share/group1/topic/name1')) + subscriberTree.subscribe(new SingleSubscriber(owner1, makeSharedSubscription('$share/group1/topic/name1'))) matched = subscriberTree .matches(TopicName.valueOf("topic/name1")) .toSet() diff --git a/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy b/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy new file mode 100644 index 00000000..257ab27d --- /dev/null +++ b/model/src/testFixtures/groovy/javasabr/mqtt/model/subscription/TestPublishFactory.groovy @@ -0,0 +1,67 @@ +package javasabr.mqtt.model.subscription + +import javasabr.mqtt.model.PayloadFormat +import javasabr.mqtt.model.QoS +import javasabr.mqtt.model.publishing.Publish +import javasabr.mqtt.model.topic.TopicName +import javasabr.rlib.collections.array.Array +import javasabr.rlib.collections.array.IntArray + +import static java.nio.charset.StandardCharsets.UTF_8 + +class TestPublishFactory { + + static def makePublish(String topicName) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + "payload".getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } + + static def makePublishWithRetain(String topicName, String payload) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + payload.getBytes(UTF_8), + false, + true, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } + + static def makePublishWithoutRetain(String topicName, String payload) { + return new Publish( + 1, + QoS.AT_MOST_ONCE, + TopicName.valueOf(topicName), + null, + payload.getBytes(UTF_8), + false, + false, + null, + IntArray.of(30), + null, + 60000, + 1, + PayloadFormat.UTF8_STRING, + Array.of()); + } +}