diff --git a/array-api-strict-skips.txt b/array-api-strict-skips.txt index afc1b845..10a55fdd 100644 --- a/array-api-strict-skips.txt +++ b/array-api-strict-skips.txt @@ -32,3 +32,18 @@ array_api_tests/test_data_type_functions.py::test_finfo array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo array_api_tests/test_data_type_functions.py::test_iinfo_dtype + + +# complex special cases which failed "forever" +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] + +array_api_tests/test_special_cases.py::test_unary[sign((real(x_i) is -0 or real(x_i) == +0) and (imag(x_i) is -0 or imag(x_i) == +0)) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index f7fa306b..2fe3c1b2 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -199,6 +199,39 @@ def is_scalar(x): return isinstance(x, (int, float, complex, bool)) +def complex_dtype_for(dtyp): + """Complex dtype for a float or complex.""" + if dtyp in complex_dtypes: + return dtyp + if dtyp not in real_float_dtypes: + raise ValueError(f"no complex dtype to match {dtyp}") + + real_name = dtype_to_name[dtyp] + complex_name = {"float32": "complex64", "float64": "complex128"}[real_name] + + complex_dtype = _name_to_dtype.get(complex_name, None) + if complex_dtype is None: + raise ValueError(f"no complex dtype to match {dtyp}") + return complex_dtype + + +def real_dtype_for(dtyp): + """Real float dtype for a float or complex.""" + if dtyp in real_float_dtypes: + return dtyp + if dtyp not in complex_dtypes: + raise ValueError(f"no real float dtype to match {dtyp}") + + complex_name = dtype_to_name[dtyp] + real_name = {"complex64": "float32", "complex128": "float64"}[complex_name] + + real_dtype = _name_to_dtype.get(real_name, None) + if real_dtype is None: + raise ValueError(f"no real dtype to match {dtyp}") + return real_dtype + + + def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: dtype_value_pairs = [] for name, value in mapping.items(): diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index bf05a262..e05fd02d 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -15,6 +15,7 @@ import inspect import math import operator +import os import re from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal @@ -99,7 +100,7 @@ def or_(i: float) -> bool: def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck: def and_(i: float) -> bool: - return cond1(i) or cond2(i) + return cond1(i) and cond2(i) return and_ @@ -492,6 +493,170 @@ def check_result(result: float) -> bool: return check_result, expr +def parse_complex_value(value_str: str) -> complex: + """ + Parses a complex value string to return a complex number, e.g. + + >>> parse_complex_value('+0 + 0j') + 0j + >>> parse_complex_value('NaN + NaN j') + (nan+nanj) + >>> parse_complex_value('0 + NaN j') + nanj + >>> parse_complex_value('+0 + πj/2') + 1.5707963267948966j + >>> parse_complex_value('+infinity + 3πj/4') + (inf+2.356194490192345j) + + Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M" + """ + m = r_complex_value.match(value_str) + if m is None: + raise ParseError(value_str) + + # Parse real part with its sign + real_sign = m.group(1) if m.group(1) else "+" + real_val_str = m.group(2) + real_val = parse_value(real_sign + real_val_str) + + # Parse imaginary part with its sign + imag_sign = m.group(3) + # Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN") + if m.group(4): # πj form + imag_val_str_raw = m.group(4) + # Remove 'j' to get coefficient: "πj/2" -> "π/2" + imag_val_str = imag_val_str_raw.replace('j', '') + else: # plain form + imag_val_str_raw = m.group(5) + # Strip trailing 'j' if present: "0j" -> "0" + imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw + + imag_val = parse_value(imag_sign + imag_val_str) + + return complex(real_val, imag_val) + + +def make_strict_eq_complex(v: complex) -> Callable[[complex], bool]: + """ + Creates a checker for complex values that respects sign of zero and NaN. + """ + real_check = make_strict_eq(v.real) + imag_check = make_strict_eq(v.imag) + + def strict_eq_complex(z: complex) -> bool: + return real_check(z.real) and imag_check(z.imag) + + return strict_eq_complex + + +def parse_complex_cond( + a_cond_str: str, b_cond_str: str +) -> Tuple[Callable[[complex], bool], str, FromDtypeFunc]: + """ + Parses complex condition strings for real (a) and imaginary (b) parts. + + Returns: + - cond: Function that checks if a complex number meets the condition + - expr: String expression for the condition + - from_dtype: Strategy generator for complex numbers meeting the condition + """ + # Parse conditions for real and imaginary parts separately + a_cond, a_expr_template, a_from_dtype = parse_cond(a_cond_str) + b_cond, b_expr_template, b_from_dtype = parse_cond(b_cond_str) + + # Create compound condition + def complex_cond(z: complex) -> bool: + return a_cond(z.real) and b_cond(z.imag) + + # Create expression + a_expr = a_expr_template.replace("{}", "real(x_i)") + b_expr = b_expr_template.replace("{}", "imag(x_i)") + expr = f"{a_expr} and {b_expr}" + + # Create strategy that generates complex numbers + def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]: + assert len(kw) == 0 # sanity check + # For complex dtype, we need to get the corresponding float dtype + # complex64 -> float32, complex128 -> float64 + float_dtype = dh.real_dtype_for(dtype) + + real_strat = a_from_dtype(float_dtype) + imag_strat = b_from_dtype(float_dtype) + return st.builds(complex, real_strat, imag_strat) + + return complex_cond, expr, complex_from_dtype + + +def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool: + """ + Helper to check if actual matches expected, with optional sign flexibility and tolerance. + """ + if allow_any_sign and not math.isnan(expected): + return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01) + elif not math.isnan(expected): + check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected) + return check_fn(actual) + else: + return math.isnan(actual) + + +def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]: + """ + Parses a complex result string to return a checker and expression. + + Handles cases like: + - "``+0 + 0j``" - exact complex value + - "``0 + NaN j`` (sign of the real component is unspecified)" + - "``+0 + πj/2``" - with π expressions (uses approximate equality) + """ + # Check for unspecified sign notes + unspecified_real_sign = "sign of the real component is unspecified" in result_str + unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str + + # Extract the complex value from backticks - need to handle spaces in complex values + # Pattern: ``...`` where ... can contain spaces (for complex values like "0 + NaN j") + m = re.search(r"``([^`]+)``", result_str) + if m: + value_str = m.group(1) + # Check if the value contains π expressions (for approximate comparison) + has_pi = 'π' in value_str + + try: + expected = parse_complex_value(value_str) + except ParseError: + raise ParseError(result_str) + + # Create checker based on whether signs are unspecified and whether π is involved + if has_pi: + # Use approximate equality for both real and imaginary parts if they involve π + def check_result(z: complex) -> bool: + real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign) + imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign) + return real_match and imag_match + elif unspecified_real_sign and not math.isnan(expected.real): + # Allow any sign for real part + def check_result(z: complex) -> bool: + imag_check = make_strict_eq(expected.imag) + return abs(z.real) == abs(expected.real) and imag_check(z.imag) + elif unspecified_imag_sign and not math.isnan(expected.imag): + # Allow any sign for imaginary part + def check_result(z: complex) -> bool: + real_check = make_strict_eq(expected.real) + return real_check(z.real) and abs(z.imag) == abs(expected.imag) + elif unspecified_real_sign and unspecified_imag_sign: + # Allow any sign for both parts + def check_result(z: complex) -> bool: + return abs(z.real) == abs(expected.real) and abs(z.imag) == abs(expected.imag) + else: + # Exact match including signs + check_result = make_strict_eq_complex(expected) + + expr = value_str + return check_result, expr + else: + raise ParseError(result_str) + + class Case(Protocol): cond_expr: str result_expr: str @@ -535,6 +700,7 @@ class UnaryCase(Case): cond: UnaryCheck check_result: UnaryResultCheck raw_case: Optional[str] = field(default=None) + is_complex: bool = field(default=False) r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") @@ -549,6 +715,16 @@ class UnaryCase(Case): "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " "the result is ``(.+)``" ) +# Regex patterns for complex special cases +r_complex_marker = re.compile( + r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``" +) +r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)") +# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4" +# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j +r_complex_value = re.compile( + r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -630,7 +806,15 @@ def check_result(i: float, result: float) -> bool: return check_result -def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]: +def make_complex_unary_check_result(check_fn: Callable[[complex], bool]) -> UnaryResultCheck: + """Wraps a complex check function for use in UnaryCase.""" + def check_result(in_value, out_value): + # in_value is complex, out_value is complex + return check_fn(out_value) + return check_result + + +def parse_unary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[UnaryCase]: """ Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. @@ -677,8 +861,52 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]: """ cases = [] + # Check if the case block contains complex cases by looking for the marker + in_complex_section = r_complex_marker.search(case_block) is not None + for case_m in r_case.finditer(case_block): case_str = case_m.group(1) + + # Record this special case if a record list is provided + if record_list is not None: + record_list.append(f"{func_name}: {case_str}.") + + + # Try to parse complex cases if we're in the complex section + if in_complex_section and (m := r_complex_case.search(case_str)): + try: + a_cond_str = m.group(1) + b_cond_str = m.group(2) + result_str = m.group(3) + + # Skip cases with complex expressions like "cis(b)" + if "cis" in result_str or "*" in result_str: + warn(f"case for {func_name} not machine-readable: '{case_str}'") + continue + + # Parse the complex condition and result + complex_cond, cond_expr, complex_from_dtype = parse_complex_cond( + a_cond_str, b_cond_str + ) + _check_result, result_expr = parse_complex_result(result_str) + + check_result = make_complex_unary_check_result(_check_result) + + case = UnaryCase( + cond_expr=cond_expr, + cond=complex_cond, + cond_from_dtype=complex_from_dtype, + result_expr=result_expr, + check_result=check_result, + raw_case=case_str, + is_complex=True, + ) + cases.append(case) + except ParseError as e: + warn(f"case for {func_name} not machine-readable: '{e.value}'") + continue + + # Parse regular (real-valued) cases if r_already_int_case.search(case_str): cases.append(already_int_case) elif r_even_round_halves_case.search(case_str): @@ -1103,7 +1331,7 @@ def cond(i1: float, i2: float) -> bool: r_redundant_case = re.compile("result.+determined by the rule already stated above") -def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]: +def parse_binary_case_block(case_block: str, func_name: str, record_list: Optional[List[str]] = None) -> List[BinaryCase]: """ Parses a Sphinx-formatted docstring of a binary function to return a list of codified binary cases, e.g. @@ -1145,6 +1373,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] cases = [] for case_m in r_case.finditer(case_block): case_str = case_m.group(1) + + # Record this special case if a record list is provided + if record_list is not None: + record_list.append(f"{func_name}: {case_str}.") + if r_redundant_case.search(case_str): continue if r_binary_case.match(case_str): @@ -1162,6 +1395,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] unary_params = [] binary_params = [] iop_params = [] +special_case_records = [] # List of "func_name: case_str" for all special cases func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} for stub in category_to_funcs["elementwise"]: func_name = stub.__name__ @@ -1186,7 +1420,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] warn(f"{func=} has no parameters") continue if param_names[0] == "x": - if cases := parse_unary_case_block(case_block, func_name): + if cases := parse_unary_case_block(case_block, func_name, special_case_records): name_to_func = {func_name: func} if func_name in func_to_op.keys(): op_name = func_to_op[func_name] @@ -1204,7 +1438,7 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_case_block(case_block, func_name): + if cases := parse_binary_case_block(case_block, func_name, special_case_records): name_to_func = {func_name: func} if func_name in func_to_op.keys(): op_name = func_to_op[func_name] @@ -1249,6 +1483,22 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase] assert len(iop_params) != 0 +@pytest.fixture(scope="session", autouse=True) +def emit_special_case_records(): + """Emit all special case records at the start of test session.""" + # This runs once at the beginning of the test session + if os.environ.get('ARRAY_API_TESTS_SPECIAL_CASES_VERBOSE') == '1': + print("\n" + "="*80) + print("SPECIAL CASE RECORDS") + print("="*80) + for record in special_case_records: + print(record) + print("="*80) + print(f"Total special cases: {len(special_case_records)}") + print("="*80 + "\n") + yield # Tests run after this point + + @pytest.mark.parametrize("func_name, func, case", unary_params) def test_unary(func_name, func, case): with catch_warnings(): @@ -1257,10 +1507,24 @@ def test_unary(func_name, func, case): # drawing multiple examples like a normal test, or just hard-coding a # single example test case without using hypothesis. filterwarnings('ignore', category=NonInteractiveExampleWarning) - in_value = case.cond_from_dtype(xp.float64).example() - x = xp.asarray(in_value, dtype=xp.float64) + + # Use the is_complex flag to determine the appropriate dtype + if case.is_complex: + dtype = xp.complex128 + in_value = case.cond_from_dtype(dtype).example() + else: + dtype = xp.float64 + in_value = case.cond_from_dtype(dtype).example() + + # Create array and compute result based on dtype + x = xp.asarray(in_value, dtype=dtype) out = func(x) - out_value = float(out) + + if case.is_complex: + out_value = complex(out) + else: + out_value = float(out) + assert case.check_result(in_value, out_value), ( f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" )