// Copyright 2017 The Wuffs Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// "?bad code" is returned by decode_io_writer when the compressed stream
// contains a code greater than save_code, i.e. a code that refers to a table
// entry that has not been defined yet.
pub status "?bad code"

// "?internal error: inconsistent I/O" is returned when the decoder's own
// bookkeeping is violated: read_from could not undo bytes it read ahead, or
// write_to found output_ri > output_wi.
pri status "?internal error: inconsistent I/O"

// TODO: move bulk data buffers like decoder.suffixes or decoder.output into
// the workbuf? The first attempt at this was a performance regression for
// decoding all but the smallest GIFs. See these git commits for numbers:
//  - 49627b4 Flatten the lzw.decoder.suffixes array
//  - f877fb2 Use the workbuf instead of lzw.decoder.suffixes
//  - 85be5b9 Delete the obsolete lzw.decoder.suffixes array
// and the rollback, which has combined numbers:
//  - 3056a84 Roll back 3 recent lzw.decoder.suffixes commits
//
// Until that happens, all decoding state lives in the decoder struct itself
// and no workbuf is needed, so the worst-case workbuf length is zero. This
// agrees with the [0, 0] range returned by decoder.workbuf_len.
pub const decoder_workbuf_len_max_incl_worst_case base.u64 = 0

// decoder decodes LZW-compressed data. Codes are packed least-significant-bit
// first. They start out literal_width + 1 bits wide and grow to at most 12
// bits. The code (1 << literal_width) is the clear code, and the code after
// it is the end code.
pub struct decoder?(
	// set_literal_width_arg is the saved argument passed to set_literal_width.
	// This field is copied to the literal_width field at the start of
	// decode_io_writer. During that method, calling set_literal_width will
	// change set_literal_width_arg but not literal_width. A value below 2
	// (such as zero, if set_literal_width was never called) selects the
	// default literal width of 8.
	set_literal_width_arg base.u32[..8],

	// read_from state that does not change during a decode call.
	// clear_code is (1 << literal_width) and end_code is (clear_code + 1).
	literal_width base.u32[..8],
	clear_code    base.u32[..256],
	end_code      base.u32[..257],

	// read_from state that does change during a decode call.
	//  - save_code is the code the next new table entry will be assigned.
	//    Once it reaches 4096 the table is full, and no more entries are
	//    added until the next clear code.
	//  - prev_code is the previously decoded code. The next new entry's
	//    value is prev_code's value followed by one more byte.
	//  - width is the current code width, in bits.
	//  - bits buffers input bits. Its low n_bits bits have not been consumed
	//    yet.
	//  - output_ri and output_wi index into the output array (see below).
	save_code base.u32[..4096],
	prev_code base.u32[..4095],
	width     base.u32[..12],
	bits      base.u32,
	n_bits    base.u32[..31],
	output_ri base.u32[..8191],
	output_wi base.u32[..8191],

	// read_from return value. The read_from method effectively returns a
	// base.u32 to show how decode should continue after calling write_to. That
	// value needs to be saved across write_to's possible suspension, so we
	// might as well save it explicitly as a decoder field. The values are:
	//  - 0: end code seen.
	//  - 1: output needs flushing.
	//  - 2: short read.
	//  - 3: bad code.
	//  - 4: internal inconsistency.
	read_from_return_value base.u32,

	util base.utility,
)(
	// read_from per-code state. The value of a code c (the bytes it decodes
	// to) is stored as a chain of 8-byte chunks, followed from the last chunk
	// back to the first:
	//  - suffixes[c] holds c's final, possibly partial, chunk. Its bytes
	//    [0 ..= lm1s[c] % 8] are meaningful; any later bytes may be stale.
	//  - prefixes[c] is the code whose value is exactly the part of c's
	//    value before that final chunk. It is only followed when c's value
	//    is longer than 8 bytes.
	suffixes array[4096] array[8] base.u8,
	prefixes array[4096] base.u16[..4095],
	// lm1s is the "length minus 1"s of the values for the implicit key-value
	// table in this decoder. See std/lzw/README.md for more detail.
	lm1s array[4096] base.u16[..4095],

	// output[output_ri:output_wi] is the buffered output, connecting read_from
	// with write_to and flush. read_from writes decoded bytes 8 at a time, at
	// offsets up to 8191. The 7 extra bytes keep those writes in bounds.
	output array[8192 + 7] base.u8,
)

// set_literal_width records the literal width, in bits, to use for decoding.
// The value only takes effect when decode_io_writer starts. It does not
// change the width used by a decode_io_writer call that is already running.
pub func decoder.set_literal_width!(lw base.u32[2..8]) {
	this.set_literal_width_arg = args.lw
}

// workbuf_len returns the inclusive range of workbuf lengths that
// decode_io_writer needs. It is always [0, 0] because decode_io_writer never
// uses its workbuf argument; all state lives in the decoder struct.
pub func decoder.workbuf_len() base.range_ii_u64 {
	return this.util.make_range_ii_u64(min_incl:0, max_incl:0)
}

// decode_io_writer decompresses the LZW-encoded bytes from src and writes the
// decoded bytes to dst. The workbuf argument is not used.
//
// It alternates between read_from, which decodes codes into this.output, and
// write_to, which copies that buffered output to dst. It suspends with
// "$short read" when src runs dry (or "$short write", via write_to, when dst
// is full). It returns "?bad code" for an invalid code and finishes once the
// end code is decoded.
pub func decoder.decode_io_writer?(dst base.io_writer, src base.io_reader, workbuf slice base.u8) {
	var i base.u32[..8191]

	// Initialize read_from state. A set_literal_width_arg below 2 (e.g. zero,
	// when set_literal_width was never called) keeps the default width of 8.
	this.literal_width = 8
	if this.set_literal_width_arg >= 2 {
		this.literal_width = this.set_literal_width_arg
	}
	this.clear_code = (1 as base.u32) << this.literal_width
	this.end_code = this.clear_code + 1
	this.save_code = this.end_code
	this.prev_code = this.end_code
	this.width = this.literal_width + 1
	this.bits = 0
	this.n_bits = 0
	this.output_ri = 0
	this.output_wi = 0
	// Each literal code i decodes to the single byte i: its length minus 1 is
	// zero and its first (and only) suffix byte is i.
	i = 0
	while i < this.clear_code {
		assert i < 256 via "a < b: a < c; c <= b"(c:this.clear_code)
		this.lm1s[i] = 0
		this.suffixes[i][0] = i as base.u8
		i += 1
	}

	while true {
		this.read_from!(src:args.src)

		// Drain any decoded bytes to dst before acting on read_from's result.
		if this.output_wi > 0 {
			this.write_to?(dst:args.dst)
		}

		if this.read_from_return_value == 0 {
			// The end code was decoded.
			break
		} else if this.read_from_return_value == 1 {
			// The output buffer was flushed above; keep decoding.
			continue
		} else if this.read_from_return_value == 2 {
			// src has no more bytes; suspend, then loop back to read_from
			// when resumed.
			yield base."$short read"
		} else if this.read_from_return_value == 3 {
			// A code referred to a table entry that does not exist yet.
			return "?bad code"
		} else {
			return "?internal error: inconsistent I/O"
		}
	}
}

// read_from decodes codes from src into this.output, starting at
// this.output_wi. It stops when it decodes the end code, when src runs out of
// bytes, when the output buffer needs flushing, or when it finds an invalid
// code. It reports which of these happened via this.read_from_return_value:
//  - 0: the end code was decoded.
//  - 1: output_wi exceeded 4095. The caller should flush the output before
//    decoding more, so that the next code's bytes are guaranteed to fit.
//  - 2: src had no more bytes available.
//  - 3: a code greater than save_code was read (a bad code).
//  - 4: internal inconsistency, e.g. read-ahead bytes could not be undone.
//
// The decoder's mutable read_from state is copied into local variables at
// the start and written back to the decoder's fields at the end.
pri func decoder.read_from!(src base.io_reader) {
	var clear_code base.u32[..256]
	var end_code   base.u32[..257]

	var save_code base.u32[..4096]
	var prev_code base.u32[..4095]
	var width     base.u32[..12]
	var bits      base.u32
	var n_bits    base.u32[..31]
	var output_wi base.u32[..8191]

	var code       base.u32[..4095]
	var c          base.u32[..4095]
	var o          base.u32[..8191]
	var steps      base.u32
	var first_byte base.u8
	var lm1_b      base.u16[..4095]
	var lm1_a      base.u16[..4095]

	clear_code = this.clear_code
	end_code = this.end_code

	save_code = this.save_code
	prev_code = this.prev_code
	width = this.width
	bits = this.bits
	n_bits = this.n_bits
	output_wi = this.output_wi

	while true {
		// Make sure at least width bits are buffered, reading 4 bytes at a
		// time when possible and 1 byte at a time otherwise.
		if n_bits < width {
			assert n_bits < 12 via "a < b: a < c; c <= b"(c:width)
			if args.src.available() >= 4 {
				// Read 4 bytes, using the "Variant 4" technique of
				// https://fgiesen.wordpress.com/2018/02/20/reading-bits-in-far-too-many-ways-part-2/
				bits |= args.src.peek_u32le() ~mod<< n_bits
				args.src.skip_fast!(actual:(31 - n_bits) >> 3, worst_case:3)
				n_bits |= 24
				assert width <= n_bits via "a <= b: a <= c; c <= b"(c:12)
				assert n_bits >= width via "a >= b: b <= a"()
			} else if args.src.available() <= 0 {
				this.read_from_return_value = 2
				break
			} else {
				// Fewer than 4 bytes are available. Since n_bits < 12 and
				// width <= 12, at most two more bytes are needed.
				bits |= args.src.peek_u8_as_u32() << n_bits
				args.src.skip_fast!(actual:1, worst_case:1)
				n_bits += 8
				if n_bits >= width {
					// No-op.
				} else if args.src.available() <= 0 {
					this.read_from_return_value = 2
					break
				} else {
					bits |= args.src.peek_u8_as_u32() << n_bits
					args.src.skip_fast!(actual:1, worst_case:1)
					n_bits += 8
					assert width <= n_bits via "a <= b: a <= c; c <= b"(c:12)
					assert n_bits >= width via "a >= b: b <= a"()

					// This if condition is always false, but for some unknown
					// reason, removing it worsens the benchmarks slightly.
					if n_bits < width {
						this.read_from_return_value = 4
						break
					}
				}
			}
		}

		// Codes are packed LSB first: the next code is the low width bits.
		code = bits.low_bits(n:width)
		bits >>= width
		n_bits -= width

		if code < clear_code {
			// Literal code: emit its byte. Unless the table is full, add a new
			// entry whose value is prev_code's value followed by this byte.
			assert code < 256 via "a < b: a < c; c <= b"(c:clear_code)
			this.output[output_wi] = code as base.u8
			output_wi = (output_wi + 1) & 8191
			if save_code <= 4095 {
				lm1_a = (this.lm1s[prev_code] + 1) & 4095
				this.lm1s[save_code] = lm1_a

				// Either extend prev_code's partial final chunk, or start a
				// new chunk whose prefix is prev_code itself.
				if (lm1_a % 8) != 0 {
					this.prefixes[save_code] = this.prefixes[prev_code]
					this.suffixes[save_code] = this.suffixes[prev_code]
					this.suffixes[save_code][lm1_a % 8] = code as base.u8
				} else {
					this.prefixes[save_code] = prev_code as base.u16
					this.suffixes[save_code][0] = code as base.u8
				}

				// Widen codes by one bit when save_code reaches 1 << width.
				save_code += 1
				if width < 12 {
					width += 1 & (save_code >> width)
				}
				prev_code = code
			}

		} else if code <= end_code {
			// End code: stop. Clear code: reset save_code, prev_code and
			// width, so that codes above end_code are invalid again until
			// they are redefined.
			if code == end_code {
				this.read_from_return_value = 0
				break
			}
			save_code = end_code
			prev_code = end_code
			width = this.literal_width + 1

		} else if code <= save_code {
			// Copy code: emit the value of an existing entry or, when
			// code == save_code, of the entry being defined, whose value is
			// prev_code's value followed by its own first byte.
			c = code
			if code == save_code {
				c = prev_code
			}

			// Letting old_wi and new_wi denote the values of output_wi before
			// and after these two lines of code, the decoded bytes will be
			// written to output[old_wi:new_wi]. They will be written
			// back-to-front, 8 bytes at a time, starting by writing
			// output[o:o + 8], which will contain output[new_wi - 1].
			//
			// In the special case that code == save_code, the decoded bytes
			// contain an extra copy (at the end) of the first byte, and will
			// be written to output[old_wi:new_wi + 1].
			o = (output_wi + ((this.lm1s[c] as base.u32) & 0xFFFFFFF8)) & 8191
			output_wi = (output_wi + 1 + (this.lm1s[c] as base.u32)) & 8191

			steps = (this.lm1s[c] as base.u32) >> 3
			while true {
				assert o <= (o + 8) via "a <= (a + b): 0 <= b"(b:8)

				// The final "8" is redundant semantically, but helps the
				// wuffs-c code generator recognize that both slices have the
				// same constant length, and hence produce efficient C code.
				this.output[o:o + 8].copy_from_slice!(s:this.suffixes[c][:8])

				if steps <= 0 {
					break
				}
				steps -= 1

				// This line is essentially "o -= 8". The "& 8191" is a no-op
				// in practice, but is necessary for the overflow checker.
				o = (o ~mod- 8) & 8191
				c = this.prefixes[c] as base.u32
			}
			// c is now the code holding the value's first chunk.
			first_byte = this.suffixes[c][0]

			if code == save_code {
				this.output[output_wi] = first_byte
				output_wi = (output_wi + 1) & 8191
			}

			// Unless the table is full, add a new entry whose value is
			// prev_code's value followed by first_byte.
			if save_code <= 4095 {
				lm1_b = (this.lm1s[prev_code] + 1) & 4095
				this.lm1s[save_code] = lm1_b

				if (lm1_b % 8) != 0 {
					this.prefixes[save_code] = this.prefixes[prev_code]
					this.suffixes[save_code] = this.suffixes[prev_code]
					this.suffixes[save_code][lm1_b % 8] = first_byte
				} else {
					this.prefixes[save_code] = prev_code as base.u16
					this.suffixes[save_code][0] = first_byte
				}

				save_code += 1
				if width < 12 {
					width += 1 & (save_code >> width)
				}
				prev_code = code
			}

		} else {
			// code > save_code refers to an entry that does not exist yet.
			this.read_from_return_value = 3
			break
		}

		// Flush the output if it could be too full to contain the entire
		// decoding of the next code. The longest possible decoding is slightly
		// less than 4096 and output's length is 8192, so a conservative
		// threshold is ensuring that output_wi <= 4095.
		if output_wi > 4095 {
			this.read_from_return_value = 1
			break
		}
	}

	// Rewind args.src, if we're not in "$short read" and we've read too many
	// bits.
	if this.read_from_return_value != 2 {
		while n_bits >= 8 {
			n_bits -= 8
			if args.src.can_undo_byte() {
				args.src.undo_byte!()
			} else {
				this.read_from_return_value = 4
				break
			}
		}
	}

	this.save_code = save_code
	this.prev_code = prev_code
	this.width = width
	this.bits = bits
	this.n_bits = n_bits
	this.output_wi = output_wi
}

// write_to copies the buffered bytes output[output_ri:output_wi] to dst.
// After a complete copy it resets output_ri and output_wi to zero and returns
// ok. After a partial copy it advances output_ri past the bytes copied and
// suspends with "$short write", retrying the rest when resumed.
pri func decoder.write_to?(dst base.io_writer) {
	var s slice base.u8
	var n base.u64

	while this.output_wi > 0 {
		// The slice expression below requires output_ri <= output_wi.
		if this.output_ri > this.output_wi {
			return "?internal error: inconsistent I/O"
		}
		s = this.output[this.output_ri:this.output_wi]
		n = args.dst.copy_from_slice!(s:s)
		if n == s.length() {
			this.output_ri = 0
			this.output_wi = 0
			return ok
		}
		// The masks keep the sum within output_ri's declared range, [..8191].
		this.output_ri = (this.output_ri ~mod+ ((n & 0xFFFFFFFF) as base.u32)) & 8191
		yield base."$short write"
	}
}

// flush returns the buffered bytes output[output_ri:output_wi] that have not
// been written to dst yet, and marks the buffer as empty. If output_ri >
// output_wi, it returns s at its default (empty) value instead. The returned
// slice is a view into this.output, not a copy, so later decoding may
// overwrite its contents.
pub func decoder.flush!() slice base.u8 {
	var s slice base.u8

	if this.output_ri <= this.output_wi {
		s = this.output[this.output_ri:this.output_wi]
	}
	this.output_ri = 0
	this.output_wi = 0
	return s
}
