blob: a9b570fdab506655ffc687d76c30ecb80ab6b875 [file] [log] [blame]
# Copyright 2023 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
# TODO(https://fxbug.dev/346628306): Remove this comment to ignore mypy errors.
# mypy: ignore-errors
import asyncio
import logging
from abc import abstractmethod
from inspect import getframeinfo, stack
from typing import Any
import fuchsia_controller_py as fc
from fidl_codec import decode_fidl_request, encode_fidl_message
from ._fidl_common import (
DomainError,
FidlMessage,
FidlMeta,
FrameworkError,
GenericResult,
StopServer,
parse_ordinal,
parse_txid,
)
from ._ipc import GlobalHandleWaker
# Rather than make a long server UUID, this will be a monotonically increasing
# ID to differentiate servers for debugging purposes. Each ServerBase instance
# copies the current value into its `id` attribute and then increments it.
_SERVER_ID = 0
# Logger used by every server in this module; registered under the name
# "fidl.server".
_LOGGER = logging.getLogger("fidl.server")
class ServerError(Exception):
    """Raised when a server method's return value breaks its FIDL contract.

    This covers a one-way method that returned a value and a two-way method
    that returned None when a response payload was expected.
    """
class ServerBase(
    metaclass=FidlMeta,
    required_class_variables=[
        ("library", str),
        ("method_map", dict),
    ],
):
    """Base object for doing basic FIDL server tasks.

    Concrete servers must supply the `library` (str) and `method_map` (dict)
    class variables and implement `construct_response_object`. `method_map`
    is indexed by method ordinal; each entry is read here for its `name`,
    `request_ident`, `requires_response`, `empty_response`, `has_result` and
    `response_identifier` attributes.
    """

    @staticmethod
    @abstractmethod
    def construct_response_object(
        response_ident: str, response_obj: Any
    ) -> Any:
        """Convert a decoded FIDL request into the object given to a handler.

        Args:
            response_ident: The identifier taken from the method's
                `request_ident` entry in `method_map`.
            response_obj: The value returned by `decode_fidl_request`.

        Returns:
            The object passed to the method handler, or None when the handler
            should be invoked without arguments.
        """
        ...

    def __str__(self):
        """Return a debug label of the form `server:<class name>:<id(self)>`."""
        return f"server:{type(self).__name__}:{id(self)}"

    def __init__(self, channel: fc.Channel, channel_waker=None):
        """Bind the server to a channel and assign it a debugging ID.

        Args:
            channel: The channel requests are read from and responses are
                written to.
            channel_waker: Object used to register the channel and wait for
                it to become ready. A new `GlobalHandleWaker` is created when
                omitted.
        """
        global _SERVER_ID
        self._channel = channel
        self.id = _SERVER_ID
        _SERVER_ID += 1
        if channel_waker is None:
            self._channel_waker = GlobalHandleWaker()
        else:
            self._channel_waker = channel_waker
        # Walking the interpreter stack is costly, so only do it when the
        # message will actually be emitted. context=0 skips reading source
        # lines from disk; only the file name and line number are needed.
        if _LOGGER.isEnabledFor(logging.DEBUG):
            caller = stack(0)[1]
            _LOGGER.debug(
                "%s instantiated from %s:%s",
                self,
                caller.filename,
                caller.lineno,
            )

    def __del__(self):
        """Unregister the channel from the waker and drop the reference.

        Safe to call more than once, and safe on an instance whose
        `__init__` did not run to completion.
        """
        _LOGGER.debug("%s closing", self)
        # getattr guards against a partially constructed instance: __del__
        # still runs when __init__ raised before setting these attributes.
        channel = getattr(self, "_channel", None)
        if channel is None:
            return
        waker = getattr(self, "_channel_waker", None)
        if waker is not None:
            waker.unregister(channel)
        self._channel = None

    def close(self):
        """Explicitly release the channel; equivalent to finalizing the server."""
        self.__del__()

    def serve(self):
        """Register the channel and return a coroutine that serves requests.

        Returns:
            A coroutine that calls `handle_next_request` repeatedly until it
            returns False.
        """
        # Registered at call time so the channel is known to the waker before
        # the coroutine is scheduled, and once more when the coroutine starts.
        # NOTE(review): this presumes register() tolerates a repeated call for
        # the same channel -- confirm against the waker implementation.
        self._channel_waker.register(self._channel)

        async def _serve():
            self._channel_waker.register(self._channel)
            while await self.handle_next_request():
                pass

        return _serve()

    async def handle_next_request(self) -> bool:
        """Read and dispatch a single request.

        Returns:
            True when a request was handled and serving should continue,
            False when the peer closed the channel or a handler raised
            `StopServer`.

        Raises:
            Exception: Any other error raised while handling the request is
                re-raised after the channel has been closed.
        """
        try:
            # TODO(b/299946378): Handle case where ordinal is unknown.
            return await self._handle_request_helper()
        except StopServer:
            # The handler may already have called close(), leaving no channel.
            if self._channel is not None:
                self._channel.close()
            return False
        except Exception as e:
            # It's very important to close the channel, because if this is run
            # inside a task, then it isn't possible for the exception to get
            # raised in time. So if another coroutine depends on this server
            # functioning (like a client), then it'll hang forever. So, we
            # must close the channel in order to make progress.
            #
            # The None check keeps the original exception from being masked
            # by an AttributeError when the channel was already released.
            if self._channel is not None:
                self._channel.close()
                self._channel = None
            _LOGGER.debug("%s request handling error: %s", self, e)
            raise

    async def _handle_request_helper(self) -> bool:
        """Read one message, invoke its handler and write any response.

        Returns:
            True after a request has been processed, False when the read
            failed with ZX_ERR_PEER_CLOSED.

        Raises:
            ServerError: A one-way method returned a value, or a two-way
                method with a non-empty response returned None.
            fc.ZxStatus: Any channel read error other than PEER_CLOSED.
        """
        # TODO(b/303532690): When attempting to decode a method that is
        # unrecognized, there should be a message sent declaring this is
        # an unknown method.
        try:
            msg, txid, ordinal = await self._channel_read_and_parse()
        except fc.ZxStatus as e:
            if e.args[0] == fc.ZxStatus.ZX_ERR_PEER_CLOSED:
                _LOGGER.debug("%s shutting down. PEER_CLOSED received", self)
                return False
            _LOGGER.warning("%s channel received error: %s", self, e)
            raise

        info = self.method_map[ordinal]
        method = getattr(self, info.name)
        # Handlers for methods without a request payload take no arguments.
        res = method(msg) if msg is not None else method()
        # Handlers may be synchronous or asynchronous.
        if asyncio.iscoroutine(res) or asyncio.isfuture(res):
            res = await res

        if res is not None and not info.requires_response:
            raise ServerError(
                f"{self} method {info.name} received a "
                + "response but is one-way method"
            )
        if res is None and info.requires_response and not info.empty_response:
            raise ServerError(
                f"{self} method {info.name} returned "
                + "None when a response was expected"
            )

        if info.has_result:
            _LOGGER.debug("%s received method response %s", self, res)
            if isinstance(res, DomainError):
                res = GenericResult(
                    fidl_type=info.response_identifier, err=res.error
                )
            elif isinstance(res, FrameworkError):
                res = GenericResult(
                    fidl_type=info.response_identifier, framework_err=res
                )
            else:
                # A None return stands for an empty success payload, which is
                # represented by a bare placeholder object.
                res = GenericResult(
                    fidl_type=info.response_identifier,
                    response=object() if res is None else res,
                )

        # One-way methods (no value and no empty response) write nothing.
        if res is not None or info.empty_response:
            self._encode_and_write(ordinal, self.library, txid, res)
        return True

    async def _channel_read(self) -> FidlMessage:
        """Read the next raw message, waiting while the channel is not ready.

        Raises:
            fc.ZxStatus: Any read error other than ZX_ERR_SHOULD_WAIT. The
                channel is unregistered from the waker before re-raising.
        """
        while True:
            try:
                return self._channel.read()
            except fc.ZxStatus as e:
                # Any number of spurious wakeups are possible. Stay in the
                # loop if the error is ZX_ERR_SHOULD_WAIT.
                if e.args[0] == fc.ZxStatus.ZX_ERR_SHOULD_WAIT:
                    _LOGGER.debug("%s channel spurious wakeup", self)
                    await self._channel_waker.wait_ready(self._channel)
                    continue
                self._channel_waker.unregister(self._channel)
                _LOGGER.warning("%s channel received error: %s", self, e)
                raise

    async def _channel_read_and_parse(self):
        """Read one message and decode it into a request object.

        Returns:
            A `(request_object, txid, ordinal)` tuple, where `request_object`
            is whatever `construct_response_object` built from the decoded
            message.
        """
        raw_msg = await self._channel_read()
        ordinal = parse_ordinal(raw_msg)
        txid = parse_txid(raw_msg)
        # raw_msg[0] holds the message bytes and raw_msg[1] the handles; each
        # handle is extracted with take() before decoding.
        handles = [handle.take() for handle in raw_msg[1]]
        msg = decode_fidl_request(bytes=raw_msg[0], handles=handles)
        result_obj = self.construct_response_object(
            self.method_map[ordinal].request_ident, msg
        )
        return result_obj, txid, ordinal

    def _encode_and_write(self, ordinal, library, txid, msg_obj):
        """Encode `msg_obj` as a FIDL message and write it to the channel.

        A `msg_obj` of None is encoded with no object and no type name.
        """
        type_name = None if msg_obj is None else msg_obj.__fidl_raw_type__
        encoded_fidl_message = encode_fidl_message(
            ordinal=ordinal,
            object=msg_obj,
            library=library,
            txid=txid,
            type_name=type_name,
        )
        self._channel.write(encoded_fidl_message)

    def _send_event(self, ordinal: int, library: str, msg_obj):
        """Write an event to the channel.

        Events are encoded with a transaction ID of 0.

        Args:
            ordinal: The ordinal of the event's method.
            library: The FIDL library name used for encoding.
            msg_obj: The event payload, or None for an event without one.
        """
        self._encode_and_write(ordinal, library, 0, msg_obj)