blob: f891e8ce72635d4707b0fb2aded15627da3cb729 [file] [log] [blame]
// Copyright 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package com.google.crypto.tink.subtle;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NonWritableChannelException;
import java.nio.channels.SeekableByteChannel;
import java.security.GeneralSecurityException;
import java.util.Arrays;
/**
* An instance of {@link SeekableByteChannel} that allows random access to the plaintext of some
* ciphertext.
*/
class StreamingAeadSeekableDecryptingChannel implements SeekableByteChannel {
// Each plaintext segment has 16 bytes more of memory than the actual plaintext that it contains.
// This is a workaround for an incompatibility between Conscrypt and OpenJDK in their
// AES-GCM implementations, see b/67416642, b/31574439, and cr/170969008 for more information.
// Conscrypt refused to fix this issue, but even if they fixed it, there are always Android phones
// running old versions of Conscrypt, so we decided to take matters into our own hands.
// Why 16? Actually any number larger than 16 should work. 16 is the lower bound because it's the
// size of the tags of each AES-GCM ciphertext segment.
private static final int PLAINTEXT_SEGMENT_EXTRA_SIZE = 16;
private final SeekableByteChannel ciphertextChannel;
private final ByteBuffer ciphertextSegment;
private final ByteBuffer plaintextSegment;
private final ByteBuffer header;
private final long ciphertextChannelSize; // unverified size of the ciphertext
private final int numberOfSegments; // unverified number of segments
private final int lastCiphertextSegmentSize; // unverified size of the last segment.
private final byte[] aad;
private final StreamSegmentDecrypter decrypter;
private long plaintextPosition;
private long plaintextSize;
private boolean headerRead;
private boolean isCurrentSegmentDecrypted;
private int currentSegmentNr;
private boolean isopen;
private final int plaintextSegmentSize;
private final int ciphertextSegmentSize;
private final int ciphertextOffset;
private final int firstSegmentOffset;
public StreamingAeadSeekableDecryptingChannel(
NonceBasedStreamingAead streamAead,
SeekableByteChannel ciphertext,
byte[] associatedData) throws IOException, GeneralSecurityException {
decrypter = streamAead.newStreamSegmentDecrypter();
ciphertextChannel = ciphertext;
header = ByteBuffer.allocate(streamAead.getHeaderLength());
ciphertextSegmentSize = streamAead.getCiphertextSegmentSize();
ciphertextSegment = ByteBuffer.allocate(ciphertextSegmentSize);
plaintextSegmentSize = streamAead.getPlaintextSegmentSize();
plaintextSegment = ByteBuffer.allocate(plaintextSegmentSize + PLAINTEXT_SEGMENT_EXTRA_SIZE);
plaintextPosition = 0;
headerRead = false;
currentSegmentNr = -1;
isCurrentSegmentDecrypted = false;
ciphertextChannelSize = ciphertextChannel.size();
aad = Arrays.copyOf(associatedData, associatedData.length);
isopen = ciphertextChannel.isOpen();
int fullSegments = (int) (ciphertextChannelSize / ciphertextSegmentSize);
int remainder = (int) (ciphertextChannelSize % ciphertextSegmentSize);
int ciphertextOverhead = streamAead.getCiphertextOverhead();
if (remainder > 0) {
numberOfSegments = fullSegments + 1;
if (remainder < ciphertextOverhead) {
throw new IOException("Invalid ciphertext size");
}
lastCiphertextSegmentSize = remainder;
} else {
numberOfSegments = fullSegments;
lastCiphertextSegmentSize = ciphertextSegmentSize;
}
ciphertextOffset = streamAead.getCiphertextOffset();
firstSegmentOffset = ciphertextOffset - streamAead.getHeaderLength();
if (firstSegmentOffset < 0) {
throw new IOException("Invalid ciphertext offset or header length");
}
long overhead = (long) numberOfSegments * ciphertextOverhead + ciphertextOffset;
if (overhead > ciphertextChannelSize) {
throw new IOException("Ciphertext is too short");
}
plaintextSize = ciphertextChannelSize - overhead;
}
/**
* A description of the state of this StreamingAeadSeekableDecryptingChannel.
* While this description does not contain plaintext or key material
* it contains length information that might be confidential.
*/
@Override
public synchronized String toString() {
StringBuilder res =
new StringBuilder();
String ctChannel;
try {
ctChannel = "position:" + ciphertextChannel.position();
} catch (IOException ex) {
ctChannel = "position: n/a";
}
res.append("StreamingAeadSeekableDecryptingChannel")
.append("\nciphertextChannel").append(ctChannel)
.append("\nciphertextChannelSize:").append(ciphertextChannelSize)
.append("\nplaintextSize:").append(plaintextSize)
.append("\nciphertextSegmentSize:").append(ciphertextSegmentSize)
.append("\nnumberOfSegments:").append(numberOfSegments)
.append("\nheaderRead:").append(headerRead)
.append("\nplaintextPosition:").append(plaintextPosition)
.append("\nHeader")
.append(" position:").append(header.position())
.append(" limit:").append(header.position())
.append("\ncurrentSegmentNr:").append(currentSegmentNr)
.append("\nciphertextSgement")
.append(" position:").append(ciphertextSegment.position())
.append(" limit:").append(ciphertextSegment.limit())
.append("\nisCurrentSegmentDecrypted:").append(isCurrentSegmentDecrypted)
.append("\nplaintextSegment")
.append(" position:").append(plaintextSegment.position())
.append(" limit:").append(plaintextSegment.limit());
return res.toString();
}
/**
* Returns the position of this channel.
* The position is relative to the plaintext.
*/
@Override
public synchronized long position() {
return plaintextPosition;
}
/**
* Sets the position in the plaintext.
* Setting the position to a value greater than the plaintext size is legal.
* A later attempt to read byte will throw an IOException.
*/
@Override
public synchronized SeekableByteChannel position(long newPosition) {
plaintextPosition = newPosition;
return this;
}
/**
* Tries to read the header of the ciphertext and derive the key used for the
* ciphertext from the information in the header.
*
* @return true if the header was fully read and has a correct format.
* Returns false if the header could not be read.
* @throws IOException if the header was incorrectly formatted or if there
* was an exception during the key derivation.
*/
private boolean tryReadHeader() throws IOException {
ciphertextChannel.position(header.position() + firstSegmentOffset);
ciphertextChannel.read(header);
if (header.remaining() > 0) {
return false;
} else {
header.flip();
try {
decrypter.init(header, aad);
headerRead = true;
} catch (GeneralSecurityException ex) {
// TODO(bleichen): Define the state of this.
throw new IOException(ex);
}
return true;
}
}
private int getSegmentNr(long plaintextPosition) {
return (int) ((plaintextPosition + ciphertextOffset) / plaintextSegmentSize);
}
/**
* Tries to read and decrypt a ciphertext segment.
* @param segmentNr the number of the segment
* @return true if the segment was read and correctly decrypted.
* Returns false if the segment could not be fully read.
* @throws IOException if there was an exception reading the ciphertext,
* if the segment number was incorrect, or
* if there was an exception trying to decrypt the ciphertext segment.
*/
private boolean tryLoadSegment(int segmentNr) throws IOException {
if (segmentNr < 0 || segmentNr >= numberOfSegments) {
throw new IOException("Invalid position");
}
boolean isLast = segmentNr == numberOfSegments - 1;
if (segmentNr == currentSegmentNr) {
if (isCurrentSegmentDecrypted) {
return true;
}
} else {
// segmentNr != currentSegmentNr
long ciphertextPosition = (long) segmentNr * ciphertextSegmentSize;
int segmentSize = ciphertextSegmentSize;
if (isLast) {
segmentSize = lastCiphertextSegmentSize;
}
if (segmentNr == 0) {
segmentSize -= ciphertextOffset;
ciphertextPosition = ciphertextOffset;
}
ciphertextChannel.position(ciphertextPosition);
ciphertextSegment.clear();
ciphertextSegment.limit(segmentSize);
currentSegmentNr = segmentNr;
isCurrentSegmentDecrypted = false;
}
if (ciphertextSegment.remaining() > 0) {
ciphertextChannel.read(ciphertextSegment);
}
if (ciphertextSegment.remaining() > 0) {
return false;
}
ciphertextSegment.flip();
plaintextSegment.clear();
try {
decrypter.decryptSegment(ciphertextSegment, segmentNr, isLast, plaintextSegment);
} catch (GeneralSecurityException ex) {
// The current segment did not validate. Ensure that this instance remains
// in a valid state.
currentSegmentNr = -1;
throw new IOException("Failed to decrypt", ex);
}
plaintextSegment.flip();
isCurrentSegmentDecrypted = true;
return true;
}
/**
* Returns true if plaintextPositon is at the end of the file
* and this has been verified, by decrypting the last segment.
*/
private boolean reachedEnd() {
return (isCurrentSegmentDecrypted
&& currentSegmentNr == numberOfSegments - 1
&& plaintextSegment.remaining() == 0);
}
/**
* Atomic read from a given position.
*
* This method works in the same way as read(ByteBuffer), except that it starts at the given
* position and does not modify the channel's position.
*/
public synchronized int read(ByteBuffer dst, long start) throws IOException {
long oldPosition = position();
try {
position(start);
return read(dst);
} finally {
position(oldPosition);
}
}
@Override
public synchronized int read(ByteBuffer dst) throws IOException {
if (!isopen) {
throw new ClosedChannelException();
}
if (!headerRead) {
if (!tryReadHeader()) {
return 0;
}
}
int startPos = dst.position();
while (dst.remaining() > 0 && plaintextPosition < plaintextSize) {
// Determine segmentNr for the plaintext to read and the offset in
// the plaintext, where reading should start.
int segmentNr = getSegmentNr(plaintextPosition);
int segmentOffset;
if (segmentNr == 0) {
segmentOffset = (int) plaintextPosition;
} else {
segmentOffset = (int) ((plaintextPosition + ciphertextOffset) % plaintextSegmentSize);
}
if (tryLoadSegment(segmentNr)) {
plaintextSegment.position(segmentOffset);
if (plaintextSegment.remaining() <= dst.remaining()) {
plaintextPosition += plaintextSegment.remaining();
dst.put(plaintextSegment);
} else {
int sliceSize = dst.remaining();
ByteBuffer slice = plaintextSegment.duplicate();
slice.limit(slice.position() + sliceSize);
dst.put(slice);
plaintextPosition += sliceSize;
plaintextSegment.position(plaintextSegment.position() + sliceSize);
}
} else {
break;
}
}
int read = dst.position() - startPos;
if (read == 0 && reachedEnd()) {
return -1;
}
return read;
}
/**
* Returns the expected size of the plaintext.
* Note that this implementation does not perform an integrity check on the size.
* I.e. if the file has been truncated then size() will return the wrong
* result. Reading the last block of the ciphertext will verify whether size()
* is correct.
*/
@Override
public long size() {
return plaintextSize;
}
public synchronized long verifiedSize() throws IOException {
if (tryLoadSegment(numberOfSegments - 1)) {
return plaintextSize;
} else {
throw new IOException("could not verify the size");
}
}
@Override
public SeekableByteChannel truncate(long size) throws NonWritableChannelException {
throw new NonWritableChannelException();
}
@Override
public int write(ByteBuffer src) throws NonWritableChannelException {
throw new NonWritableChannelException();
}
@Override
public synchronized void close() throws IOException {
ciphertextChannel.close();
isopen = false;
}
@Override
public synchronized boolean isOpen() {
return isopen;
}
}