| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "chre_host/socket_server.h" |
| |
| #include <poll.h> |
| |
| #include <cassert> |
| #include <cinttypes> |
| #include <csignal> |
| #include <cstdlib> |
| #include <map> |
| #include <mutex> |
| |
| #include <cutils/sockets.h> |
| |
| #include "chre_host/log.h" |
| |
| namespace android { |
| namespace chre { |
| |
| std::atomic<bool> SocketServer::sSignalReceived(false); |
| |
| SocketServer::SocketServer() { |
| // Initialize the socket fds field for all inactive client slots to -1, so |
| // poll skips over it, and we don't attempt to send on it |
| for (size_t i = 1; i <= kMaxActiveClients; i++) { |
| mPollFds[i].fd = -1; |
| mPollFds[i].events = POLLIN; |
| } |
| } |
| |
| void SocketServer::run(const char *socketName, bool allowSocketCreation, |
| ClientMessageCallback clientMessageCallback) { |
| mClientMessageCallback = clientMessageCallback; |
| |
| mSockFd = android_get_control_socket(socketName); |
| if (mSockFd == INVALID_SOCKET && allowSocketCreation) { |
| LOGI("Didn't inherit socket, creating..."); |
| mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED, |
| SOCK_SEQPACKET); |
| } |
| |
| if (mSockFd == INVALID_SOCKET) { |
| LOGE("Couldn't get/create socket"); |
| } else { |
| int ret = listen(mSockFd, kMaxPendingConnectionRequests); |
| if (ret < 0) { |
| LOG_ERROR("Couldn't listen on socket", errno); |
| } else { |
| serviceSocket(); |
| } |
| |
| { |
| std::lock_guard<std::mutex> lock(mClientsMutex); |
| for (const auto &pair : mClients) { |
| int clientSocket = pair.first; |
| if (close(clientSocket) != 0) { |
| LOGI("Couldn't close client %" PRIu16 "'s socket: %s", |
| pair.second.clientId, strerror(errno)); |
| } |
| } |
| mClients.clear(); |
| } |
| close(mSockFd); |
| } |
| } |
| |
| void SocketServer::sendToAllClients(const void *data, size_t length) { |
| std::lock_guard<std::mutex> lock(mClientsMutex); |
| |
| int deliveredCount = 0; |
| for (const auto &pair : mClients) { |
| int clientSocket = pair.first; |
| uint16_t clientId = pair.second.clientId; |
| if (sendToClientSocket(data, length, clientSocket, clientId)) { |
| deliveredCount++; |
| } else if (errno == EINTR) { |
| // Exit early if we were interrupted - we should only get this for |
| // SIGINT/SIGTERM, so we should exit quickly |
| break; |
| } |
| } |
| |
| if (deliveredCount == 0) { |
| LOGW("Got message but didn't deliver to any clients"); |
| } |
| } |
| |
| bool SocketServer::sendToClientById(const void *data, size_t length, |
| uint16_t clientId) { |
| std::lock_guard<std::mutex> lock(mClientsMutex); |
| |
| bool sent = false; |
| for (const auto &pair : mClients) { |
| uint16_t thisClientId = pair.second.clientId; |
| if (thisClientId == clientId) { |
| int clientSocket = pair.first; |
| sent = sendToClientSocket(data, length, clientSocket, thisClientId); |
| break; |
| } |
| } |
| |
| return sent; |
| } |
| |
| void SocketServer::acceptClientConnection() { |
| int clientSocket = accept(mSockFd, NULL, NULL); |
| if (clientSocket < 0) { |
| LOG_ERROR("Couldn't accept client connection", errno); |
| } else if (mClients.size() >= kMaxActiveClients) { |
| LOGW("Rejecting client request - maximum number of clients reached"); |
| close(clientSocket); |
| } else { |
| ClientData clientData; |
| clientData.clientId = mNextClientId++; |
| |
| // We currently don't handle wraparound - if we're getting this many |
| // connects/disconnects, then something is wrong. |
| // TODO: can handle this properly by iterating over the existing clients to |
| // avoid a conflict. |
| if (clientData.clientId == 0) { |
| LOGE("Couldn't allocate client ID"); |
| std::exit(-1); |
| } |
| |
| bool slotFound = false; |
| for (size_t i = 1; i <= kMaxActiveClients; i++) { |
| if (mPollFds[i].fd < 0) { |
| mPollFds[i].fd = clientSocket; |
| slotFound = true; |
| break; |
| } |
| } |
| |
| if (!slotFound) { |
| LOGE("Couldn't find slot for client!"); |
| assert(slotFound); |
| close(clientSocket); |
| } else { |
| { |
| std::lock_guard<std::mutex> lock(mClientsMutex); |
| mClients[clientSocket] = clientData; |
| } |
| LOGI( |
| "Accepted new client connection (count %zu), assigned client ID " |
| "%" PRIu16, |
| mClients.size(), clientData.clientId); |
| } |
| } |
| } |
| |
| void SocketServer::handleClientData(int clientSocket) { |
| const ClientData &clientData = mClients[clientSocket]; |
| uint16_t clientId = clientData.clientId; |
| |
| ssize_t packetSize = |
| recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT); |
| if (packetSize < 0) { |
| LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId, |
| strerror(errno)); |
| } else if (packetSize == 0) { |
| LOGI("Client %" PRIu16 " disconnected", clientId); |
| disconnectClient(clientSocket); |
| } else { |
| LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId); |
| mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize); |
| } |
| } |
| |
| void SocketServer::disconnectClient(int clientSocket) { |
| { |
| std::lock_guard<std::mutex> lock(mClientsMutex); |
| mClients.erase(clientSocket); |
| } |
| close(clientSocket); |
| |
| bool removed = false; |
| for (size_t i = 1; i <= kMaxActiveClients; i++) { |
| if (mPollFds[i].fd == clientSocket) { |
| mPollFds[i].fd = -1; |
| removed = true; |
| break; |
| } |
| } |
| |
| if (!removed) { |
| LOGE("Out of sync"); |
| assert(removed); |
| } |
| } |
| |
| bool SocketServer::sendToClientSocket(const void *data, size_t length, |
| int clientSocket, uint16_t clientId) { |
| errno = 0; |
| ssize_t bytesSent = send(clientSocket, data, length, 0); |
| if (bytesSent < 0) { |
| LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length, |
| clientId, strerror(errno)); |
| } else if (bytesSent == 0) { |
| LOGW("Client %" PRIu16 " disconnected before message could be delivered", |
| clientId); |
| } else { |
| LOGV("Delivered message of size %zu bytes to client %" PRIu16, length, |
| clientId); |
| } |
| |
| return (bytesSent > 0); |
| } |
| |
| void SocketServer::serviceSocket() { |
| constexpr size_t kListenIndex = 0; |
| static_assert(kListenIndex == 0, |
| "Code assumes that the first index is always the listen " |
| "socket"); |
| |
| mPollFds[kListenIndex].fd = mSockFd; |
| mPollFds[kListenIndex].events = POLLIN; |
| |
| // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM, |
| // and ignore other signals |
| sigset_t signalMask; |
| sigfillset(&signalMask); |
| sigdelset(&signalMask, SIGINT); |
| sigdelset(&signalMask, SIGTERM); |
| |
| LOGI("Ready to accept connections"); |
| while (!sSignalReceived) { |
| int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask); |
| if (ret == -1) { |
| // Don't use TEMP_FAILURE_RETRY since our logic needs to check |
| // sSignalReceived to see if it should exit where as TEMP_FAILURE_RETRY |
| // is a tight retry loop around ppoll. |
| if (errno == EINTR) { |
| continue; |
| } |
| LOGI("Exiting poll loop: %s", strerror(errno)); |
| break; |
| } |
| |
| if (mPollFds[kListenIndex].revents & POLLIN) { |
| acceptClientConnection(); |
| } |
| |
| for (size_t i = 1; i <= kMaxActiveClients; i++) { |
| if (mPollFds[i].fd < 0) { |
| continue; |
| } |
| |
| if (mPollFds[i].revents & POLLIN) { |
| handleClientData(mPollFds[i].fd); |
| } |
| } |
| } |
| } |
| |
| } // namespace chre |
| } // namespace android |