blob: 6ecfd9558346425d8be47487109be8f9427ff020 [file] [log] [blame]
# Copyright 2023 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import asyncio
import unittest
from fuchsia_controller_py import Channel, Socket, ZxStatus
from fidl import AlreadyReadingAll, AsyncSocket
class SocketTests(unittest.IsolatedAsyncioTestCase):
"""Socket tests."""
def test_socket_write_then_read(self):
"""Tests a simple write followed by an immediate read."""
(sock_out, sock_in) = Socket.create()
bytes_in = bytearray([1, 2, 3])
sock_out.write(bytes_in)
bytes_out = sock_in.read()
self.assertEqual(bytes_out, bytes_in)
def test_socket_write_fails_when_closed(self):
"""Verifies PEER_CLOSED is surfaced when opposing socket is closed."""
(sock_out, sock_in) = Socket.create()
del sock_in
with self.assertRaises(ZxStatus):
try:
sock_out.write(bytearray([1, 2, 3]))
except ZxStatus as e:
self.assertEqual(e.args[0], ZxStatus.ZX_ERR_PEER_CLOSED)
raise e
def test_socket_passing(self):
"""Verifies a socket remains connected when passed through a channel."""
(chan_out, chan_in) = Channel.create()
(sock_out, sock_in) = Socket.create()
# This is using 'take' rather than 'as_int' as using 'as_int' would cause a double-close
# error on a channel that has already been closed.
chan_out.write((bytearray(), [(0, sock_out.take(), 0, 0, 0)]))
_, hdls = chan_in.read()
self.assertEqual(len(hdls), 1)
bytes_in = bytearray([1, 2, 3])
new_sock_out = Socket(hdls[0])
new_sock_out.write(bytes_in)
bytes_out = sock_in.read()
self.assertEqual(bytes_out, bytes_in)
async def test_async_socket(self):
"""Verifies an async socket is able to wait for the other end to complete writing."""
(sock_out, sock_in) = Socket.create()
sock_in = AsyncSocket(sock_in)
bytes_in = bytearray([1, 2, 3])
async def slow_write(sock: Socket, b: bytes) -> None:
await asyncio.sleep(1)
sock.write(b)
loop = asyncio.get_running_loop()
write_task = loop.create_task(slow_write(sock_out, bytes_in))
read_task = loop.create_task(sock_in.read())
done, _ = await asyncio.wait([write_task, read_task])
self.assertEqual(len(done), 2)
read_bytes = None
for item in done:
if item.result() is not None:
read_bytes = item.result()
self.assertEqual(read_bytes, bytes_in)
async def test_async_socket_write_fails_when_closed(self):
"""Verifies an async socket write fails in the expected way when the other end is closed."""
(sock_out, sock_in) = Socket.create()
del sock_in
with self.assertRaises(ZxStatus):
sock_out = AsyncSocket(sock_out)
try:
sock_out.write(bytearray([1, 2, 3]))
except ZxStatus as e:
self.assertEqual(e.args[0], ZxStatus.ZX_ERR_PEER_CLOSED)
raise e
async def test_async_socket_read_fails_when_closed(self):
"""Verifies an async socket read fails in the expected way when the other end is closed."""
(sock_out, sock_in) = Socket.create()
del sock_out
with self.assertRaises(ZxStatus):
sock_in = AsyncSocket(sock_in)
try:
await sock_in.read()
except ZxStatus as e:
self.assertEqual(e.args[0], ZxStatus.ZX_ERR_PEER_CLOSED)
raise e
async def test_async_socket_read_fails_when_already_reading_all(self):
"""Verifies running `read_all()` twice fails."""
(sock_out, sock_in) = Socket.create()
with self.assertRaises(AlreadyReadingAll):
sock_out = AsyncSocket(sock_out)
task = asyncio.get_running_loop().create_task(sock_out.read_all())
await asyncio.gather(sock_out.read_all(), task)
async def test_async_socket_read_fails_when_already_reading_all(self):
"""Verifies running `read_all()` then `read()` fails."""
(sock_out, sock_in) = Socket.create()
with self.assertRaises(AlreadyReadingAll):
sock_out = AsyncSocket(sock_out)
task = asyncio.get_running_loop().create_task(sock_out.read_all())
await asyncio.gather(sock_out.read_all(), task)