blob: 61d7af4010e5423b8521cddb987f800d723cacc8 [file] [log] [blame]
use winapi::shared::{sspi, winerror};
use winapi::shared::minwindef::ULONG;
use winapi::um::{minschannel, schannel};
use std::mem;
use std::ptr;
use std::io;
use crate::{INIT_REQUESTS, Inner, secbuf, secbuf_desc};
use crate::alpn_list::AlpnList;
use crate::cert_context::CertContext;
use crate::context_buffer::ContextBuffer;
use crate::schannel_cred::SchannelCred;
/// Owned wrapper around a raw SSPI security context handle (`sspi::CtxtHandle`).
///
/// The wrapped handle is passed to `DeleteSecurityContext` when this value is dropped.
pub struct SecurityContext(sspi::CtxtHandle);
impl Drop for SecurityContext {
    /// Hands the owned context handle back to SSPI via `DeleteSecurityContext`.
    ///
    /// The returned status is intentionally discarded: `drop` has no way to
    /// report a failure. NOTE(review): a value built by `initialize` with
    /// `accept == true` still holds an all-zero handle at this point if no
    /// accept ever happened; presumably SSPI rejects that harmlessly — confirm.
    fn drop(&mut self) {
        let handle: *mut sspi::CtxtHandle = &mut self.0;
        // SAFETY: `handle` points at the context owned exclusively by `self`,
        // and it is never used again after this call.
        let _ = unsafe { sspi::DeleteSecurityContext(handle) };
    }
}
impl Inner<sspi::CtxtHandle> for SecurityContext {
unsafe fn from_inner(inner: sspi::CtxtHandle) -> SecurityContext {
SecurityContext(inner)
}
fn as_inner(&self) -> sspi::CtxtHandle {
self.0
}
fn get_mut(&mut self) -> &mut sspi::CtxtHandle {
&mut self.0
}
}
impl SecurityContext {
    /// Creates the security context for one side of a Schannel session.
    ///
    /// * `accept == true` (server side): no SSPI call is made. A context wrapping
    ///   an all-zero handle is returned together with `None`, because
    ///   `AcceptSecurityContext` can only be called once data has been read.
    /// * `accept == false` (client side): `InitializeSecurityContextW` is called
    ///   once with `cred`, the optional UTF-16 target name `domain`, and, when
    ///   `requested_application_protocols` is `Some`, an ALPN input buffer.
    ///   On `SEC_I_CONTINUE_NEEDED` the new context is returned together with the
    ///   first output `SecBuffer` filled in by that call, wrapped in a
    ///   `ContextBuffer`.
    ///
    /// # Errors
    ///
    /// Any status other than `SEC_I_CONTINUE_NEEDED` is returned as an
    /// `io::Error` built from the raw status code. In that case the partially
    /// initialised handle is not wrapped, so no `DeleteSecurityContext` is issued.
    pub fn initialize(cred: &mut SchannelCred,
                      accept: bool,
                      domain: Option<&[u16]>,
                      requested_application_protocols: &Option<Vec<Vec<u8>>>)
                      -> io::Result<(SecurityContext, Option<ContextBuffer>)> {
        // SAFETY: every pointer handed to `InitializeSecurityContextW` (`domain`,
        // the ALPN buffer referenced from `inbufs`, `outbuf`, `ctxt`,
        // `attributes`) refers to data that outlives this block.
        unsafe {
            let mut ctxt = mem::zeroed();

            if accept {
                // If we're performing an accept then we need to wait to call
                // `AcceptSecurityContext` until we've actually read some data.
                return Ok((SecurityContext(ctxt), None));
            }

            // SSPI takes a mutable wide-string pointer. `domain` is only borrowed
            // for the duration of this call.
            let domain = domain.map(|b| b.as_ptr() as *mut u16).unwrap_or(ptr::null_mut());

            let mut inbufs = vec![];

            // Make sure `AlpnList` is kept alive for the duration of this function:
            // the `SecBuffer` pushed below points into its storage.
            let mut alpns = requested_application_protocols
                .as_ref()
                .map(|alpn| AlpnList::new(alpn));
            if let Some(ref mut alpns) = alpns {
                inbufs.push(secbuf(sspi::SECBUFFER_APPLICATION_PROTOCOLS,
                                   Some(&mut alpns[..])));
            }

            let mut inbuf_desc = secbuf_desc(&mut inbufs[..]);

            let mut outbuf = [secbuf(sspi::SECBUFFER_EMPTY, None)];
            let mut outbuf_desc = secbuf_desc(&mut outbuf);

            let mut attributes = 0;

            match sspi::InitializeSecurityContextW(&mut cred.as_inner(),
                                                   ptr::null_mut(),
                                                   domain,
                                                   INIT_REQUESTS,
                                                   0,
                                                   0,
                                                   &mut inbuf_desc,
                                                   0,
                                                   &mut ctxt,
                                                   &mut outbuf_desc,
                                                   &mut attributes,
                                                   ptr::null_mut()) {
                winerror::SEC_I_CONTINUE_NEEDED => {
                    Ok((SecurityContext(ctxt), Some(ContextBuffer(outbuf[0]))))
                }
                err => Err(io::Error::from_raw_os_error(err as i32)),
            }
        }
    }

    /// Queries context attribute `attr` with `QueryContextAttributesW` and
    /// returns the value SSPI wrote, or the failing status as an `io::Error`.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type SSPI writes for `attr` (for example
    /// `SecPkgContext_StreamSizes` for `SECPKG_ATTR_STREAM_SIZES`), and an
    /// all-zero bit pattern must be a valid `T`. Otherwise SSPI may write out
    /// of bounds, or an invalid `T` may be produced.
    unsafe fn attribute<T>(&self, attr: ULONG) -> io::Result<T> {
        let mut value = mem::zeroed();
        // The API is declared with a mutable handle pointer, but the handle is an
        // input-only parameter, so deriving it from `&self` is not mutated.
        let status = sspi::QueryContextAttributesW(&self.0 as *const _ as *mut _,
                                                   attr,
                                                   &mut value as *mut _ as *mut _);
        if status == winerror::SEC_E_OK {
            Ok(value)
        } else {
            Err(io::Error::from_raw_os_error(status as i32))
        }
    }

    /// Queries `SECPKG_ATTR_APPLICATION_PROTOCOL` (ALPN information) for this context.
    pub fn application_protocol(&self) -> io::Result<sspi::SecPkgContext_ApplicationProtocol> {
        // SAFETY: the requested type is the SSPI structure for this attribute.
        unsafe { self.attribute(sspi::SECPKG_ATTR_APPLICATION_PROTOCOL) }
    }

    /// Queries `SECPKG_ATTR_SESSION_INFO` for this context.
    pub fn session_info(&self) -> io::Result<schannel::SecPkgContext_SessionInfo> {
        // SAFETY: the requested type is the Schannel structure for this attribute.
        unsafe { self.attribute(minschannel::SECPKG_ATTR_SESSION_INFO) }
    }

    /// Queries `SECPKG_ATTR_STREAM_SIZES` (`SecPkgContext_StreamSizes`) for this context.
    pub fn stream_sizes(&self) -> io::Result<sspi::SecPkgContext_StreamSizes> {
        // SAFETY: the requested type is the SSPI structure for this attribute.
        unsafe { self.attribute(sspi::SECPKG_ATTR_STREAM_SIZES) }
    }

    /// Returns the peer's certificate (`SECPKG_ATTR_REMOTE_CERT_CONTEXT`),
    /// wrapped in an owning `CertContext`.
    pub fn remote_cert(&self) -> io::Result<CertContext> {
        // SAFETY: the attribute yields a certificate-context pointer, whose
        // ownership is transferred to the returned `CertContext`.
        // (`from_inner` is an `unsafe fn`, so it cannot be passed to `map` directly.)
        unsafe {
            self.attribute(minschannel::SECPKG_ATTR_REMOTE_CERT_CONTEXT)
                .map(|p| CertContext::from_inner(p))
        }
    }

    /// Returns the local certificate (`SECPKG_ATTR_LOCAL_CERT_CONTEXT`),
    /// wrapped in an owning `CertContext`.
    pub fn local_cert(&self) -> io::Result<CertContext> {
        // SAFETY: as in `remote_cert`: ownership of the returned certificate
        // context pointer moves into the `CertContext`.
        unsafe {
            self.attribute(minschannel::SECPKG_ATTR_LOCAL_CERT_CONTEXT)
                .map(|p| CertContext::from_inner(p))
        }
    }
}