Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions indico_toolkit/results/predictions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import TYPE_CHECKING

from ..normalization import normalize_prediction_dict
Expand Down Expand Up @@ -55,3 +56,14 @@ def from_dict(
return Unbundling.from_dict(document, task, review, prediction)
else:
raise ValueError(f"unsupported task type {task.type!r}")


# On Python 3.13+ `dataclass()` generates a `__replace__` method and does not (yet)
# offer a way to configure it. Remove the generated method from each prediction
# subclass so that attribute lookup falls through to `Prediction.__replace__`.
if sys.version_info >= (3, 13):
    delattr(Classification, "__replace__")
    delattr(DocumentExtraction, "__replace__")
    delattr(Extraction, "__replace__")
    delattr(FormExtraction, "__replace__")
    delattr(Summarization, "__replace__")
    delattr(Unbundling, "__replace__")
29 changes: 15 additions & 14 deletions indico_toolkit/results/predictions/documentextraction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy, deepcopy
from dataclasses import dataclass, field, replace
from copy import copy
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from ...etloutput import (
Expand Down Expand Up @@ -131,6 +131,19 @@ def table_cells(self, table_cells: "Iterable[tuple[Table, Cell]]") -> None:
self.tables.append(table)
self.cells.append(cell)

def __deepcopy__(self, memo: Any) -> "Self":
    """
    Deep-copy hook used by `copy.deepcopy(prediction)`.

    Delegates to the parent `__deepcopy__`, then gives the copy its own shallow
    copies of `groups`, `spans`, `tokens`, `tables`, and `cells`. The elements of
    those containers are treated as immutable and are shared with the original
    instead of being duplicated, which is a significant time and memory saving
    when OCR is assigned.
    """
    clone = super().__deepcopy__(memo)

    # Only the containers are copied; their elements are reused as-is.
    for name in ("groups", "spans", "tokens", "tables", "cells"):
        setattr(clone, name, copy(getattr(self, name)))

    return clone

@staticmethod
def from_dict(
document: "Document",
Expand Down Expand Up @@ -190,15 +203,3 @@ def to_dict(self) -> "dict[str, Any]":
prediction["rejected"] = True

return prediction

def copy(self) -> "Self":
    """
    Return a new instance whose mutable fields are independent of this one.

    `groups`, `spans`, `tokens`, `tables`, `cells`, and `confidences` are
    shallow-copied; `extras` is deep-copied. All other fields are carried over
    unchanged by `dataclasses.replace`.
    """
    shallow_fields = ("groups", "spans", "tokens", "tables", "cells", "confidences")
    changes = {name: copy(getattr(self, name)) for name in shallow_fields}
    changes["extras"] = deepcopy(self.extras)
    return replace(self, **changes)
43 changes: 36 additions & 7 deletions indico_toolkit/results/predictions/prediction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy, deepcopy
from dataclasses import dataclass, replace
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
Expand Down Expand Up @@ -28,15 +28,44 @@ def confidence(self) -> float:
def confidence(self, value: float) -> None:
self.confidences[self.label] = value

def __deepcopy__(self, memo: Any) -> "Self":
    """
    Deep-copy hook used by `copy.deepcopy(prediction)`.

    Starts from a shallow copy of the instance so that immutable attributes are
    shared rather than duplicated, then replaces the two mutable attributes:
    `confidences` with a shallow copy and `extras` with a deep copy that
    participates in the caller's `memo`.
    """
    duplicate = copy(self)
    duplicate.confidences = copy(self.confidences)
    # Thread `memo` through so objects shared within one deepcopy call stay shared.
    duplicate.extras = deepcopy(self.extras, memo)
    return duplicate

def __replace__override__(self, **attributes: Any) -> "Self":
    """
    Supports `copy.replace(prediction, **attrs)` on Python 3.13+

    The instance is deep-copied first and each keyword argument is then assigned
    onto the copy with `setattr`. Because assignment goes through `setattr`,
    properties can be given as well as plain attributes, e.g.
    `copy.replace(prediction, confidence=1.0)`.

    This differs from `dataclasses.replace(prediction, ...)`, which makes a
    shallow copy and raises `TypeError` when given a property such as
    `confidence`.
    """
    clone = deepcopy(self)

    for name in attributes:
        setattr(clone, name, attributes[name])

    return clone

def to_dict(self) -> "dict[str, Any]":
    """
    Create a prediction dictionary for auto review changes.

    This implementation always raises `NotImplementedError`.
    """
    raise NotImplementedError

def copy(self) -> "Self":
    """
    Return a new instance whose mutable fields are independent of this one.

    `confidences` is shallow-copied and `extras` is deep-copied; all other
    fields are carried over unchanged by `dataclasses.replace`.
    """
    changes = {
        "confidences": copy(self.confidences),
        "extras": deepcopy(self.extras),
    }
    return replace(self, **changes)

# `dataclass()` does not (yet) let a class supply its own `__replace__` on Python
# 3.13+; it generates one. So the custom implementation is written under a
# temporary name, installed as `Prediction.__replace__` once the class exists, and
# the temporary name is removed. Derived classes must have their generated
# `__replace__` deleted so that this one is not shadowed.
setattr(Prediction, "__replace__", Prediction.__replace__override__)
delattr(Prediction, "__replace__override__")
18 changes: 9 additions & 9 deletions indico_toolkit/results/predictions/summarization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from copy import copy, deepcopy
from copy import copy
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -70,6 +70,14 @@ def span(self, span: "Span") -> None:
"""
self.citation = replace(self.citation, span=span)

def __deepcopy__(self, memo: Any) -> "Self":
    """
    Deep-copy hook used by `copy.deepcopy(prediction)`.

    Delegates to the parent `__deepcopy__`, then gives the copy its own shallow
    copy of `citations`. The citation objects themselves are treated as
    immutable and are shared with the original.
    """
    clone = super().__deepcopy__(memo)
    setattr(clone, "citations", copy(self.citations))
    return clone

@staticmethod
def from_dict(
document: "Document",
Expand Down Expand Up @@ -125,11 +133,3 @@ def to_dict(self) -> "dict[str, Any]":
prediction["rejected"] = True

return prediction

def copy(self) -> "Self":
    """
    Return a new instance whose mutable fields are independent of this one.

    `citations` and `confidences` are shallow-copied and `extras` is
    deep-copied; all other fields are carried over unchanged by
    `dataclasses.replace`.
    """
    changes = {
        "citations": copy(self.citations),
        "confidences": copy(self.confidences),
        "extras": deepcopy(self.extras),
    }
    return replace(self, **changes)
20 changes: 10 additions & 10 deletions indico_toolkit/results/predictions/unbundling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import copy, deepcopy
from dataclasses import dataclass, replace
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from ...etloutput import Span
Expand All @@ -25,6 +25,14 @@ def pages(self) -> "tuple[int, ...]":
"""
return tuple(span.page for span in self.spans)

def __deepcopy__(self, memo: Any) -> "Self":
    """
    Deep-copy hook used by `copy.deepcopy(prediction)`.

    Delegates to the parent `__deepcopy__`, then gives the copy its own shallow
    copy of `spans`. The span objects themselves are treated as immutable and
    are shared with the original.
    """
    clone = super().__deepcopy__(memo)
    setattr(clone, "spans", copy(self.spans))
    return clone

@staticmethod
def from_dict(
document: "Document",
Expand Down Expand Up @@ -55,11 +63,3 @@ def to_dict(self) -> "dict[str, Any]":
"confidence": self.confidences,
"spans": [span.to_dict() for span in self.spans],
}

def copy(self) -> "Self":
    """
    Return a new instance whose mutable fields are independent of this one.

    `spans` and `confidences` are shallow-copied and `extras` is deep-copied;
    all other fields are carried over unchanged by `dataclasses.replace`.
    """
    changes = {
        "spans": copy(self.spans),
        "confidences": copy(self.confidences),
        "extras": deepcopy(self.extras),
    }
    return replace(self, **changes)
13 changes: 12 additions & 1 deletion indico_toolkit/results/result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from copy import deepcopy
from dataclasses import dataclass, replace
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any

from . import predictions as prediction
from .document import Document
Expand All @@ -11,6 +13,9 @@
from .task import Task
from .utils import get

if TYPE_CHECKING:
from typing_extensions import Self


@dataclass(frozen=True, order=True)
class Result:
Expand Down Expand Up @@ -44,6 +49,12 @@ def admin_review(self) -> "PredictionList[Prediction]":
def final(self) -> "PredictionList[Prediction]":
return self.predictions.where(review=self.reviews[-1] if self.reviews else None)

def __deepcopy__(self, memo: Any) -> "Self":
    """
    Deep-copy hook used by `copy.deepcopy(result)`.

    Only `predictions` is deep-copied (sharing the caller's `memo`). Every other
    field is treated as immutable and carried over by reference through
    `dataclasses.replace`, which builds a new instance.
    """
    copied_predictions = deepcopy(self.predictions, memo)
    return replace(self, predictions=copied_predictions)

@staticmethod
def from_dict(result: object) -> "Result":
"""
Expand Down