blob: 73333790081cf02dad8fd6d4fbdeac45ba3e8327 [file] [log] [blame]
import os
import pytest
def _read_example(*parts):
    """Return the UTF-8 text of a file below the ``examples`` directory.

    *parts* are path components joined onto ``<directory of this file>/examples``;
    the last one is the file name including its extension.
    """
    path = os.path.join(os.path.dirname(__file__), "examples", *parts)
    with open(path, encoding="utf-8") as f:
        return f.read()


@pytest.fixture
def example():
    """Provide a loader for ``examples/<name>.toml``.

    The fixture value is a callable that takes an example's base name (no
    extension) and returns the file's contents as a string.
    """

    def _example(name):
        return _read_example(name + ".toml")

    return _example


@pytest.fixture
def json_example():
    """Provide a loader for ``examples/json/<name>.json``.

    The fixture value is a callable that takes an example's base name (no
    extension) and returns the file's contents as a string.
    """

    def _example(name):
        return _read_example("json", name + ".json")

    return _example


@pytest.fixture
def invalid_example():
    """Provide a loader for ``examples/invalid/<name>.toml``.

    The fixture value is a callable that takes an example's base name (no
    extension) and returns the file's contents as a string.
    """

    def _example(name):
        return _read_example("invalid", name + ".toml")

    return _example
# Root of the suite's "tests" directory, found inside a "toml-test" directory
# next to this file (NOTE(review): presumably a checkout of the upstream
# toml-test project — confirm).
TEST_DIR = os.path.join(os.path.dirname(__file__), "toml-test", "tests")

# Case keys of the form "<subdir>/<basename>" to skip, grouped by the suite's
# top-level directory name.
IGNORED_TESTS = dict(
    valid=[
        # NaN never compares equal to itself, so the decoded result cannot be
        # checked against the expected value.
        "float/inf-and-nan",
    ],
)
def get_tomltest_cases(test_dir=None, ignored_tests=None):
    """Collect the toml-test suite cases from disk.

    Walks the ``invalid`` and ``valid`` subdirectories of *test_dir* and groups
    the files by case key ``"<relative dir>/<basename>"``. The relative
    directory always uses ``/`` as separator and is empty for files directly in
    ``valid``/``invalid``, giving keys like ``"/name"``. Files with extension
    ``multi`` are skipped, as are keys listed in *ignored_tests*.

    Args:
        test_dir: Root of the suite's ``tests`` directory. Defaults to
            ``TEST_DIR``.
        ignored_tests: Mapping of top-level directory name (``"valid"`` /
            ``"invalid"``) to a collection of case keys to skip. Defaults to
            ``IGNORED_TESTS``.

    Returns:
        A dict with three entries:

        * ``"valid"`` and ``"invalid"``: case key -> ``{extension: file text}``,
          each file read as UTF-8.
        * ``"invalid_encode"``: basename -> file path for every file under
          ``invalid/encoding``. These are stored as paths rather than decoded
          (presumably because their bytes are not valid UTF-8 — confirm).

        Keys are inserted in sorted directory/file order, so the result (and
        any parametrization built from it) is deterministic.

    Raises:
        AssertionError: if *test_dir* does not contain exactly the ``invalid``
            and ``valid`` subdirectories.
    """
    if test_dir is None:
        test_dir = TEST_DIR
    if ignored_tests is None:
        ignored_tests = IGNORED_TESTS

    dirs = sorted(
        d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))
    )
    assert dirs == ["invalid", "valid"]

    rv = {"invalid_encode": {}}
    for d in dirs:
        rv[d] = {}
        ignored = ignored_tests.get(d, [])
        base = os.path.join(test_dir, d)
        for root, subdirs, files in os.walk(base):
            # Sorting in place makes os.walk descend in a stable order instead
            # of whatever order the filesystem lists entries in.
            subdirs.sort()
            relpath = os.path.relpath(root, base)
            if relpath == ".":
                relpath = ""
            # Keys (and IGNORED_TESTS entries) use "/" on every OS; relpath
            # uses os.sep, which would break matching on Windows.
            relpath = relpath.replace(os.sep, "/")
            for f in sorted(files):
                try:
                    bn, ext = f.rsplit(".", 1)
                except ValueError:
                    # No "." in the name: fall back to a "-" separated suffix.
                    bn, ext = f.rsplit("-", 1)
                key = f"{relpath}/{bn}"
                if ext == "multi":
                    continue
                if key in ignored:
                    continue
                if d == "invalid" and relpath == "encoding":
                    # Keep only the path; these files are not decoded here.
                    rv["invalid_encode"][bn] = os.path.join(root, f)
                    continue
                with open(os.path.join(root, f), encoding="utf-8") as inp:
                    rv[d].setdefault(key, {})[ext] = inp.read()
    return rv
def pytest_generate_tests(metafunc):
    """Parametrize toml-test driven tests over the collected suite cases.

    A test requesting ``valid_case`` runs once per case under ``valid``,
    ``invalid_decode_case`` once per case under ``invalid``, and
    ``invalid_encode_case`` once per file under ``invalid/encoding``. The
    case key is used as the test id. If a test requests several of these
    fixtures, only the first in that order is parametrized.

    The suite is scanned only when a test actually requests one of these
    fixtures. Collecting unrelated tests therefore neither walks the toml-test
    tree nor fails when that directory is missing.
    """
    for fixture_name, group in (
        ("valid_case", "valid"),
        ("invalid_decode_case", "invalid"),
        ("invalid_encode_case", "invalid_encode"),
    ):
        if fixture_name in metafunc.fixturenames:
            cases = get_tomltest_cases()[group]
            metafunc.parametrize(
                fixture_name,
                list(cases.values()),
                ids=list(cases.keys()),
            )
            return