| # Copyright (c) 2020 Valve Corporation |
| # |
| # SPDX-License-Identifier: MIT |
| |
| import re |
| import sys |
| import os.path |
| import struct |
| import string |
| import copy |
| from math import floor |
| |
| if os.isatty(sys.stdout.fileno()): |
| set_red = "\033[31m" |
| set_green = "\033[1;32m" |
| set_normal = "\033[0m" |
| else: |
| set_red = '' |
| set_green = '' |
| set_normal = '' |
| |
| initial_code = ''' |
| import re |
| |
| def insert_code(code): |
| insert_queue.append(CodeCheck(code, current_position)) |
| |
| def insert_pattern(pattern): |
| insert_queue.append(PatternCheck(pattern, False, current_position)) |
| |
| def vector_gpr(prefix, name, size, align): |
| insert_code(f'{name} = {name}0') |
| for i in range(size): |
| insert_code(f'{name}{i} = {name}0 + {i}') |
| insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}') |
| insert_code(f'success = {name}0 % {align} == 0') |
| return f'{prefix}[#{name}0:#{name}{size - 1}]' |
| |
| def sgpr_vector(name, size, align): |
| return vector_gpr('s', name, size, align) |
| |
| funcs.update({ |
| 's64': lambda name: vector_gpr('s', name, 2, 2), |
| 's96': lambda name: vector_gpr('s', name, 3, 2), |
| 's128': lambda name: vector_gpr('s', name, 4, 4), |
| 's256': lambda name: vector_gpr('s', name, 8, 4), |
| 's512': lambda name: vector_gpr('s', name, 16, 4), |
| }) |
| for i in range(2, 14): |
| funcs['v%d' % (i * 32)] = lambda name: vector_gpr('v', name, i, 1) |
| |
| def _match_func(names): |
| for name in names.split(' '): |
| insert_code(f'funcs["{name}"] = lambda _: {name}') |
| return ' '.join(f'${name}' for name in names.split(' ')) |
| |
| funcs['match_func'] = _match_func |
| |
| def search_re(pattern): |
| global success |
| success = re.search(pattern, output.read_line()) != None and success |
| |
| ''' |
| |
| class Check: |
| def __init__(self, data, position): |
| self.data = data.rstrip() |
| self.position = position |
| |
| def run(self, state): |
| pass |
| |
| class CodeCheck(Check): |
| def run(self, state): |
| indent = 0 |
| first_line = [l for l in self.data.split('\n') if l.strip() != ''][0] |
| indent_amount = len(first_line) - len(first_line.lstrip()) |
| indent = first_line[:indent_amount] |
| new_lines = [] |
| for line in self.data.split('\n'): |
| if line.strip() == '': |
| new_lines.append('') |
| continue |
| if line[:indent_amount] != indent: |
| state.result.log += 'unexpected indent in code check:\n' |
| state.result.log += self.data + '\n' |
| return False |
| new_lines.append(line[indent_amount:]) |
| code = '\n'.join(new_lines) |
| |
| try: |
| exec(code, state.g) |
| state.result.log += state.g['log'] |
| state.g['log'] = '' |
| except BaseException as e: |
| state.result.log += 'code check at %s raised exception:\n' % self.position |
| state.result.log += code + '\n' |
| state.result.log += str(e) |
| return False |
| if not state.g['success']: |
| state.result.log += 'code check at %s failed:\n' % self.position |
| state.result.log += code + '\n' |
| return False |
| return True |
| |
| class StringStream: |
| class Pos: |
| def __init__(self): |
| self.line = 1 |
| self.column = 1 |
| |
| def __init__(self, data, name): |
| self.name = name |
| self.data = data |
| self.offset = 0 |
| self.pos = StringStream.Pos() |
| |
| def reset(self): |
| self.offset = 0 |
| self.pos = StringStream.Pos() |
| |
| def peek(self, num=1): |
| return self.data[self.offset:self.offset+num] |
| |
| def peek_test(self, chars): |
| c = self.peek(1) |
| return c != '' and c in chars |
| |
| def read(self, num=4294967296): |
| res = self.peek(num) |
| self.offset += len(res) |
| for c in res: |
| if c == '\n': |
| self.pos.line += 1 |
| self.pos.column = 1 |
| else: |
| self.pos.column += 1 |
| return res |
| |
| def get_line(self, num): |
| return self.data.split('\n')[num - 1].rstrip() |
| |
| def read_line(self): |
| line = '' |
| while self.peek(1) not in ['\n', '']: |
| line += self.read(1) |
| self.read(1) |
| return line |
| |
| def skip_whitespace(self, inc_line): |
| chars = [' ', '\t'] + (['\n'] if inc_line else []) |
| while self.peek(1) in chars: |
| self.read(1) |
| |
| def get_number(self): |
| num = '' |
| while self.peek() in string.digits: |
| num += self.read(1) |
| return num |
| |
| def check_identifier(self): |
| return self.peek_test(string.ascii_letters + '_') |
| |
| def get_identifier(self): |
| res = '' |
| if self.check_identifier(): |
| while self.peek_test(string.ascii_letters + string.digits + '_'): |
| res += self.read(1) |
| return res |
| |
| def format_error_lines(at, line_num, column_num, ctx, line): |
| pred = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx) |
| return [pred + line + '"', |
| '-' * (column_num - 1 + len(pred)) + '^'] |
| |
| class MatchResult: |
| def __init__(self, pattern): |
| self.success = True |
| self.func_res = None |
| self.pattern = pattern |
| self.pattern_pos = StringStream.Pos() |
| self.output_pos = StringStream.Pos() |
| self.fail_message = '' |
| |
| def set_pos(self, pattern, output): |
| self.pattern_pos.line = pattern.pos.line |
| self.pattern_pos.column = pattern.pos.column |
| self.output_pos.line = output.pos.line |
| self.output_pos.column = output.pos.column |
| |
| def fail(self, msg): |
| self.success = False |
| self.fail_message = msg |
| |
| def format_pattern_pos(self): |
| pat_pos = self.pattern_pos |
| pat_line = self.pattern.get_line(pat_pos.line) |
| res = format_error_lines('at', pat_pos.line, pat_pos.column, 'pattern', pat_line) |
| func_res = self.func_res |
| while func_res: |
| pat_pos = func_res.pattern_pos |
| pat_line = func_res.pattern.get_line(pat_pos.line) |
| res += format_error_lines('in', pat_pos.line, pat_pos.column, func_res.pattern.name, pat_line) |
| func_res = func_res.func_res |
| return '\n'.join(res) |
| |
| def do_match(g, pattern, output, skip_lines, in_func=False): |
| assert(not in_func or not skip_lines) |
| |
| if not in_func: |
| output.skip_whitespace(False) |
| pattern.skip_whitespace(False) |
| |
| old_g = copy.copy(g) |
| old_g_keys = list(g.keys()) |
| res = MatchResult(pattern) |
| escape = False |
| while True: |
| res.set_pos(pattern, output) |
| |
| c = pattern.read(1) |
| fail = False |
| if c == '': |
| if not in_func: |
| while output.peek() in [' ', '\t']: |
| output.read(1) |
| if output.read(1) not in ['', '\n']: |
| res.fail('expected end of output') |
| |
| if res.success: |
| break |
| elif output.peek() == '': |
| res.fail('unexpected end of output') |
| elif c == '\\': |
| escape = True |
| continue |
| elif c == '\n': |
| old_line = output.pos.line |
| output.skip_whitespace(True) |
| if output.pos.line == old_line: |
| res.fail('expected newline in output') |
| elif not escape and c == '#': |
| num = output.get_number() |
| if num == '': |
| res.fail('expected number in output') |
| elif pattern.check_identifier(): |
| name = pattern.get_identifier() |
| if name in g and int(num) != g[name]: |
| res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name])) |
| elif name != '_': |
| g[name] = int(num) |
| elif not escape and c == '$': |
| name = pattern.get_identifier() |
| |
| val = '' |
| while not output.peek_test(string.whitespace): |
| val += output.read(1) |
| |
| if name in g and val != g[name]: |
| res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name])) |
| elif name != '_': |
| g[name] = val |
| elif not escape and c == '%' and pattern.check_identifier(): |
| if output.read(1) != '%': |
| res.fail('expected \'%\' in output') |
| else: |
| num = output.get_number() |
| if num == '': |
| res.fail('expected number in output') |
| else: |
| name = pattern.get_identifier() |
| if name in g and int(num) != g[name]: |
| res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name])) |
| elif name != '_': |
| g[name] = int(num) |
| elif not escape and c == '@' and pattern.check_identifier(): |
| name = pattern.get_identifier() |
| args = '' |
| if pattern.peek_test('('): |
| pattern.read(1) |
| while pattern.peek() not in ['', ')']: |
| args += pattern.read(1) |
| assert(pattern.read(1) == ')') |
| func_res = g['funcs'][name](args) |
| match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True) |
| if not match_res.success: |
| res.func_res = match_res |
| res.output_pos = match_res.output_pos |
| res.fail(match_res.fail_message) |
| elif not escape and c == ' ': |
| while pattern.peek_test(' '): |
| pattern.read(1) |
| |
| read_whitespace = False |
| while output.peek_test(' \t'): |
| output.read(1) |
| read_whitespace = True |
| if not read_whitespace: |
| res.fail('expected whitespace in output, got %r' % (output.peek(1))) |
| else: |
| outc = output.peek(1) |
| if outc != c: |
| res.fail('expected %r in output, got %r' % (c, outc)) |
| else: |
| output.read(1) |
| if not res.success: |
| if skip_lines and output.peek() != '': |
| g.clear() |
| g.update(old_g) |
| res.success = True |
| output.read_line() |
| pattern.reset() |
| output.skip_whitespace(False) |
| pattern.skip_whitespace(False) |
| else: |
| return res |
| |
| escape = False |
| |
| return res |
| |
| class PatternCheck(Check): |
| def __init__(self, data, search, position): |
| Check.__init__(self, data, position) |
| self.search = search |
| |
| def run(self, state): |
| pattern_stream = StringStream(self.data.rstrip(), 'pattern') |
| res = do_match(state.g, pattern_stream, state.g['output'], self.search) |
| if not res.success: |
| state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message) |
| state.result.log += res.format_pattern_pos() + '\n\n' |
| if not self.search: |
| out_line = state.g['output'].get_line(res.output_pos.line) |
| state.result.log += '\n'.join(format_error_lines('at', res.output_pos.line, res.output_pos.column, 'output', out_line)) |
| else: |
| state.result.log += 'output was:\n' |
| state.result.log += state.g['output'].data.rstrip() + '\n' |
| return False |
| return True |
| |
| class CheckState: |
| def __init__(self, result, variant, checks, output): |
| self.result = result |
| self.variant = variant |
| self.checks = checks |
| |
| self.checks.insert(0, CodeCheck(initial_code, None)) |
| self.insert_queue = [] |
| |
| self.g = {'success': True, 'funcs': {}, 'insert_queue': self.insert_queue, |
| 'variant': variant, 'log': '', 'output': StringStream(output, 'output'), |
| 'CodeCheck': CodeCheck, 'PatternCheck': PatternCheck, |
| 'current_position': ''} |
| |
| class TestResult: |
| def __init__(self, expected): |
| self.result = '' |
| self.expected = expected |
| self.log = '' |
| |
| def check_output(result, variant, checks, output): |
| state = CheckState(result, variant, checks, output) |
| |
| while len(state.checks): |
| check = state.checks.pop(0) |
| state.current_position = check.position |
| if not check.run(state): |
| result.result = 'failed' |
| return |
| |
| for check in state.insert_queue[::-1]: |
| state.checks.insert(0, check) |
| state.insert_queue.clear() |
| |
| result.result = 'passed' |
| return |
| |
| def parse_check(variant, line, checks, pos): |
| if line.startswith(';'): |
| line = line[1:] |
| if len(checks) and isinstance(checks[-1], CodeCheck): |
| checks[-1].data += '\n' + line |
| else: |
| checks.append(CodeCheck(line, pos)) |
| elif line.startswith('!'): |
| checks.append(PatternCheck(line[1:], False, pos)) |
| elif line.startswith('>>'): |
| checks.append(PatternCheck(line[2:], True, pos)) |
| elif line.startswith('~'): |
| end = len(line) |
| start = len(line) |
| for c in [';', '!', '>>']: |
| if line.find(c) != -1 and line.find(c) < end: |
| end = line.find(c) |
| if end != len(line): |
| match = re.match(line[1:end], variant) |
| if match and match.end() == len(variant): |
| parse_check(variant, line[end:], checks, pos) |
| |
| def parse_test_source(test_name, variant, fname): |
| in_test = False |
| test = [] |
| expected_result = 'passed' |
| line_num = 1 |
| for line in open(fname, 'r').readlines(): |
| if line.startswith('BEGIN_TEST(%s)' % test_name): |
| in_test = True |
| elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name): |
| in_test = True |
| expected_result = 'todo' |
| elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name): |
| in_test = True |
| expected_result = 'failed' |
| elif line.startswith('END_TEST'): |
| in_test = False |
| elif in_test: |
| test.append((line_num, line.strip())) |
| line_num += 1 |
| |
| checks = [] |
| for line_num, check in [(line_num, l[2:]) for line_num, l in test if l.startswith('//')]: |
| parse_check(variant, check, checks, 'line %d of %s' % (line_num, os.path.split(fname)[1])) |
| |
| return checks, expected_result |
| |
| def parse_and_check_test(test_name, variant, test_file, output, current_result): |
| checks, expected = parse_test_source(test_name, variant, test_file) |
| |
| result = TestResult(expected) |
| if len(checks) == 0: |
| result.result = 'empty' |
| result.log = 'no checks found' |
| elif current_result != None: |
| result.result, result.log = current_result |
| else: |
| check_output(result, variant, checks, output) |
| if result.result == 'failed' and expected == 'todo': |
| result.result = 'todo' |
| |
| return result |
| |
| def print_results(results, output, expected): |
| results = {name: result for name, result in results.items() if result.result == output} |
| results = {name: result for name, result in results.items() if (result.result == result.expected) == expected} |
| |
| if not results: |
| return 0 |
| |
| print('%s tests (%s):' % (output, 'expected' if expected else 'unexpected')) |
| for test, result in results.items(): |
| color = '' if expected else set_red |
| print(' %s%s%s' % (color, test, set_normal)) |
| if result.log.strip() != '': |
| for line in result.log.rstrip().split('\n'): |
| print(' ' + line.rstrip()) |
| print('') |
| |
| return len(results) |
| |
| def get_cstr(fp): |
| res = b'' |
| while True: |
| c = fp.read(1) |
| if c == b'\x00': |
| return res.decode('utf-8') |
| else: |
| res += c |
| |
| if __name__ == "__main__": |
| results = {} |
| |
| stdin = sys.stdin.buffer |
| while True: |
| packet_type = stdin.read(4) |
| if packet_type == b'': |
| break; |
| |
| test_name = get_cstr(stdin) |
| test_variant = get_cstr(stdin) |
| if test_variant != '': |
| full_name = test_name + '/' + test_variant |
| else: |
| full_name = test_name |
| |
| test_source_file = get_cstr(stdin) |
| current_result = None |
| if ord(stdin.read(1)): |
| current_result = (get_cstr(stdin), get_cstr(stdin)) |
| code_size = struct.unpack("=L", stdin.read(4))[0] |
| code = stdin.read(code_size).decode('utf-8') |
| |
| results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result) |
| |
| result_types = ['passed', 'failed', 'todo', 'empty'] |
| num_expected = 0 |
| num_unexpected = 0 |
| for t in result_types: |
| num_expected += print_results(results, t, True) |
| for t in result_types: |
| num_unexpected += print_results(results, t, False) |
| num_expected_skipped = print_results(results, 'skipped', True) |
| num_unexpected_skipped = print_results(results, 'skipped', False) |
| |
| num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped |
| color = set_red if num_unexpected else set_green |
| print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, floor(num_expected / num_unskipped * 100), num_unskipped, set_normal)) |
| if num_unexpected_skipped: |
| print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal)) |
| |
| if num_unexpected: |
| sys.exit(1) |