blob: 654c26ec3ba5d644066efb070e6eedf1f0c28b2a [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2022 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest import TestCase
import sys
from antlion.event import subscription_bundle
from antlion.event.decorators import subscribe
from antlion.event.decorators import subscribe_static
from antlion.event.event import Event
from antlion.event.subscription_bundle import SubscriptionBundle
from mock import Mock
from mock import patch
class SubscriptionBundleTest(TestCase):
    """Tests the SubscriptionBundle class."""

    def test_add_calls_add_subscription_properly(self):
        """Tests that the convenience function add() calls add_subscription."""
        # Opaque sentinels: the test only checks that add() forwards them
        # unchanged into the subscription it builds.
        event = object()
        func = object()
        event_filter = object()
        order = object()
        package = SubscriptionBundle()
        package.add_subscription = Mock()

        package.add(event, func, event_filter=event_filter, order=order)

        self.assertEqual(package.add_subscription.call_count, 1)
        subscription = package.add_subscription.call_args[0][0]
        self.assertEqual(subscription._event_type, event)
        self.assertEqual(subscription._func, func)
        self.assertEqual(subscription._event_filter, event_filter)
        self.assertEqual(subscription.order, order)

    @patch('antlion.event.event_bus.register_subscription')
    def test_add_subscription_registers_sub_if_package_is_registered(
            self, register_subscription):
        """Tests that add_subscription registers the subscription if the
        SubscriptionBundle is already registered."""
        package = SubscriptionBundle()
        package._registered = True
        mock_subscription = Mock()

        package.add_subscription(mock_subscription)

        register_subscription.assert_called_once_with(mock_subscription)

    def test_add_subscription_adds_to_subscriptions(self):
        """Tests add_subscription adds the subscription to subscriptions."""
        mock_subscription = Mock()
        package = SubscriptionBundle()

        package.add_subscription(mock_subscription)

        self.assertIn(mock_subscription, package.subscriptions)

    def test_remove_subscription_removes_subscription_from_subscriptions(self):
        """Tests remove_subscription removes the given subscription from the
        subscriptions dictionary."""
        mock_subscription = Mock()
        package = SubscriptionBundle()
        package.subscriptions[mock_subscription] = id(mock_subscription)

        package.remove_subscription(mock_subscription)

        self.assertNotIn(mock_subscription, package.subscriptions)

    @patch('antlion.event.event_bus.unregister')
    def test_remove_subscription_unregisters_subscription(self, unregister):
        """Tests that removing a subscription will also unregister it if the
        SubscriptionBundle is registered."""
        mock_subscription = Mock()
        package = SubscriptionBundle()
        package._registered = True
        package.subscriptions[mock_subscription] = id(mock_subscription)

        package.remove_subscription(mock_subscription)

        self.assertEqual(unregister.call_count, 1)
        unregistered_obj = unregister.call_args[0][0]
        # Accept either the subscription itself or the value stored for it in
        # `subscriptions` (set to its id above) as the unregister argument.
        self.assertIn(unregistered_obj,
                      (mock_subscription, id(mock_subscription)))

    @patch('antlion.event.event_bus.register_subscription')
    def test_register_registers_all_subscriptions(self, register_subscription):
        """Tests register() registers all subscriptions within the bundle."""
        mock_subscription_list = [Mock(), Mock(), Mock()]
        package = SubscriptionBundle()
        package._registered = False
        for subscription in mock_subscription_list:
            package.subscriptions[subscription] = None

        package.register()

        self.assertEqual(register_subscription.call_count,
                         len(mock_subscription_list))
        registered = {
            positional[0]
            for positional, _ in register_subscription.call_args_list
        }
        for subscription in mock_subscription_list:
            self.assertTrue(subscription in registered or
                            id(subscription) in registered)

    # NOTE: this test was previously also named
    # test_register_registers_all_subscriptions, which silently replaced the
    # register() test above so that it never ran.
    @patch('antlion.event.event_bus.unregister')
    def test_unregister_unregisters_all_subscriptions(self, unregister):
        """Tests unregister() unregisters all subscriptions within the
        bundle."""
        mock_subscription_list = [Mock(), Mock(), Mock()]
        package = SubscriptionBundle()
        package._registered = True
        for subscription in mock_subscription_list:
            package.subscriptions[subscription] = id(subscription)

        package.unregister()

        self.assertEqual(unregister.call_count, len(mock_subscription_list))
        unregistered = {
            positional[0] for positional, _ in unregister.call_args_list
        }
        for subscription in mock_subscription_list:
            self.assertTrue(subscription in unregistered or
                            id(subscription) in unregistered)
class SubscriptionBundleStaticFunctions(TestCase):
    """Tests the static functions found in subscription_bundle.py"""

    @staticmethod
    @subscribe_static(Event)
    def static_listener_1():
        """No-op static handler; the test only checks it gets collected."""

    @staticmethod
    @subscribe_static(Event)
    def static_listener_2():
        """No-op static handler; the test only checks it gets collected."""

    @subscribe(Event)
    def instance_listener_1(self):
        """No-op instance handler; the test only checks it gets collected."""

    @subscribe(Event)
    def instance_listener_2(self):
        """No-op instance handler; the test only checks it gets collected."""

    def test_create_from_static(self):
        """Tests create_from_static gets all StaticSubscriptionHandles."""
        owner = type(self)
        created = subscription_bundle.create_from_static(owner)
        self.assertEqual(len(created.subscriptions), 2)
        collected = created.subscriptions.keys()
        for handle in (owner.static_listener_1, owner.static_listener_2):
            self.assertIn(handle.subscription, collected)

    def test_create_from_instance(self):
        """Tests create_from_instance gets all InstanceSubscriptionHandles."""
        created = subscription_bundle.create_from_instance(self)
        self.assertEqual(len(created.subscriptions), 2)
        collected = created.subscriptions.keys()
        for handle in (self.instance_listener_1, self.instance_listener_2):
            self.assertIn(handle.subscription, collected)
@subscribe_static(Event)
def static_listener_1():
    """Module-level no-op handler collected by SubscribeStaticModuleLevelTest."""
class SubscribeStaticModuleLevelTest(TestCase):
    """Tests create_from_static when given this test module itself."""

    def test_create_from_static(self):
        """Tests create_from_static gets all StaticSubscriptionHandles."""
        this_module = sys.modules[self.__module__]
        created = subscription_bundle.create_from_static(this_module)
        self.assertEqual(len(created.subscriptions), 1)
        self.assertIn(static_listener_1.subscription,
                      created.subscriptions.keys())
if __name__ == '__main__':
    # Run this module's test cases when executed directly as a script.
    unittest.main()