Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions array-api-strict-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ array_api_tests/test_data_type_functions.py::test_finfo
array_api_tests/test_data_type_functions.py::test_finfo_dtype
array_api_tests/test_data_type_functions.py::test_iinfo
array_api_tests/test_data_type_functions.py::test_iinfo_dtype


# complex special cases which failed "forever"
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j]
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j]
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j]
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j]
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j]
array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j]
array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j]

array_api_tests/test_special_cases.py::test_unary[sign((real(x_i) is -0 or real(x_i) == +0) and (imag(x_i) is -0 or imag(x_i) == +0)) -> 0 + 0j]
array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j]

array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j]
33 changes: 33 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@ def is_scalar(x):
return isinstance(x, (int, float, complex, bool))


def complex_dtype_for(dtyp):
"""Complex dtype for a float or complex."""
if dtyp in complex_dtypes:
return dtyp
if dtyp not in real_float_dtypes:
raise ValueError(f"no complex dtype to match {dtyp}")

real_name = dtype_to_name[dtyp]
complex_name = {"float32": "complex64", "float64": "complex128"}[real_name]

complex_dtype = _name_to_dtype.get(complex_name, None)
if complex_dtype is None:
raise ValueError(f"no complex dtype to match {dtyp}")
return complex_dtype


def real_dtype_for(dtyp):
"""Real float dtype for a float or complex."""
if dtyp in real_float_dtypes:
return dtyp
if dtyp not in complex_dtypes:
raise ValueError(f"no real float dtype to match {dtyp}")

complex_name = dtype_to_name[dtyp]
real_name = {"complex64": "float32", "complex128": "float64"}[complex_name]

real_dtype = _name_to_dtype.get(real_name, None)
if real_dtype is None:
raise ValueError(f"no real dtype to match {dtyp}")
return real_dtype



def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
for name, value in mapping.items():
Expand Down
Loading