| // Copyright 2015 Google Inc. All rights reserved. |
| // |
| // Permission is hereby granted, free of charge, to any person obtaining a copy |
| // of this software and associated documentation files (the "Software"), to deal |
| // in the Software without restriction, including without limitation the rights |
| // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| // copies of the Software, and to permit persons to whom the Software is |
| // furnished to do so, subject to the following conditions: |
| // |
| // The above copyright notice and this permission notice shall be included in |
| // all copies or substantial portions of the Software. |
| // |
| // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| // THE SOFTWARE. |
| |
| //! Utility functions for HTML escaping |
| |
| use std::io; |
| use std::str::from_utf8; |
| |
| use crate::html::StrWrite; |
| |
| #[rustfmt::skip] |
| static HREF_SAFE: [u8; 128] = [ |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, |
| 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, |
| 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, |
| ]; |
| |
| static HEX_CHARS: &[u8] = b"0123456789ABCDEF"; |
| static AMP_ESCAPE: &str = "&"; |
| static SLASH_ESCAPE: &str = "'"; |
| |
| pub(crate) fn escape_href<W>(mut w: W, s: &str) -> io::Result<()> |
| where |
| W: StrWrite, |
| { |
| let bytes = s.as_bytes(); |
| let mut mark = 0; |
| for i in 0..bytes.len() { |
| let c = bytes[i]; |
| if c >= 0x80 || HREF_SAFE[c as usize] == 0 { |
| // character needing escape |
| |
| // write partial substring up to mark |
| if mark < i { |
| w.write_str(&s[mark..i])?; |
| } |
| match c { |
| b'&' => { |
| w.write_str(AMP_ESCAPE)?; |
| } |
| b'\'' => { |
| w.write_str(SLASH_ESCAPE)?; |
| } |
| _ => { |
| let mut buf = [0u8; 3]; |
| buf[0] = b'%'; |
| buf[1] = HEX_CHARS[((c as usize) >> 4) & 0xF]; |
| buf[2] = HEX_CHARS[(c as usize) & 0xF]; |
| let escaped = from_utf8(&buf).unwrap(); |
| w.write_str(escaped)?; |
| } |
| } |
| mark = i + 1; // all escaped characters are ASCII |
| } |
| } |
| w.write_str(&s[mark..]) |
| } |
| |
| const fn create_html_escape_table() -> [u8; 256] { |
| let mut table = [0; 256]; |
| table[b'"' as usize] = 1; |
| table[b'&' as usize] = 2; |
| table[b'<' as usize] = 3; |
| table[b'>' as usize] = 4; |
| table |
| } |
| |
| static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table(); |
| |
| static HTML_ESCAPES: [&'static str; 5] = ["", """, "&", "<", ">"]; |
| |
| /// Writes the given string to the Write sink, replacing special HTML bytes |
| /// (<, >, &, ") by escape sequences. |
| pub(crate) fn escape_html<W: StrWrite>(w: W, s: &str) -> io::Result<()> { |
| #[cfg(all(target_arch = "x86_64", feature = "simd"))] |
| { |
| simd::escape_html(w, s) |
| } |
| #[cfg(not(all(target_arch = "x86_64", feature = "simd")))] |
| { |
| escape_html_scalar(w, s) |
| } |
| } |
| |
| fn escape_html_scalar<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> { |
| let bytes = s.as_bytes(); |
| let mut mark = 0; |
| let mut i = 0; |
| while i < s.len() { |
| match bytes[i..] |
| .iter() |
| .position(|&c| HTML_ESCAPE_TABLE[c as usize] != 0) |
| { |
| Some(pos) => { |
| i += pos; |
| } |
| None => break, |
| } |
| let c = bytes[i]; |
| let escape = HTML_ESCAPE_TABLE[c as usize]; |
| let escape_seq = HTML_ESCAPES[escape as usize]; |
| w.write_str(&s[mark..i])?; |
| w.write_str(escape_seq)?; |
| i += 1; |
| mark = i; // all escaped characters are ASCII |
| } |
| w.write_str(&s[mark..]) |
| } |
| |
| #[cfg(all(target_arch = "x86_64", feature = "simd"))] |
| mod simd { |
| use crate::html::StrWrite; |
| use std::arch::x86_64::*; |
| use std::io; |
| use std::mem::size_of; |
| |
| const VECTOR_SIZE: usize = size_of::<__m128i>(); |
| |
| pub(crate) fn escape_html<W: StrWrite>(mut w: W, s: &str) -> io::Result<()> { |
| // The SIMD accelerated code uses the PSHUFB instruction, which is part |
| // of the SSSE3 instruction set. Further, we can only use this code if |
| // the buffer is at least one VECTOR_SIZE in length to prevent reading |
| // out of bounds. If either of these conditions is not met, we fall back |
| // to scalar code. |
| if is_x86_feature_detected!("ssse3") && s.len() >= VECTOR_SIZE { |
| let bytes = s.as_bytes(); |
| let mut mark = 0; |
| |
| unsafe { |
| foreach_special_simd(bytes, 0, |i| { |
| let escape_ix = *bytes.get_unchecked(i) as usize; |
| let replacement = |
| super::HTML_ESCAPES[super::HTML_ESCAPE_TABLE[escape_ix] as usize]; |
| w.write_str(&s.get_unchecked(mark..i))?; |
| mark = i + 1; // all escaped characters are ASCII |
| w.write_str(replacement) |
| })?; |
| w.write_str(&s.get_unchecked(mark..)) |
| } |
| } else { |
| super::escape_html_scalar(w, s) |
| } |
| } |
| |
| /// Creates the lookup table for use in `compute_mask`. |
| const fn create_lookup() -> [u8; 16] { |
| let mut table = [0; 16]; |
| table[(b'<' & 0x0f) as usize] = b'<'; |
| table[(b'>' & 0x0f) as usize] = b'>'; |
| table[(b'&' & 0x0f) as usize] = b'&'; |
| table[(b'"' & 0x0f) as usize] = b'"'; |
| table[0] = 0b0111_1111; |
| table |
| } |
| |
| #[target_feature(enable = "ssse3")] |
| /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant) |
| /// bits correspond to whether there is an HTML special byte (&, <, ", >) at the 16 bytes |
| /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte |
| /// at `offset + 3`. It is only safe to call this function when |
| /// `bytes.len() >= offset + VECTOR_SIZE`. |
| unsafe fn compute_mask(bytes: &[u8], offset: usize) -> i32 { |
| debug_assert!(bytes.len() >= offset + VECTOR_SIZE); |
| |
| let table = create_lookup(); |
| let lookup = _mm_loadu_si128(table.as_ptr() as *const __m128i); |
| let raw_ptr = bytes.as_ptr().offset(offset as isize) as *const __m128i; |
| |
| // Load the vector from memory. |
| let vector = _mm_loadu_si128(raw_ptr); |
| // We take the least significant 4 bits of every byte and use them as indices |
| // to map into the lookup vector. |
| // Note that shuffle maps bytes with their most significant bit set to lookup[0]. |
| // Bytes that share their lower nibble with an HTML special byte get mapped to that |
| // corresponding special byte. Note that all HTML special bytes have distinct lower |
| // nibbles. Other bytes either get mapped to 0 or 127. |
| let expected = _mm_shuffle_epi8(lookup, vector); |
| // We compare the original vector to the mapped output. Bytes that shared a lower |
| // nibble with an HTML special byte match *only* if they are that special byte. Bytes |
| // that have either a 0 lower nibble or their most significant bit set were mapped to |
| // 127 and will hence never match. All other bytes have non-zero lower nibbles but |
| // were mapped to 0 and will therefore also not match. |
| let matches = _mm_cmpeq_epi8(expected, vector); |
| |
| // Translate matches to a bitmask, where every 1 corresponds to a HTML special character |
| // and a 0 is a non-HTML byte. |
| _mm_movemask_epi8(matches) |
| } |
| |
| /// Calls the given function with the index of every byte in the given byteslice |
| /// that is either ", &, <, or > and for no other byte. |
| /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may |
| /// occur otherwise. |
| #[target_feature(enable = "ssse3")] |
| unsafe fn foreach_special_simd<F>( |
| bytes: &[u8], |
| mut offset: usize, |
| mut callback: F, |
| ) -> io::Result<()> |
| where |
| F: FnMut(usize) -> io::Result<()>, |
| { |
| // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16) |
| // bytes at a time starting at the given offset. For each chunk, we compute a |
| // a bitmask indicating whether the corresponding byte is a HTML special byte. |
| // We then iterate over all the 1 bits in this mask and call the callback function |
| // with the corresponding index in the buffer. |
| // When the number of HTML special bytes in the buffer is relatively low, this |
| // allows us to quickly go through the buffer without a lookup and for every |
| // single byte. |
| |
| debug_assert!(bytes.len() >= VECTOR_SIZE); |
| let upperbound = bytes.len() - VECTOR_SIZE; |
| while offset < upperbound { |
| let mut mask = compute_mask(bytes, offset); |
| while mask != 0 { |
| let ix = mask.trailing_zeros(); |
| callback(offset + ix as usize)?; |
| mask ^= mask & -mask; |
| } |
| offset += VECTOR_SIZE; |
| } |
| |
| // Final iteration. We align the read with the end of the slice and |
| // shift off the bytes at start we have already scanned. |
| let mut mask = compute_mask(bytes, upperbound); |
| mask >>= offset - upperbound; |
| while mask != 0 { |
| let ix = mask.trailing_zeros(); |
| callback(offset + ix as usize)?; |
| mask ^= mask & -mask; |
| } |
| Ok(()) |
| } |
| |
| #[cfg(test)] |
| mod html_scan_tests { |
| #[test] |
| fn multichunk() { |
| let mut vec = Vec::new(); |
| unsafe { |
| super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix| { |
| Ok(vec.push(ix)) |
| }) |
| .unwrap(); |
| } |
| assert_eq!(vec, vec![0, 14, 15, 19]); |
| } |
| |
| // only match these bytes, and when we match them, match them VECTOR_SIZE times |
| #[test] |
| fn only_right_bytes_matched() { |
| for b in 0..255u8 { |
| let right_byte = b == b'&' || b == b'<' || b == b'>' || b == b'"'; |
| let vek = vec![b; super::VECTOR_SIZE]; |
| let mut match_count = 0; |
| unsafe { |
| super::foreach_special_simd(&vek, 0, |_| { |
| match_count += 1; |
| Ok(()) |
| }) |
| .unwrap(); |
| } |
| assert!((match_count > 0) == (match_count == super::VECTOR_SIZE)); |
| assert_eq!( |
| (match_count == super::VECTOR_SIZE), |
| right_byte, |
| "match_count: {}, byte: {:?}", |
| match_count, |
| b as char |
| ); |
| } |
| } |
| } |
| } |