// blob: 4f30ac5e83ce13a0f7db856fcce7024c21692758 [file] [log] [blame]
// Copyright 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
///////////////////////////////////////////////////////////////////////////////
#include "tink/streamingaead/buffered_input_stream.h"
#include <algorithm>
#include <cstring>
#include <vector>
#include "absl/memory/memory.h"
#include "tink/input_stream.h"
#include "tink/util/errors.h"
#include "tink/util/status.h"
#include "tink/util/statusor.h"
namespace crypto {
namespace tink {
namespace streamingaead {
using util::Status;
using util::StatusOr;
// Wraps `input_stream` (taking ownership of it) in a stream that starts out
// in buffering mode: rewinding is enabled, nothing has been read yet, and no
// error has been recorded.
BufferedInputStream::BufferedInputStream(
    std::unique_ptr<crypto::tink::InputStream> input_stream)
    : input_stream_(std::move(input_stream)) {
  // Initial size of the replay buffer; Next() grows it on demand.
  constexpr int kInitialBufferSize = 4 * 1024;  // 4 KB
  buffer_.resize(kInitialBufferSize);

  // Mode flags.
  rewinding_enabled_ = true;
  direct_access_ = false;
  after_rewind_ = false;

  // Bookkeeping for the (still empty) buffer contents.
  buffer_offset_ = 0;
  count_in_buffer_ = 0;
  count_backedup_ = 0;
  position_ = 0;

  status_ = Status::OK;
}
// Hands out the next chunk of the stream through `*data` and returns its size
// in bytes.
//
// Once direct access is active, the call is forwarded to input_stream_.
// Otherwise the chunk comes from one of three places, in this order:
//   1. right after Rewind(): everything accumulated in buffer_ so far;
//   2. after BackUp(): the backed-up tail of buffer_;
//   3. a fresh chunk from input_stream_, which is appended to buffer_ while
//      rewinding is enabled; if rewinding has been disabled, the buffer is
//      dropped and this and all later calls go directly to input_stream_.
//
// A non-OK status obtained from input_stream_ in buffering mode is stored in
// status_ and returned by subsequent calls as long as status_ stays non-OK.
crypto::tink::util::StatusOr<int> BufferedInputStream::Next(const void** data) {
  if (direct_access_) return input_stream_->Next(data);
  if (!status_.ok()) return status_;

  // We're just after rewind, so return all the data in the buffer, if any.
  if (after_rewind_ && count_in_buffer_ > 0) {
    after_rewind_ = false;
    // The whole buffer is handed out here, so no backed-up bytes may remain
    // pending; otherwise the tail of the buffer would be returned twice.
    count_backedup_ = 0;
    *data = buffer_.data();
    position_ = count_in_buffer_;
    return count_in_buffer_;
  }

  if (count_backedup_ > 0) {  // Return the backed-up bytes.
    buffer_offset_ = count_in_buffer_ - count_backedup_;
    *data = buffer_.data() + buffer_offset_;
    const int backedup = count_backedup_;
    count_backedup_ = 0;
    position_ = count_in_buffer_;
    return backedup;
  }

  // Read new bytes from input_stream_.
  //
  // If we don't allow rewind any more, all the data buffered so far
  // can be discarded, and from now on we go directly to input_stream_.
  if (!rewinding_enabled_) {
    direct_access_ = true;
    buffer_.resize(0);
    return input_stream_->Next(data);
  }

  // Otherwise, we read from input_stream_ the next chunk of data,
  // and append it to buffer_.
  after_rewind_ = false;
  const void* chunk = nullptr;
  auto next_result = input_stream_->Next(&chunk);
  if (!next_result.ok()) {
    status_ = next_result.status();
    return status_;
  }
  const size_t count_read = next_result.ValueOrDie();

  // Grow the buffer (at least doubling it) when the new chunk does not fit.
  if (buffer_.size() < count_in_buffer_ + count_read) {
    buffer_.resize(buffer_.size() + std::max(buffer_.size(), count_read));
  }
  // A zero-length chunk carries no bytes and `chunk` may not point at valid
  // memory in that case; memcpy requires valid pointers even for size 0.
  if (count_read > 0) {
    std::memcpy(buffer_.data() + count_in_buffer_, chunk, count_read);
  }

  buffer_offset_ = count_in_buffer_;
  count_backedup_ = 0;
  count_in_buffer_ += count_read;
  position_ = position_ + count_read;
  *data = buffer_.data() + buffer_offset_;
  return static_cast<int>(count_read);
}
void BufferedInputStream::BackUp(int count) {
if (direct_access_) {
input_stream_->BackUp(count);
return;
}
if (!status_.ok() || count < 1 ||
count_backedup_ == (count_in_buffer_ - buffer_offset_)) {
return;
}
int actual_count = std::min(
count, count_in_buffer_ - buffer_offset_ - count_backedup_);
count_backedup_ += actual_count;
position_ = position_ - actual_count;
}
// Turns rewinding off.  From now on Rewind() reports an error, and Next()
// switches to reading input_stream_ directly (dropping the buffer) once the
// bytes that are still pending in the buffer have been handed out.
void BufferedInputStream::DisableRewinding() { rewinding_enabled_ = false; }
// Moves the stream back to its beginning so that the next call to Next()
// replays the buffered data.
//
// Returns INVALID_ARGUMENT when rewinding has been disabled.  Otherwise, if
// the stream is healthy or has merely hit OUT_OF_RANGE, the status is cleared
// and the read state is reset; any other stored error is left in place.  In
// both of these cases the resulting status_ is returned.
crypto::tink::util::Status BufferedInputStream::Rewind() {
  if (!rewinding_enabled_) {
    return util::Status(util::error::INVALID_ARGUMENT, "rewinding is disabled");
  }

  const bool can_restart =
      status_.ok() || status_.error_code() == util::error::OUT_OF_RANGE;
  if (!can_restart) {
    return status_;
  }

  status_ = util::OkStatus();
  after_rewind_ = true;
  buffer_offset_ = 0;
  count_backedup_ = 0;
  position_ = 0;
  return status_;
}
// Nothing to release by hand: the members clean up after themselves.
BufferedInputStream::~BufferedInputStream() = default;
// Reports the current position in the stream: the wrapped stream's own
// position once direct access is active, otherwise the locally tracked
// position_.
int64_t BufferedInputStream::Position() const {
  return direct_access_ ? input_stream_->Position() : position_;
}
} // namespace streamingaead
} // namespace tink
} // namespace crypto