blob: 8e35006b43f6ca3ccdb31a8256b8fab1570dc0c4 [file] [log] [blame]
//! Schannel TLS streams.
use std::any::Any;
use std::cmp;
use std::error::Error;
use std::fmt;
use std::io::{self, Read, BufRead, Write, Cursor};
use std::mem;
use std::ptr;
use std::slice;
use std::sync::Arc;
use winapi::shared::minwindef as winapi;
use winapi::shared::{ntdef, sspi, winerror};
use winapi::um::{self, schannel, wincrypt};
use crate::{INIT_REQUESTS, ACCEPT_REQUESTS, Inner, secbuf, secbuf_desc};
use crate::alpn_list::AlpnList;
use crate::cert_chain::{CertChain, CertChainContext};
use crate::cert_store::{CertAdd, CertStore};
use crate::cert_context::CertContext;
use crate::security_context::SecurityContext;
use crate::context_buffer::ContextBuffer;
use crate::schannel_cred::SchannelCred;
// Lazily built, NUL-terminated byte copies of three `wincrypt` OID string
// constants. `validate` hands out raw pointers to these vectors as the
// `LPSTR` usage identifiers of the requested certificate usage, so each
// string needs its own trailing 0 byte.
lazy_static! {
static ref szOID_PKIX_KP_SERVER_AUTH: Vec<u8> =
wincrypt::szOID_PKIX_KP_SERVER_AUTH.bytes().chain(Some(0)).collect();
static ref szOID_SERVER_GATED_CRYPTO: Vec<u8> =
wincrypt::szOID_SERVER_GATED_CRYPTO.bytes().chain(Some(0)).collect();
static ref szOID_SGC_NETSCAPE: Vec<u8> =
wincrypt::szOID_SGC_NETSCAPE.bytes().chain(Some(0)).collect();
}
/// A builder type for `TlsStream`s.
///
/// Collects the per-connection options and is consumed by reference through
/// `connect` (client side) or `accept` (server side).
pub struct Builder {
// Domain as a NUL-terminated UTF-16 string; filled in by `domain()`.
domain: Option<Vec<u16>>,
// When false, the domain is withheld from context initialization (no SNI).
use_sni: bool,
// When true, the domain is not passed to the certificate policy check.
accept_invalid_hostnames: bool,
// Optional user hook; when present its return value becomes the validation result.
verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
// Optional extra certificates used while building/validating the server's chain.
cert_store: Option<CertStore>,
// Optional list of ALPN protocol identifiers to request.
requested_application_protocols: Option<Vec<Vec<u8>>>,
}
impl Default for Builder {
fn default() -> Builder {
Builder {
domain: None,
use_sni: true,
accept_invalid_hostnames: false,
verify_callback: None,
cert_store: None,
requested_application_protocols: None,
}
}
}
impl Builder {
    /// Creates a `Builder` with the default configuration.
    pub fn new() -> Builder {
        Default::default()
    }

    /// Sets the domain associated with connections created with this `Builder`.
    ///
    /// The domain is used for Server Name Indication as well as for
    /// certificate validation. It is stored as a NUL-terminated UTF-16 string.
    pub fn domain(&mut self, domain: &str) -> &mut Builder {
        let mut wide: Vec<u16> = domain.encode_utf16().collect();
        wide.push(0);
        self.domain = Some(wide);
        self
    }

    /// Chooses whether Server Name Indication (SNI) is used.
    ///
    /// Defaults to `true`.
    pub fn use_sni(&mut self, use_sni: bool) -> &mut Builder {
        self.use_sni = use_sni;
        self
    }

    /// Chooses whether a mismatching server hostname is tolerated during
    /// certificate verification.
    ///
    /// When set to `true` the configured domain is not handed to the
    /// certificate policy check. Defaults to `false`.
    pub fn accept_invalid_hostnames(&mut self, accept_invalid_hostnames: bool) -> &mut Builder {
        self.accept_invalid_hostnames = accept_invalid_hostnames;
        self
    }

    /// Installs a verification callback for connections created with this `Builder`.
    ///
    /// The callback receives a `CertValidationResult`. Its `result()` method
    /// reports whether the built-in validation succeeded (`Ok`) or carries the
    /// error code of the internal verification process (`Err`); the certificate
    /// chain and the failing certificate are reachable through its other
    /// accessors. The value returned by the callback replaces the built-in
    /// verification outcome.
    pub fn verify_callback<F>(&mut self, callback: F) -> &mut Builder
    where
        F: Fn(CertValidationResult) -> io::Result<()> + 'static + Sync + Send,
    {
        let shared = Arc::new(callback);
        self.verify_callback = Some(shared);
        self
    }

    /// Supplies a custom certificate store which is later used when
    /// validating a server's certificate.
    ///
    /// This option only matters for client connections, where it takes part
    /// in constructing the certificate chain the server's certificate is
    /// validated against.
    ///
    /// Note that adding certificates here means that they are
    /// implicitly trusted.
    pub fn cert_store(&mut self, cert_store: CertStore) -> &mut Builder {
        self.cert_store = Some(cert_store);
        self
    }

    /// Requests one of a set of application protocols using ALPN.
    pub fn request_application_protocols(&mut self, alpns: &[&[u8]]) -> &mut Builder {
        let protocols: Vec<Vec<u8>> = alpns.iter().map(|proto| proto.to_vec()).collect();
        self.requested_application_protocols = Some(protocols);
        self
    }

    /// Starts a new TLS session in which the provided stream connects to a
    /// remote TLS server.
    ///
    /// With a blocking stream the whole handshake is performed if possible.
    /// With a nonblocking stream a `HandshakeError::Interrupted` may be
    /// returned instead; the contained `MidHandshakeTlsStream` can be used to
    /// call `MidHandshakeTlsStream::handshake` again once data is available.
    pub fn connect<S: Read + Write>(
        &mut self,
        cred: SchannelCred,
        stream: S,
    ) -> Result<TlsStream<S>, HandshakeError<S>> {
        let server = false;
        self.initialize(cred, server, stream)
    }

    /// Starts a new TLS session in which the provided stream accepts a
    /// connection.
    ///
    /// This method will tweak the protocol for "who talks first" and also
    /// currently disables validation of the client that's connecting to us.
    ///
    /// With a blocking stream the whole handshake is performed if possible.
    /// With a nonblocking stream a `HandshakeError::Interrupted` may be
    /// returned instead; the contained `MidHandshakeTlsStream` can be used to
    /// call `MidHandshakeTlsStream::handshake` again once data is available.
    pub fn accept<S: Read + Write>(
        &mut self,
        cred: SchannelCred,
        stream: S,
    ) -> Result<TlsStream<S>, HandshakeError<S>> {
        let server = true;
        self.initialize(cred, server, stream)
    }

    /// Shared implementation of `connect` and `accept`: creates the security
    /// context, assembles the `TlsStream` in its `Initializing` state and
    /// immediately drives the handshake.
    fn initialize<S: Read + Write>(
        &mut self,
        mut cred: SchannelCred,
        server: bool,
        stream: S,
    ) -> Result<TlsStream<S>, HandshakeError<S>> {
        // The domain is only forwarded when SNI is enabled.
        let sni_domain = if self.use_sni {
            self.domain.as_ref().map(|wide| &wide[..])
        } else {
            None
        };

        let (ctxt, buf) = SecurityContext::initialize(
            &mut cred,
            server,
            sni_domain,
            &self.requested_application_protocols,
        )
        .map_err(HandshakeError::Failure)?;

        // Any token produced by the first call is queued for sending.
        let initial_output = match buf {
            Some(token) => token.to_owned(),
            None => Vec::new(),
        };

        let inner = TlsStream {
            cred,
            context: ctxt,
            cert_store: self.cert_store.clone(),
            domain: self.domain.clone(),
            use_sni: self.use_sni,
            accept_invalid_hostnames: self.accept_invalid_hostnames,
            verify_callback: self.verify_callback.clone(),
            stream,
            server,
            accept_first: true,
            state: State::Initializing {
                needs_flush: false,
                more_calls: true,
                shutting_down: false,
                validated: false,
            },
            needs_read: 1,
            dec_in: Cursor::new(Vec::new()),
            enc_in: Cursor::new(Vec::new()),
            out_buf: Cursor::new(initial_output),
            last_write_len: 0,
            requested_application_protocols: self.requested_application_protocols.clone(),
        };

        let pending = MidHandshakeTlsStream { inner };
        pending.handshake()
    }
}
// Lifecycle of a `TlsStream`, driven by `TlsStream::initialize`.
enum State {
// A handshake (initial, renegotiation, or shutdown exchange) is in progress.
Initializing {
// Bytes were written to the wrapped stream and it still has to be flushed.
needs_flush: bool,
// Cleared once `step_initialize` sees `SEC_E_OK`; no further context calls are needed.
more_calls: bool,
// Set by `shutdown`; the stream ends in `Shutdown` instead of `Streaming`.
shutting_down: bool,
// Set once `validate` reported a successful certificate verification.
validated: bool,
},
// Handshake finished; `sizes` holds the stream sizes queried from the context.
Streaming { sizes: sspi::SecPkgContext_StreamSizes, },
// The shutdown exchange has completed.
Shutdown,
}
/// An Schannel TLS stream.
///
/// Wraps an underlying stream `S` together with the security context and the
/// buffers used to move encrypted and decrypted data between the two.
pub struct TlsStream<S> {
// Credentials passed to every context call.
cred: SchannelCred,
// The security context for this session.
context: SecurityContext,
// Optional extra certificates consulted by `validate`.
cert_store: Option<CertStore>,
// Domain as NUL-terminated UTF-16 (copied from the `Builder`).
domain: Option<Vec<u16>>,
// Whether the domain is passed to `InitializeSecurityContextW`.
use_sni: bool,
// Whether the domain is withheld from the certificate policy check.
accept_invalid_hostnames: bool,
// Optional user hook whose return value becomes the validation result.
verify_callback: Option<Arc<dyn Fn(CertValidationResult) -> io::Result<()> + Sync + Send>>,
// The wrapped transport.
stream: S,
// Current handshake/streaming/shutdown state.
state: State,
// True for the accepting side of the session.
server: bool,
// True until a server-side context call has returned `SEC_I_CONTINUE_NEEDED`.
accept_first: bool,
// Number of bytes `read_in` should still try to read (0 means none needed).
needs_read: usize,
// decrypted data; valid from position() to len()
dec_in: Cursor<Vec<u8>>,
// encrypted data read from the stream; valid from 0 to position()
enc_in: Cursor<Vec<u8>>,
// outgoing data not yet written to the stream; valid from position() to len()
out_buf: Cursor<Vec<u8>>,
/// the (unencrypted) length of the last write call used to track writes
last_write_len: usize,
// ALPN protocol identifiers re-sent on every `step_initialize` call.
requested_application_protocols: Option<Vec<Vec<u8>>>,
}
/// Compile-time check that `TlsStream<()>` is both `Send` and `Sync`.
///
/// Never called; it only has to type-check.
fn _is_sync() {
    fn assert_send_sync<T>()
    where
        T: Send + Sync,
    {
    }
    assert_send_sync::<TlsStream<()>>();
}
/// A failure which can happen during the `Builder::initialize` phase, either an
/// I/O error or an intermediate stream which has not completed its handshake.
///
/// `MidHandshakeTlsStream::handshake` maps a `WouldBlock` I/O error to
/// `Interrupted` and every other error to `Failure`.
#[derive(Debug)]
pub enum HandshakeError<S> {
/// A fatal I/O error occurred
Failure(io::Error),
/// The stream connection is in progress, but the handshake is not completed
/// yet. The contained stream can be used to resume the handshake.
Interrupted(MidHandshakeTlsStream<S>),
}
/// A struct used to wrap various cert chain validation results for callback processing.
///
/// An instance is built by `TlsStream::validate` from the policy status of
/// the chain check and handed to the user-supplied verify callback.
pub struct CertValidationResult {
// The certificate chain context that was checked.
chain: CertChainContext,
// The policy status' `dwError`, stored as `i32`.
res: i32,
// The policy status' `lChainIndex`.
chain_index: i32,
// The policy status' `lElementIndex`.
element_index: i32,
}
impl CertValidationResult {
    /// Returns the certificate that failed validation, if applicable.
    ///
    /// Looks up the chain at the recorded chain index and, when it exists,
    /// the element at the recorded element index.
    pub fn failed_certificate(&self) -> Option<CertContext> {
        let chain_idx = self.chain_index as usize;
        let element_idx = self.element_index as usize;
        self.chain
            .get_chain(chain_idx)
            .and_then(|cert_chain| cert_chain.get(element_idx))
    }

    /// Returns the final certificate chain in the certificate context, if applicable.
    pub fn chain(&self) -> Option<CertChain> {
        self.chain.final_chain()
    }

    /// Returns the outcome of the built-in certificate verification process:
    /// `Ok(())` for `ERROR_SUCCESS`, otherwise the stored code as an OS error.
    pub fn result(&self) -> io::Result<()> {
        match self.res as u32 {
            winerror::ERROR_SUCCESS => Ok(()),
            _ => Err(io::Error::from_raw_os_error(self.res)),
        }
    }
}
impl<S: fmt::Debug + Any> Error for HandshakeError<S> {
    /// Exposes the underlying I/O error of a `Failure`; an `Interrupted`
    /// handshake has no source.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            HandshakeError::Interrupted(_) => None,
            HandshakeError::Failure(io_err) => {
                let cause: &(dyn Error + 'static) = io_err;
                Some(cause)
            }
        }
    }
}
impl<S: fmt::Debug + Any> fmt::Display for HandshakeError<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let desc = match *self {
HandshakeError::Failure(_) => "failed to perform handshake",
HandshakeError::Interrupted(_) => "interrupted performing handshake",
};
write!(f, "{}", desc)?;
if let Some(e) = self.source() {
write!(f, ": {}", e)?;
}
Ok(())
}
}
/// A stream which has not yet completed its handshake.
///
/// Returned inside `HandshakeError::Interrupted`; call `handshake` on it to
/// resume once the underlying stream is ready again.
#[derive(Debug)]
pub struct MidHandshakeTlsStream<S> {
// The partially initialized TLS stream.
inner: TlsStream<S>,
}
impl<S> fmt::Debug for TlsStream<S>
where
    S: fmt::Debug,
{
    /// Formats the stream as `TlsStream { stream: .. }`; only the wrapped
    /// stream is shown.
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        let mut repr = fmt.debug_struct("TlsStream");
        repr.field("stream", &self.stream);
        repr.finish()
    }
}
// Accessors that are available regardless of the wrapped stream's capabilities.
impl<S> TlsStream<S> {
/// Returns a reference to the wrapped stream.
pub fn get_ref(&self) -> &S {
&self.stream
}
/// Returns a mutable reference to the wrapped stream.
///
/// Data read from or written to the wrapped stream directly bypasses this
/// `TlsStream`'s buffers.
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
/// Indicates if this stream is the server- or client-side of a TLS session.
///
/// Returns `true` for streams created through `Builder::accept`.
pub fn is_server(&self) -> bool {
self.server
}
}
impl<S> TlsStream<S>
where S: Read + Write
{
/// Returns the certificate used to identify this side of the TLS session.
///
/// Its associated cert store contains any intermediate certificates sent
/// along with the leaf.
///
/// Any error from querying the security context's local certificate is
/// returned unchanged.
pub fn certificate(&self) -> io::Result<CertContext> {
self.context.local_cert()
}
/// Returns the peer's certificate, if available.
///
/// Its associated cert store contains any intermediate certificates sent
/// by the server.
///
/// Any error from querying the security context's remote certificate is
/// returned unchanged.
pub fn peer_certificate(&self) -> io::Result<CertContext> {
self.context.remote_cert()
}
/// Returns the negotiated application protocol for this TLS stream, if one exists.
///
/// `Ok(None)` is returned unless the context reports a successful
/// negotiation that was carried out through ALPN.
pub fn negotiated_application_protocol(&self) -> io::Result<Option<Vec<u8>>> {
    let proto = self.context.application_protocol()?;
    let negotiated =
        proto.ProtoNegoStatus == sspi::SecApplicationProtocolNegotiationStatus_Success;
    let via_alpn = proto.ProtoNegoExt == sspi::SecApplicationProtocolNegotiationExt_ALPN;
    if negotiated && via_alpn {
        let id_len = proto.ProtocolIdSize as usize;
        Ok(Some(proto.ProtocolId[..id_len].to_vec()))
    } else {
        Ok(None)
    }
}
/// Returns whether or not the session was resumed.
///
/// The answer is taken from the `SSL_SESSION_RECONNECT` bit of the
/// context's session info flags.
pub fn session_resumed(&self) -> io::Result<bool> {
    let flags = self.context.session_info()?.dwFlags;
    let resumed = (flags & schannel::SSL_SESSION_RECONNECT) != 0;
    Ok(resumed)
}
/// Returns a reference to the buffer of pending data.
///
/// Works like `BufRead::fill_buf`, except that an empty slice is returned
/// when nothing is buffered instead of reading from the wrapped stream.
pub fn get_buf(&self) -> &[u8] {
    let consumed = self.dec_in.position() as usize;
    let decrypted = self.dec_in.get_ref();
    &decrypted[consumed..]
}
/// Shuts the TLS session down.
///
/// Returns immediately if the stream is already in the `Shutdown` state.
/// Otherwise, unless a shutdown is already in progress, an
/// `SCHANNEL_SHUTDOWN` control token is applied to the security context and
/// the stream re-enters the `Initializing` state with `shutting_down` set.
/// `initialize` is then run to drive the remaining exchange; its error, if
/// any, is returned.
pub fn shutdown(&mut self) -> io::Result<()> {
match self.state {
State::Shutdown => return Ok(()),
// A shutdown is already under way (presumably started by an earlier call
// that did not complete); skip the control token and just resume it.
State::Initializing { shutting_down: true, .. } => {}
_ => {
unsafe {
// View the shutdown constant as a raw byte slice so it can be
// wrapped in a token security buffer.
let mut token = um::schannel::SCHANNEL_SHUTDOWN;
let ptr = &mut token as *mut _ as *mut u8;
let size = mem::size_of_val(&token);
let token = slice::from_raw_parts_mut(ptr, size);
let mut buf = [secbuf(sspi::SECBUFFER_TOKEN, Some(token))];
let mut desc = secbuf_desc(&mut buf);
match sspi::ApplyControlToken(self.context.get_mut(), &mut desc) {
winerror::SEC_E_OK => {}
err => return Err(io::Error::from_raw_os_error(err as i32)),
}
}
// Restart the handshake state machine in shutdown mode.
self.state = State::Initializing {
needs_flush: false,
more_calls: true,
shutting_down: true,
validated: false,
};
// The first shutdown step does not wait for input from the peer.
self.needs_read = 0;
}
}
self.initialize().map(|_| ())
}
// Performs one handshake step: feeds the encrypted bytes buffered in `enc_in`
// to `AcceptSecurityContext` (server) or `InitializeSecurityContextW` (client)
// and processes the returned status.
//
// * `SEC_I_CONTINUE_NEEDED`: consumed input is dropped and the output token is
//   queued in `out_buf`.
// * `SEC_E_INCOMPLETE_MESSAGE`: `needs_read` is set so that more input is read.
// * `SEC_E_OK`: like `CONTINUE_NEEDED`, plus any leftover input is decrypted
//   and `more_calls` is cleared.
// * Anything else is returned as an OS error.
fn step_initialize(&mut self) -> io::Result<()> {
unsafe {
// Input buffers: the encrypted bytes received so far, plus an empty
// buffer whose type/size is inspected after the call (EXTRA / MISSING).
let pos = self.enc_in.position() as usize;
let mut inbufs = vec![secbuf(sspi::SECBUFFER_TOKEN,
Some(&mut self.enc_in.get_mut()[..pos])),
secbuf(sspi::SECBUFFER_EMPTY, None)];
// Make sure `AlpnList` is kept alive for the duration of this function.
let mut alpns = self.requested_application_protocols.as_ref().map(|alpn| AlpnList::new(&alpn));
if let Some(ref mut alpns) = alpns {
inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS,
Some(&mut alpns[..])));
};
let mut inbuf_desc = secbuf_desc(&mut inbufs[..]);
// Output buffers start out empty (no memory supplied here).
let mut outbufs = [secbuf(sspi::SECBUFFER_TOKEN, None),
secbuf(sspi::SECBUFFER_ALERT, None),
secbuf(sspi::SECBUFFER_EMPTY, None)];
let mut outbuf_desc = secbuf_desc(&mut outbufs);
let mut attributes = 0;
let status = if self.server {
// Until a call has returned CONTINUE_NEEDED no existing context is
// passed in (see the comment on `accept_first` below).
let ptr = if self.accept_first {
ptr::null_mut()
} else {
self.context.get_mut()
};
sspi::AcceptSecurityContext(&mut self.cred.as_inner(),
ptr,
&mut inbuf_desc,
ACCEPT_REQUESTS,
0,
self.context.get_mut(),
&mut outbuf_desc,
&mut attributes,
ptr::null_mut())
} else {
// The target name is only supplied when SNI is enabled.
let domain = match self.domain {
Some(ref domain) if self.use_sni => domain.as_ptr() as *mut u16,
_ => ptr::null_mut(),
};
sspi::InitializeSecurityContextW(&mut self.cred.as_inner(),
self.context.get_mut(),
domain,
INIT_REQUESTS,
0,
0,
&mut inbuf_desc,
0,
ptr::null_mut(),
&mut outbuf_desc,
&mut attributes,
ptr::null_mut())
};
// Free everything except the token buffer, which is handled per status below.
// NOTE(review): outbufs[0] is only wrapped in a `ContextBuffer` (which
// presumably frees it on drop) in the CONTINUE_NEEDED and OK arms; if any
// other status can leave a non-null token buffer it is not freed here — confirm.
for buf in &outbufs[1..] {
if !buf.pvBuffer.is_null() {
sspi::FreeContextBuffer(buf.pvBuffer);
}
}
match status {
winerror::SEC_I_CONTINUE_NEEDED => {
// Windows apparently doesn't like AcceptSecurityContext
// being called as if it were the second time unless the
// first call to AcceptSecurityContext succeeded with
// CONTINUE_NEEDED.
//
// In other words, if we were to set `accept_first` to
// `false` after the literal first call to
// `AcceptSecurityContext` while the call returned
// INCOMPLETE_MESSAGE, the next call would return an error.
//
// For that reason we only set `accept_first` to false here
// once we've actually successfully received the full
// "token" from the client.
self.accept_first = false;
// An EXTRA buffer's size is treated as the number of trailing input
// bytes that were not consumed by this call.
let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
} else {
self.enc_in.position() as usize
};
let to_write = ContextBuffer(outbufs[0]);
self.consume_enc_in(nread);
// Only ask for more input when nothing is left over.
self.needs_read = (self.enc_in.position() == 0) as usize;
self.out_buf.get_mut().extend_from_slice(&to_write);
}
winerror::SEC_E_INCOMPLETE_MESSAGE => {
// A MISSING buffer's size is used as the number of bytes to read;
// otherwise request at least one more byte.
self.needs_read = if inbufs[1].BufferType == sspi::SECBUFFER_MISSING {
inbufs[1].cbBuffer as usize
} else {
1
};
}
winerror::SEC_E_OK => {
let nread = if inbufs[1].BufferType == sspi::SECBUFFER_EXTRA {
self.enc_in.position() as usize - inbufs[1].cbBuffer as usize
} else {
self.enc_in.position() as usize
};
// The final step may or may not produce an output token.
let to_write = if outbufs[0].pvBuffer.is_null() {
None
} else {
Some(ContextBuffer(outbufs[0]))
};
self.consume_enc_in(nread);
self.needs_read = (self.enc_in.position() == 0) as usize;
if let Some(to_write) = to_write {
self.out_buf.get_mut().extend_from_slice(&to_write);
}
// Leftover input after the handshake is handed to `decrypt` right away.
if self.enc_in.position() != 0 {
self.decrypt()?;
}
// Tell `initialize` that no further context calls are required.
if let State::Initializing { ref mut more_calls, .. } = self.state {
*more_calls = false;
}
}
_ => {
return Err(io::Error::from_raw_os_error(status as i32))
}
}
Ok(())
}
}
// Drives the handshake state machine until the stream is usable or shut down.
//
// Returns `Ok(Some(sizes))` once the stream is in the `Streaming` state and
// `Ok(None)` once it is in the `Shutdown` state. While `Initializing`, each
// loop iteration sends pending output, flushes, validates the peer
// certificate (unless shutting down or already validated), reads more input
// if requested and runs one `step_initialize`. Any I/O error (including
// `WouldBlock` from a nonblocking stream) is returned to the caller; the
// state kept in `self` allows a later call to resume.
fn initialize(&mut self) -> io::Result<Option<sspi::SecPkgContext_StreamSizes>> {
loop {
match self.state {
// The bound flags are copies of the state's fields taken at the top
// of this iteration; updates go through `self.state` explicitly.
State::Initializing { mut needs_flush, more_calls, shutting_down, validated } => {
if self.write_out()? > 0 {
needs_flush = true;
// Persist the flag so a failed flush is retried on the next call.
if let State::Initializing { ref mut needs_flush, .. } = self.state {
*needs_flush = true;
}
}
if needs_flush {
self.stream.flush()?;
if let State::Initializing { ref mut needs_flush, .. } = self.state {
*needs_flush = false;
}
}
if !shutting_down && !validated {
// on the last call, we require a valid certificate
if self.validate(!more_calls)? {
if let State::Initializing { ref mut validated, .. } = self.state {
*validated = true;
}
}
}
// Handshake finished: switch to the final state and let the next
// loop iteration return the result.
if !more_calls {
self.state = if shutting_down {
State::Shutdown
} else {
State::Streaming { sizes: self.context.stream_sizes()? }
};
continue;
}
if self.needs_read > 0 {
if self.read_in()? == 0 {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof,
"unexpected EOF during handshake"));
}
}
self.step_initialize()?;
}
State::Streaming { sizes } => return Ok(Some(sizes)),
State::Shutdown => return Ok(None),
}
}
}
/// Validates the peer's certificate chain (client side only).
///
/// Returns true when the certificate was successfully verified
/// Returns false, when a verification isn't necessary (yet)
/// Returns an error when the verification failed
///
/// `require_cert` controls what happens when the remote certificate cannot
/// be obtained: with `false` the function returns `Ok(false)`, with `true`
/// the error is propagated.
fn validate(&mut self, require_cert: bool) -> io::Result<bool> {
// If we're accepting connections then we don't perform any validation
// for the remote certificate, that's what they're doing!
if self.server {
return Ok(false);
}
let cert_context = match self.context.remote_cert() {
Err(_) if !require_cert => return Ok(false),
ret => ret?
};
// Step 1: build the certificate chain for the peer certificate.
let cert_chain = unsafe {
// Pick the additional store for chain building: the store attached to
// the peer certificate, extended by (or replaced with) the user's store.
let cert_store = match (cert_context.cert_store(), &self.cert_store) {
(Some(ref mut chain_certs), &Some(ref extra_certs)) => {
for extra_cert in extra_certs.certs() {
chain_certs.add_cert(&extra_cert, CertAdd::ReplaceExisting)?;
}
chain_certs.as_inner()
},
(Some(chain_certs), &None) => chain_certs.as_inner(),
(None, &Some(ref extra_certs)) => extra_certs.as_inner(),
(None, &None) => ptr::null_mut()
};
let flags = wincrypt::CERT_CHAIN_CACHE_END_CERT |
wincrypt::CERT_CHAIN_REVOCATION_CHECK_CACHE_ONLY |
wincrypt::CERT_CHAIN_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT;
let mut para: wincrypt::CERT_CHAIN_PARA = mem::zeroed();
para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
// Accept a certificate that carries any one of the listed usages.
para.RequestedUsage.dwType = wincrypt::USAGE_MATCH_TYPE_OR;
let mut identifiers = [szOID_PKIX_KP_SERVER_AUTH.as_ptr() as ntdef::LPSTR,
szOID_SERVER_GATED_CRYPTO.as_ptr() as ntdef::LPSTR,
szOID_SGC_NETSCAPE.as_ptr() as ntdef::LPSTR];
para.RequestedUsage.Usage.cUsageIdentifier = identifiers.len() as winapi::DWORD;
para.RequestedUsage.Usage.rgpszUsageIdentifier = identifiers.as_mut_ptr();
let mut cert_chain = mem::zeroed();
let res = wincrypt::CertGetCertificateChain(ptr::null_mut(),
cert_context.as_inner(),
ptr::null_mut(),
cert_store,
&mut para,
flags,
ptr::null_mut(),
&mut cert_chain);
if res == winapi::TRUE {
CertChainContext(cert_chain as *mut _)
} else {
return Err(io::Error::last_os_error())
}
};
// Step 2: check the chain against the SSL policy.
unsafe {
// check if we trust the root-CA explicitly
let mut para_flags = wincrypt::CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS;
if let Some(ref mut store) = self.cert_store {
if let Some(chain) = cert_chain.final_chain() {
// check if any cert of the chain is in the passed store (and therefore trusted)
if chain.certificates().any(|cert| store.certs().any(|root_cert| root_cert == cert)) {
para_flags |= wincrypt::CERT_CHAIN_POLICY_ALLOW_UNKNOWN_CA_FLAG;
}
}
}
let mut extra_para: wincrypt::SSL_EXTRA_CERT_CHAIN_POLICY_PARA = mem::zeroed();
*extra_para.u.cbSize_mut() = mem::size_of_val(&extra_para) as winapi::DWORD;
extra_para.dwAuthType = wincrypt::AUTHTYPE_SERVER;
// The server name is only supplied when hostname checking is enabled;
// otherwise the field stays null from `mem::zeroed()`.
match self.domain {
Some(ref mut domain) if !self.accept_invalid_hostnames => {
extra_para.pwszServerName = domain.as_mut_ptr();
}
_ => {}
}
let mut para: wincrypt::CERT_CHAIN_POLICY_PARA = mem::zeroed();
para.cbSize = mem::size_of_val(&para) as winapi::DWORD;
para.dwFlags = para_flags;
para.pvExtraPolicyPara = &mut extra_para as *mut _ as *mut _;
let mut status: wincrypt::CERT_CHAIN_POLICY_STATUS = mem::zeroed();
status.cbSize = mem::size_of_val(&status) as winapi::DWORD;
let verify_chain_policy_structure = wincrypt::CERT_CHAIN_POLICY_SSL as ntdef::LPCSTR;
let res = wincrypt::CertVerifyCertificateChainPolicy(verify_chain_policy_structure,
cert_chain.0,
&mut para,
&mut status);
// A FALSE return means the policy check itself could not be performed.
if res == winapi::FALSE {
return Err(io::Error::last_os_error())
}
// The verdict of the policy check is reported through `status.dwError`.
let mut verify_result = if status.dwError != winerror::ERROR_SUCCESS {
Err(io::Error::from_raw_os_error(status.dwError as i32))
} else {
Ok(())
};
// check if there's a user-specified verify callback
// (its return value replaces the built-in result, so it can both
// reject a valid chain and accept an invalid one)
if let Some(ref callback) = self.verify_callback {
verify_result = callback(CertValidationResult{
chain: cert_chain,
res: status.dwError as i32,
chain_index: status.lChainIndex,
element_index: status.lElementIndex});
}
verify_result?;
}
Ok(true)
}
// Writes everything pending in `out_buf` (from `position()` to the end) to
// the wrapped stream and returns the number of bytes written by this call.
//
// The buffer position is advanced after every successful partial write, so
// when an error such as `WouldBlock` is returned a later call resumes where
// this one stopped.
//
// A write that accepts zero bytes while data is still pending is reported as
// `ErrorKind::WriteZero` (the same convention as `Write::write_all`) instead
// of being retried forever.
fn write_out(&mut self) -> io::Result<usize> {
    let mut out = 0;
    while self.out_buf.position() as usize != self.out_buf.get_ref().len() {
        let position = self.out_buf.position() as usize;
        let nwritten = self.stream.write(&self.out_buf.get_ref()[position..])?;
        if nwritten == 0 {
            // The stream can no longer accept bytes; looping would never terminate.
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "failed to write buffered TLS data to the underlying stream",
            ));
        }
        out += nwritten;
        self.out_buf.set_position((position + nwritten) as u64);
    }
    Ok(out)
}
// Reads encrypted bytes from the wrapped stream into `enc_in` until
// `needs_read` has been satisfied or the stream reports end-of-file.
//
// Returns the number of bytes read by this call; 0 means the very first read
// hit end-of-file (or nothing was requested). Progress is stored in `enc_in`'s
// position and in `needs_read`, so an error such as `WouldBlock` loses no data.
fn read_in(&mut self) -> io::Result<usize> {
    let mut total = 0;
    while self.needs_read > 0 {
        let filled = self.enc_in.position() as usize;
        // Grow the buffer to at least 1 KiB, twice the buffered amount, or
        // the requested size, whichever is largest.
        let wanted_len = cmp::max(self.needs_read, cmp::max(1024, 2 * filled));
        if wanted_len > self.enc_in.get_ref().len() {
            self.enc_in.get_mut().resize(wanted_len, 0);
        }
        let received = self.stream.read(&mut self.enc_in.get_mut()[filled..])?;
        if received == 0 {
            break;
        }
        self.enc_in.set_position((filled + received) as u64);
        self.needs_read = self.needs_read.saturating_sub(received);
        total += received;
    }
    Ok(total)
}
// Discards the first `nread` bytes of the valid region of `enc_in`.
//
// Panics if `nread` exceeds the number of buffered bytes. When everything is
// consumed the vector is left untouched and only the position is reset.
fn consume_enc_in(&mut self, nread: usize) {
    let size = self.enc_in.position() as usize;
    assert!(size >= nread);
    let leftover = size - nread;
    if leftover != 0 {
        // Shift the unconsumed tail to the front of the buffer.
        self.enc_in.get_mut().drain(..nread);
    }
    self.enc_in.set_position(leftover as u64);
}
// Runs `DecryptMessage` over the encrypted bytes buffered in `enc_in`.
//
// Returns `Ok(true)` only when the context reports `SEC_I_CONTEXT_EXPIRED`
// (callers treat this as end-of-stream) and `Ok(false)` otherwise:
// * `SEC_E_OK`: the plaintext replaces the contents of `dec_in` and the
//   consumed input is dropped from `enc_in`.
// * `SEC_E_INCOMPLETE_MESSAGE`: `needs_read` is set to request more input.
// * `SEC_I_RENEGOTIATE`: the stream goes back to the `Initializing` state.
// Any other status is returned as an OS error.
fn decrypt(&mut self) -> io::Result<bool> {
unsafe {
let position = self.enc_in.position() as usize;
// One data buffer holding the input plus three empty buffers that are
// inspected after the call.
let mut bufs = [secbuf(sspi::SECBUFFER_DATA,
Some(&mut self.enc_in.get_mut()[..position])),
secbuf(sspi::SECBUFFER_EMPTY, None),
secbuf(sspi::SECBUFFER_EMPTY, None),
secbuf(sspi::SECBUFFER_EMPTY, None)];
let mut bufdesc = secbuf_desc(&mut bufs);
match sspi::DecryptMessage(self.context.get_mut(),
&mut bufdesc,
0,
ptr::null_mut()) {
winerror::SEC_E_OK => {
// bufs[1] is taken to point at the plaintext inside `enc_in`'s own
// storage; convert the pointer back to an offset into that vector.
let start = bufs[1].pvBuffer as usize - self.enc_in.get_ref().as_ptr() as usize;
let end = start + bufs[1].cbBuffer as usize;
self.dec_in.get_mut().clear();
self.dec_in
.get_mut()
.extend_from_slice(&self.enc_in.get_ref()[start..end]);
self.dec_in.set_position(0);
// An EXTRA buffer's size is treated as the number of trailing input
// bytes that were not consumed.
let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
self.enc_in.position() as usize - bufs[3].cbBuffer as usize
} else {
self.enc_in.position() as usize
};
self.consume_enc_in(nread);
// Only ask for more input when nothing is left over.
self.needs_read = (self.enc_in.position() == 0) as usize;
Ok(false)
}
winerror::SEC_E_INCOMPLETE_MESSAGE => {
// A MISSING buffer's size is used as the number of bytes to read;
// otherwise request at least one more byte.
self.needs_read = if bufs[1].BufferType == sspi::SECBUFFER_MISSING {
bufs[1].cbBuffer as usize
} else {
1
};
Ok(false)
}
winerror::SEC_I_CONTEXT_EXPIRED => Ok(true),
winerror::SEC_I_RENEGOTIATE => {
// Restart the handshake state machine; the next `initialize` call
// performs the renegotiation (including a fresh validation).
self.state = State::Initializing {
needs_flush: false,
more_calls: true,
shutting_down: false,
validated: false,
};
let nread = if bufs[3].BufferType == sspi::SECBUFFER_EXTRA {
self.enc_in.position() as usize - bufs[3].cbBuffer as usize
} else {
self.enc_in.position() as usize
};
self.consume_enc_in(nread);
// The first renegotiation step runs on what is already buffered.
self.needs_read = 0;
Ok(false)
}
e => Err(io::Error::from_raw_os_error(e as i32)),
}
}
}
// Encrypts `buf` into `out_buf` as a single message.
//
// `out_buf` is overwritten from offset 0: header space, then the plaintext,
// then trailer space, all of which `EncryptMessage` transforms in place. On
// success `out_buf` is truncated to the total size reported for the three
// buffers and its position is rewound to 0 so `write_out` sends the message.
//
// Panics if `buf` is longer than `sizes.cbMaximumMessage`. `Write::write`
// only calls this once all previously buffered output has been sent.
fn encrypt(&mut self, buf: &[u8], sizes: &sspi::SecPkgContext_StreamSizes) -> io::Result<()> {
assert!(buf.len() <= sizes.cbMaximumMessage as usize);
unsafe {
// Make sure there is room for header + payload + trailer.
let len = sizes.cbHeader as usize + buf.len() + sizes.cbTrailer as usize;
if self.out_buf.get_ref().len() < len {
self.out_buf.get_mut().resize(len, 0);
}
// Copy the plaintext right behind the header space.
let message_start = sizes.cbHeader as usize;
self.out_buf
.get_mut()[message_start..message_start + buf.len()]
.clone_from_slice(buf);
let mut bufs = {
let out_buf = self.out_buf.get_mut();
let size = sizes.cbHeader as usize;
let header = secbuf(sspi::SECBUFFER_STREAM_HEADER,
Some(&mut out_buf[..size]));
let data = secbuf(sspi::SECBUFFER_DATA,
Some(&mut out_buf[size..size + buf.len()]));
// The trailer buffer covers the rest of `out_buf`, which may be
// larger than `cbTrailer` if the vector was already bigger.
let trailer = secbuf(sspi::SECBUFFER_STREAM_TRAILER,
Some(&mut out_buf[size + buf.len()..]));
let empty = secbuf(sspi::SECBUFFER_EMPTY, None);
[header, data, trailer, empty]
};
let mut bufdesc = secbuf_desc(&mut bufs);
match sspi::EncryptMessage(self.context.get_mut(), 0, &mut bufdesc, 0) {
winerror::SEC_E_OK => {
// Keep exactly the bytes reported for header, data and trailer.
let len = bufs[0].cbBuffer + bufs[1].cbBuffer + bufs[2].cbBuffer;
self.out_buf.get_mut().truncate(len as usize);
self.out_buf.set_position(0);
Ok(())
}
err => Err(io::Error::from_raw_os_error(err as i32)),
}
}
}
}
impl<S> MidHandshakeTlsStream<S> {
    /// Returns a shared reference to the inner stream, i.e. the transport
    /// wrapped by the partially initialized `TlsStream`.
    pub fn get_ref(&self) -> &S {
        &self.inner.stream
    }

    /// Returns a mutable reference to the inner stream, i.e. the transport
    /// wrapped by the partially initialized `TlsStream`.
    pub fn get_mut(&mut self) -> &mut S {
        &mut self.inner.stream
    }
}
impl<S> MidHandshakeTlsStream<S>
where S: Read + Write,
{
/// Restarts the handshake process.
pub fn handshake(mut self) -> Result<TlsStream<S>, HandshakeError<S>> {
match self.inner.initialize() {
Ok(_) => Ok(self.inner),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
Err(HandshakeError::Interrupted(self))
}
Err(e) => Err(HandshakeError::Failure(e)),
}
}
}
impl<S> Write for TlsStream<S>
where S: Read + Write
{
/// In the case of a WouldBlock error, we expect another call
/// starting with the same input data
/// This is similar to the use of ACCEPT_MOVING_WRITE_BUFFER in openssl
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let sizes = match self.initialize()? {
Some(sizes) => sizes,
None => return Err(io::Error::from_raw_os_error(winerror::SEC_E_CONTEXT_EXPIRED as i32)),
};
// if we have pending output data, it must have been because a previous
// attempt to send this part of the data ran into an error.
if self.out_buf.position() == self.out_buf.get_ref().len() as u64 {
let len = cmp::min(buf.len(), sizes.cbMaximumMessage as usize);
self.encrypt(&buf[..len], &sizes)?;
self.last_write_len = len;
}
self.write_out()?;
Ok(self.last_write_len)
}
fn flush(&mut self) -> io::Result<()> {
// Make sure the write buffer is emptied
self.write_out()?;
self.stream.flush()
}
}
impl<S> Read for TlsStream<S>
where
    S: Read + Write,
{
    /// Copies decrypted data into `buf`, filling the internal buffer first
    /// if it is empty. Returns the number of bytes copied.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let available = self.fill_buf()?;
        let count = cmp::min(available.len(), buf.len());
        buf[..count].copy_from_slice(&available[..count]);
        self.consume(count);
        Ok(count)
    }
}
impl<S> BufRead for TlsStream<S>
where
    S: Read + Write,
{
    /// Returns the buffered decrypted data, reading and decrypting more
    /// input from the wrapped stream while the buffer is empty.
    ///
    /// An empty slice is returned once the session is shut down, the wrapped
    /// stream reaches end-of-file, or the context reports that it expired.
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        loop {
            if !self.get_buf().is_empty() {
                break;
            }
            // Complete any pending handshake; `None` means the stream is shut down.
            if self.initialize()?.is_none() {
                break;
            }
            if self.needs_read > 0 {
                let received = self.read_in()?;
                if received == 0 {
                    break;
                }
                self.needs_read = 0;
            }
            let eof = self.decrypt()?;
            if eof {
                break;
            }
        }
        Ok(self.get_buf())
    }

    /// Marks `amt` bytes of the decrypted buffer as consumed.
    ///
    /// Panics if `amt` exceeds the amount of buffered data.
    fn consume(&mut self, amt: usize) {
        let pos = self.dec_in.position() + amt as u64;
        assert!(pos <= self.dec_in.get_ref().len() as u64);
        self.dec_in.set_position(pos);
    }
}