Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,3 @@ benchmark/scripts/leftovers/
# direnv
.envrc
AGENTS.md
coverage.xml
3 changes: 2 additions & 1 deletion linopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# we need to extend their __mul__ functions with a quick special case
import linopy.monkey_patch_xarray # noqa: F401
from linopy.common import align
from linopy.config import options
from linopy.config import LinopyDeprecationWarning, options
from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL
from linopy.constraints import Constraint, Constraints
from linopy.expressions import LinearExpression, QuadraticExpression, merge
Expand All @@ -34,6 +34,7 @@
"EQUAL",
"GREATER_EQUAL",
"LESS_EQUAL",
"LinopyDeprecationWarning",
"LinearExpression",
"Model",
"Objective",
Expand Down
67 changes: 41 additions & 26 deletions linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import operator
import os
from collections.abc import Callable, Generator, Hashable, Iterable, Sequence
from functools import reduce, wraps
from functools import partial, reduce, wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from warnings import warn
Expand Down Expand Up @@ -1205,7 +1205,7 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool:

def align(
*objects: LinearExpression | QuadraticExpression | Variable | T_Alignable,
join: JoinOptions = "exact",
join: JoinOptions | None = None,
copy: bool = True,
indexes: Any = None,
exclude: str | Iterable[Hashable] = frozenset(),
Expand Down Expand Up @@ -1265,41 +1265,56 @@ def align(


"""
from linopy.config import options
from linopy.expressions import LinearExpression, QuadraticExpression
from linopy.variables import Variable

# Extract underlying Datasets for index computation.
if join is None:
join = options["arithmetic_convention"]

if join == "legacy":
from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning

warn(
LEGACY_DEPRECATION_MESSAGE,
LinopyDeprecationWarning,
stacklevel=2,
)
join = "inner"

elif join == "v1":
join = "exact"

finisher: list[partial[Any] | Callable[[Any], Any]] = []
das: list[Any] = []
for obj in objects:
if isinstance(obj, LinearExpression | QuadraticExpression | Variable):
if isinstance(obj, LinearExpression | QuadraticExpression):
finisher.append(partial(obj.__class__, model=obj.model))
das.append(obj.data)
elif isinstance(obj, Variable):
finisher.append(
partial(
obj.__class__,
model=obj.model,
name=obj.data.attrs["name"],
skip_broadcast=True,
)
)
das.append(obj.data)
else:
finisher.append(lambda x: x)
das.append(obj)

exclude = frozenset(exclude).union(HELPER_DIMS)

# Compute target indexes.
target_aligned = xr_align(
*das, join=join, copy=False, indexes=indexes, exclude=exclude
aligned = xr_align(
*das,
join=join,
copy=copy,
indexes=indexes,
exclude=exclude,
fill_value=fill_value,
)

# Reindex each object to target indexes.
reindex_kwargs: dict[str, Any] = {}
if fill_value is not dtypes.NA:
reindex_kwargs["fill_value"] = fill_value
results: list[Any] = []
for obj, target in zip(objects, target_aligned):
indexers = {
dim: target.indexes[dim]
for dim in target.dims
if dim not in exclude and dim in target.indexes
}
# Variable.reindex has no fill_value — it always uses sentinels
if isinstance(obj, Variable):
results.append(obj.reindex(indexers))
else:
results.append(obj.reindex(indexers, **reindex_kwargs)) # type: ignore[union-attr]
return tuple(results)
return tuple([f(da) for f, da in zip(finisher, aligned)])


LocT = TypeVar(
Expand Down
36 changes: 29 additions & 7 deletions linopy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,46 @@

from typing import Any

VALID_ARITHMETIC_JOINS = {"legacy", "v1"}

LEGACY_DEPRECATION_MESSAGE = (
"The 'legacy' arithmetic convention is deprecated and will be removed in "
"linopy v1. Set linopy.options['arithmetic_convention'] = 'v1' to opt in "
"to the new behavior, or filter this warning with:\n"
" import warnings; warnings.filterwarnings('ignore', category=LinopyDeprecationWarning)"
)


class LinopyDeprecationWarning(FutureWarning):
"""Warning for deprecated linopy features scheduled for removal."""


class OptionSettings:
def __init__(self, **kwargs: int) -> None:
def __init__(self, **kwargs: Any) -> None:
self._defaults = kwargs
self._current_values = kwargs.copy()

def __call__(self, **kwargs: int) -> None:
def __call__(self, **kwargs: Any) -> None:
self.set_value(**kwargs)

def __getitem__(self, key: str) -> int:
def __getitem__(self, key: str) -> Any:
return self.get_value(key)

def __setitem__(self, key: str, value: int) -> None:
def __setitem__(self, key: str, value: Any) -> None:
return self.set_value(**{key: value})

def set_value(self, **kwargs: int) -> None:
def set_value(self, **kwargs: Any) -> None:
for k, v in kwargs.items():
if k not in self._defaults:
raise KeyError(f"{k} is not a valid setting.")
if k == "arithmetic_convention" and v not in VALID_ARITHMETIC_JOINS:
raise ValueError(
f"Invalid arithmetic_convention: {v!r}. "
f"Must be one of {VALID_ARITHMETIC_JOINS}."
)
self._current_values[k] = v

def get_value(self, name: str) -> int:
def get_value(self, name: str) -> Any:
if name in self._defaults:
return self._current_values[name]
else:
Expand All @@ -57,4 +75,8 @@ def __repr__(self) -> str:
return f"OptionSettings:\n {settings}"


options = OptionSettings(display_max_rows=14, display_max_terms=6)
options = OptionSettings(
display_max_rows=14,
display_max_terms=6,
arithmetic_convention="legacy",
)
Loading