diff --git a/indico_toolkit/results/predictions/__init__.py b/indico_toolkit/results/predictions/__init__.py index dff7d51..555cf5c 100644 --- a/indico_toolkit/results/predictions/__init__.py +++ b/indico_toolkit/results/predictions/__init__.py @@ -1,3 +1,4 @@ +import sys from typing import TYPE_CHECKING from ..normalization import normalize_prediction_dict @@ -55,3 +56,14 @@ def from_dict( return Unbundling.from_dict(document, task, review, prediction) else: raise ValueError(f"unsupported task type {task.type!r}") + + +# `dataclass()` doesn't (yet) provide a way to configure the generated `__replace__` +# method on Python 3.13+. Unshadow `Prediction.__replace__` in generated subclasses. +if sys.version_info >= (3, 13): + del Classification.__replace__ + del DocumentExtraction.__replace__ + del Extraction.__replace__ + del FormExtraction.__replace__ + del Summarization.__replace__ + del Unbundling.__replace__ diff --git a/indico_toolkit/results/predictions/documentextraction.py b/indico_toolkit/results/predictions/documentextraction.py index c942ab7..94d1f0b 100644 --- a/indico_toolkit/results/predictions/documentextraction.py +++ b/indico_toolkit/results/predictions/documentextraction.py @@ -1,5 +1,5 @@ -from copy import copy, deepcopy -from dataclasses import dataclass, field, replace +from copy import copy +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any from ...etloutput import ( @@ -131,6 +131,19 @@ def table_cells(self, table_cells: "Iterable[tuple[Table, Cell]]") -> None: self.tables.append(table) self.cells.append(cell) + def __deepcopy__(self, memo: Any) -> "Self": + """ + Supports `copy.deepcopy(prediction)` without copying immutable objects. + This provides a significant time and memory improvement when OCR is assigned. + """ + new_instance = super().__deepcopy__(memo) + new_instance.groups = copy(self.groups) + new_instance.spans = copy(self.spans) + new_instance.tokens = copy(self.tokens) + new_instance.tables = copy(self.tables) + new_instance.cells = copy(self.cells) + return new_instance + @staticmethod def from_dict( document: "Document", @@ -190,15 +203,3 @@ def to_dict(self) -> "dict[str, Any]": prediction["rejected"] = True return prediction - - def copy(self) -> "Self": - return replace( - self, - groups=copy(self.groups), - spans=copy(self.spans), - tokens=copy(self.tokens), - tables=copy(self.tables), - cells=copy(self.cells), - confidences=copy(self.confidences), - extras=deepcopy(self.extras), - ) diff --git a/indico_toolkit/results/predictions/prediction.py b/indico_toolkit/results/predictions/prediction.py index 6edf03c..87db94f 100644 --- a/indico_toolkit/results/predictions/prediction.py +++ b/indico_toolkit/results/predictions/prediction.py @@ -1,5 +1,5 @@ from copy import copy, deepcopy -from dataclasses import dataclass, replace +from dataclasses import dataclass from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -28,15 +28,44 @@ def confidence(self) -> float: def confidence(self, value: float) -> None: self.confidences[self.label] = value + def __deepcopy__(self, memo: Any) -> "Self": + """ + Supports `copy.deepcopy(prediction)` without copying immutable objects. + """ + new_instance = copy(self) + new_instance.confidences = copy(self.confidences) + new_instance.extras = deepcopy(self.extras, memo) + return new_instance + + def __replace__override__(self, **attributes: Any) -> "Self": + """ + Supports `copy.replace(prediction, **attrs)` on Python 3.13+ + + Unlike `dataclasses.replace(**attrs)` this performs a deep copy and allows + assigning properties in addition to attributes. + + E.g. + >>> dataclasses.replace(prediction, confidence=1.0) + Shallow copy and raises TypeError(...) + >>> copy.replace(prediction, confidence=1.0) + Deep copy and returns Prediction(confidence=1.0, ...) + """ + new_instance = deepcopy(self) + + for attribute, value in attributes.items(): + setattr(new_instance, attribute, value) + + return new_instance + def to_dict(self) -> "dict[str, Any]": """ Create a prediction dictionary for auto review changes. """ raise NotImplementedError() - def copy(self) -> "Self": - return replace( - self, - confidences=copy(self.confidences), - extras=deepcopy(self.extras), - ) + +# `dataclass()` doesn't (yet) provide a way to override the generated `__replace__` +# method on Python 3.13+. It must be overridden after class generation and unshadowed +# on all derived classes. +Prediction.__replace__ = Prediction.__replace__override__ # type:ignore +del Prediction.__replace__override__ diff --git a/indico_toolkit/results/predictions/summarization.py b/indico_toolkit/results/predictions/summarization.py index 6f2fa3e..fa3f677 100644 --- a/indico_toolkit/results/predictions/summarization.py +++ b/indico_toolkit/results/predictions/summarization.py @@ -1,4 +1,4 @@ -from copy import copy, deepcopy +from copy import copy from dataclasses import dataclass, replace from typing import TYPE_CHECKING, Any @@ -70,6 +70,14 @@ def span(self, span: "Span") -> None: """ self.citation = replace(self.citation, span=span) + def __deepcopy__(self, memo: Any) -> "Self": + """ + Supports `copy.deepcopy(prediction)` without copying immutable objects. + """ + new_instance = super().__deepcopy__(memo) + new_instance.citations = copy(self.citations) + return new_instance + @staticmethod def from_dict( document: "Document", @@ -125,11 +133,3 @@ def to_dict(self) -> "dict[str, Any]": prediction["rejected"] = True return prediction - - def copy(self) -> "Self": - return replace( - self, - citations=copy(self.citations), - confidences=copy(self.confidences), - extras=deepcopy(self.extras), - ) diff --git a/indico_toolkit/results/predictions/unbundling.py b/indico_toolkit/results/predictions/unbundling.py index 7e998dd..242cad2 100644 --- a/indico_toolkit/results/predictions/unbundling.py +++ b/indico_toolkit/results/predictions/unbundling.py @@ -1,5 +1,5 @@ -from copy import copy, deepcopy -from dataclasses import dataclass, replace +from copy import copy +from dataclasses import dataclass from typing import TYPE_CHECKING, Any from ...etloutput import Span @@ -25,6 +25,14 @@ def pages(self) -> "tuple[int, ...]": """ return tuple(span.page for span in self.spans) + def __deepcopy__(self, memo: Any) -> "Self": + """ + Supports `copy.deepcopy(prediction)` without copying immutable objects. + """ + new_instance = super().__deepcopy__(memo) + new_instance.spans = copy(self.spans) + return new_instance + @staticmethod def from_dict( document: "Document", @@ -55,11 +63,3 @@ def to_dict(self) -> "dict[str, Any]": "confidence": self.confidences, "spans": [span.to_dict() for span in self.spans], } - - def copy(self) -> "Self": - return replace( - self, - spans=copy(self.spans), - confidences=copy(self.confidences), - extras=deepcopy(self.extras), - ) diff --git a/indico_toolkit/results/result.py b/indico_toolkit/results/result.py index cb55259..73459ed 100644 --- a/indico_toolkit/results/result.py +++ b/indico_toolkit/results/result.py @@ -1,6 +1,8 @@ -from dataclasses import dataclass +from copy import deepcopy +from dataclasses import dataclass, replace from functools import partial from itertools import chain +from typing import TYPE_CHECKING, Any from . import predictions as prediction from .document import Document @@ -11,6 +13,9 @@ from .task import Task from .utils import get +if TYPE_CHECKING: + from typing_extensions import Self + @dataclass(frozen=True, order=True) class Result: @@ -44,6 +49,12 @@ def admin_review(self) -> "PredictionList[Prediction]": def final(self) -> "PredictionList[Prediction]": return self.predictions.where(review=self.reviews[-1] if self.reviews else None) + def __deepcopy__(self, memo: Any) -> "Self": + """ + Supports `copy.deepcopy(result)` without copying immutable objects. + """ + return replace(self, predictions=deepcopy(self.predictions, memo)) + @staticmethod def from_dict(result: object) -> "Result": """