Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,10 +502,13 @@ def parse_complex_value(value_str: str) -> complex:
(nan+nanj)
>>> parse_complex_value('0 + NaN j')
nanj
>>> parse_complex_value('+0 + πj/2')
1.5707963267948966j
>>> parse_complex_value('+infinity + 3πj/4')
(inf+2.356194490192345j)

Handles both "0j" and "0 j" formats with optional spaces.
Handles formats: "A + Bj", "A + B j", "A + πj/N", "A + Nπj/M"
"""
# Handle the format like "+0 + 0j" or "NaN + NaN j"
m = r_complex_value.match(value_str)
if m is None:
raise ParseError(value_str)
Expand All @@ -517,7 +520,16 @@ def parse_complex_value(value_str: str) -> complex:

# Parse imaginary part with its sign
imag_sign = m.group(3)
imag_val_str = m.group(4)
# Group 4 is πj form (e.g., "πj/2"), group 5 is plain form (e.g., "NaN")
if m.group(4): # πj form
imag_val_str_raw = m.group(4)
# Remove 'j' to get coefficient: "πj/2" -> "π/2"
imag_val_str = imag_val_str_raw.replace('j', '')
else: # plain form
imag_val_str_raw = m.group(5)
# Strip trailing 'j' if present: "0j" -> "0"
imag_val_str = imag_val_str_raw[:-1] if imag_val_str_raw.endswith('j') else imag_val_str_raw

imag_val = parse_value(imag_sign + imag_val_str)

return complex(real_val, imag_val)
Expand Down Expand Up @@ -583,16 +595,29 @@ def complex_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[complex]:
return complex_cond, expr, complex_from_dtype


def _check_component_with_tolerance(actual: float, expected: float, allow_any_sign: bool) -> bool:
"""
Helper to check if actual matches expected, with optional sign flexibility and tolerance.
"""
if allow_any_sign and not math.isnan(expected):
return abs(actual) == abs(expected) or math.isclose(abs(actual), abs(expected), abs_tol=0.01)
elif not math.isnan(expected):
check_fn = make_strict_eq(expected) if expected == 0 or math.isinf(expected) else make_rough_eq(expected)
return check_fn(actual)
else:
return math.isnan(actual)


def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], str]:
"""
Parses a complex result string to return a checker and expression.

Handles cases like:
- "``+0 + 0j``" - exact complex value
- "``0 + NaN j`` (sign of the real component is unspecified)" - allow any sign for real
- "``NaN + NaN j``" - both parts NaN
- "``0 + NaN j`` (sign of the real component is unspecified)"
- "``+0 + πj/2``" - with π expressions (uses approximate equality)
"""
# Check for unspecified sign note
# Check for unspecified sign notes
unspecified_real_sign = "sign of the real component is unspecified" in result_str
unspecified_imag_sign = "sign of the imaginary component is unspecified" in result_str

Expand All @@ -601,13 +626,22 @@ def parse_complex_result(result_str: str) -> Tuple[Callable[[complex], bool], st
m = re.search(r"``([^`]+)``", result_str)
if m:
value_str = m.group(1)
# Check if the value contains π expressions (for approximate comparison)
has_pi = 'π' in value_str

try:
expected = parse_complex_value(value_str)
except ParseError:
raise ParseError(result_str)

# Create checker based on whether signs are unspecified
if unspecified_real_sign and not math.isnan(expected.real):
# Create checker based on whether signs are unspecified and whether π is involved
if has_pi:
# Use approximate equality for both real and imaginary parts if they involve π
def check_result(z: complex) -> bool:
real_match = _check_component_with_tolerance(z.real, expected.real, unspecified_real_sign)
imag_match = _check_component_with_tolerance(z.imag, expected.imag, unspecified_imag_sign)
return real_match and imag_match
elif unspecified_real_sign and not math.isnan(expected.real):
# Allow any sign for real part
def check_result(z: complex) -> bool:
imag_check = make_strict_eq(expected.imag)
Expand Down Expand Up @@ -693,9 +727,10 @@ class UnaryCase(Case):
r"For complex floating-point operands, let ``a = real\(x_i\)``, ``b = imag\(x_i\)``"
)
r_complex_case = re.compile(r"If ``a`` is (.+) and ``b`` is (.+), the result is (.+)")
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j"
# Matches complex values like "+0 + 0j", "NaN + NaN j", "infinity + NaN j", "πj/2", "3πj/4"
# Two formats: 1) πj/N expressions where j is part of the coefficient, 2) plain values followed by j
r_complex_value = re.compile(
r"([+-]?)([^\s]+)\s*([+-])\s*([^\s]+)\s*j"
r"([+-]?)([^\s]+)\s*([+-])\s*(?:(\d*πj(?:/\d+)?)|([^\s]+))\s*j?"
)


Expand Down
Loading