// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

// Fuchsia's BSD socket tests ensure that fdio and Netstack together produce
// POSIX-like behavior. This module contains tests that are generic over
// transport protocol.

#include <arpa/inet.h>
#include <fcntl.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/icmp6.h>
#include <netinet/if_ether.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>

#include <cstdlib>
#include <future>
#include <latch>
#include <utility>

#include <fbl/unique_fd.h>
#include <gtest/gtest.h>

#include "util.h"

namespace {

// Returns a loopback socket address for `domain` together with the length of the
// domain-specific sockaddr (sockaddr_in or sockaddr_in6) stored inside the sockaddr_storage.
// The argument 0 passed to LoopbackSockaddrV4/V6 is presumably the port, letting bind() pick
// one -- confirm against util.h.
std::pair<sockaddr_storage, socklen_t> InitLoopbackAddr(const SocketDomain& domain) {
  // Zero-initialize so the bytes beyond the domain-specific sockaddr are deterministic instead of
  // uninitialized stack contents copied into the returned pair.
  sockaddr_storage ss = {};
  switch (domain.which()) {
    case SocketDomain::Which::IPv4:
      *(reinterpret_cast<sockaddr_in*>(&ss)) = LoopbackSockaddrV4(0);
      return {ss, sizeof(sockaddr_in)};
    case SocketDomain::Which::IPv6:
      *(reinterpret_cast<sockaddr_in6*>(&ss)) = LoopbackSockaddrV6(0);
      return {ss, sizeof(sockaddr_in6)};
  }
  // Only reachable with an out-of-range domain value. Abort rather than flowing off the end of a
  // value-returning function, which is undefined behavior.
  std::abort();
}

void ConnectSocketsOverLoopback(const SocketDomain& domain, const SocketType& socket_type,
                                fbl::unique_fd& sendfd, fbl::unique_fd& recvfd) {
  auto [addr, addrlen] = InitLoopbackAddr(domain);

  ASSERT_TRUE(sendfd = fbl::unique_fd(socket(domain.Get(), socket_type.Get(), 0)))
      << strerror(errno);
  switch (socket_type.which()) {
    case SocketType::Which::Stream: {
      fbl::unique_fd acptfd;
      ASSERT_TRUE(acptfd = fbl::unique_fd(socket(domain.Get(), socket_type.Get(), 0)))
          << strerror(errno);
      EXPECT_EQ(bind(acptfd.get(), reinterpret_cast<const sockaddr*>(&addr), addrlen), 0)
          << strerror(errno);
      socklen_t found_len = addrlen;
      EXPECT_EQ(getsockname(acptfd.get(), reinterpret_cast<sockaddr*>(&addr), &found_len), 0)
          << strerror(errno);
      EXPECT_EQ(found_len, addrlen);
      EXPECT_EQ(listen(acptfd.get(), 0), 0) << strerror(errno);
      EXPECT_EQ(connect(sendfd.get(), reinterpret_cast<const sockaddr*>(&addr), addrlen), 0)
          << strerror(errno);
      ASSERT_TRUE(recvfd = fbl::unique_fd(accept(acptfd.get(), nullptr, nullptr)))
          << strerror(errno);
      EXPECT_EQ(close(acptfd.release()), 0) << strerror(errno);
      break;
    }
    case SocketType::Which::Dgram: {
      ASSERT_TRUE(recvfd = fbl::unique_fd(socket(domain.Get(), socket_type.Get(), 0)))
          << strerror(errno);
      EXPECT_EQ(bind(recvfd.get(), reinterpret_cast<const sockaddr*>(&addr), addrlen), 0)
          << strerror(errno);
      socklen_t found_len = addrlen;
      EXPECT_EQ(getsockname(recvfd.get(), reinterpret_cast<sockaddr*>(&addr), &found_len), 0)
          << strerror(errno);
      EXPECT_EQ(found_len, addrlen);
      EXPECT_EQ(connect(sendfd.get(), reinterpret_cast<const sockaddr*>(&addr), addrlen), 0)
          << strerror(errno);

      EXPECT_EQ(getsockname(sendfd.get(), reinterpret_cast<sockaddr*>(&addr), &found_len), 0)
          << strerror(errno);
      EXPECT_EQ(found_len, addrlen);
      EXPECT_EQ(connect(recvfd.get(), reinterpret_cast<const sockaddr*>(&addr), addrlen), 0)
          << strerror(errno);
      break;
    }
  }
}

// Checks the errors reported when a client's sandbox does not have access to raw or packet
// sockets.
TEST(LocalhostTest, RawSocketsNotAvailable) {
  struct Case {
    int domain;
    int type;
    int protocol;
    int want_errno;
  };
  const Case cases[] = {
      // No raw INET sockets.
      {AF_INET, SOCK_RAW, 0, EPROTONOSUPPORT},
      {AF_INET, SOCK_RAW, IPPROTO_UDP, EPERM},
      {AF_INET, SOCK_RAW, IPPROTO_RAW, EPERM},
      // No packet sockets.
      {AF_PACKET, SOCK_RAW, htons(ETH_P_ALL), EPERM},
  };
  for (const Case& c : cases) {
    ASSERT_EQ(socket(c.domain, c.type, c.protocol), -1);
    ASSERT_EQ(errno, c.want_errno) << strerror(errno);
  }
}

// TODO(https://fxbug.dev/90038): Delete once SockOptsTest is gone.
// A (level, option) pair identifying a socket option, as passed to getsockopt/setsockopt.
struct SockOption {
  int level;
  int option;

  // Two options are equal when both their level and their option name match.
  bool operator==(const SockOption& other) const {
    return std::tie(level, option) == std::tie(other.level, other.option);
  }
};

// Mask selecting the two low-order ECN (Explicit Congestion Notification) bits of an IPv4 TOS /
// IPv6 traffic class value.
constexpr int INET_ECN_MASK = 3;

// Parameter type for SocketKindTest: the (domain, type) pair sockets are created with.
using SocketKind = std::tuple<SocketDomain, SocketType>;

// Names a SocketKind-parameterized test instance as "<domain>_<type>".
std::string SocketKindToString(const testing::TestParamInfo<SocketKind>& info) {
  const SocketDomain& domain = std::get<0>(info.param);
  const SocketType& type = std::get<1>(info.param);
  std::ostringstream name;
  name << socketDomainToString(domain) << '_' << socketTypeToString(type);
  return name.str();
}

// Base fixture that shares helpers among tests parameterized over SocketKind (domain, type).
class SocketKindTest : public testing::TestWithParam<SocketKind> {
 protected:
  // Opens a socket of the parameterized domain and type with protocol 0. The returned fd is
  // invalid if socket() failed, in which case errno describes the error.
  static fbl::unique_fd NewSocket() {
    const SocketDomain& domain = std::get<0>(GetParam());
    const SocketType& type = std::get<1>(GetParam());
    return fbl::unique_fd(socket(domain.Get(), type.Get(), 0));
  }

  // Returns a loopback address, and its length, for the parameterized domain.
  static std::pair<sockaddr_storage, socklen_t> LoopbackAddr() {
    const SocketDomain& domain = std::get<0>(GetParam());
    return InitLoopbackAddr(domain);
  }
};

// Presumably the "enabled"/"disabled" int values for boolean socket options. They are not
// referenced in this part of the file; verify their uses elsewhere.
constexpr int kSockOptOn = 1;
constexpr int kSockOptOff = 0;

// A socket option identified by its (level, name) pair, together with the textual spellings of
// both. The spellings are used to build readable parameterized test names.
struct SocketOption {
  // The string parameters are sink arguments: they are taken by value and moved into the members
  // so that callers passing temporaries (e.g. via STRINGIFIED_SOCKOPT) pay no extra copy.
  SocketOption(int level, std::string level_str, int name, std::string name_str)
      : level(level),
        level_str(std::move(level_str)),
        name(name),
        name_str(std::move(name_str)) {}

  int level;
  std::string level_str;
  int name;
  std::string name_str;
};

// Builds a SocketOption whose `level_str`/`name_str` are the literal source spellings of the
// `level` and `name` arguments (via the preprocessor stringizing operator `#`).
#define STRINGIFIED_SOCKOPT(level, name) SocketOption(level, #level, name, #name)

// Describes an integer-valued (or pseudo-boolean) socket option exercised by IntSocketOptionTest.
struct IntSocketOption {
  // The option's level and name.
  SocketOption option;
  // Whether the option is pseudo-boolean: SetValid expects a nonzero value to read back as 1.
  bool is_boolean;
  // The value getsockopt is expected to report on a freshly created socket.
  int default_value;
  // Values setsockopt is expected to accept. Must be non-empty (checked in SetUp).
  std::vector<int> valid_values;
  // Values setsockopt is expected to reject with EINVAL, leaving the option at its default.
  std::vector<int> invalid_values;
};

// Base fixture for socket option tests: opens a socket of the given domain and type before each
// test and closes it afterwards.
class SocketOptionTestBase : public testing::Test {
 public:
  SocketOptionTestBase(const SocketDomain& domain, const SocketType& type)
      : sock_domain_(domain), sock_type_(type) {}

 protected:
  void SetUp() override {
    ASSERT_TRUE(sock_ = fbl::unique_fd(socket(sock_domain_.Get(), sock_type_.Get(), 0)))
        << strerror(errno);
  }

  void TearDown() override {
    // If SetUp failed, or a derived SetUp bailed out before calling it, there is no socket to
    // close. close(-1) would only add a spurious EBADF failure on top of the real one.
    if (!sock_) {
      return;
    }
    EXPECT_EQ(close(sock_.release()), 0) << strerror(errno);
  }

  // Returns whether options at `level` are expected to be accepted on this socket's domain.
  bool IsOptionLevelSupportedByDomain(int level) const {
#if defined(__Fuchsia__)
    // TODO(https://gvisor.dev/issues/6389): Remove once Fuchsia returns an error
    // when setting/getting IPv6 options on an IPv4 socket.
    return true;
#else
    // IPv6 options are only supported on AF_INET6 sockets.
    return sock_domain_.which() == SocketDomain::Which::IPv6 || level != IPPROTO_IPV6;
#endif
  }

  fbl::unique_fd const& sock() const { return sock_; }

 private:
  fbl::unique_fd sock_;
  const SocketDomain sock_domain_;
  const SocketType sock_type_;
};

std::string socketKindAndOptionToString(const SocketDomain& domain, const SocketType& type,
                                        SocketOption opt) {
  std::ostringstream oss;
  oss << socketDomainToString(domain);
  oss << '_' << socketTypeToString(type);
  oss << '_' << opt.level_str;
  oss << '_' << opt.name_str;
  return oss.str();
}

// Parameter type for IntSocketOptionTest: socket domain, socket type, and the option under test.
using SocketKindAndIntOption = std::tuple<SocketDomain, SocketType, IntSocketOption>;

// Names an IntSocketOptionTest instance after its socket kind and option.
std::string SocketKindAndIntOptionToString(
    const testing::TestParamInfo<SocketKindAndIntOption>& info) {
  const SocketKindAndIntOption& param = info.param;
  const IntSocketOption& int_opt = std::get<2>(param);
  return socketKindAndOptionToString(std::get<0>(param), std::get<1>(param), int_opt.option);
}

// Fixture for functionality common to every integer and pseudo-boolean socket option.
// Parameterized over (domain, type, option); the base fixture opens a socket of the given domain
// and type before each test.
class IntSocketOptionTest : public SocketOptionTestBase,
                            public testing::WithParamInterface<SocketKindAndIntOption> {
 protected:
  IntSocketOptionTest()
      : SocketOptionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
        opt_(std::get<2>(GetParam())) {}

  void SetUp() override {
    // Every option must be exercised with at least one accepted value.
    ASSERT_FALSE(opt_.valid_values.empty()) << "must have at least one valid value";
    SocketOptionTestBase::SetUp();
  }

  void TearDown() override { SocketOptionTestBase::TearDown(); }

  // Whether a char-sized setsockopt is expected to succeed for this option. SetChar expects
  // options at the IPPROTO_IPV6 and SOL_SOCKET levels to reject it with EINVAL.
  bool IsOptionCharCompatible() const {
    switch (opt_.option.level) {
      case IPPROTO_IPV6:
      case SOL_SOCKET:
        return false;
      default:
        return true;
    }
  }

  const IntSocketOption& opt() const { return opt_; }

 private:
  const IntSocketOption opt_;
};

// A freshly created socket reports the option's default value, or rejects the option entirely
// when its level does not apply to the socket's domain.
TEST_P(IntSocketOptionTest, Default) {
  const SocketOption& option = opt().option;
  int value = -1;
  socklen_t value_len = sizeof(value);
  const int ret = getsockopt(sock().get(), option.level, option.name, &value, &value_len);

  if (!IsOptionLevelSupportedByDomain(option.level)) {
    ASSERT_EQ(ret, -1);
    EXPECT_EQ(errno, ENOTSUP) << strerror(errno);
    return;
  }

  ASSERT_EQ(ret, 0) << strerror(errno);
  ASSERT_EQ(value_len, sizeof(value));
  EXPECT_EQ(value, opt().default_value);
}

// Every valid value is accepted and reads back unchanged (booleans read back as 0 or 1). When the
// option level does not apply to the domain, setsockopt fails with ENOPROTOOPT instead.
TEST_P(IntSocketOptionTest, SetValid) {
  const SocketOption& option = opt().option;
  const bool level_supported = IsOptionLevelSupportedByDomain(option.level);

  // Each value is checked in its own lambda so that a fatal ASSERT only abandons that value and
  // the remaining values are still exercised.
  auto check_value = [&](int value) {
    const int ret = setsockopt(sock().get(), option.level, option.name, &value, sizeof(value));
    if (!level_supported) {
      ASSERT_EQ(ret, -1);
      EXPECT_EQ(errno, ENOPROTOOPT) << strerror(errno);
      return;
    }

    ASSERT_EQ(ret, 0) << strerror(errno);
    int got = -1;
    socklen_t got_len = sizeof(got);
    ASSERT_EQ(getsockopt(sock().get(), option.level, option.name, &got, &got_len), 0)
        << strerror(errno);
    ASSERT_EQ(got_len, sizeof(got));
    const int want = opt().is_boolean ? static_cast<bool>(value) : value;
    EXPECT_EQ(got, want);
  };

  for (int value : opt().valid_values) {
    SCOPED_TRACE("value=" + std::to_string(value));
    check_value(value);
  }
}

// Every invalid value is rejected and leaves the option at its default. When the option level
// does not apply to the domain, setsockopt fails with ENOPROTOOPT instead.
TEST_P(IntSocketOptionTest, SetInvalid) {
  const SocketOption& option = opt().option;
  const bool level_supported = IsOptionLevelSupportedByDomain(option.level);

  // Each value is checked in its own lambda so that a fatal ASSERT only abandons that value and
  // the remaining values are still exercised.
  auto check_rejected = [&](int value) {
    const int ret = setsockopt(sock().get(), option.level, option.name, &value, sizeof(value));
    ASSERT_EQ(ret, -1);
    if (!level_supported) {
      EXPECT_EQ(errno, ENOPROTOOPT) << strerror(errno);
      return;
    }
    EXPECT_EQ(errno, EINVAL) << strerror(errno);

    // The rejected value must not have changed the option.
    int got = -1;
    socklen_t got_len = sizeof(got);
    ASSERT_EQ(getsockopt(sock().get(), option.level, option.name, &got, &got_len), 0)
        << strerror(errno);
    ASSERT_EQ(got_len, sizeof(got));
    EXPECT_EQ(got, opt().default_value);
  };

  for (int value : opt().invalid_values) {
    SCOPED_TRACE("value=" + std::to_string(value));
    check_rejected(value);
  }
}

// Sets the option from a single-byte (char) buffer, then reads it back through char-, int16_t-
// and int-sized buffers, checking the returned length and value for each size.
TEST_P(IntSocketOptionTest, SetChar) {
  for (int value : opt().valid_values) {
    SCOPED_TRACE("value=" + std::to_string(value));
    // Test each value in a lambda so we continue testing the other values if an ASSERT fails.
    [&]() {
      // The value the option is expected to hold after the char-sized setsockopt below.
      int want;
      {
        const char set_char = static_cast<char>(value);
        if (static_cast<int>(set_char) != value) {
          // Skip values that don't fit in a char. Which values fit depends on whether char is
          // signed on the target platform.
          return;
        }
        const int r = setsockopt(sock().get(), opt().option.level, opt().option.name, &set_char,
                                 sizeof(set_char));
        if (!IsOptionLevelSupportedByDomain(opt().option.level)) {
          ASSERT_EQ(r, -1);
          EXPECT_EQ(errno, ENOPROTOOPT) << strerror(errno);
          want = opt().default_value;
        } else if (!IsOptionCharCompatible()) {
          // A char-sized value is rejected, so the option keeps its default.
          ASSERT_EQ(r, -1);
          EXPECT_EQ(errno, EINVAL) << strerror(errno);
          want = opt().default_value;
        } else {
          ASSERT_EQ(r, 0) << strerror(errno);
          want = opt().is_boolean ? static_cast<bool>(set_char) : set_char;
        }
      }

      // Read back through a 1-byte buffer: the full byte is returned.
      {
        char get = -1;
        socklen_t get_len = sizeof(get);
        const int r =
            getsockopt(sock().get(), opt().option.level, opt().option.name, &get, &get_len);
        if (!IsOptionLevelSupportedByDomain(opt().option.level)) {
          ASSERT_EQ(r, -1);
          EXPECT_EQ(errno, ENOTSUP) << strerror(errno);
        } else {
          ASSERT_EQ(r, 0) << strerror(errno);
          ASSERT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, static_cast<char>(want));
        }
      }

      // Read back through a 2-byte buffer. Char-incompatible options fill both bytes;
      // char-compatible ones write a single byte and leave the other byte of `get` at 0xFF from
      // its -1 initializer.
      {
        int16_t get = -1;
        socklen_t get_len = sizeof(get);
        const int r =
            getsockopt(sock().get(), opt().option.level, opt().option.name, &get, &get_len);
        if (!IsOptionLevelSupportedByDomain(opt().option.level)) {
          ASSERT_EQ(r, -1);
          EXPECT_EQ(errno, ENOTSUP) << strerror(errno);
        } else if (!IsOptionCharCompatible()) {
          ASSERT_EQ(r, 0) << strerror(errno);
          ASSERT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, want);
        } else {
          ASSERT_EQ(r, 0) << strerror(errno);
          // Truncates size < 4 to 1 and only writes the low byte.
          // https://github.com/torvalds/linux/blob/2585cf9dfaa/net/ipv4/ip_sockglue.c#L1742-L1745
          ASSERT_EQ(get_len, sizeof(char));
          EXPECT_EQ(get, static_cast<int16_t>(uint16_t(-1) << 8) | want);
        }
      }

      // Read back through a full int-sized buffer.
      {
        int get = -1;
        socklen_t get_len = sizeof(get);
        const int r =
            getsockopt(sock().get(), opt().option.level, opt().option.name, &get, &get_len);
        if (!IsOptionLevelSupportedByDomain(opt().option.level)) {
          ASSERT_EQ(r, -1);
          EXPECT_EQ(errno, ENOTSUP) << strerror(errno);
        } else {
          ASSERT_EQ(r, 0) << strerror(errno);
          ASSERT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, want);
        }
      }
    }();
  }
}

// Values used for every pseudo-boolean option. All are expected to be accepted, with any nonzero
// value reading back as 1 (see SetValid). The set mixes negatives, 0/1, and values that do not
// fit in a char (e.g. 256).
const std::vector<int> kBooleanOptionValidValues = {-2, -1, 0, 1, 2, 15, 255, 256};

// The tests below use valid and invalid values that attempt to cover normal use cases,
// min/max values, and invalid negative/large values.
// Special values (e.g. ones that reset an option to its default) have option-specific tests.
INSTANTIATE_TEST_SUITE_P(
    IntSocketOptionTests, IntSocketOptionTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Stream(), SocketType::Dgram()),
                     testing::Values(
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_MULTICAST_LOOP),
                             .is_boolean = true,
                             .default_value = 1,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_TOS),
                             .is_boolean = false,
                             .default_value = 0,
                             // The ECN (2 rightmost) bits may be cleared, so we use arbitrary
                             // values without these bits set. See CheckSkipECN test.
                             .valid_values = {0x04, 0xC0, 0xFC},
                             // Larger-than-byte values are accepted but the extra bits are
                             // merely ignored. See InvalidLargeTOS test.
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_RECVTOS),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_TTL),
                             .is_boolean = false,
                             .default_value = 64,
                             // -1 is not tested here, it is a special value which resets ttl to
                             // its default value.
                             .valid_values = {1, 2, 15, 255},
                             .invalid_values = {-2, 0, 256},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_RECVTTL),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption {
                           .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_MULTICAST_LOOP),
                           .is_boolean = true, .default_value = 1,
#if defined(__Fuchsia__)
                           .valid_values = kBooleanOptionValidValues, .invalid_values = {},
#else
                           // On Linux, this option only accepts 0 or 1. This is one of a kind.
                           // There seem to be no good reasons for it, so it should probably be
                           // fixed in Linux rather than in Fuchsia.
                           // https://github.com/torvalds/linux/blob/eec4df26e24/net/ipv6/ipv6_sockglue.c#L758
                               .valid_values = {0, 1}, .invalid_values = {-2, -1, 2, 15, 255, 256},
#endif
                         },
                         IntSocketOption {
                           .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_TCLASS),
                           .is_boolean = false, .default_value = 0,
#if defined(__Fuchsia__)
                           // TODO(https://gvisor.dev/issues/6389): Remove once Fuchsia treats
                           // IPV6_TCLASS differently than IP_TOS. See CheckSkipECN test.
                               .valid_values = {0x04, 0xC0, 0xFC},
#else
                           // -1 is not tested here, it is a special value which resets the traffic
                           // class to its default value.
                               .valid_values = {0, 1, 2, 15, 255},
#endif
                           .invalid_values = {-2, 256},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_RECVTCLASS),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_UNICAST_HOPS),
                             .is_boolean = false,
                             .default_value = 64,
                             // -1 is not tested here, it is a special value which resets the hop
                             // limit to its default value.
                             .valid_values = {0, 1, 2, 15, 255},
                             .invalid_values = {-2, 256},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_RECVHOPLIMIT),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_RECVPKTINFO),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(SOL_SOCKET, SO_NO_CHECK),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(SOL_SOCKET, SO_TIMESTAMP),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(SOL_SOCKET, SO_TIMESTAMPNS),
                             .is_boolean = true,
                             .default_value = 0,
                             .valid_values = kBooleanOptionValidValues,
                             .invalid_values = {},
                         })),
    SocketKindAndIntOptionToString);

// Multicast TTL / hop-limit options, exercised on datagram sockets only.
// TODO(https://github.com/google/gvisor/issues/6972): Test multicast ttl options on SOCK_STREAM
// sockets. Right now it's complicated because setting these options on a stream socket silently
// fails (no error returned but no change observed).
INSTANTIATE_TEST_SUITE_P(
    DatagramIntSocketOptionTests, IntSocketOptionTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Dgram()),
                     testing::Values(
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_MULTICAST_TTL),
                             .is_boolean = false,
                             .default_value = 1,
                             // -1 is not tested here, it is a special value which
                             // resets the ttl to its default value.
                             .valid_values = {0, 1, 2, 15, 128, 255},
                             .invalid_values = {-2, 256},
                         },
                         IntSocketOption{
                             .option = STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_MULTICAST_HOPS),
                             .is_boolean = false,
                             .default_value = 1,
                             // -1 is not tested here, it is a special value which
                             // resets the hop limit to its default value.
                             .valid_values = {0, 1, 2, 15, 128, 255},
                             .invalid_values = {-2, 256},
                         })),
    SocketKindAndIntOptionToString);

// Parameter type for SocketOptionSharedTest-based tests: socket domain, socket type, and option.
using SocketKindAndOption = std::tuple<SocketDomain, SocketType, SocketOption>;

// Names a SocketKindAndOption-parameterized test instance after its socket kind and option.
std::string SocketKindAndOptionToString(const testing::TestParamInfo<SocketKindAndOption>& info) {
  const SocketKindAndOption& param = info.param;
  return socketKindAndOptionToString(std::get<0>(param), std::get<1>(param), std::get<2>(param));
}

// Fixture for tests parameterized over (domain, type, SocketOption). The base fixture opens a
// socket of the given domain and type before each test.
class SocketOptionSharedTest : public SocketOptionTestBase,
                               public testing::WithParamInterface<SocketKindAndOption> {
 protected:
  SocketOptionSharedTest()
      : SocketOptionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
        opt_(std::get<2>(GetParam())) {}

  void SetUp() override { SocketOptionTestBase::SetUp(); }

  void TearDown() override { SocketOptionTestBase::TearDown(); }

  // Returned by const reference: SocketOption holds two std::strings, and tests call opt()
  // repeatedly just to read `level`/`name`, so returning by value copied both strings each time.
  const SocketOption& opt() const { return opt_; }

 private:
  const SocketOption opt_;
};

using TtlHopLimitSocketOptionTest = SocketOptionSharedTest;

// Setting the TTL / hop-limit option to -1 restores its default value of 64.
TEST_P(TtlHopLimitSocketOptionTest, ResetToDefault) {
  const int level = opt().level;
  const int name = opt().name;
  if (!IsOptionLevelSupportedByDomain(level)) {
    GTEST_SKIP() << "Option not supported by socket domain";
  }

  constexpr int kDefaultTTL = 64;
  constexpr int kNonDefaultValue = kDefaultTTL + 1;
  constexpr int kResetValue = -1;

  // Move away from the default first so the reset is observable.
  ASSERT_EQ(setsockopt(sock().get(), level, name, &kNonDefaultValue, sizeof(kNonDefaultValue)),
            0)
      << strerror(errno);

  int got = -1;
  socklen_t got_len = sizeof(got);
  // Coherence check: the non-default value took effect.
  ASSERT_EQ(getsockopt(sock().get(), level, name, &got, &got_len), 0) << strerror(errno);
  ASSERT_EQ(got_len, sizeof(got));
  EXPECT_EQ(got, kNonDefaultValue);

  ASSERT_EQ(setsockopt(sock().get(), level, name, &kResetValue, sizeof(kResetValue)), 0)
      << strerror(errno);

  got = -1;
  got_len = sizeof(got);
  ASSERT_EQ(getsockopt(sock().get(), level, name, &got, &got_len), 0) << strerror(errno);
  ASSERT_EQ(got_len, sizeof(got));
  EXPECT_EQ(got, kDefaultTTL);
}

// Runs ResetToDefault for both IP_TTL and IPV6_UNICAST_HOPS on every domain/type combination.
// Combinations whose option level the domain does not support are skipped inside the test.
INSTANTIATE_TEST_SUITE_P(
    TtlHopLimitSocketOptionTests, TtlHopLimitSocketOptionTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Dgram(), SocketType::Stream()),
                     testing::Values(STRINGIFIED_SOCKOPT(IPPROTO_IP, IP_TTL),
                                     STRINGIFIED_SOCKOPT(IPPROTO_IPV6, IPV6_UNICAST_HOPS))),
    SocketKindAndOptionToString);

// TODO(https://fxbug.dev/90038): Use SocketOptionTestBase for these tests.
class SocketOptsTest : public SocketKindTest {
 protected:
  static bool IsTCP() { return std::get<1>(GetParam()).which() == SocketType::Which::Stream; }

  static bool IsIPv6() { return std::get<0>(GetParam()).which() == SocketDomain::Which::IPv6; }

  static SockOption GetTOSOption() {
    if (IsIPv6()) {
      return {
          .level = IPPROTO_IPV6,
          .option = IPV6_TCLASS,
      };
    }
    return {
        .level = IPPROTO_IP,
        .option = IP_TOS,
    };
  }

  static SockOption GetMcastTTLOption() {
    if (IsIPv6()) {
      return {
          .level = IPPROTO_IPV6,
          .option = IPV6_MULTICAST_HOPS,
      };
    }
    return {
        .level = IPPROTO_IP,
        .option = IP_MULTICAST_TTL,
    };
  }

  static SockOption GetMcastIfOption() {
    if (IsIPv6()) {
      return {
          .level = IPPROTO_IPV6,
          .option = IPV6_MULTICAST_IF,
      };
    }
    return {
        .level = IPPROTO_IP,
        .option = IP_MULTICAST_IF,
    };
  }

  static SockOption GetRecvTOSOption() {
    if (IsIPv6()) {
      return {
          .level = IPPROTO_IPV6,
          .option = IPV6_RECVTCLASS,
      };
    }
    return {
        .level = IPPROTO_IP,
        .option = IP_RECVTOS,
    };
  }

  constexpr static SockOption GetNoChecksum() {
    return {
        .level = SOL_SOCKET,
        .option = SO_NO_CHECK,
    };
  }

  constexpr static SockOption GetTimestamp() {
    return {
        .level = SOL_SOCKET,
        .option = SO_TIMESTAMP,
    };
  }

  constexpr static SockOption GetTimestampNs() {
    return {
        .level = SOL_SOCKET,
        .option = SO_TIMESTAMPNS,
    };
  }
};

// Setting IP_TTL to -1 restores whatever TTL the socket started with.
TEST_P(SocketOptsTest, ResetTtlToDefault) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  // Record the initial TTL so the reset can be compared against it.
  int initial_ttl = -1;
  socklen_t initial_len = sizeof(initial_ttl);
  EXPECT_EQ(getsockopt(fd.get(), IPPROTO_IP, IP_TTL, &initial_ttl, &initial_len), 0)
      << strerror(errno);
  EXPECT_EQ(initial_len, sizeof(initial_ttl));

  // Pick a TTL that differs from the initial one.
  int custom_ttl = (initial_ttl == 100) ? 101 : 100;
  EXPECT_EQ(setsockopt(fd.get(), IPPROTO_IP, IP_TTL, &custom_ttl, sizeof(custom_ttl)), 0)
      << strerror(errno);

  int reset_ttl = -1;
  EXPECT_EQ(setsockopt(fd.get(), IPPROTO_IP, IP_TTL, &reset_ttl, sizeof(reset_ttl)), 0)
      << strerror(errno);

  int final_ttl = -1;
  socklen_t final_len = sizeof(final_ttl);
  EXPECT_EQ(getsockopt(fd.get(), IPPROTO_IP, IP_TTL, &final_ttl, &final_len), 0)
      << strerror(errno);
  EXPECT_EQ(final_len, sizeof(final_ttl));
  EXPECT_EQ(final_ttl, initial_ttl);
  EXPECT_EQ(close(fd.release()), 0) << strerror(errno);
}

// Null optval/optlen pointers for the TOS option. Setting from a null buffer succeeds on IPv6
// and fails with EFAULT on IPv4. Getting into a null buffer, or with a null length, fails with
// EFAULT on both.
TEST_P(SocketOptsTest, NullTOS) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  const SockOption tos = GetTOSOption();
  const socklen_t int_len = sizeof(int);
  if (!IsIPv6()) {
    EXPECT_EQ(setsockopt(fd.get(), tos.level, tos.option, nullptr, int_len), -1);
    EXPECT_EQ(errno, EFAULT) << strerror(errno);
  } else {
    EXPECT_EQ(setsockopt(fd.get(), tos.level, tos.option, nullptr, int_len), 0)
        << strerror(errno);
  }

  socklen_t len = sizeof(int);
  EXPECT_EQ(getsockopt(fd.get(), tos.level, tos.option, nullptr, &len), -1);
  EXPECT_EQ(errno, EFAULT) << strerror(errno);

  int got = -1;
  EXPECT_EQ(getsockopt(fd.get(), tos.level, tos.option, &got, nullptr), -1);
  EXPECT_EQ(errno, EFAULT) << strerror(errno);

  EXPECT_EQ(close(fd.release()), 0) << strerror(errno);
}

// A TOS value that does not fit in a byte (256) is rejected on IPv6 and accepted-but-truncated on
// IPv4. Either way the option reads back as the default of 0.
TEST_P(SocketOptsTest, InvalidLargeTOS) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  constexpr int kDefaultTOS = 0;
  const SockOption tos = GetTOSOption();
  int too_large = 256;
  if (!IsIPv6()) {
    // Linux allows values larger than 255, though it only looks at the char part of the value.
    // https://github.com/torvalds/linux/blob/eec4df26e24/net/ipv4/ip_sockglue.c#L1047
    EXPECT_EQ(setsockopt(fd.get(), tos.level, tos.option, &too_large, sizeof(too_large)), 0)
        << strerror(errno);
  } else {
    EXPECT_EQ(setsockopt(fd.get(), tos.level, tos.option, &too_large, sizeof(too_large)), -1);
    EXPECT_EQ(errno, EINVAL) << strerror(errno);
  }

  int got = -1;
  socklen_t got_len = sizeof(got);
  EXPECT_EQ(getsockopt(fd.get(), tos.level, tos.option, &got, &got_len), 0) << strerror(errno);
  EXPECT_EQ(got_len, sizeof(got));
  EXPECT_EQ(got, kDefaultTOS);
  EXPECT_EQ(close(fd.release()), 0) << strerror(errno);
}

// Setting TOS to 0xFF reads back 0xFF, except where the ECN bits are expected to be cleared
// (stream sockets; on non-Fuchsia hosts, only IPv4 stream sockets).
TEST_P(SocketOptsTest, CheckSkipECN) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  const SockOption tos = GetTOSOption();
  int requested = 0xFF;
  EXPECT_EQ(setsockopt(fd.get(), tos.level, tos.option, &requested, sizeof(requested)), 0)
      << strerror(errno);

  int want = static_cast<uint8_t>(requested);
  bool ecn_bits_cleared = IsTCP();
#if !defined(__Fuchsia__)
  // gvisor-netstack's implementation of setsockopt(..IPV6_TCLASS..)
  // clears the ECN bits from the TCLASS value. This keeps gvisor
  // in parity with the Linux test-hosts that run a custom kernel.
  // But that is not the behavior of vanilla Linux kernels.
  // This #if can be removed when we migrate away from gvisor-netstack.
  ecn_bits_cleared = ecn_bits_cleared && !IsIPv6();
#endif
  if (ecn_bits_cleared) {
    want &= ~INET_ECN_MASK;
  }

  int got = -1;
  socklen_t got_len = sizeof(got);
  EXPECT_EQ(getsockopt(fd.get(), tos.level, tos.option, &got, &got_len), 0) << strerror(errno);
  EXPECT_EQ(got_len, sizeof(got));
  EXPECT_EQ(got, want);
  EXPECT_EQ(close(fd.release()), 0) << strerror(errno);
}

TEST_P(SocketOptsTest, ZeroTOSOptionSize) {
  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // Set the option with a zero-length optval.
  const SockOption opt = GetTOSOption();
  int tos = 0xC0;
  const socklen_t no_bytes = 0;
  if (!IsIPv6()) {
    EXPECT_EQ(setsockopt(s.get(), opt.level, opt.option, &tos, no_bytes), 0) << strerror(errno);
  } else {
    EXPECT_EQ(setsockopt(s.get(), opt.level, opt.option, &tos, no_bytes), -1);
    EXPECT_EQ(errno, EINVAL) << strerror(errno);
  }

  // Reading into a zero-length buffer succeeds, reports zero bytes and leaves it untouched.
  int observed = -1;
  socklen_t observed_len = 0;
  EXPECT_EQ(getsockopt(s.get(), opt.level, opt.option, &observed, &observed_len), 0)
      << strerror(errno);
  EXPECT_EQ(observed_len, 0u);
  EXPECT_EQ(observed, -1);
  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

// Sets and reads the TOS/TCLASS option with every optlen strictly smaller than sizeof(int).
// IPv6 rejects the short set and keeps the default. IPv4 accepts it and reports a
// single-byte value.
TEST_P(SocketOptsTest, SmallTOSOptionSize) {
  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  int set = 0xC0;
  constexpr int kDefaultTOS = 0;
  SockOption t = GetTOSOption();
  for (socklen_t i = 1; i < sizeof(int); i++) {
    int expect_tos;
    socklen_t expect_sz;
    if (IsIPv6()) {
      EXPECT_EQ(setsockopt(s.get(), t.level, t.option, &set, i), -1);
      EXPECT_EQ(errno, EINVAL) << strerror(errno);
      expect_tos = kDefaultTOS;
      expect_sz = i;
    } else {
      EXPECT_EQ(setsockopt(s.get(), t.level, t.option, &set, i), 0) << strerror(errno);
      expect_tos = set;
      expect_sz = sizeof(uint8_t);
    }
    // `unsigned int` rather than the non-standard `uint` typedef.
    unsigned int get = ~0u;
    socklen_t get_sz = i;
    EXPECT_EQ(getsockopt(s.get(), t.level, t.option, &get, &get_sz), 0) << strerror(errno);
    EXPECT_EQ(get_sz, expect_sz);
    // Account for partial copies by getsockopt: retrieve the lower bits specified by get_sz,
    // then compare them against expect_tos. Shifting an unsigned int by its full width is
    // undefined behavior, so a full-width result keeps every bit instead of shifting.
    const unsigned int mask = get_sz >= sizeof(get) ? ~0u : ~(~0u << (get_sz * 8));
    EXPECT_EQ(get & mask, static_cast<unsigned int>(expect_tos));
  }
  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

// Sets the TOS/TCLASS option with optlen values larger than sizeof(int). It expects
// getsockopt to report back exactly one int holding the value that was set.
TEST_P(SocketOptsTest, LargeTOSOptionSize) {
  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // setsockopt below is given optlen values up to 9 bytes, so `set` points into a zeroed
  // buffer larger than one int; the call therefore never reads past the buffer. Declaring
  // the buffer as an int array, rather than reinterpreting a char array as int*, guarantees
  // that `*set` is suitably aligned.
  int buffer[100 / sizeof(int)] = {};
  int* set = buffer;
  *set = 0xC0;
  SockOption t = GetTOSOption();
  for (socklen_t i = sizeof(int); i < 10; i++) {
    EXPECT_EQ(setsockopt(s.get(), t.level, t.option, set, i), 0) << strerror(errno);
    int get = -1;
    socklen_t get_sz = i;
    // We expect the system call handler to copy at most sizeof(int) bytes, as asserted by
    // the check below. Hence, we do not expect the copy to overflow in getsockopt.
    EXPECT_EQ(getsockopt(s.get(), t.level, t.option, &get, &get_sz), 0) << strerror(errno);
    EXPECT_EQ(get_sz, sizeof(int));
    EXPECT_EQ(get, *set);
  }
  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

TEST_P(SocketOptsTest, NegativeTOS) {
  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // -1 is accepted on both families.
  const SockOption opt = GetTOSOption();
  int tos = -1;
  EXPECT_EQ(setsockopt(s.get(), opt.level, opt.option, &tos, sizeof(tos)), 0) << strerror(errno);

  // On IPv6 TCLASS, -1 resets the traffic class, so 0 is expected. On IPv4 only the low
  // byte is kept, with the ECN bits cleared for TCP.
  int expect = 0;
  if (!IsIPv6()) {
    expect = static_cast<uint8_t>(tos);
    if (IsTCP()) {
      expect &= ~INET_ECN_MASK;
    }
  }

  int observed = -1;
  socklen_t observed_len = sizeof(observed);
  EXPECT_EQ(getsockopt(s.get(), opt.level, opt.option, &observed, &observed_len), 0)
      << strerror(errno);
  EXPECT_EQ(observed_len, sizeof(observed));
  EXPECT_EQ(observed, expect);
  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

// Sets the TOS/TCLASS option to -2. IPv6 rejects it with EINVAL and keeps the default of 0.
// IPv4 accepts it and keeps the low byte, with the ECN bits cleared for TCP.
TEST_P(SocketOptsTest, InvalidNegativeTOS) {
  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  int set = -2;
  socklen_t set_sz = sizeof(set);
  SockOption t = GetTOSOption();
  int expect;
  if (IsIPv6()) {
    EXPECT_EQ(setsockopt(s.get(), t.level, t.option, &set, set_sz), -1);
    EXPECT_EQ(errno, EINVAL) << strerror(errno);
    expect = 0;
  } else {
    EXPECT_EQ(setsockopt(s.get(), t.level, t.option, &set, set_sz), 0) << strerror(errno);
    expect = static_cast<uint8_t>(set);
    if (IsTCP()) {
      expect &= ~INET_ECN_MASK;
    }
  }
  // Start from a sentinel that matches neither expected value. A getsockopt that fails to
  // write `get` then cannot pass the check by accident (0 would equal the IPv6 expectation).
  int get = -1;
  socklen_t get_sz = sizeof(get);
  EXPECT_EQ(getsockopt(s.get(), t.level, t.option, &get, &get_sz), 0) << strerror(errno);
  EXPECT_EQ(get_sz, sizeof(get));
  EXPECT_EQ(get, expect);
  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

TEST_P(SocketOptsTest, SetUDPMulticastTTLNegativeOne) {
  if (IsTCP()) {
    GTEST_SKIP() << "Skip multicast tests on TCP socket";
  }

  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // Set an arbitrary TTL (6) first and then -1. Afterwards the TTL is expected to read back
  // as 1.
  const SockOption ttl_opt = GetMcastTTLOption();
  for (const int ttl : {6, -1}) {
    EXPECT_EQ(setsockopt(s.get(), ttl_opt.level, ttl_opt.option, &ttl, sizeof(ttl)), 0)
        << strerror(errno);
  }

  int observed = -1;
  socklen_t observed_len = sizeof(observed);
  EXPECT_EQ(getsockopt(s.get(), ttl_opt.level, ttl_opt.option, &observed, &observed_len), 0)
      << strerror(errno);
  EXPECT_EQ(observed_len, sizeof(observed));
  EXPECT_EQ(observed, 1);

  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

TEST_P(SocketOptsTest, SetUDPMulticastIfImrIfindex) {
  if (IsTCP()) {
    GTEST_SKIP() << "Skip multicast tests on TCP socket";
  }

  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // Select the outgoing multicast interface by index 1.
  constexpr int kIfindex = 1;
  const SockOption mcast_if = GetMcastIfOption();
  if (!IsIPv6()) {
    // IPv4 takes an ip_mreqn. When the interface is selected by index, the address read
    // back is INADDR_ANY.
    ip_mreqn request = {
        .imr_ifindex = kIfindex,
    };
    ASSERT_EQ(setsockopt(s.get(), mcast_if.level, mcast_if.option, &request, sizeof(request)), 0)
        << strerror(errno);

    in_addr reported;
    socklen_t reported_len = sizeof(reported);
    ASSERT_EQ(getsockopt(s.get(), mcast_if.level, mcast_if.option, &reported, &reported_len), 0)
        << strerror(errno);
    ASSERT_EQ(reported_len, sizeof(reported));

    ASSERT_EQ(reported.s_addr, INADDR_ANY);
  } else {
    // IPv6 takes and reports the interface index as a plain int.
    EXPECT_EQ(setsockopt(s.get(), mcast_if.level, mcast_if.option, &kIfindex, sizeof(kIfindex)),
              0)
        << strerror(errno);

    int reported;
    socklen_t reported_len = sizeof(reported);
    ASSERT_EQ(getsockopt(s.get(), mcast_if.level, mcast_if.option, &reported, &reported_len), 0)
        << strerror(errno);
    ASSERT_EQ(reported_len, sizeof(reported));

    ASSERT_EQ(reported, kIfindex);
  }

  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

TEST_P(SocketOptsTest, SetUDPMulticastIfImrAddress) {
  if (IsTCP()) {
    GTEST_SKIP() << "Skip multicast tests on TCP socket";
  }
  if (IsIPv6()) {
    GTEST_SKIP() << "V6 sockets don't support setting IP_MULTICAST_IF by addr";
  }

  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  // Select the outgoing multicast interface by its local (loopback) address rather than by
  // index. The same address is expected to be reported back.
  const SockOption mcast_if = GetMcastIfOption();
  ip_mreqn request = {};
  request.imr_address.s_addr = htonl(INADDR_LOOPBACK);
  ASSERT_EQ(setsockopt(s.get(), mcast_if.level, mcast_if.option, &request, sizeof(request)), 0)
      << strerror(errno);

  in_addr reported;
  socklen_t reported_len = sizeof(reported);
  ASSERT_EQ(getsockopt(s.get(), mcast_if.level, mcast_if.option, &reported, &reported_len), 0)
      << strerror(errno);
  ASSERT_EQ(reported_len, sizeof(reported));

  ASSERT_EQ(reported.s_addr, request.imr_address.s_addr);

  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

// Tests that a two-byte RECVTOS/RECVTCLASS optval is acceptable for IPv4. IPv6 rejects it
// with EINVAL and leaves the option off.
TEST_P(SocketOptsTest, SetReceiveTOSShort) {
  if (IsTCP()) {
    GTEST_SKIP() << "Skip receive TOS tests on TCP socket";
  }

  fbl::unique_fd s;
  ASSERT_TRUE(s = NewSocket()) << strerror(errno);

  constexpr char kSockOptOn2Byte[] = {kSockOptOn, 0};
  constexpr char kSockOptOff2Byte[] = {kSockOptOff, 0};

  SockOption t = GetRecvTOSOption();
  if (IsIPv6()) {
    ASSERT_EQ(setsockopt(s.get(), t.level, t.option, &kSockOptOn2Byte, sizeof(kSockOptOn2Byte)), -1)
        << strerror(errno);
    EXPECT_EQ(errno, EINVAL) << strerror(errno);
  } else {
    ASSERT_EQ(setsockopt(s.get(), t.level, t.option, &kSockOptOn2Byte, sizeof(kSockOptOn2Byte)), 0)
        << strerror(errno);
  }

  int get = -1;
  socklen_t get_len = sizeof(get);
  EXPECT_EQ(getsockopt(s.get(), t.level, t.option, &get, &get_len), 0) << strerror(errno);
  EXPECT_EQ(get_len, sizeof(get));
  if (IsIPv6()) {
    EXPECT_EQ(get, kSockOptOff);
  } else {
    EXPECT_EQ(get, kSockOptOn);
  }

  if (IsIPv6()) {
    ASSERT_EQ(setsockopt(s.get(), t.level, t.option, &kSockOptOff2Byte, sizeof(kSockOptOff2Byte)),
              -1)
        << strerror(errno);
    EXPECT_EQ(errno, EINVAL) << strerror(errno);
  } else {
    ASSERT_EQ(setsockopt(s.get(), t.level, t.option, &kSockOptOff2Byte, sizeof(kSockOptOff2Byte)),
              0)
        << strerror(errno);
  }

  // Reset the output buffers so this read must actually report the option's state rather than
  // echo what the first read left behind. On IPv6 `get` already held kSockOptOff, which would
  // make the check below vacuous.
  get = -1;
  get_len = sizeof(get);
  EXPECT_EQ(getsockopt(s.get(), t.level, t.option, &get, &get_len), 0) << strerror(errno);
  EXPECT_EQ(get_len, sizeof(get));
  EXPECT_EQ(get, kSockOptOff);

  EXPECT_EQ(close(s.release()), 0) << strerror(errno);
}

// SO_TIMESTAMP and SO_TIMESTAMPNS are treated as mutually exclusive. The test enables one,
// then sets the other (to off or on), and checks that the originally enabled option reads
// back as off.
TEST_P(SocketOptsTest, UpdateAnyTimestampDisablesOtherTimestampOptions) {
  constexpr std::pair<SockOption, const char*> kOpts[] = {
      std::make_pair(GetTimestamp(), "SO_TIMESTAMP"),
      std::make_pair(GetTimestampNs(), "SO_TIMESTAMPNS"),
  };
  // Values written to the second option; both must disable the first one.
  constexpr int optvals[] = {kSockOptOff, kSockOptOn};

  for (const auto& [opt_to_enable, opt_to_enable_name] : kOpts) {
    SCOPED_TRACE("Enable option " + std::string(opt_to_enable_name));
    for (const auto& [opt_to_update, opt_to_update_name] : kOpts) {
      SCOPED_TRACE("Update option " + std::string(opt_to_update_name));
      // Only pairs of distinct options are interesting.
      if (opt_to_enable == opt_to_update) {
        continue;
      }
      for (const int optval : optvals) {
        SCOPED_TRACE("Update value " + std::to_string(optval));
        // A fresh socket per combination so earlier iterations cannot leak state.
        fbl::unique_fd s;
        ASSERT_TRUE(s = NewSocket()) << strerror(errno);

        // Enable the first option and confirm it took effect.
        ASSERT_EQ(setsockopt(s.get(), opt_to_enable.level, opt_to_enable.option, &kSockOptOn,
                             sizeof(kSockOptOn)),
                  0)
            << strerror(errno);
        {
          int get = -1;
          socklen_t get_len = sizeof(get);
          ASSERT_EQ(getsockopt(s.get(), opt_to_enable.level, opt_to_enable.option, &get, &get_len),
                    0)
              << strerror(errno);
          EXPECT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, kSockOptOn);
        }

        // Set the other option and confirm it holds the written value.
        ASSERT_EQ(
            setsockopt(s.get(), opt_to_update.level, opt_to_update.option, &optval, sizeof(optval)),
            0)
            << strerror(errno);
        {
          int get = -1;
          socklen_t get_len = sizeof(get);
          ASSERT_EQ(getsockopt(s.get(), opt_to_update.level, opt_to_update.option, &get, &get_len),
                    0)
              << strerror(errno);
          EXPECT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, optval);
        }

        // The initially enabled option should be disabled after the mutually exclusive option is
        // updated.
        {
          int get = -1;
          socklen_t get_len = sizeof(get);
          ASSERT_EQ(getsockopt(s.get(), opt_to_enable.level, opt_to_enable.option, &get, &get_len),
                    0)
              << strerror(errno);
          EXPECT_EQ(get_len, sizeof(get));
          EXPECT_EQ(get, kSockOptOff);
        }

        EXPECT_EQ(close(s.release()), 0) << strerror(errno);
      }
    }
  }
}

// Runs every SocketOptsTest for IPv4 and IPv6, each with datagram and stream sockets.
INSTANTIATE_TEST_SUITE_P(
    LocalhostTest, SocketOptsTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Dgram(), SocketType::Stream())),
    SocketKindToString);

// ReuseTest parameter: the socket type, and whether the shared bind address is multicast.
using TypeMulticast = std::tuple<SocketType, bool>;

// Builds a test name from the socket type followed by "Multicast" or "Loopback".
std::string TypeMulticastToString(const testing::TestParamInfo<TypeMulticast>& info) {
  const SocketType& type = std::get<0>(info.param);
  const bool multicast = std::get<1>(info.param);
  std::ostringstream name;
  name << socketTypeToString(type) << (multicast ? "Multicast" : "Loopback");
  return name.str();
}

// Checks that two sockets can bind the same address when SO_REUSEPORT is set, parameterized
// over socket type and loopback vs. multicast address.
class ReuseTest : public testing::TestWithParam<TypeMulticast> {};

// Binds two AF_INET sockets with SO_REUSEPORT to the same address and port and expects both
// binds to succeed. The address is loopback, or 224.0.2.1 for the multicast variant.
TEST_P(ReuseTest, AllowsAddressReuse) {
  const int on = true;
  auto const& [type, multicast] = GetParam();

#if defined(__Fuchsia__)
  if (multicast && type.which() == SocketType::Which::Stream) {
    GTEST_SKIP() << "Cannot bind a TCP socket to a multicast address on Fuchsia";
  }
#endif

  // Port 0: the first bind picks an ephemeral port, which the second socket then reuses.
  sockaddr_in addr = LoopbackSockaddrV4(0);
  if (multicast) {
    int n = inet_pton(addr.sin_family, "224.0.2.1", &addr.sin_addr);
    ASSERT_GE(n, 0) << strerror(errno);
    ASSERT_EQ(n, 1);
  }

  fbl::unique_fd s1;
  ASSERT_TRUE(s1 = fbl::unique_fd(socket(AF_INET, type.Get(), 0))) << strerror(errno);

// TODO(https://gvisor.dev/issue/3839): Remove this.
#if defined(__Fuchsia__)
  // Must outlive the block below.
  fbl::unique_fd s;
  // When binding a non-datagram socket to a multicast address, first expect EADDRNOTAVAIL.
  // Then join the group on a separate UDP socket `s` (IP_ADD_MEMBERSHIP on ifindex 1).
  // NOTE(review): with the Dgram/Stream instantiation below, Stream+multicast is already
  // skipped above on Fuchsia, so this branch appears unreachable — confirm before removing.
  if (type.which() != SocketType::Which::Dgram && multicast) {
    ASSERT_EQ(bind(s1.get(), reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)), -1);
    ASSERT_EQ(errno, EADDRNOTAVAIL) << strerror(errno);
    ASSERT_TRUE(s = fbl::unique_fd(socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP))) << strerror(errno);
    ip_mreqn param = {
        .imr_multiaddr = addr.sin_addr,
        .imr_address =
            {
                .s_addr = htonl(INADDR_ANY),
            },
        .imr_ifindex = 1,
    };
    ASSERT_EQ(setsockopt(s.get(), SOL_IP, IP_ADD_MEMBERSHIP, &param, sizeof(param)), 0)
        << strerror(errno);
  }
#endif

  ASSERT_EQ(setsockopt(s1.get(), SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on)), 0) << strerror(errno);
  ASSERT_EQ(bind(s1.get(), reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)), 0)
      << strerror(errno);

  // Read back the bound address so `addr` carries the port chosen by the first bind.
  socklen_t addrlen = sizeof(addr);
  ASSERT_EQ(getsockname(s1.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), 0)
      << strerror(errno);
  ASSERT_EQ(addrlen, sizeof(addr));

  // A second SO_REUSEPORT socket must be able to bind the exact same address and port.
  fbl::unique_fd s2;
  ASSERT_TRUE(s2 = fbl::unique_fd(socket(AF_INET, type.Get(), 0))) << strerror(errno);
  ASSERT_EQ(setsockopt(s2.get(), SOL_SOCKET, SO_REUSEPORT, &on, sizeof(on)), 0) << strerror(errno);
  ASSERT_EQ(bind(s2.get(), reinterpret_cast<const sockaddr*>(&addr), sizeof(addr)), 0)
      << strerror(errno);
}

// Runs ReuseTest for datagram and stream sockets, each against a loopback and a multicast
// address.
INSTANTIATE_TEST_SUITE_P(LocalhostTest, ReuseTest,
                         testing::Combine(testing::Values(SocketType::Dgram(),
                                                          SocketType::Stream()),
                                          testing::Values(false, true)),
                         TypeMulticastToString);

// Test parameter selecting the address flavor a socket test uses: plain IPv4, plain IPv6, or
// an IPv4 address mapped into IPv6.
class AddrKind {
 public:
  enum class Kind {
    V4,
    V6,
    V4MAPPEDV6,
  };

  // The member function `Kind()` hides the enum's name inside the class. A bare `Kind` used
  // as a type in the class (outside member bodies) would therefore change meaning once
  // `Kind()` is declared, which is ill-formed per [basic.scope.class]. The elaborated
  // `enum Kind` always names the enum, so it is used for those types.
  explicit AddrKind(enum Kind kind) : kind_(kind) {}
  enum Kind Kind() const { return kind_; }

  // Returns the short name of the address kind, suitable for use in a test name.
  constexpr const char* AddrKindToString() const {
    switch (kind_) {
      case Kind::V4:
        return "V4";
      case Kind::V6:
        return "V6";
      case Kind::V4MAPPEDV6:
        return "V4MAPPEDV6";
    }
  }

 private:
  const enum Kind kind_;
};

// Fixture owning one socket of type `socktype`. The socket's domain follows the AddrKind
// test parameter. Subclasses supply the address to use via Address().
template <int socktype>
class SocketTest : public testing::TestWithParam<AddrKind> {
 protected:
  // Opens the socket under test before each test case.
  void SetUp() override {
    sock_ = fbl::unique_fd(socket(Domain().Get(), socktype, 0));
    ASSERT_TRUE(sock_) << strerror(errno);
  }

  // Closes the socket after each test case, expecting close() to succeed.
  void TearDown() override {
    const int close_result = close(sock_.release());
    EXPECT_EQ(close_result, 0) << strerror(errno);
  }

  const fbl::unique_fd& sock() { return sock_; }

  // V6 and V4-mapped-V6 addresses both live in the IPv6 domain; V4 uses IPv4.
  SocketDomain Domain() const {
    switch (GetParam().Kind()) {
      case AddrKind::Kind::V6:
      case AddrKind::Kind::V4MAPPEDV6:
        return SocketDomain::IPv6();
      case AddrKind::Kind::V4:
        return SocketDomain::IPv4();
    }
  }

  // Size of the sockaddr structure matching Domain().
  socklen_t AddrLen() const {
    const bool is_v4 = Domain().which() == SocketDomain::Which::IPv4;
    return is_v4 ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
  }

  virtual sockaddr_storage Address(uint16_t port) const = 0;

 private:
  fbl::unique_fd sock_;
};

// SocketTest whose Address() is the unspecified ("any") address of the parameterized kind.
template <int socktype>
class AnyAddrSocketTest : public SocketTest<socktype> {
 protected:
  // Returns the any-address for the parameterized AddrKind, carrying `port`.
  // NOTE(review): `port` is stored as-is without htons(), so callers presumably pass it in
  // network byte order; the callers in view all pass 0, where byte order does not matter.
  sockaddr_storage Address(uint16_t port) const override {
    sockaddr_storage storage{
        .ss_family = this->Domain().Get(),
    };

    switch (this->GetParam().Kind()) {
      case AddrKind::Kind::V4: {
        sockaddr_in& v4 = *reinterpret_cast<sockaddr_in*>(&storage);
        v4.sin_port = port;
        v4.sin_addr.s_addr = htonl(INADDR_ANY);
        break;
      }
      case AddrKind::Kind::V6: {
        sockaddr_in6& v6 = *reinterpret_cast<sockaddr_in6*>(&storage);
        v6.sin6_port = port;
        v6.sin6_addr = IN6ADDR_ANY_INIT;
        break;
      }
      case AddrKind::Kind::V4MAPPEDV6: {
        // Build the IPv4 any-address first. MapIpv4SockaddrToIpv6Sockaddr then yields its
        // IPv4-mapped IPv6 form, which overwrites the storage.
        sockaddr_in v4_any{
            .sin_port = port,
            .sin_addr =
                {
                    .s_addr = htonl(INADDR_ANY),
                },
        };
        *reinterpret_cast<sockaddr_in6*>(&storage) = MapIpv4SockaddrToIpv6Sockaddr(v4_any);
        break;
      }
    }
    return storage;
  }
};

// Any-address fixtures for stream (SOCK_STREAM) and datagram (SOCK_DGRAM) sockets.
using AnyAddrStreamSocketTest = AnyAddrSocketTest<SOCK_STREAM>;
using AnyAddrDatagramSocketTest = AnyAddrSocketTest<SOCK_DGRAM>;

TEST_P(AnyAddrStreamSocketTest, Connect) {
  // Nothing listens on the any-address at port 0, so the connection is refused.
  const sockaddr_storage target = Address(0);
  const socklen_t target_len = AddrLen();
  ASSERT_EQ(connect(sock().get(), reinterpret_cast<const sockaddr*>(&target), target_len), -1);
  ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);

  // connect() already reported the error, so SO_ERROR must now be clear.
  int pending_error;
  socklen_t pending_error_len = sizeof(pending_error);
  ASSERT_EQ(getsockopt(sock().get(), SOL_SOCKET, SO_ERROR, &pending_error, &pending_error_len), 0)
      << strerror(errno);
  ASSERT_EQ(pending_error_len, sizeof(pending_error));
  ASSERT_EQ(pending_error, 0) << strerror(pending_error);
}

TEST_P(AnyAddrDatagramSocketTest, Connect) {
  // A datagram connect() to the any-address is expected to succeed.
  const sockaddr_storage target = Address(0);
  const int connect_result =
      connect(sock().get(), reinterpret_cast<const sockaddr*>(&target), AddrLen());
  EXPECT_EQ(connect_result, 0) << strerror(errno);
}

// Both any-address fixtures run for V4, V6 and V4-mapped-V6 addresses, named by address kind.
INSTANTIATE_TEST_SUITE_P(AnyAddrSocketTestStream, AnyAddrStreamSocketTest,
                         testing::Values(AddrKind::Kind::V4, AddrKind::Kind::V6,
                                         AddrKind::Kind::V4MAPPEDV6),
                         [](const auto& info) { return info.param.AddrKindToString(); });
INSTANTIATE_TEST_SUITE_P(AnyAddrSocketTestDatagram, AnyAddrDatagramSocketTest,
                         testing::Values(AddrKind::Kind::V4, AddrKind::Kind::V6,
                                         AddrKind::Kind::V4MAPPEDV6),
                         [](const auto& info) { return info.param.AddrKindToString(); });

// Which socket of the connected pair shutdown() is called on. `Local` is the socket that is
// read from afterwards; `Remote` is its peer.
enum class ShutdownEnd {
  Local,
  Remote,
};

// The outcome expected from recv() on the local socket after shutdown.
enum class ExpectedPostShutdownReadResult {
  // recv() returns pending data, or 0 when nothing is pending.
  Success,
  // recv() fails with EAGAIN: immediately for non-blocking reads, and after SO_RCVTIMEO
  // expires for blocking reads.
  Eagain,
};

// Whether the post-shutdown recv() is blocking or passes MSG_DONTWAIT.
enum class ReadType {
  Blocking,
  NonBlocking,
};

// Whether the remote end writes data that is pending on the local socket before shutdown.
enum class ReadSocketState {
  WithPendingData,
  NoPendingData,
};

// (domain, socket type, end that is shut down, SHUT_RD/SHUT_WR, read type, pending-data state).
using ReadAfterShutdownTestCase =
    std::tuple<SocketDomain, SocketType, ShutdownEnd, ShutdownType, ReadType, ReadSocketState>;

// Returns the result a recv() on the local socket should produce after the shutdown described
// by `test_case`:
//  - Pending data is always returned, regardless of which end was shut down or how.
//  - Stream sockets succeed (returning 0 when no data is pending) after a local SHUT_RD or a
//    remote SHUT_WR. Every other combination yields EAGAIN.
//  - Datagram sockets succeed only for a blocking read after a local SHUT_RD. Every other
//    combination yields EAGAIN.
ExpectedPostShutdownReadResult GetExpectedPostShutdownReadResult(
    const ReadAfterShutdownTestCase& test_case) {
  const auto& [domain, socket_type, which_end, shutdown_type, read_type, read_socket_state] =
      test_case;

  if (read_socket_state == ReadSocketState::WithPendingData) {
    // Post-shutdown reads always return pending data if it is present.
    return ExpectedPostShutdownReadResult::Success;
  }

  switch (socket_type.which()) {
    case SocketType::Which::Stream:
      if ((which_end == ShutdownEnd::Local && shutdown_type.which() == ShutdownType::Which::Read) ||
          (which_end == ShutdownEnd::Remote &&
           shutdown_type.which() == ShutdownType::Which::Write)) {
        return ExpectedPostShutdownReadResult::Success;
      }
      return ExpectedPostShutdownReadResult::Eagain;
    case SocketType::Which::Dgram:
      if (which_end == ShutdownEnd::Local && shutdown_type.which() == ShutdownType::Which::Read &&
          read_type == ReadType::Blocking) {
        return ExpectedPostShutdownReadResult::Success;
      }
      return ExpectedPostShutdownReadResult::Eagain;
  }
}

// Checks recv() behavior on a connected socket after shutdown() on either end.
class ReadAfterShutdownTest : public testing::TestWithParam<ReadAfterShutdownTestCase> {};

// Connects a pair over loopback and optionally queues data on `local`. It then calls
// shutdown() on one end and checks that recv() on `local` matches
// GetExpectedPostShutdownReadResult.
TEST_P(ReadAfterShutdownTest, Success) {
  const auto& [domain, socket_type, which_end, shutdown_type, read_type, read_socket_state] =
      GetParam();

#ifdef __Fuchsia__
  if (socket_type.which() == SocketType::Which::Dgram && read_type == ReadType::Blocking &&
      shutdown_type.which() == ShutdownType::Which::Read && which_end == ShutdownEnd::Local &&
      read_socket_state == ReadSocketState::NoPendingData) {
    // TODO(https://fxbug.dev/42041): Support blocking reads after shutdown for dgram sockets.
    GTEST_SKIP() << "Blocking dgram reads with no pending data hang on Fuchsia when the socket "
                    "is shutdown with SHUT_RD";
  }
#endif

  // `remote` is the sending side and `local` the receiving side of the connected pair.
  fbl::unique_fd remote;
  fbl::unique_fd local;
  ASSERT_NO_FATAL_FAILURE(ConnectSocketsOverLoopback(domain, socket_type, remote, local));

  char buf[] = "abc";
  if (read_socket_state == ReadSocketState::WithPendingData) {
    // Write from `remote`, then wait until `local` is readable so the data is pending before
    // shutdown.
    ASSERT_EQ(write(remote.get(), &buf, sizeof(buf)), ssize_t(sizeof(buf))) << strerror(errno);
    pollfd pfd = {
        .fd = local.get(),
        .events = POLLIN,
    };
    int n = poll(&pfd, 1, std::chrono::milliseconds(kTimeout).count());
    ASSERT_GE(n, 0) << strerror(errno);
    ASSERT_EQ(n, 1);
    EXPECT_EQ(pfd.revents, POLLIN);
  }

  // The init-capture copies `which_end` because C++17 lambdas cannot capture structured
  // bindings directly.
  int shutdown_fd = [&, which_end = which_end]() {
    switch (which_end) {
      case ShutdownEnd::Local:
        return local.get();
      case ShutdownEnd::Remote:
        return remote.get();
    }
  }();

  EXPECT_EQ(shutdown(shutdown_fd, shutdown_type.Get()), 0) << strerror(errno);

  if (socket_type.which() == SocketType::Which::Stream && which_end == ShutdownEnd::Remote &&
      shutdown_type.which() == ShutdownType::Which::Write) {
    // Give the TCP FIN time to propagate from `remote` to `local`.
    pollfd pfd = {
        .fd = local.get(),
        .events = POLLRDHUP,
    };
    int n = poll(&pfd, 1, std::chrono::milliseconds(kTimeout).count());
    ASSERT_GE(n, 0) << strerror(errno);
    ASSERT_EQ(n, 1);
    EXPECT_EQ(pfd.revents, POLLRDHUP);
  }

  // One byte larger than `buf`, so a recv() returning more than was written shows up as a
  // length mismatch.
  char recv_buf[sizeof(buf) + 1];
  const int flags = read_type == ReadType::Blocking ? 0 : MSG_DONTWAIT;
  switch (GetExpectedPostShutdownReadResult(GetParam())) {
    case ExpectedPostShutdownReadResult::Success: {
      switch (read_socket_state) {
        case ReadSocketState::WithPendingData:
          EXPECT_EQ(recv(local.get(), &recv_buf, sizeof(recv_buf), flags), ssize_t(sizeof(buf)))
              << strerror(errno);
          EXPECT_EQ(std::string_view(recv_buf, sizeof(buf)), std::string_view(buf, sizeof(buf)));
          break;
        case ReadSocketState::NoPendingData:
          EXPECT_EQ(recv(local.get(), &recv_buf, sizeof(recv_buf), flags), 0) << strerror(errno);
          break;
      }
    } break;
    case ExpectedPostShutdownReadResult::Eagain: {
      switch (read_type) {
        case ReadType::Blocking: {
          // Bound the blocking read with a 1s receive timeout so that it ends with EAGAIN.
          timeval tv = {
              .tv_sec = 1,
          };
          EXPECT_EQ(setsockopt(local.get(), SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)), 0)
              << strerror(errno);

          // Run the read on another thread. Once it has started, AssertBlocked (defined
          // elsewhere) presumably checks that it has not completed yet. The std::async future
          // waits for the task in its destructor, so this scope lasts until the read times out.
          std::latch fut_started(1);
          const auto fut = std::async(std::launch::async, [&]() {
            fut_started.count_down();

            EXPECT_EQ(recv(local.get(), &recv_buf, sizeof(recv_buf), flags), -1) << strerror(errno);
            EXPECT_EQ(errno, EAGAIN);
          });
          fut_started.wait();
          ASSERT_NO_FATAL_FAILURE(AssertBlocked(fut));
        } break;
        case ReadType::NonBlocking:
          EXPECT_EQ(recv(local.get(), &recv_buf, sizeof(recv_buf), flags), -1) << strerror(errno);
          EXPECT_EQ(errno, EAGAIN);
          break;
      }
    } break;
  }
}

// Builds a test name such as "<domain>_<type>_Self_SHUT_RD_BlockingRead_NoPendingData" from a
// ReadAfterShutdownTestCase.
std::string ReadAfterShutdownTestCaseToString(
    const testing::TestParamInfo<ReadAfterShutdownTestCase>& info) {
  const auto& [domain, socket_type, which_end, shutdown_type, read_type, read_socket_state] =
      info.param;

  std::ostringstream name;
  // Appends one '_'-separated component to the name.
  const auto append = [&name](const auto& component) { name << '_' << component; };

  name << socketDomainToString(domain);
  append(socketTypeToString(socket_type));

  switch (which_end) {
    case ShutdownEnd::Local:
      append("Self");
      break;
    case ShutdownEnd::Remote:
      append("Peer");
      break;
  }

  switch (shutdown_type.which()) {
    case ShutdownType::Which::Read:
      append("SHUT_RD");
      break;
    case ShutdownType::Which::Write:
      append("SHUT_WR");
      break;
  }

  switch (read_type) {
    case ReadType::Blocking:
      append("BlockingRead");
      break;
    case ReadType::NonBlocking:
      append("NonBlockingRead");
      break;
  }

  switch (read_socket_state) {
    case ReadSocketState::WithPendingData:
      append("WithPendingData");
      break;
    case ReadSocketState::NoPendingData:
      append("NoPendingData");
      break;
  }

  return name.str();
}

// Exhaustively combines domain, socket type, shutdown end, shutdown type, read type and
// pending-data state.
INSTANTIATE_TEST_SUITE_P(
    ReadAfterShutdownTests, ReadAfterShutdownTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Dgram(), SocketType::Stream()),
                     testing::Values(ShutdownEnd::Remote, ShutdownEnd::Local),
                     testing::Values(ShutdownType::Read(), ShutdownType::Write()),
                     testing::Values(ReadType::Blocking, ReadType::NonBlocking),
                     testing::Values(ReadSocketState::WithPendingData,
                                     ReadSocketState::NoPendingData)),
    ReadAfterShutdownTestCaseToString);

// Socket tests parameterized over socket type: SOCK_DGRAM or SOCK_STREAM.
class NetSocketTest : public testing::TestWithParam<SocketType> {};

// Test MSG_PEEK
// MSG_PEEK : Peek into the socket receive queue without moving the contents from it.
//
// TODO(https://fxbug.dev/90876): change this test to use recvmsg instead of recvfrom to exercise
// MSG_PEEK with scatter/gather.
TEST_P(NetSocketTest, SocketPeekTest) {
  const SocketType socket_type = GetParam();
  ssize_t expectReadLen = 0;
  char sendbuf[8] = {};
  char recvbuf[2 * sizeof(sendbuf)] = {};
  ssize_t sendlen = sizeof(sendbuf);

  fbl::unique_fd sendfd;
  fbl::unique_fd recvfd;
  ASSERT_NO_FATAL_FAILURE(
      ConnectSocketsOverLoopback(SocketDomain::IPv4(), socket_type, sendfd, recvfd));

  switch (socket_type.which()) {
    case SocketType::Which::Stream: {
      // Expect to read both the packets in a single recv() call.
      expectReadLen = sizeof(recvbuf);
      break;
    }
    case SocketType::Which::Dgram: {
      // Expect to read single packet per recv() call.
      expectReadLen = sizeof(sendbuf);
      break;
    }
  }

  // This test sends 2 packets with known values and validates MSG_PEEK across the 2 packets.
  sendbuf[0] = 0x56;
  sendbuf[6] = 0x78;

  // send 2 separate packets and test peeking across
  EXPECT_EQ(send(sendfd.get(), sendbuf, sizeof(sendbuf), 0), sendlen) << strerror(errno);
  EXPECT_EQ(send(sendfd.get(), sendbuf, sizeof(sendbuf), 0), sendlen) << strerror(errno);

  auto start = std::chrono::steady_clock::now();
  // First peek on first byte.
  EXPECT_EQ(asyncSocketRead(recvfd.get(), sendfd.get(), recvbuf, 1, MSG_PEEK, socket_type,
                            SocketDomain::IPv4(), kTimeout),
            1);
  auto success_rcv_duration = std::chrono::steady_clock::now() - start;
  EXPECT_EQ(recvbuf[0], sendbuf[0]);

  // Second peek across first 2 packets and drain them from the socket receive queue.
  // Even iterations peek; odd iterations consume what was peeked.
  ssize_t torecv = sizeof(recvbuf);
  for (int i = 0; torecv > 0; i++) {
    int flags = i % 2 ? 0 : MSG_PEEK;
    ssize_t readLen = 0;
    // Retry socket read with MSG_PEEK to ensure all of the expected data is received. A
    // non-positive result means nothing was read, so stop retrying instead of spinning.
    //
    // TODO(https://fxbug.dev/74639) : Use SO_RCVLOWAT instead of retry.
    do {
      readLen = asyncSocketRead(recvfd.get(), sendfd.get(), recvbuf, sizeof(recvbuf), flags,
                                socket_type, SocketDomain::IPv4(), kTimeout);
      if (HasFailure()) {
        break;
      }
    } while (flags == MSG_PEEK && readLen > 0 && readLen < expectReadLen);
    EXPECT_EQ(readLen, expectReadLen);
    // `torecv` only shrinks when a read makes progress; fail fast rather than loop forever.
    ASSERT_GT(readLen, 0);

    EXPECT_EQ(recvbuf[0], sendbuf[0]);
    EXPECT_EQ(recvbuf[6], sendbuf[6]);
    // For SOCK_STREAM, we validate peek across 2 packets with a single recv call.
    if (readLen == static_cast<ssize_t>(sizeof(recvbuf))) {
      EXPECT_EQ(recvbuf[8], sendbuf[0]);
      EXPECT_EQ(recvbuf[14], sendbuf[6]);
    }
    if (flags != MSG_PEEK) {
      torecv -= readLen;
    }
  }

  // Third peek on empty socket receive buffer, expect failure.
  //
  // As we expect failure, to keep the recv wait time minimal, we base it on the time taken for a
  // successful recv.
  EXPECT_EQ(asyncSocketRead(recvfd.get(), sendfd.get(), recvbuf, 1, MSG_PEEK, socket_type,
                            SocketDomain::IPv4(), success_rcv_duration * 10),
            0);
  EXPECT_EQ(close(recvfd.release()), 0) << strerror(errno);
  EXPECT_EQ(close(sendfd.release()), 0) << strerror(errno);
}

// Runs NetSocketTest for datagram and stream sockets.
INSTANTIATE_TEST_SUITE_P(NetSocket, NetSocketTest,
                         testing::Values(SocketType::Dgram(), SocketType::Stream()));

// Maps interface index 1 to a name with SIOCGIFNAME, maps that name back to index 1 with
// SIOCGIFINDEX, and checks that a name lacking a NUL terminator matches no interface.
TEST_P(SocketKindTest, IoctlInterfaceLookupRoundTrip) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  // This test assumes index 1 is bound to a valid interface. In Fuchsia's test environment (the
  // component executing this test), 1 is always bound to "lo".
  ifreq ifr_iton = {};
  ifr_iton.ifr_ifindex = 1;
  // Fill ifr_name with non-zero filler bytes to test that ioctl correctly sets the null
  // terminator.
  memset(ifr_iton.ifr_name, 0xde, IFNAMSIZ);
  ASSERT_EQ(strnlen(ifr_iton.ifr_name, IFNAMSIZ), static_cast<size_t>(IFNAMSIZ));
  ASSERT_EQ(ioctl(fd.get(), SIOCGIFNAME, &ifr_iton), 0) << strerror(errno);
  ASSERT_LT(strnlen(ifr_iton.ifr_name, IFNAMSIZ), static_cast<size_t>(IFNAMSIZ));

  // Zero-initialize so no indeterminate bytes are handed to the kernel.
  ifreq ifr_ntoi = {};
  strncpy(ifr_ntoi.ifr_name, ifr_iton.ifr_name, IFNAMSIZ);
  ASSERT_EQ(ioctl(fd.get(), SIOCGIFINDEX, &ifr_ntoi), 0) << strerror(errno);
  EXPECT_EQ(ifr_ntoi.ifr_ifindex, 1);

  ifreq ifr_err = {};
  memset(ifr_err.ifr_name, 0xde, IFNAMSIZ);
  // Although the first few bytes of ifr_name contain the correct name, there is no null
  // terminator and the remaining bytes are gibberish, so this should match no interfaces.
  memcpy(ifr_err.ifr_name, ifr_iton.ifr_name, strnlen(ifr_iton.ifr_name, IFNAMSIZ));

  const struct {
    std::string name;
    int request;
  } requests[] = {
      {
          .name = "SIOCGIFINDEX",
          .request = SIOCGIFINDEX,
      },
      {
          .name = "SIOCGIFFLAGS",
          .request = SIOCGIFFLAGS,
      },
  };
  for (const auto& request : requests) {
    ASSERT_EQ(ioctl(fd.get(), request.request, &ifr_err), -1) << request.name;
    EXPECT_EQ(errno, ENODEV) << request.name << ": " << strerror(errno);
  }
}

// Sends one byte and expects FIONREAD on the receiving socket to report exactly one byte
// queued.
TEST_P(SocketKindTest, IoctlFIONREAD) {
  auto const& [domain, socket_type] = GetParam();

  fbl::unique_fd recvfd;
  fbl::unique_fd sendfd;
  ASSERT_NO_FATAL_FAILURE(ConnectSocketsOverLoopback(domain, socket_type, sendfd, recvfd));

  char payload[1];
  EXPECT_EQ(send(sendfd.get(), payload, sizeof(payload), 0), ssize_t(sizeof(payload)))
      << strerror(errno);

  // Wait for the byte to become readable before querying the queue size.
  pollfd readable = {
      .fd = recvfd.get(),
      .events = POLLIN,
  };
  const int ready = poll(&readable, 1, std::chrono::milliseconds(kTimeout).count());
  ASSERT_GE(ready, 0) << strerror(errno);
  ASSERT_EQ(ready, 1);

  int queued_bytes;
  const int ioctl_result = ioctl(recvfd.get(), FIONREAD, &queued_bytes);

#ifdef __Fuchsia__
  if (socket_type.which() == SocketType::Which::Dgram) {
    // TODO(https://fxbug.dev/42040): Support FIONREAD on Fuchsia.
    ASSERT_EQ(ioctl_result, -1);
    EXPECT_EQ(errno, ENOTTY) << strerror(errno);
    return;
  }
#endif

  ASSERT_EQ(ioctl_result, 0) << strerror(errno);
  ASSERT_GE(queued_bytes, 0);
  EXPECT_EQ(static_cast<size_t>(queued_bytes), sizeof(payload));
}

// Interface lookups by an invalid index or an empty name must fail with ENODEV.
TEST_P(SocketKindTest, IoctlInterfaceNotFound) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = NewSocket()) << strerror(errno);

  // Invalid ifindex "-1" should match no interfaces.
  ifreq by_index = {};
  by_index.ifr_ifindex = -1;
  ASSERT_EQ(ioctl(fd.get(), SIOCGIFNAME, &by_index), -1);
  EXPECT_EQ(errno, ENODEV) << strerror(errno);

  // An empty name should match no interface either.
  ifreq by_name = {};
  const std::pair<const char*, int> kRequests[] = {
      {"SIOCGIFINDEX", SIOCGIFINDEX},
      {"SIOCGIFFLAGS", SIOCGIFFLAGS},
  };
  for (const auto& [request_name, request] : kRequests) {
    ASSERT_EQ(ioctl(fd.get(), request, &by_name), -1) << request_name;
    EXPECT_EQ(errno, ENODEV) << request_name << ": " << strerror(errno);
  }
}

// Exercises a getsockname/getpeername-style function |getname| on |fd| and checks that
// it reports the address |sa|, whose length is |sa_len|. Covers:
//   - null address and/or null length pointers (EFAULT),
//   - a zero length with a null buffer (succeeds, reports the required length),
//   - a non-zero length with a null buffer (EFAULT, |len| left unchanged),
//   - a buffer only large enough for the family (truncated, full length reported),
//   - an exact-size buffer, and
//   - an oversized buffer (the byte past the sockaddr_storage must not be written).
template <typename F>
void TestGetname(const fbl::unique_fd& fd, F getname, const sockaddr* sa, const socklen_t sa_len) {
  ASSERT_EQ(getname(fd.get(), nullptr, nullptr), -1);
  EXPECT_EQ(errno, EFAULT) << strerror(errno);
  errno = 0;

  sockaddr_storage ss;
  ASSERT_EQ(getname(fd.get(), reinterpret_cast<sockaddr*>(&ss), nullptr), -1);
  EXPECT_EQ(errno, EFAULT) << strerror(errno);
  errno = 0;

  // A zero-length request may pass a null buffer; only the required length is reported.
  socklen_t len = 0;
  ASSERT_EQ(getname(fd.get(), nullptr, &len), 0) << strerror(errno);
  EXPECT_EQ(len, sa_len);

  // A non-zero length with a null buffer faults and must not update |len|.
  len = 1;
  ASSERT_EQ(getname(fd.get(), nullptr, &len), -1);
  EXPECT_EQ(errno, EFAULT) << strerror(errno);
  EXPECT_EQ(len, 1u);
  errno = 0;

  // A buffer that only fits the family receives a truncated address, but |len| still
  // reports the full address length.
  sa_family_t family;
  len = sizeof(family);
  ASSERT_EQ(getname(fd.get(), reinterpret_cast<sockaddr*>(&family), &len), 0) << strerror(errno);
  ASSERT_EQ(len, sa_len);
  EXPECT_EQ(family, sa->sa_family);

  len = sa_len;
  ASSERT_EQ(getname(fd.get(), reinterpret_cast<sockaddr*>(&ss), &len), 0) << strerror(errno);
  ASSERT_EQ(len, sa_len);
  EXPECT_EQ(memcmp(&ss, sa, sa_len), 0);

  // Oversized buffer: the trailing sentinel byte beyond the storage must be untouched.
  struct {
    sockaddr_storage ss;
    char unused;
  } ss_with_extra = {
      .unused = 0x44,
  };
  len = sizeof(ss_with_extra);
  ASSERT_EQ(getname(fd.get(), reinterpret_cast<sockaddr*>(&ss_with_extra), &len), 0)
      << strerror(errno);
  ASSERT_EQ(len, sa_len);
  // Compare the address written by *this* call, not the stale |ss| from the previous one.
  EXPECT_EQ(memcmp(&ss_with_extra.ss, sa, sa_len), 0);
  EXPECT_EQ(ss_with_extra.unused, 0x44);
}

// Binds a socket to the loopback address and validates getsockname() on it via
// TestGetname.
TEST_P(SocketKindTest, Getsockname) {
  auto [addr, addr_len] = LoopbackAddr();
  sockaddr* const addr_ptr = reinterpret_cast<sockaddr*>(&addr);

  fbl::unique_fd sock;
  ASSERT_TRUE(sock = NewSocket()) << strerror(errno);
  ASSERT_EQ(bind(sock.get(), addr_ptr, sizeof(addr)), 0) << strerror(errno);

  // Overwrite |addr| with the address actually bound, so TestGetname compares
  // against exactly what the kernel/netstack reports.
  socklen_t bound_len = sizeof(addr);
  ASSERT_EQ(getsockname(sock.get(), addr_ptr, &bound_len), 0) << strerror(errno);
  ASSERT_EQ(bound_len, addr_len);

  ASSERT_NO_FATAL_FAILURE(TestGetname(sock, getsockname, addr_ptr, addr_len));
}

// Connects a client to a loopback-bound socket (listening, for stream sockets) and
// validates getpeername() on the client via TestGetname, using the listener's bound
// address as the expected peer.
TEST_P(SocketKindTest, Getpeername) {
  auto const& [domain, socket_type] = GetParam();
  auto [ss, len] = LoopbackAddr();

  fbl::unique_fd listener;
  ASSERT_TRUE(listener = NewSocket()) << strerror(errno);
  ASSERT_EQ(bind(listener.get(), reinterpret_cast<sockaddr*>(&ss), sizeof(ss)), 0)
      << strerror(errno);
  // Read back the listener's actual bound address: the client connects to it, and it is
  // the peer address the client must report.
  socklen_t ss_len = sizeof(ss);
  ASSERT_EQ(getsockname(listener.get(), reinterpret_cast<sockaddr*>(&ss), &ss_len), 0)
      << strerror(errno);
  // Same check as the Getsockname test: TestGetname below relies on |len| being the
  // length of the address now stored in |ss|.
  ASSERT_EQ(ss_len, len);
  if (socket_type.which() == SocketType::Which::Stream) {
    ASSERT_EQ(listen(listener.get(), 1), 0) << strerror(errno);
  }

  fbl::unique_fd client;
  ASSERT_TRUE(client = NewSocket()) << strerror(errno);
  ASSERT_EQ(connect(client.get(), reinterpret_cast<sockaddr*>(&ss), sizeof(ss)), 0)
      << strerror(errno);

  ASSERT_NO_FATAL_FAILURE(TestGetname(client, getpeername, reinterpret_cast<sockaddr*>(&ss), len));
}

// Verifies that interface-lookup ioctls issued on a file descriptor that is not a
// socket (here, the root directory) fail with ENOTTY.
TEST(SocketKindTest, IoctlLookupForNonSocketFd) {
  fbl::unique_fd fd;
  ASSERT_TRUE(fd = fbl::unique_fd(open("/", O_RDONLY | O_DIRECTORY))) << strerror(errno);

  ifreq ifr_iton = {};
  ifr_iton.ifr_ifindex = 1;
  ASSERT_EQ(ioctl(fd.get(), SIOCGIFNAME, &ifr_iton), -1);
  EXPECT_EQ(errno, ENOTTY) << strerror(errno);

  // Value-initialize the request, as the other lookup tests do, so no indeterminate
  // bytes are passed to ioctl(). Also prove at compile time that the name, including
  // its terminator, fits in ifr_name.
  constexpr char kIfName[] = "loblah";
  static_assert(sizeof(kIfName) <= IFNAMSIZ, "interface name does not fit in ifr_name");
  ifreq ifr = {};
  strcpy(ifr.ifr_name, kIfName);
  const struct {
    std::string name;
    int request;
  } requests[] = {
      {
          .name = "SIOCGIFINDEX",
          .request = SIOCGIFINDEX,
      },
      {
          .name = "SIOCGIFFLAGS",
          .request = SIOCGIFFLAGS,
      },
  };
  for (const auto& request : requests) {
    ASSERT_EQ(ioctl(fd.get(), request.request, &ifr), -1) << request.name;
    EXPECT_EQ(errno, ENOTTY) << request.name << ": " << strerror(errno);
  }
}

// Runs every SocketKindTest case for each combination of domain (IPv4, IPv6) and
// socket type (datagram, stream). SocketKindToString is passed as gtest's
// parameter-name generator for the instantiated test names.
INSTANTIATE_TEST_SUITE_P(
    NetSocket, SocketKindTest,
    testing::Combine(testing::Values(SocketDomain::IPv4(), SocketDomain::IPv6()),
                     testing::Values(SocketType::Dgram(), SocketType::Stream())),
    SocketKindToString);

// Test parameter: the socket domain and the protocol number passed to socket().
using DomainProtocol = std::tuple<SocketDomain, int>;

// Fixture that opens a SOCK_DGRAM socket for the parameterized domain/protocol before
// each test. Outside Fuchsia, tests are skipped unless running as root.
class IcmpSocketTest : public testing::TestWithParam<DomainProtocol> {
 protected:
  void SetUp() override {
#if !defined(__Fuchsia__)
    if (!IsRoot()) {
      GTEST_SKIP() << "This test requires root";
    }
#endif
    const SocketDomain& socket_domain = std::get<0>(GetParam());
    const int socket_protocol = std::get<1>(GetParam());
    ASSERT_TRUE(fd_ = fbl::unique_fd(socket(socket_domain.Get(), SOCK_DGRAM, socket_protocol)))
        << strerror(errno);
  }

  // The socket opened in SetUp().
  const fbl::unique_fd& fd() const { return fd_; }

 private:
  fbl::unique_fd fd_;
};

// Verifies that SO_PROTOCOL reports the protocol the fixture's socket was created with.
TEST_P(IcmpSocketTest, GetSockoptSoProtocol) {
  const int expected_protocol = std::get<1>(GetParam());

  int reported_protocol;
  socklen_t reported_len = sizeof(reported_protocol);
  ASSERT_EQ(getsockopt(fd().get(), SOL_SOCKET, SO_PROTOCOL, &reported_protocol, &reported_len),
            0)
      << strerror(errno);
  EXPECT_EQ(reported_len, sizeof(reported_protocol));
  EXPECT_EQ(reported_protocol, expected_protocol);
}

// Verifies that the echo identifier in a user-supplied ICMP echo request is ignored.
// The socket is bound with identifier kBindIdent. A request whose header (apart from
// type/code) is filled with 0x4a garbage is then sent to a destination carrying
// kDestinationIdent. The echo reply read back must carry kBindIdent as its id and
// echo the request's sequence number.
TEST_P(IcmpSocketTest, PayloadIdentIgnored) {
  const SocketDomain& domain = std::get<0>(GetParam());

  constexpr short kBindIdent = 3;
  constexpr short kDestinationIdent = kBindIdent + 1;

  switch (domain.which()) {
    case SocketDomain::Which::IPv4: {
      const sockaddr_in bind_addr = LoopbackSockaddrV4(kBindIdent);
      ASSERT_EQ(bind(fd().get(), reinterpret_cast<const sockaddr*>(&bind_addr), sizeof(bind_addr)),
                0)
          << strerror(errno);
      const icmphdr pkt = []() {
        icmphdr pkt;
        // Populate with garbage to prove other fields are unused.
        memset(&pkt, 0x4a, sizeof(pkt));
        pkt.type = ICMP_ECHO;
        pkt.code = 0;
        return pkt;
      }();
      const sockaddr_in dst_addr = {
          .sin_family = bind_addr.sin_family,
          .sin_port = htons(kDestinationIdent),
          .sin_addr = bind_addr.sin_addr,
      };
      ASSERT_EQ(sendto(fd().get(), &pkt, sizeof(pkt), 0,
                       reinterpret_cast<const sockaddr*>(&dst_addr), sizeof(dst_addr)),
                ssize_t(sizeof(pkt)))
          << strerror(errno);

      // Read into a buffer one byte larger than the header; the trailing sentinel
      // must be untouched because only sizeof(pkt) bytes may be returned.
      struct {
        std::remove_const<decltype(pkt)>::type hdr;
        char unused;
      } hdr_with_extra = {
          .unused = 0x44,
      };
      memset(&hdr_with_extra.hdr, 0x4a, sizeof(hdr_with_extra.hdr));
      ASSERT_EQ(read(fd().get(), &hdr_with_extra, sizeof(hdr_with_extra)), ssize_t(sizeof(pkt)))
          << strerror(errno);
      EXPECT_EQ(hdr_with_extra.unused, 0x44);
      EXPECT_EQ(hdr_with_extra.hdr.type, ICMP_ECHOREPLY);
      EXPECT_EQ(hdr_with_extra.hdr.code, 0);
      EXPECT_NE(hdr_with_extra.hdr.checksum, 0);
      // The id arrives in network byte order; it must be the bound identifier rather
      // than kDestinationIdent or the request's 0x4a filler.
      EXPECT_EQ(ntohs(hdr_with_extra.hdr.un.echo.id), kBindIdent);
      EXPECT_EQ(hdr_with_extra.hdr.un.echo.sequence, pkt.un.echo.sequence);
    } break;
    case SocketDomain::Which::IPv6: {
      const sockaddr_in6 bind_addr = LoopbackSockaddrV6(kBindIdent);
      ASSERT_EQ(bind(fd().get(), reinterpret_cast<const sockaddr*>(&bind_addr), sizeof(bind_addr)),
                0)
          << strerror(errno);
      const icmp6_hdr pkt = []() {
        icmp6_hdr pkt;
        // Populate with garbage to prove other fields are unused.
        memset(&pkt, 0x4a, sizeof(pkt));
        pkt.icmp6_type = ICMP6_ECHO_REQUEST;
        pkt.icmp6_code = 0;
        return pkt;
      }();
      const sockaddr_in6 dst_addr = {
          .sin6_family = bind_addr.sin6_family,
          .sin6_port = htons(kDestinationIdent),
          .sin6_addr = bind_addr.sin6_addr,
      };
      ASSERT_EQ(sendto(fd().get(), &pkt, sizeof(pkt), 0,
                       reinterpret_cast<const sockaddr*>(&dst_addr), sizeof(dst_addr)),
                ssize_t(sizeof(pkt)))
          << strerror(errno);

      // Read into a buffer one byte larger than the header; the trailing sentinel
      // must be untouched because only sizeof(pkt) bytes may be returned.
      struct {
        std::remove_const<decltype(pkt)>::type hdr;
        char unused;
      } hdr_with_extra = {
          .unused = 0x44,
      };
      memset(&hdr_with_extra.hdr, 0x4a, sizeof(hdr_with_extra.hdr));
      ASSERT_EQ(read(fd().get(), &hdr_with_extra, sizeof(hdr_with_extra)), ssize_t(sizeof(pkt)))
          << strerror(errno);
      EXPECT_EQ(hdr_with_extra.unused, 0x44);
      EXPECT_EQ(hdr_with_extra.hdr.icmp6_type, ICMP6_ECHO_REPLY);
      EXPECT_EQ(hdr_with_extra.hdr.icmp6_code, 0);
      EXPECT_NE(hdr_with_extra.hdr.icmp6_cksum, 0);
      // The id arrives in network byte order; it must be the bound identifier rather
      // than kDestinationIdent or the request's 0x4a filler.
      EXPECT_EQ(ntohs(hdr_with_extra.hdr.icmp6_id), kBindIdent);
      EXPECT_EQ(hdr_with_extra.hdr.icmp6_seq, pkt.icmp6_seq);
    } break;
  }
}

// Runs every IcmpSocketTest case twice: IPv4 with IPPROTO_ICMP, and IPv6 with
// IPPROTO_ICMPV6. Each std::pair converts to the DomainProtocol (std::tuple) parameter.
INSTANTIATE_TEST_SUITE_P(NetSocket, IcmpSocketTest,
                         testing::Values(std::make_pair(SocketDomain::IPv4(), IPPROTO_ICMP),
                                         std::make_pair(SocketDomain::IPv6(), IPPROTO_ICMPV6)));

}  // namespace
