blob: fa47131b85f9371048c98c3ac4796d0b62fc5b1c [file] [log] [blame]
// Copyright 2023 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <unistd.h>

#include "interrupt_handling.h"
#include "util.h"
// Set to 1 to print debug messages to stderr during development
#define DEBUG 0
namespace {

// Returns a signal set containing exactly the interrupt signals handled by
// this file: SIGINT, SIGHUP and SIGTERM.
sigset_t GetInterruptSignalMask() {
  const int kInterruptSignals[] = {SIGINT, SIGHUP, SIGTERM};
  sigset_t signals;
  sigemptyset(&signals);
  for (int signum : kInterruptSignals)
    sigaddset(&signals, signum);
  return signals;
}

// Installs |action| for |signum|. If |old_action| is not nullptr, the
// previously installed action is stored into |*old_action|. A failing
// sigaction() call is reported through ErrnoFatal().
void SetSignalAction(int signum, const struct sigaction* action,
                     struct sigaction* old_action) {
  const int status = sigaction(signum, action, old_action);
  if (status < 0)
    ErrnoFatal("sigaction");
}

}  // namespace
//////////////////////////////////////////////////////////////////////////
///
/// InterruptBlocker
///
InterruptBlocker::InterruptBlocker() {
  // Add SIGINT/SIGHUP/SIGTERM to the blocked set. The mask that was in effect
  // before is saved in prev_signal_mask_ so the destructor can reinstate it.
  const sigset_t interrupts = GetInterruptSignalMask();
  const int status = sigprocmask(SIG_BLOCK, &interrupts, &prev_signal_mask_);
  if (status < 0)
    ErrnoFatal("sigprocmask");
}
InterruptBlocker::~InterruptBlocker() {
  // Put back exactly the signal mask saved by the constructor.
  const int status = sigprocmask(SIG_SETMASK, &prev_signal_mask_, nullptr);
  if (status < 0)
    ErrnoFatal("sigprocmask");
}
//////////////////////////////////////////////////////////////////////////
///
/// InterruptHandlerBase
///
InterruptHandlerBase::InterruptHandlerBase(const struct sigaction& action) {
  // Install |action| for SIGINT, SIGHUP and SIGTERM, saving the previous
  // actions and signal mask so the destructor can restore them.
  //
  // Keep the interrupt signals blocked while the three actions are swapped,
  // so none can be delivered while only some of the handlers are installed.
  sigset_t mask = GetInterruptSignalMask();
  if (sigprocmask(SIG_BLOCK, &mask, &old_mask_) < 0)
    ErrnoFatal("sigprocmask");
  SetSignalAction(SIGINT, &action, &old_int_action_);
  SetSignalAction(SIGHUP, &action, &old_hup_action_);
  SetSignalAction(SIGTERM, &action, &old_term_action_);
  // Unblock the interrupt signals, even if they were already blocked before
  // this constructor ran. Any that became pending may be delivered to
  // |action| before sigprocmask() returns. The destructor restores old_mask_.
  if (sigprocmask(SIG_UNBLOCK, &mask, nullptr) < 0)
    ErrnoFatal("sigprocmask");
}
InterruptHandlerBase::~InterruptHandlerBase() {
  // Restore the actions and signal mask saved by the constructor.
  //
  // Block the interrupt signals first, so none can be delivered while only
  // some of the original actions have been put back.
  sigset_t mask = GetInterruptSignalMask();
  if (sigprocmask(SIG_BLOCK, &mask, nullptr) < 0)
    ErrnoFatal("sigprocmask");
  SetSignalAction(SIGINT, &old_int_action_, nullptr);
  SetSignalAction(SIGHUP, &old_hup_action_, nullptr);
  SetSignalAction(SIGTERM, &old_term_action_, nullptr);
  // Reinstate the mask that was in effect before the constructor ran.
  if (sigprocmask(SIG_SETMASK, &old_mask_, nullptr) < 0)
    ErrnoFatal("sigprocmask");
}
//////////////////////////////////////////////////////////////////////////
///
/// InterruptCatcher
///
InterruptCatcher::InterruptCatcher() : InterruptHandlerBase(MakeAction()) {
s_interrupted_ = 0;
HandlePendingInterrupt();
}
InterruptCatcher::~InterruptCatcher() = default;
// WRITE(msg): debug trace used by the signal handlers and the pending-signal
// check below. When DEBUG is non-zero it writes |msg| directly to fd 2
// (stderr) with write() rather than stdio; write() is on POSIX's
// async-signal-safe list, stdio functions are not. sizeof(msg) - 1 only
// yields the text length when |msg| is a string literal (an array), not a
// pointer. When DEBUG is 0 it expands to the no-op (void)(msg).
#if DEBUG
#define WRITE(msg) ::write(2, msg, sizeof(msg) - 1)
#else
#define WRITE(msg) (void)(msg)
#endif
// static
// Builds the sigaction used by InterruptCatcher. Its handler stores the
// received signal number in s_interrupted_ and, in DEBUG builds, traces which
// signal arrived.
struct sigaction InterruptCatcher::MakeAction() {
  struct sigaction catcher = {};
  catcher.sa_handler = [](int signum) {
    s_interrupted_ = signum;
    switch (signum) {
      case SIGINT:
        WRITE("\nSIGINT SIGNALED\n");
        break;
      case SIGHUP:
        WRITE("\nSIGHUP SIGNALED\n");
        break;
      case SIGTERM:
        WRITE("\nSIGTERM SIGNALED\n");
        break;
      default:
        break;
    }
  };
  return catcher;
}
// static
// Records in s_interrupted_ an interrupt signal reported as pending by
// sigpending(). If several are pending, only one is recorded, preferring
// SIGINT, then SIGTERM, then SIGHUP. The pending set is only inspected; the
// signals are not consumed. A sigpending() failure is reported with perror()
// and leaves s_interrupted_ untouched.
void InterruptCatcher::HandlePendingInterrupt() {
  sigset_t pending_signals;
  sigemptyset(&pending_signals);
  if (sigpending(&pending_signals) == -1) {
    perror("ninja: sigpending");
    return;
  }
  if (sigismember(&pending_signals, SIGINT)) {
    WRITE("\nSIGINT PENDING\n");
    s_interrupted_ = SIGINT;
    return;
  }
  if (sigismember(&pending_signals, SIGTERM)) {
    WRITE("\nSIGTERM PENDING\n");
    s_interrupted_ = SIGTERM;
    return;
  }
  if (sigismember(&pending_signals, SIGHUP)) {
    WRITE("\nSIGHUP PENDING\n");
    s_interrupted_ = SIGHUP;
  }
}
// static
// Last interrupt signal number recorded by the InterruptCatcher handler, or 0.
volatile sig_atomic_t InterruptCatcher::s_interrupted_{0};
//////////////////////////////////////////////////////////////////////////
///
/// InterruptForwarder
///
#include <unistd.h>
// Installs the forwarding handler and makes |process_group| the target of
// forwarded interrupts. The previous target is saved in old_process_group_ and
// restored by the destructor (presumably so forwarders can nest; confirm with
// callers).
//
// NOTE(review): the base-class constructor installs and unblocks the handler
// before the body below publishes |process_group|, so a signal delivered in
// that window sees the previous value of s_process_group_.
InterruptForwarder::InterruptForwarder(pid_t process_group)
    : InterruptHandlerBase(MakeAction()),
      old_process_group_(s_process_group_) {
  s_process_group_ = process_group;
}
// Puts back the forwarding target that was current before this object was
// constructed. The base-class destructor runs after this body and only then
// restores the previous signal actions and mask.
InterruptForwarder::~InterruptForwarder() {
  s_process_group_ = old_process_group_;
}
// static
// Builds the sigaction used by InterruptForwarder. Its handler re-sends the
// received signal to the process group stored in s_process_group_.
struct sigaction InterruptForwarder::MakeAction() {
  struct sigaction result = {};
  sigemptyset(&result.sa_mask);
  result.sa_handler = [](int signum) {
    // kill() may change errno. Preserve it for the code this handler
    // interrupted.
    const int saved_errno = errno;
    // The handler is live while s_process_group_ can still be 0: the
    // base-class constructor installs it before the derived constructor sets
    // the group, and the derived destructor restores the old value before the
    // base destructor removes the handler. kill(0, sig) would signal this
    // process's *own* process group, including this process, so the handler
    // would keep re-triggering itself. A negative value would address a single
    // pid instead. Only forward to a real (positive) process group id.
    const pid_t group = s_process_group_;
    if (group > 0) {
      // Send the interrupt to the server's process group.
      kill(-group, signum);
      WRITE("\nINTERRUPT FORWARDED\n");
    }
    errno = saved_errno;
  };
  return result;
}
// static
// Process group that the forwarding handler targets; 0 when none has been set.
pid_t InterruptForwarder::s_process_group_{0};