diff --git a/python/pyspark/ml/tests/connect/test_parity_classification.py b/python/pyspark/ml/tests/connect/test_parity_classification.py
index 7805546dba70..3c7e8ff71a2d 100644
--- a/python/pyspark/ml/tests/connect/test_parity_classification.py
+++ b/python/pyspark/ml/tests/connect/test_parity_classification.py
@@ -21,8 +21,6 @@
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
-# TODO(SPARK-52764): Re-enable this test after fixing the flakiness.
-@unittest.skip("Disabled due to flakiness, should be enabled after fixing the issue")
 class ClassificationParityTests(ClassificationTestsMixin, ReusedConnectTestCase):
     pass
 
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 5849f4a6dad8..f249b0070090 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import asyncio
 import json
 import logging
 import os
@@ -154,12 +155,16 @@ def add_ref(self) -> None:
         self._ref_count += 1
 
     def release_ref(self) -> None:
+        should_del = False
         with self._lock:
             assert self._ref_count > 0
             self._ref_count -= 1
             if self._ref_count == 0:
-                # Delete the model if possible
-                del_remote_cache(self.ref_id)
+                should_del = True
+
+        if should_del:
+            # Delete the model if possible
+            asyncio.run(del_remote_cache(self.ref_id))
 
     def __str__(self) -> str:
         return self.ref_id
@@ -348,7 +353,7 @@ def remote_call() -> Any:
 
 
 # delete the object from the ml cache eagerly
-def del_remote_cache(ref_id: str) -> None:
+async def del_remote_cache(ref_id: str) -> None:
     if ref_id is not None and "." not in ref_id:
         try:
             from pyspark.sql.connect.session import SparkSession