diff --git a/.gitignore b/.gitignore index 10ac8e45..7b962a6b 100644 --- a/.gitignore +++ b/.gitignore @@ -50,4 +50,3 @@ benchmark/scripts/leftovers/ # direnv .envrc AGENTS.md -coverage.xml diff --git a/linopy/__init__.py b/linopy/__init__.py index b1dc33b9..a372c087 100644 --- a/linopy/__init__.py +++ b/linopy/__init__.py @@ -13,7 +13,7 @@ # we need to extend their __mul__ functions with a quick special case import linopy.monkey_patch_xarray # noqa: F401 from linopy.common import align -from linopy.config import options +from linopy.config import LinopyDeprecationWarning, options from linopy.constants import EQUAL, GREATER_EQUAL, LESS_EQUAL from linopy.constraints import Constraint, Constraints from linopy.expressions import LinearExpression, QuadraticExpression, merge @@ -34,6 +34,7 @@ "EQUAL", "GREATER_EQUAL", "LESS_EQUAL", + "LinopyDeprecationWarning", "LinearExpression", "Model", "Objective", diff --git a/linopy/common.py b/linopy/common.py index 48755d6e..4b3f84d6 100644 --- a/linopy/common.py +++ b/linopy/common.py @@ -10,7 +10,7 @@ import operator import os from collections.abc import Callable, Generator, Hashable, Iterable, Sequence -from functools import reduce, wraps +from functools import partial, reduce, wraps from pathlib import Path from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload from warnings import warn @@ -1205,7 +1205,7 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: def align( *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, - join: JoinOptions = "exact", + join: JoinOptions | None = None, copy: bool = True, indexes: Any = None, exclude: str | Iterable[Hashable] = frozenset(), @@ -1265,41 +1265,56 @@ def align( """ + from linopy.config import options from linopy.expressions import LinearExpression, QuadraticExpression from linopy.variables import Variable - # Extract underlying Datasets for index computation. 
+ if join is None: + join = options["arithmetic_convention"] + + if join == "legacy": + from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning + + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=2, + ) + join = "inner" + + elif join == "v1": + join = "exact" + + finisher: list[partial[Any] | Callable[[Any], Any]] = [] das: list[Any] = [] for obj in objects: - if isinstance(obj, LinearExpression | QuadraticExpression | Variable): + if isinstance(obj, LinearExpression | QuadraticExpression): + finisher.append(partial(obj.__class__, model=obj.model)) + das.append(obj.data) + elif isinstance(obj, Variable): + finisher.append( + partial( + obj.__class__, + model=obj.model, + name=obj.data.attrs["name"], + skip_broadcast=True, + ) + ) das.append(obj.data) else: + finisher.append(lambda x: x) das.append(obj) exclude = frozenset(exclude).union(HELPER_DIMS) - - # Compute target indexes. - target_aligned = xr_align( - *das, join=join, copy=False, indexes=indexes, exclude=exclude + aligned = xr_align( + *das, + join=join, + copy=copy, + indexes=indexes, + exclude=exclude, + fill_value=fill_value, ) - - # Reindex each object to target indexes. 
- reindex_kwargs: dict[str, Any] = {} - if fill_value is not dtypes.NA: - reindex_kwargs["fill_value"] = fill_value - results: list[Any] = [] - for obj, target in zip(objects, target_aligned): - indexers = { - dim: target.indexes[dim] - for dim in target.dims - if dim not in exclude and dim in target.indexes - } - # Variable.reindex has no fill_value — it always uses sentinels - if isinstance(obj, Variable): - results.append(obj.reindex(indexers)) - else: - results.append(obj.reindex(indexers, **reindex_kwargs)) # type: ignore[union-attr] - return tuple(results) + return tuple([f(da) for f, da in zip(finisher, aligned)]) LocT = TypeVar( diff --git a/linopy/config.py b/linopy/config.py index c098709d..9f04ce17 100644 --- a/linopy/config.py +++ b/linopy/config.py @@ -9,28 +9,46 @@ from typing import Any +VALID_ARITHMETIC_JOINS = {"legacy", "v1"} + +LEGACY_DEPRECATION_MESSAGE = ( + "The 'legacy' arithmetic convention is deprecated and will be removed in " + "linopy v1. Set linopy.options['arithmetic_convention'] = 'v1' to opt in " + "to the new behavior, or filter this warning with:\n" + " import warnings; warnings.filterwarnings('ignore', category=LinopyDeprecationWarning)" +) + + +class LinopyDeprecationWarning(FutureWarning): + """Warning for deprecated linopy features scheduled for removal.""" + class OptionSettings: - def __init__(self, **kwargs: int) -> None: + def __init__(self, **kwargs: Any) -> None: self._defaults = kwargs self._current_values = kwargs.copy() - def __call__(self, **kwargs: int) -> None: + def __call__(self, **kwargs: Any) -> None: self.set_value(**kwargs) - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Any: return self.get_value(key) - def __setitem__(self, key: str, value: int) -> None: + def __setitem__(self, key: str, value: Any) -> None: return self.set_value(**{key: value}) - def set_value(self, **kwargs: int) -> None: + def set_value(self, **kwargs: Any) -> None: for k, v in kwargs.items(): if k not in 
self._defaults: raise KeyError(f"{k} is not a valid setting.") + if k == "arithmetic_convention" and v not in VALID_ARITHMETIC_JOINS: + raise ValueError( + f"Invalid arithmetic_convention: {v!r}. " + f"Must be one of {VALID_ARITHMETIC_JOINS}." + ) self._current_values[k] = v - def get_value(self, name: str) -> int: + def get_value(self, name: str) -> Any: if name in self._defaults: return self._current_values[name] else: @@ -57,4 +75,8 @@ def __repr__(self) -> str: return f"OptionSettings:\n {settings}" -options = OptionSettings(display_max_rows=14, display_max_terms=6) +options = OptionSettings( + display_max_rows=14, + display_max_terms=6, + arithmetic_convention="legacy", +) diff --git a/linopy/expressions.py b/linopy/expressions.py index 22a1fb1d..64e2ecb7 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -32,6 +32,7 @@ from xarray.core.indexes import Indexes from xarray.core.types import JoinOptions from xarray.core.utils import Frozen +from xarray.structure.alignment import AlignmentError try: # resolve breaking change in xarray 2025.03.0 @@ -49,7 +50,6 @@ LocIndexer, as_dataarray, assign_multiindex_safe, - check_common_keys_values, check_has_nulls, check_has_nulls_polars, fill_missing_coords, @@ -68,7 +68,7 @@ to_dataframe, to_polars, ) -from linopy.config import options +from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning, options from linopy.constants import ( CV_DIM, EQUAL, @@ -563,7 +563,7 @@ def _align_constant( fill_value : float, default: 0 Fill value for missing coordinates. join : str, optional - Alignment method. If None, uses size-aware default behavior. + Alignment method. If None, uses ``options["arithmetic_convention"]``. Returns ------- @@ -575,6 +575,15 @@ def _align_constant( Whether the expression's data needs reindexing. 
""" if join is None: + join = options["arithmetic_convention"] + + if join == "legacy": + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=4, + ) + # Old behavior: override when same sizes, left join otherwise if other.sizes == self.const.sizes: return self.const, other.assign_coords(coords=self.coords), False return ( @@ -582,30 +591,52 @@ def _align_constant( other.reindex_like(self.const, fill_value=fill_value), False, ) - elif join == "override": + + elif join == "v1": + join = "exact" + + if join == "override": return self.const, other.assign_coords(coords=self.coords), False - else: - self_const, aligned = xr.align( + elif join == "left": + return ( self.const, - other, - join=join, - fill_value=fill_value, + other.reindex_like(self.const, fill_value=fill_value), + False, ) + else: + try: + self_const, aligned = xr.align( + self.const, other, join=join, fill_value=fill_value + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .add()/.sub()/.mul()/.div() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection of coordinates\n' + ' .add(other, join="outer") # union of coordinates (with fill)\n' + ' .add(other, join="left") # keep left operand\'s coordinates\n' + ' .add(other, join="override") # positional alignment' + ) from None + raise return self_const, aligned, True def _add_constant( self: GenericExpression, other: ConstantLike, join: JoinOptions | None = None ) -> GenericExpression: - # NaN values in self.const or other are filled with 0 (additive identity) - # so that missing data does not silently propagate through arithmetic. 
+ is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" if np.isscalar(other) and join is None: - return self.assign(const=self.const.fillna(0) + other) + const = self.const.fillna(0) + other if is_legacy else self.const + other + return self.assign(const=const) da = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, da, needs_data_reindex = self._align_constant( da, fill_value=0, join=join ) - da = da.fillna(0) - self_const = self_const.fillna(0) + if is_legacy: + da = da.fillna(0) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} return self.__class__( @@ -623,31 +654,29 @@ def _apply_constant_op( fill_value: float, join: JoinOptions | None = None, ) -> GenericExpression: - """ - Apply a constant operation (mul, div, etc.) to this expression with a scalar or array. - - NaN values are filled with neutral elements before the operation: - - factor (other) is filled with fill_value (0 for mul, 1 for div) - - coeffs and const are filled with 0 (additive identity) - """ + is_legacy = ( + join is None and options["arithmetic_convention"] == "legacy" + ) or join == "legacy" factor = as_dataarray(other, coords=self.coords, dims=self.coord_dims) self_const, factor, needs_data_reindex = self._align_constant( factor, fill_value=fill_value, join=join ) - factor = factor.fillna(fill_value) - self_const = self_const.fillna(0) + if is_legacy: + factor = factor.fillna(fill_value) + self_const = self_const.fillna(0) if needs_data_reindex: fv = {**self._fill_value, "const": 0} data = self.data.reindex_like(self_const, fill_value=fv) + coeffs = data.coeffs.fillna(0) if is_legacy else data.coeffs return self.__class__( assign_multiindex_safe( data, - coeffs=op(data.coeffs.fillna(0), factor), + coeffs=op(coeffs, factor), const=op(self_const, factor), ), self.model, ) - coeffs = self.coeffs.fillna(0) + coeffs = self.coeffs.fillna(0) if is_legacy else self.coeffs return 
self.assign(coeffs=op(coeffs, factor), const=op(self_const, factor)) def _multiply_by_constant( @@ -1147,34 +1176,74 @@ def to_constraint( f"Both sides of the constraint are constant. At least one side must contain variables. {self} {rhs}" ) - if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): - rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + effective_join = join if join is not None else options["arithmetic_convention"] - extra_dims = set(rhs.dims) - set(self.coord_dims) - if extra_dims: - logger.warning( - f"Constant RHS contains dimensions {extra_dims} not present " - f"in the expression, which might lead to inefficiencies. " - f"Consider collapsing the dimensions by taking min/max." + if effective_join == "legacy": + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=3, + ) + # Old behavior: convert to DataArray, warn about extra dims, + # reindex_like (left join), then sub + if isinstance(rhs, SUPPORTED_CONSTANT_TYPES): + rhs = as_dataarray(rhs, coords=self.coords, dims=self.coord_dims) + extra_dims = set(rhs.dims) - set(self.coord_dims) + if extra_dims: + logger.warning( + f"Constant RHS contains dimensions {extra_dims} not present " + f"in the expression, which might lead to inefficiencies. " + f"Consider collapsing the dimensions by taking min/max." + ) + rhs = rhs.reindex_like(self.const, fill_value=np.nan) + # Alignment already done — compute constraint directly + constraint_rhs = rhs - self.const + data = assign_multiindex_safe( + self.data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs ) - rhs = rhs.reindex_like(self.const, fill_value=np.nan) + return constraints.Constraint(data, model=self.model) + # Non-constant rhs (Variable/Expression) — fall through to sub path + + if effective_join == "v1": + effective_join = "exact" - # Remember where RHS is NaN (meaning "no constraint") before the - # subtraction, which may fill NaN with 0 as part of normal - # expression arithmetic. 
if isinstance(rhs, DataArray): - rhs_nan_mask = rhs.isnull() - else: - rhs_nan_mask = None + if effective_join == "override": + aligned_rhs = rhs.assign_coords(coords=self.const.coords) + expr_const = self.const + expr_data = self.data + elif effective_join == "left": + aligned_rhs = rhs.reindex_like(self.const, fill_value=np.nan) + expr_const = self.const + expr_data = self.data + else: + try: + expr_const_aligned, aligned_rhs = xr.align( + self.const, rhs, join=effective_join, fill_value=np.nan + ) + except ValueError as e: + if "exact" in str(e): + raise ValueError( + f"{e}\n" + "Use .le()/.ge()/.eq() with an explicit join= parameter:\n" + ' .le(rhs, join="inner") # intersection of coordinates\n' + ' .le(rhs, join="left") # keep expression coordinates (NaN fill)\n' + ' .le(rhs, join="override") # positional alignment' + ) from None + raise + expr_const = expr_const_aligned.fillna(0) + expr_data = self.data.reindex_like( + expr_const_aligned, fill_value=self._fill_value + ) + constraint_rhs = aligned_rhs - expr_const + data = assign_multiindex_safe( + expr_data[["coeffs", "vars"]], sign=sign, rhs=constraint_rhs + ) + return constraints.Constraint(data, model=self.model) all_to_lhs = self.sub(rhs, join=join).data computed_rhs = -all_to_lhs.const - # Restore NaN at positions where the original constant RHS had no - # value so that downstream code still treats them as unconstrained. 
- if rhs_nan_mask is not None and rhs_nan_mask.any(): - computed_rhs = xr.where(rhs_nan_mask, np.nan, computed_rhs) - data = assign_multiindex_safe( all_to_lhs[["coeffs", "vars"]], sign=sign, rhs=computed_rhs ) @@ -1650,6 +1719,18 @@ def __add__( return self._add_constant(other) else: other = as_expression(other, model=self.model, dims=self.coord_dims) + if options["arithmetic_convention"] == "v1": + # Enforce exact coordinate alignment before merge + try: + xr.align(self.const, other.const, join="exact") + except (ValueError, AlignmentError) as e: + raise ValueError( + f"{e}\n" + "Use .add()/.sub() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection\n' + ' .add(other, join="outer") # union with fill\n' + ' .add(other, join="left") # keep left coordinates' + ) from None return merge([self, other], cls=self.__class__) except TypeError: return NotImplemented @@ -2188,6 +2269,18 @@ def __add__(self, other: SideLike) -> QuadraticExpression: if isinstance(other, LinearExpression): other = other.to_quadexpr() + if options["arithmetic_convention"] == "v1": + try: + xr.align(self.const, other.const, join="exact") + except (ValueError, AlignmentError) as e: + raise ValueError( + f"{e}\n" + "Use .add()/.sub() with an explicit join= parameter:\n" + ' .add(other, join="inner") # intersection\n' + ' .add(other, join="outer") # union with fill\n' + ' .add(other, join="left") # keep left coordinates' + ) from None + return merge([self, other], cls=self.__class__) except TypeError: return NotImplemented @@ -2441,16 +2534,6 @@ def merge( model = exprs[0].model - if join is not None: - override = join == "override" - elif cls in linopy_types and dim in HELPER_DIMS: - coord_dims = [ - {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} for e in exprs - ] - override = check_common_keys_values(coord_dims) # type: ignore - else: - override = False - data = [e.data if isinstance(e, linopy_types) else e for e in exprs] data = [fill_missing_coords(ds, 
fill_helper_dims=True) for ds in data] @@ -2464,12 +2547,38 @@ def merge( elif cls == variables.Variable: kwargs["fill_value"] = variables.FILL_VALUE - if join is not None: - kwargs["join"] = join - elif override: - kwargs["join"] = "override" + effective_join = join if join is not None else options["arithmetic_convention"] + + if effective_join == "legacy": + warn( + LEGACY_DEPRECATION_MESSAGE, + LinopyDeprecationWarning, + stacklevel=2, + ) + # Reproduce old behavior: override when all shared dims have + # matching sizes, outer otherwise. + if cls in linopy_types and dim in HELPER_DIMS: + coord_dims = [ + {k: v for k, v in e.sizes.items() if k not in HELPER_DIMS} + for e in exprs + ] + common_keys = set.intersection(*(set(d.keys()) for d in coord_dims)) + override = all( + len({d[k] for d in coord_dims if k in d}) == 1 for k in common_keys + ) + else: + override = False + + kwargs["join"] = "override" if override else "outer" + elif effective_join == "v1": + # Merge uses outer join for xr.concat since helper dims + # (_term, _factor) commonly have different sizes and + # expressions may have different user dimensions. + # Coordinate enforcement for v1 is done at the operator + # level (__add__, __sub__, etc.) before calling merge. + kwargs["join"] = "outer" else: - kwargs.setdefault("join", "outer") + kwargs["join"] = effective_join if dim == TERM_DIM: ds = xr.concat([d[["coeffs", "vars"]] for d in data], dim, **kwargs) diff --git a/linopy/variables.py b/linopy/variables.py index 396703fb..1e2ea6ae 100644 --- a/linopy/variables.py +++ b/linopy/variables.py @@ -401,6 +401,13 @@ def __mul__(self, other: SideLike) -> ExpressionLike: Multiply variables with a coefficient, variable, or expression. 
""" try: + if isinstance(other, Variable | ScalarVariable): + return self.to_linexpr() * other + + # Fast path for scalars: build expression directly with coefficient + if np.isscalar(other): + return self.to_linexpr(other) + return self.to_linexpr() * other except TypeError: return NotImplemented @@ -449,7 +456,13 @@ def __div__( """ Divide variables with a coefficient. """ - return self.to_linexpr() / other + if isinstance(other, expressions.LinearExpression | Variable): + raise TypeError( + "unsupported operand type(s) for /: " + f"{type(self)} and {type(other)}. " + "Non-linear expressions are not yet supported." + ) + return self.to_linexpr()._divide_by_constant(other) def __truediv__( self, coefficient: ConstantLike | LinearExpression | Variable diff --git a/test/conftest.py b/test/conftest.py index ee20cdc2..5e2170a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from collections.abc import Generator from typing import TYPE_CHECKING import pandas as pd @@ -57,6 +58,16 @@ def pytest_collection_modifyitems( item.add_marker(pytest.mark.gpu) +@pytest.fixture +def v1_convention() -> Generator[None, None, None]: + """Set arithmetic_convention to 'v1' for the duration of a test.""" + import linopy + + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + @pytest.fixture def m() -> Model: from linopy import Model diff --git a/test/test_algebraic_properties.py b/test/test_algebraic_properties.py index c0f04f22..04103b61 100644 --- a/test/test_algebraic_properties.py +++ b/test/test_algebraic_properties.py @@ -39,16 +39,27 @@ from __future__ import annotations +from collections.abc import Generator + import numpy as np import pandas as pd import pytest import xarray as xr +import linopy from linopy import Model from linopy.expressions import LinearExpression from linopy.variables import Variable +@pytest.fixture(autouse=True) +def 
_use_v1_convention() -> Generator[None, None, None]: + """Use v1 arithmetic convention for all tests in this module.""" + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + @pytest.fixture def m() -> Model: return Model() diff --git a/test/test_common.py b/test/test_common.py index 64e4bf6f..719ab093 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -5,6 +5,8 @@ @author: fabian """ +from collections.abc import Generator + import numpy as np import pandas as pd import polars as pl @@ -13,6 +15,7 @@ from xarray import DataArray from xarray.testing.assertions import assert_equal +import linopy from linopy import LinearExpression, Model, Variable from linopy.common import ( align, @@ -27,6 +30,17 @@ from linopy.testing import assert_linequal, assert_varequal +@pytest.fixture(autouse=True) +def _use_v1_convention() -> Generator[None, None, None]: + """Use v1 arithmetic convention for all tests in this module.""" + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + +# Fixtures m, u, x are provided by conftest.py + + def test_as_dataarray_with_series_dims_default() -> None: target_dim = "dim_0" target_index = [0, 1, 2] diff --git a/test/test_common_legacy.py b/test/test_common_legacy.py new file mode 100644 index 00000000..f1190024 --- /dev/null +++ b/test/test_common_legacy.py @@ -0,0 +1,734 @@ +#!/usr/bin/env python3 +""" +Created on Mon Jun 19 12:11:03 2023 + +@author: fabian +""" + +import numpy as np +import pandas as pd +import polars as pl +import pytest +import xarray as xr +from xarray import DataArray +from xarray.testing.assertions import assert_equal + +from linopy import LinearExpression, Model, Variable +from linopy.common import ( + align, + as_dataarray, + assign_multiindex_safe, + best_int, + get_dims_with_index_levels, + is_constant, + iterate_slices, + maybe_group_terms_polars, +) +from linopy.testing import 
assert_linequal, assert_varequal + + +def test_as_dataarray_with_series_dims_default() -> None: + target_dim = "dim_0" + target_index = [0, 1, 2] + s = pd.Series([1, 2, 3]) + da = as_dataarray(s) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_set() -> None: + target_dim = "dim1" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + dims = [target_dim] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_given() -> None: + target_dim = "dim1" + target_index = ["a", "b", "c"] + index = pd.Index(target_index, name=target_dim) + s = pd.Series([1, 2, 3], index=index) + dims: list[str] = [] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_priority() -> None: + """The dimension name from the pandas object should have priority.""" + target_dim = "dim1" + target_index = ["a", "b", "c"] + index = pd.Index(target_index, name=target_dim) + s = pd.Series([1, 2, 3], index=index) + dims = ["other"] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_subset() -> None: + target_dim = "dim_0" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + dims: list[str] = [] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_dims_superset() -> None: + target_dim = "dim_a" + target_index = ["a", "b", "c"] + 
s = pd.Series([1, 2, 3], index=target_index) + dims = [target_dim, "other"] + da = as_dataarray(s, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_series_aligned_coords() -> None: + """This should not give out a warning even though coords are given.""" + target_dim = "dim_0" + target_index = ["a", "b", "c"] + s = pd.Series([1, 2, 3], index=target_index) + da = as_dataarray(s, coords=[target_index]) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + da = as_dataarray(s, coords={target_dim: target_index}) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_with_pl_series_dims_default() -> None: + target_dim = "dim_0" + target_index = [0, 1, 2] + s = pl.Series([1, 2, 3]) + da = as_dataarray(s) + assert isinstance(da, DataArray) + assert da.dims == (target_dim,) + assert list(da.coords[target_dim].values) == target_index + + +def test_as_dataarray_dataframe_dims_default() -> None: + target_dims = ("dim_0", "dim_1") + target_index = [0, 1] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_set() -> None: + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + 
assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_given() -> None: + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + index = pd.Index(target_index, name=target_dims[0]) + columns = pd.Index(target_columns, name=target_dims[1]) + df = pd.DataFrame([[1, 2], [3, 4]], index=index, columns=columns) + dims: list[str] = [] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_priority() -> None: + """The dimension name from the pandas object should have priority.""" + target_dims = ("dim1", "dim2") + target_index = ["a", "b"] + target_columns = ["A", "B"] + index = pd.Index(target_index, name=target_dims[0]) + columns = pd.Index(target_columns, name=target_dims[1]) + df = pd.DataFrame([[1, 2], [3, 4]], index=index, columns=columns) + dims = ["other"] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_subset() -> None: + target_dims = ("dim_0", "dim_1") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + dims: list[str] = [] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_dims_superset() -> None: + target_dims = ("dim_a", "dim_b") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], 
index=target_index, columns=target_columns) + dims = [*target_dims, "other"] + da = as_dataarray(df, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_dataframe_aligned_coords() -> None: + """This should not give out a warning even though coords are given.""" + target_dims = ("dim_0", "dim_1") + target_index = ["a", "b"] + target_columns = ["A", "B"] + df = pd.DataFrame([[1, 2], [3, 4]], index=target_index, columns=target_columns) + da = as_dataarray(df, coords=[target_index, target_columns]) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + coords = dict(zip(target_dims, [target_index, target_columns])) + da = as_dataarray(df, coords=coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + assert list(da.coords[target_dims[0]].values) == target_index + assert list(da.coords[target_dims[1]].values) == target_columns + + +def test_as_dataarray_with_ndarray_no_coords_no_dims() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = [[0, 1], [0, 1]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_no_dims() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_indexes_no_dims() -> 
None: + target_dims = ("dim1", "dim2") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == list(target_coords[i]) + + +def test_as_dataarray_with_ndarray_coords_dict_set_no_dims() -> None: + """If no dims are given and coords are a dict, the keys of the dict should be used as dims.""" + target_dims = ("dim_0", "dim_2") + target_coords = {"dim_0": ["a", "b"], "dim_2": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for dim in target_dims: + assert list(da.coords[dim]) == target_coords[dim] + + +def test_as_dataarray_with_ndarray_coords_list_dims() -> None: + target_dims = ("dim1", "dim2") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_dims_superset() -> None: + target_dims = ("dim1", "dim2") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + dims = [*target_dims, "dim3"] + da = as_dataarray(arr, coords=target_coords, dims=dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_list_dims_subset() -> None: + target_dims = ("dim0", "dim_1") + target_coords = [["a", "b"], ["A", "B"]] + arr = np.array([[1, 2], [3, 4]]) + dims = ["dim0"] + da = as_dataarray(arr, coords=target_coords, dims=dims) + assert isinstance(da, 
DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == target_coords[i] + + +def test_as_dataarray_with_ndarray_coords_indexes_dims_aligned() -> None: + target_dims = ("dim1", "dim2") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for i, dim in enumerate(target_dims): + assert list(da.coords[dim]) == list(target_coords[i]) + + +def test_as_dataarray_with_ndarray_coords_indexes_dims_not_aligned() -> None: + target_dims = ("dim3", "dim4") + target_coords = [ + pd.Index(["a", "b"], name="dim1"), + pd.Index(["A", "B"], name="dim2"), + ] + arr = np.array([[1, 2], [3, 4]]) + with pytest.raises(ValueError): + as_dataarray(arr, coords=target_coords, dims=target_dims) + + +def test_as_dataarray_with_ndarray_coords_dict_dims_aligned() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = {"dim_0": ["a", "b"], "dim_1": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert isinstance(da, DataArray) + assert da.dims == target_dims + for dim in target_dims: + assert list(da.coords[dim]) == target_coords[dim] + + +def test_as_dataarray_with_ndarray_coords_dict_set_dims_not_aligned() -> None: + target_dims = ("dim_0", "dim_1") + target_coords = {"dim_0": ["a", "b"], "dim_2": ["A", "B"]} + arr = np.array([[1, 2], [3, 4]]) + da = as_dataarray(arr, coords=target_coords, dims=target_dims) + assert da.dims == target_dims + assert list(da.coords["dim_0"].values) == ["a", "b"] + assert "dim_2" not in da.coords + + +def test_as_dataarray_with_number() -> None: + num = 1 + da = as_dataarray(num, dims=["dim1"], coords=[["a"]]) + assert isinstance(da, DataArray) + assert da.dims == ("dim1",) + assert list(da.coords["dim1"].values) == ["a"] + 
+ +def test_as_dataarray_with_np_number() -> None: + num = np.float64(1) + da = as_dataarray(num, dims=["dim1"], coords=[["a"]]) + assert isinstance(da, DataArray) + assert da.dims == ("dim1",) + assert list(da.coords["dim1"].values) == ["a"] + + +def test_as_dataarray_with_number_default_dims_coords() -> None: + num = 1 + da = as_dataarray(num) + assert isinstance(da, DataArray) + assert da.dims == () + assert da.coords == {} + + +def test_as_dataarray_with_number_and_coords() -> None: + num = 1 + da = as_dataarray(num, coords=[pd.RangeIndex(10, name="a")]) + assert isinstance(da, DataArray) + assert da.dims == ("a",) + assert list(da.coords["a"].values) == list(range(10)) + + +def test_as_dataarray_with_dataarray() -> None: + da_in = DataArray( + data=[[1, 2], [3, 4]], + dims=["dim1", "dim2"], + coords={"dim1": ["a", "b"], "dim2": ["A", "B"]}, + ) + da_out = as_dataarray(da_in, dims=["dim1", "dim2"], coords=[["a", "b"], ["A", "B"]]) + assert isinstance(da_out, DataArray) + assert da_out.dims == da_in.dims + assert list(da_out.coords["dim1"].values) == list(da_in.coords["dim1"].values) + assert list(da_out.coords["dim2"].values) == list(da_in.coords["dim2"].values) + + +def test_as_dataarray_with_dataarray_default_dims_coords() -> None: + da_in = DataArray( + data=[[1, 2], [3, 4]], + dims=["dim1", "dim2"], + coords={"dim1": ["a", "b"], "dim2": ["A", "B"]}, + ) + da_out = as_dataarray(da_in) + assert isinstance(da_out, DataArray) + assert da_out.dims == da_in.dims + assert list(da_out.coords["dim1"].values) == list(da_in.coords["dim1"].values) + assert list(da_out.coords["dim2"].values) == list(da_in.coords["dim2"].values) + + +def test_as_dataarray_with_unsupported_type() -> None: + with pytest.raises(TypeError): + as_dataarray(lambda x: 1, dims=["dim1"], coords=[["a"]]) + + +def test_best_int() -> None: + # Test for int8 + assert best_int(127) == np.int8 + # Test for int16 + assert best_int(128) == np.int16 + assert best_int(32767) == np.int16 + # Test for int32 
+ assert best_int(32768) == np.int32 + assert best_int(2147483647) == np.int32 + # Test for int64 + assert best_int(2147483648) == np.int64 + assert best_int(9223372036854775807) == np.int64 + + # Test for value too large + with pytest.raises( + ValueError, match=r"Value 9223372036854775808 is too large for int64." + ): + best_int(9223372036854775808) + + +def test_assign_multiindex_safe() -> None: + # Create a multi-indexed dataset + index = pd.MultiIndex.from_product([["A", "B"], [1, 2]], names=["letter", "number"]) + data = xr.DataArray([1, 2, 3, 4], dims=["index"], coords={"index": index}) + ds = xr.Dataset({"value": data}) + + # This would now warn about the index deletion of single index level + # ds["humidity"] = data + + # Case 1: Assigning a single DataArray + result = assign_multiindex_safe(ds, humidity=data) + assert "humidity" in result + assert "value" in result + assert result["humidity"].equals(data) + + # Case 2: Assigning a Dataset + result = assign_multiindex_safe(ds, **xr.Dataset({"humidity": data})) # type: ignore + assert "humidity" in result + assert "value" in result + assert result["humidity"].equals(data) + + # Case 3: Assigning multiple DataArrays + result = assign_multiindex_safe(ds, humidity=data, pressure=data) + assert "humidity" in result + assert "pressure" in result + assert "value" in result + assert result["humidity"].equals(data) + assert result["pressure"].equals(data) + + +def test_iterate_slices_basic() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=20)) + assert len(slices) == 5 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_with_exclude_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(20)}, + ) + slices = 
list(iterate_slices(ds, slice_size=20, slice_dims=["x"])) + assert len(slices) == 10 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_large_max_size() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=200)) + assert len(slices) == 1 + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_small_max_size() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(20)}, + ) + slices = list(iterate_slices(ds, slice_size=8, slice_dims=["x"])) + assert ( + len(slices) == 10 + ) # goes to the smallest slice possible which is 1 for the x dimension + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_slice_size_none() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=None)) + assert len(slices) == 1 + for s in slices: + assert ds.equals(s) + + +def test_iterate_slices_includes_last_slice() -> None: + ds = xr.Dataset( + {"var": (("x"), np.random.rand(10))}, # noqa: NPY002 + coords={"x": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=3, slice_dims=["x"])) + assert len(slices) == 4 # 10 slices for dimension 'x' with size 10 + total_elements = sum(s.sizes["x"] for s in slices) + assert total_elements == ds.sizes["x"] # Ensure all elements are included + for s in slices: + assert isinstance(s, xr.Dataset) + assert set(s.dims) == set(ds.dims) + + +def test_iterate_slices_empty_slice_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": 
np.arange(10), "y": np.arange(10)}, + ) + slices = list(iterate_slices(ds, slice_size=50, slice_dims=[])) + assert len(slices) == 1 + for s in slices: + assert ds.equals(s) + + +def test_iterate_slices_invalid_slice_dims() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002 + coords={"x": np.arange(10), "y": np.arange(10)}, + ) + with pytest.raises(ValueError): + list(iterate_slices(ds, slice_size=50, slice_dims=["z"])) + + +def test_iterate_slices_empty_dataset() -> None: + ds = xr.Dataset( + {"var": (("x", "y"), np.array([]).reshape(0, 0))}, coords={"x": [], "y": []} + ) + slices = list(iterate_slices(ds, slice_size=10, slice_dims=["x"])) + assert len(slices) == 1 + assert ds.equals(slices[0]) + + +def test_iterate_slices_single_element() -> None: + ds = xr.Dataset({"var": (("x", "y"), np.array([[1]]))}, coords={"x": [0], "y": [0]}) + slices = list(iterate_slices(ds, slice_size=1, slice_dims=["x"])) + assert len(slices) == 1 + assert ds.equals(slices[0]) + + +def test_get_dims_with_index_levels() -> None: + # Create test data + + # Case 1: Simple dataset with regular dimensions + ds1 = xr.Dataset( + {"temp": (("time", "lat"), np.random.rand(3, 2))}, # noqa: NPY002 + coords={"time": pd.date_range("2024-01-01", periods=3), "lat": [0, 1]}, + ) + + # Case 2: Dataset with a multi-index dimension + stations_index = pd.MultiIndex.from_product( + [["USA", "Canada"], ["NYC", "Toronto"]], names=["country", "city"] + ) + stations_coords = xr.Coordinates.from_pandas_multiindex(stations_index, "station") + ds2 = xr.Dataset( + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 + coords={"time": pd.date_range("2024-01-01", periods=3), **stations_coords}, + ) + + # Case 3: Dataset with unnamed multi-index levels + unnamed_stations_index = pd.MultiIndex.from_product( + [["USA", "Canada"], ["NYC", "Toronto"]] + ) + unnamed_stations_coords = xr.Coordinates.from_pandas_multiindex( + unnamed_stations_index, "station" + ) + 
ds3 = xr.Dataset( + {"temp": (("time", "station"), np.random.rand(3, 4))}, # noqa: NPY002 + coords={ + "time": pd.date_range("2024-01-01", periods=3), + **unnamed_stations_coords, + }, + ) + + # Case 4: Dataset with multiple multi-indexed dimensions + locations_index = pd.MultiIndex.from_product( + [["North", "South"], ["A", "B"]], names=["region", "site"] + ) + locations_coords = xr.Coordinates.from_pandas_multiindex( + locations_index, "location" + ) + + ds4 = xr.Dataset( + {"temp": (("time", "station", "location"), np.random.rand(2, 4, 4))}, # noqa: NPY002 + coords={ + "time": pd.date_range("2024-01-01", periods=2), + **stations_coords, + **locations_coords, + }, + ) + + # Run tests + + # Test case 1: Regular dimensions + assert get_dims_with_index_levels(ds1) == ["time", "lat"] + + # Test case 2: Named multi-index + assert get_dims_with_index_levels(ds2) == ["time", "station (country, city)"] + + # Test case 3: Unnamed multi-index + assert get_dims_with_index_levels(ds3) == [ + "time", + "station (station_level_0, station_level_1)", + ] + + # Test case 4: Multiple multi-indices + expected = ["time", "station (country, city)", "location (region, site)"] + assert get_dims_with_index_levels(ds4) == expected + + # Test case 5: Empty dataset + ds5 = xr.Dataset() + assert get_dims_with_index_levels(ds5) == [] + + +def test_align(x: Variable, u: Variable) -> None: # noqa: F811 + alpha = xr.DataArray([1, 2], [[1, 2]]) + beta = xr.DataArray( + [1, 2, 3], + [ + ( + "dim_3", + pd.MultiIndex.from_tuples( + [(1, "b"), (2, "b"), (1, "c")], names=["level1", "level2"] + ), + ) + ], + ) + + # inner join + x_obs, alpha_obs = align(x, alpha) + assert isinstance(x_obs, Variable) + assert x_obs.shape == alpha_obs.shape == (1,) + assert_varequal(x_obs, x.loc[[1]]) + + # left-join + x_obs, alpha_obs = align(x, alpha, join="left") + assert x_obs.shape == alpha_obs.shape == (2,) + assert isinstance(x_obs, Variable) + assert_varequal(x_obs, x) + assert_equal(alpha_obs, 
DataArray([np.nan, 1], [[0, 1]])) + + # multiindex + beta_obs, u_obs = align(beta, u) + assert u_obs.shape == beta_obs.shape == (2,) + assert isinstance(u_obs, Variable) + assert_varequal(u_obs, u.loc[[(1, "b"), (2, "b")]]) + assert_equal(beta_obs, beta.loc[[(1, "b"), (2, "b")]]) + + # with linear expression + expr = 20 * x + x_obs, expr_obs, alpha_obs = align(x, expr, alpha) + assert x_obs.shape == alpha_obs.shape == (1,) + assert expr_obs.shape == (1, 1) # _term dim + assert isinstance(expr_obs, LinearExpression) + assert_linequal(expr_obs, expr.loc[[1]]) + + +def test_is_constant() -> None: + model = Model() + index = pd.Index(range(10), name="t") + a = model.add_variables(name="a", coords=[index]) + b = a.sel(t=1) + c = a * 2 + d = a * a + + non_constant = [a, b, c, d] + for nc in non_constant: + assert not is_constant(nc) + + constant_values = [ + 5, + 3.14, + np.int32(7), + np.float64(2.71), + pd.Series([1, 2, 3]), + np.array([4, 5, 6]), + xr.DataArray([k for k in range(10)], coords=[index]), + ] + for cv in constant_values: + assert is_constant(cv) + + +def test_maybe_group_terms_polars_no_duplicates() -> None: + """Fast path: distinct (labels, vars) pairs skip group_by.""" + df = pl.DataFrame({"labels": [0, 0], "vars": [1, 2], "coeffs": [3.0, 4.0]}) + result = maybe_group_terms_polars(df) + assert result.shape == (2, 3) + assert result.columns == ["labels", "vars", "coeffs"] + assert result["coeffs"].to_list() == [3.0, 4.0] + + +def test_maybe_group_terms_polars_with_duplicates() -> None: + """Slow path: duplicate (labels, vars) pairs trigger group_by.""" + df = pl.DataFrame({"labels": [0, 0], "vars": [1, 1], "coeffs": [3.0, 4.0]}) + result = maybe_group_terms_polars(df) + assert result.shape == (1, 3) + assert result["coeffs"].to_list() == [7.0] diff --git a/test/test_constraints.py b/test/test_constraints.py index 9fc0086b..e94f0152 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,6 +5,7 @@ @author: fabulous """ +from 
collections.abc import Generator from typing import Any import dask @@ -14,9 +15,19 @@ import pytest import xarray as xr +import linopy from linopy import EQUAL, GREATER_EQUAL, LESS_EQUAL, Model, Variable, available_solvers from linopy.testing import assert_conequal + +@pytest.fixture(autouse=True) +def _use_v1_convention() -> Generator[None, None, None]: + """Use v1 arithmetic convention for all tests in this module.""" + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + # Test model functions @@ -347,67 +358,72 @@ def superset(self, request: Any) -> xr.DataArray | pd.Series: np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") ) - def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: - con = v <= subset + def test_var_le_subset_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v <= subset + + def test_var_le_subset_join_left(self, v: Variable) -> None: + subset_da = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + con = v.to_linexpr().le(subset_da, join="left") assert con.sizes["dim_2"] == v.sizes["dim_2"] assert con.rhs.sel(dim_2=1).item() == 10.0 assert con.rhs.sel(dim_2=3).item() == 30.0 assert np.isnan(con.rhs.sel(dim_2=0).item()) @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_var_comparison_subset( + def test_var_comparison_subset_raises( self, v: Variable, subset: xr.DataArray, sign: str ) -> None: - if sign == LESS_EQUAL: - con = v <= subset - elif sign == GREATER_EQUAL: - con = v >= subset - else: - con = v == subset - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert con.rhs.sel(dim_2=1).item() == 10.0 - assert np.isnan(con.rhs.sel(dim_2=0).item()) + with pytest.raises(ValueError, match="exact"): + if sign == LESS_EQUAL: + v <= subset + elif sign == GREATER_EQUAL: + v >= subset + else: + v == subset + + def test_expr_le_subset_raises(self, v: 
Variable, subset: xr.DataArray) -> None: + expr = v + 5 + with pytest.raises(ValueError, match="exact"): + expr <= subset - def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + def test_expr_le_subset_join_left(self, v: Variable) -> None: + subset_da = xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) expr = v + 5 - con = expr <= subset + con = expr.le(subset_da, join="left") assert con.sizes["dim_2"] == v.sizes["dim_2"] assert con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) assert np.isnan(con.rhs.sel(dim_2=0).item()) - @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) - def test_subset_comparison_var( - self, v: Variable, subset: xr.DataArray, sign: str + def test_subset_comparison_var_raises( + self, v: Variable, subset: xr.DataArray ) -> None: - if sign == LESS_EQUAL: - con = subset <= v - elif sign == GREATER_EQUAL: - con = subset >= v - else: - con = subset == v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert np.isnan(con.rhs.sel(dim_2=0).item()) - assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + with pytest.raises(ValueError, match="exact"): + subset <= v - @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) - def test_superset_comparison_var( - self, v: Variable, superset: xr.DataArray, sign: str + def test_superset_comparison_var_raises( + self, v: Variable, superset: xr.DataArray ) -> None: - if sign == LESS_EQUAL: - con = superset <= v - else: - con = superset >= v - assert con.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(con.lhs.coeffs.values).any() - assert not np.isnan(con.rhs.values).any() + with pytest.raises(ValueError, match="exact"): + superset <= v - def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: + def test_constraint_rhs_extra_dims_raises_on_mismatch(self, v: Variable) -> None: rhs = xr.DataArray( [[1.0, 2.0]], dims=["extra", "dim_2"], coords={"dim_2": [0, 
1]}, ) + # dim_2 coords [0,1] don't match v's [0..19] under exact join + with pytest.raises(ValueError, match="exact"): + v <= rhs + + def test_constraint_rhs_extra_dims_broadcasts_matching(self, v: Variable) -> None: + rhs = xr.DataArray( + np.ones((2, 20)), + dims=["extra", "dim_2"], + coords={"dim_2": range(20)}, + ) c = v <= rhs assert "extra" in c.dims @@ -419,7 +435,8 @@ def test_subset_constraint_solve_integration(self) -> None: coords = pd.RangeIndex(5, name="i") x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) - m.add_constraints(x <= subset_ub, name="subset_ub") + # exact default raises — use explicit join="left" (NaN = no constraint) + m.add_constraints(x.to_linexpr().le(subset_ub, join="left"), name="subset_ub") m.add_objective(x.sum(), sense="max") m.solve(solver_name=solver) sol = m.solution["x"] diff --git a/test/test_constraints_legacy.py b/test/test_constraints_legacy.py new file mode 100644 index 00000000..9a467c8c --- /dev/null +++ b/test/test_constraints_legacy.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +""" +Created on Wed Mar 10 11:23:13 2021. 
+ +@author: fabulous +""" + +from typing import Any + +import dask +import dask.array.core +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from linopy import EQUAL, GREATER_EQUAL, LESS_EQUAL, Model, Variable, available_solvers +from linopy.testing import assert_conequal + +# Test model functions + + +def test_constraint_assignment() -> None: + m: Model = Model() + + lower: xr.DataArray = xr.DataArray( + np.zeros((10, 10)), coords=[range(10), range(10)] + ) + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + for attr in m.constraints.dataset_attrs: + assert "con0" in getattr(m.constraints, attr) + + assert m.constraints.labels.con0.shape == (10, 10) + assert m.constraints.labels.con0.dtype == int + assert m.constraints.coeffs.con0.dtype in (int, float) + assert m.constraints.vars.con0.dtype in (int, float) + assert m.constraints.rhs.con0.dtype in (int, float) + + assert_conequal(m.constraints.con0, con0) + + +def test_constraint_equality() -> None: + m: Model = Model() + + lower: xr.DataArray = xr.DataArray( + np.zeros((10, 10)), coords=[range(10), range(10)] + ) + upper: xr.DataArray = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + con0 = m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + assert_conequal(con0, 1 * x + 10 * y == 0, strict=False) + assert_conequal(1 * x + 10 * y == 0, 1 * x + 10 * y == 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y <= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(con0, 1 * x + 10 * y >= 0, strict=False) + + with pytest.raises(AssertionError): + assert_conequal(10 * y + 2 * x == 0, 1 * x + 10 * y == 0, strict=False) + + +def test_constraints_getattr_formatted() -> 
None: + m: Model = Model() + x = m.add_variables(0, 10, name="x") + m.add_constraints(1 * x == 0, name="con-0") + assert_conequal(m.constraints.con_0, m.constraints["con-0"]) + + +def test_anonymous_constraint_assignment() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + con = 1 * x + 10 * y == 0 + m.add_constraints(con) + + for attr in m.constraints.dataset_attrs: + assert "con0" in getattr(m.constraints, attr) + + assert m.constraints.labels.con0.shape == (10, 10) + assert m.constraints.labels.con0.dtype == int + assert m.constraints.coeffs.con0.dtype in (int, float) + assert m.constraints.vars.con0.dtype in (int, float) + assert m.constraints.rhs.con0.dtype in (int, float) + + +def test_constraint_assignment_with_tuples() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + m.add_constraints([(1, x), (10, y)], EQUAL, 0, name="c") + for attr in m.constraints.dataset_attrs: + assert "c" in getattr(m.constraints, attr) + assert m.constraints.labels.c.shape == (10, 10) + + +def test_constraint_assignment_chunked() -> None: + # setting bounds with one pd.DataFrame and one pd.Series + m: Model = Model(chunk=5) + lower = pd.DataFrame(np.zeros((10, 10))) + upper = pd.Series(np.ones(10)) + x = m.add_variables(lower, upper) + m.add_constraints(x, GREATER_EQUAL, 0, name="c") + assert m.constraints.coeffs.c.data.shape == ( + 10, + 10, + 1, + ) + assert isinstance(m.constraints.coeffs.c.data, dask.array.core.Array) + + +def test_constraint_assignment_with_reindex() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = 
xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + m.add_constraints(1 * x + 10 * y, EQUAL, 0) + + shuffled_coords = [2, 1, 3, 4, 6, 5, 7, 9, 8, 0] + + con = x.loc[shuffled_coords] + y >= 10 + assert (con.coords["dim_0"].values == shuffled_coords).all() + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param(lambda m, v: v, id="numpy"), + pytest.param(lambda m, v: xr.DataArray(v, dims=["dim_0"]), id="dataarray"), + pytest.param(lambda m, v: pd.Series(v, index=v), id="series"), + pytest.param( + lambda m, v: m.add_variables(coords=[v]), + id="variable", + ), + pytest.param( + lambda m, v: 2 * m.add_variables(coords=[v]) + 1, + id="linexpr", + ), + ], +) +def test_constraint_rhs_lower_dim(rhs_factory: Any) -> None: + m = Model() + naxis = np.arange(10, dtype=float) + maxis = np.arange(10).astype(str) + x = m.add_variables(coords=[naxis, maxis]) + y = m.add_variables(coords=[naxis, maxis]) + + c = m.add_constraints(x - y >= rhs_factory(m, naxis)) + assert c.shape == (10, 10) + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param(lambda m: np.ones((5, 3)), id="numpy"), + pytest.param(lambda m: pd.DataFrame(np.ones((5, 3))), id="dataframe"), + ], +) +def test_constraint_rhs_higher_dim_constant_warns( + rhs_factory: Any, caplog: Any +) -> None: + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + + with caplog.at_level("WARNING", logger="linopy.expressions"): + m.add_constraints(x >= rhs_factory(m)) + assert "dimensions" in caplog.text + + +def test_constraint_rhs_higher_dim_dataarray_reindexes() -> None: + """DataArray RHS with extra dims reindexes to expression coords (no raise).""" + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + rhs = xr.DataArray(np.ones((5, 3)), dims=["dim_0", "extra"]) + + c = m.add_constraints(x >= rhs) + assert c.shape == (5, 3) + + +@pytest.mark.parametrize( + "rhs_factory", + [ + pytest.param( + 
lambda m: m.add_variables(coords=[range(5), range(3)]), + id="variable", + ), + pytest.param( + lambda m: 2 * m.add_variables(coords=[range(5), range(3)]) + 1, + id="linexpr", + ), + ], +) +def test_constraint_rhs_higher_dim_expression(rhs_factory: Any) -> None: + m = Model() + x = m.add_variables(coords=[range(5)], name="x") + + c = m.add_constraints(x >= rhs_factory(m)) + assert c.shape == (5, 3) + + +def test_wrong_constraint_assignment_repeated() -> None: + # repeated variable assignment is forbidden + m: Model = Model() + x = m.add_variables() + m.add_constraints(x, LESS_EQUAL, 0, name="con") + with pytest.raises(ValueError): + m.add_constraints(x, LESS_EQUAL, 0, name="con") + + +def test_masked_constraints() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + mask = pd.Series([True] * 5 + [False] * 5) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, mask=mask) + assert (m.constraints.labels.con0[0:5, :] != -1).all() + assert (m.constraints.labels.con0[5:10, :] == -1).all() + + +def test_masked_constraints_broadcast() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + mask = pd.Series([True] * 5 + [False] * 5) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc1", mask=mask) + assert (m.constraints.labels.bc1[0:5, :] != -1).all() + assert (m.constraints.labels.bc1[5:10, :] == -1).all() + + mask2 = xr.DataArray([True] * 5 + [False] * 5, dims=["dim_1"]) + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc2", mask=mask2) + assert (m.constraints.labels.bc2[:, 0:5] != -1).all() + assert (m.constraints.labels.bc2[:, 5:10] == -1).all() + + mask3 = xr.DataArray( + [True, True, False, False, 
False], + dims=["dim_0"], + coords={"dim_0": range(5)}, + ) + with pytest.warns(FutureWarning, match="Missing values will be filled"): + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc3", mask=mask3) + assert (m.constraints.labels.bc3[0:2, :] != -1).all() + assert (m.constraints.labels.bc3[2:5, :] == -1).all() + assert (m.constraints.labels.bc3[5:10, :] == -1).all() + + # Mask with extra dimension not in data should raise + mask4 = xr.DataArray([True, False], dims=["extra_dim"]) + with pytest.raises(AssertionError, match="not a subset"): + m.add_constraints(1 * x + 10 * y, EQUAL, 0, name="bc4", mask=mask4) + + +def test_non_aligned_constraints() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros(10), coords=[range(10)]) + x = m.add_variables(lower, name="x") + + lower = xr.DataArray(np.zeros(8), coords=[range(8)]) + y = m.add_variables(lower, name="y") + + m.add_constraints(x == 0.0) + m.add_constraints(y == 0.0) + + with pytest.warns(UserWarning): + m.constraints.labels + + for dtype in m.constraints.labels.dtypes.values(): + assert np.issubdtype(dtype, np.integer) + + for dtype in m.constraints.coeffs.dtypes.values(): + assert np.issubdtype(dtype, np.floating) + + for dtype in m.constraints.vars.dtypes.values(): + assert np.issubdtype(dtype, np.integer) + + for dtype in m.constraints.rhs.dtypes.values(): + assert np.issubdtype(dtype, np.floating) + + +def test_constraints_flat() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper) + y = m.add_variables() + + assert isinstance(m.constraints.flat, pd.DataFrame) + assert m.constraints.flat.empty + with pytest.raises(ValueError): + m.constraints.to_matrix() + + m.add_constraints(1 * x + 10 * y, EQUAL, 0) + m.add_constraints(1 * x + 10 * y, LESS_EQUAL, 0) + m.add_constraints(1 * x + 10 * y, GREATER_EQUAL, 0) + + assert 
isinstance(m.constraints.flat, pd.DataFrame) + assert not m.constraints.flat.empty + + +def test_sanitize_infinities() -> None: + m: Model = Model() + + lower = xr.DataArray(np.zeros((10, 10)), coords=[range(10), range(10)]) + upper = xr.DataArray(np.ones((10, 10)), coords=[range(10), range(10)]) + x = m.add_variables(lower, upper, name="x") + y = m.add_variables(name="y") + + # Test correct infinities + m.add_constraints(x <= np.inf, name="con_inf") + m.add_constraints(y >= -np.inf, name="con_neg_inf") + m.constraints.sanitize_infinities() + assert (m.constraints["con_inf"].labels == -1).all() + assert (m.constraints["con_neg_inf"].labels == -1).all() + + # Test incorrect infinities + with pytest.raises(ValueError): + m.add_constraints(x >= np.inf, name="con_wrong_inf") + with pytest.raises(ValueError): + m.add_constraints(y <= -np.inf, name="con_wrong_neg_inf") + + +class TestConstraintCoordinateAlignment: + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def subset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "xarray": + return xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + return pd.Series([10.0, 30.0], index=pd.Index([1, 3], name="dim_2")) + + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def superset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "xarray": + return xr.DataArray( + np.arange(25, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + return pd.Series( + np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") + ) + + def test_var_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + con = v <= subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert con.rhs.sel(dim_2=3).item() == 30.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def 
test_var_comparison_subset( + self, v: Variable, subset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = v <= subset + elif sign == GREATER_EQUAL: + con = v >= subset + else: + con = v == subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == 10.0 + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + def test_expr_le_subset(self, v: Variable, subset: xr.DataArray) -> None: + expr = v + 5 + con = expr <= subset + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert con.rhs.sel(dim_2=1).item() == pytest.approx(5.0) + assert con.rhs.sel(dim_2=3).item() == pytest.approx(25.0) + assert np.isnan(con.rhs.sel(dim_2=0).item()) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL, EQUAL]) + def test_subset_comparison_var( + self, v: Variable, subset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = subset <= v + elif sign == GREATER_EQUAL: + con = subset >= v + else: + con = subset == v + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert np.isnan(con.rhs.sel(dim_2=0).item()) + assert con.rhs.sel(dim_2=1).item() == pytest.approx(10.0) + + @pytest.mark.parametrize("sign", [LESS_EQUAL, GREATER_EQUAL]) + def test_superset_comparison_var( + self, v: Variable, superset: xr.DataArray, sign: str + ) -> None: + if sign == LESS_EQUAL: + con = superset <= v + else: + con = superset >= v + assert con.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(con.lhs.coeffs.values).any() + assert not np.isnan(con.rhs.values).any() + + def test_constraint_rhs_extra_dims_broadcasts(self, v: Variable) -> None: + rhs = xr.DataArray( + [[1.0, 2.0]], + dims=["extra", "dim_2"], + coords={"dim_2": [0, 1]}, + ) + c = v <= rhs + assert "extra" in c.dims + + def test_subset_constraint_solve_integration(self) -> None: + if not available_solvers: + pytest.skip("No solver available") + solver = "highs" if "highs" in available_solvers else available_solvers[0] + m = Model() + coords = pd.RangeIndex(5, 
name="i") + x = m.add_variables(lower=0, upper=100, coords=[coords], name="x") + subset_ub = xr.DataArray([10.0, 20.0], dims=["i"], coords={"i": [1, 3]}) + m.add_constraints(x <= subset_ub, name="subset_ub") + m.add_objective(x.sum(), sense="max") + m.solve(solver_name=solver) + sol = m.solution["x"] + assert sol.sel(i=1).item() == pytest.approx(10.0) + assert sol.sel(i=3).item() == pytest.approx(20.0) + assert sol.sel(i=0).item() == pytest.approx(100.0) + assert sol.sel(i=2).item() == pytest.approx(100.0) + assert sol.sel(i=4).item() == pytest.approx(100.0) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 1378f48d..a4e4abfa 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,6 +7,7 @@ from __future__ import annotations +from collections.abc import Generator from typing import Any import numpy as np @@ -16,6 +17,7 @@ import xarray as xr from xarray.testing import assert_equal +import linopy from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge from linopy.constants import HELPER_DIMS, TERM_DIM from linopy.expressions import ScalarLinearExpression @@ -23,6 +25,14 @@ from linopy.variables import ScalarVariable +@pytest.fixture(autouse=True) +def _use_v1_convention() -> Generator[None, None, None]: + """Use v1 arithmetic convention for all tests in this module.""" + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + def test_empty_linexpr(m: Model) -> None: LinearExpression(None, m) @@ -403,8 +413,10 @@ def test_linear_expression_sum( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords - expr = v.loc[:9] + v.loc[10:] + # test special case override coords using assign_coords + a = v.loc[:9] + b = v.loc[10:].assign_coords(dim_2=a.coords["dim_2"]) + expr = a + b assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 @@ -427,8 +439,10 @@ def 
test_linear_expression_sum_with_const( assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) - # test special case otherride coords - expr = v.loc[:9] + v.loc[10:] + # test special case override coords using assign_coords + a = v.loc[:9] + b = v.loc[10:].assign_coords(dim_2=a.coords["dim_2"]) + expr = a + b assert expr.nterm == 2 assert len(expr.coords["dim_2"]) == 10 @@ -538,6 +552,12 @@ def test_linear_expression_multiplication_invalid( class TestCoordinateAlignment: + @pytest.fixture + def matching(self) -> xr.DataArray: + return xr.DataArray( + np.arange(20, dtype=float), dims=["dim_2"], coords={"dim_2": range(20)} + ) + @pytest.fixture(params=["da", "series"]) def subset(self, request: Any) -> xr.DataArray | pd.Series: if request.param == "da": @@ -574,8 +594,24 @@ def nan_constant(self, request: Any) -> xr.DataArray | pd.Series: return pd.Series(vals, index=pd.Index(range(20), name="dim_2")) class TestSubset: + """ + Under v1, subset operations raise ValueError (exact join). + Use explicit join= to recover desired behavior. 
+ """ + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target * subset + @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_mul_subset_fills_zeros( + def test_mul_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -583,13 +619,24 @@ def test_mul_subset_fills_zeros( operand: str, ) -> None: target = v if operand == "var" else 1 * v - result = target * subset + result = target.mul(subset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_add_subset_fills_zeros( + def test_add_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target + subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -597,17 +644,28 @@ def test_add_subset_fills_zeros( operand: str, ) -> None: if operand == "var": - result = v + subset + result = v.add(subset, join="left") expected = expected_fill else: - result = (v + 5) + subset + result = (v + 5).add(subset, join="left") expected = expected_fill + 5 assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_sub_subset_fills_negated( + def test_sub_subset_raises( + self, + v: Variable, + subset: xr.DataArray, + operand: str, + ) -> None: + target = v if operand == "var" else v + 5 + with pytest.raises(ValueError, match="exact"): + target - 
subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_subset_join_left( self, v: Variable, subset: xr.DataArray, @@ -615,242 +673,264 @@ def test_sub_subset_fills_negated( operand: str, ) -> None: if operand == "var": - result = v - subset + result = v.sub(subset, join="left") expected = -expected_fill else: - result = (v + 5) - subset + result = (v + 5).sub(subset, join="left") expected = 5 - expected_fill assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_div_subset_inverts_nonzero( + def test_div_subset_raises( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + with pytest.raises(ValueError, match="exact"): + target / subset + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_subset_join_left( self, v: Variable, subset: xr.DataArray, operand: str ) -> None: target = v if operand == "var" else 1 * v - result = target / subset + result = target.div(subset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) - def test_subset_add_var_coefficients( - self, v: Variable, subset: xr.DataArray - ) -> None: - result = subset + v - np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + def test_subset_add_var_raises(self, v: Variable, subset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset + v - def test_subset_sub_var_coefficients( - self, v: Variable, subset: xr.DataArray - ) -> None: - result = subset - v - np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + def test_subset_sub_var_raises(self, v: Variable, subset: 
xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + subset - v class TestSuperset: - def test_add_superset_pins_to_lhs_coords( + """Under v1, superset operations raise ValueError (exact join).""" + + def test_add_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v + superset + + def test_add_superset_join_left( self, v: Variable, superset: xr.DataArray ) -> None: - result = v + superset + result = v.add(superset, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() - def test_add_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset + v, v + superset) + def test_mul_superset_raises(self, v: Variable, superset: xr.DataArray) -> None: + with pytest.raises(ValueError, match="exact"): + v * superset - def test_sub_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset - v, -v + superset) - - def test_mul_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: - assert_linequal(superset * v, v * superset) - - def test_mul_superset_pins_to_lhs_coords( + def test_mul_superset_join_inner( self, v: Variable, superset: xr.DataArray ) -> None: - result = v * superset + result = v.mul(superset, join="inner") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() - def test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: + def test_div_superset_raises(self, v: Variable) -> None: superset_nonzero = xr.DataArray( np.arange(1, 26, dtype=float), dims=["dim_2"], coords={"dim_2": range(25)}, ) - result = v / superset_nonzero + with pytest.raises(ValueError, match="exact"): + v / superset_nonzero + + def test_div_superset_join_inner(self, v: Variable) -> None: + superset_nonzero = xr.DataArray( + np.arange(1, 26, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + result = 
v.div(superset_nonzero, join="inner") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() class TestDisjoint: - def test_add_disjoint_fills_zeros(self, v: Variable) -> None: + """Under v1, disjoint operations raise ValueError (exact join).""" + + def test_add_disjoint_raises(self, v: Variable) -> None: disjoint = xr.DataArray( [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v + disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.const.values).any() - np.testing.assert_array_equal(result.const.values, np.zeros(20)) + with pytest.raises(ValueError, match="exact"): + v + disjoint + + def test_add_disjoint_join_outer(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v.add(disjoint, join="outer") + assert result.sizes["dim_2"] == 22 # union of [0..19] and [50, 60] + + def test_mul_disjoint_raises(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + with pytest.raises(ValueError, match="exact"): + v * disjoint - def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: + def test_mul_disjoint_join_left(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v * disjoint + result = v.mul(disjoint, join="left") assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) - def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: + def test_div_disjoint_raises(self, v: Variable) -> None: disjoint = xr.DataArray( [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} ) - result = v / disjoint - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - 
np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + with pytest.raises(ValueError, match="exact"): + v / disjoint class TestCommutativity: - @pytest.mark.parametrize( - "make_lhs,make_rhs", - [ - (lambda v, s: s * v, lambda v, s: v * s), - (lambda v, s: s * (1 * v), lambda v, s: (1 * v) * s), - (lambda v, s: s + v, lambda v, s: v + s), - (lambda v, s: s + (v + 5), lambda v, s: (v + 5) + s), - ], - ids=["subset*var", "subset*expr", "subset+var", "subset+expr"], - ) - def test_commutativity( - self, - v: Variable, - subset: xr.DataArray, - make_lhs: Any, - make_rhs: Any, + """Commutativity tests with matching coordinates under v1.""" + + def test_add_commutativity_matching_coords( + self, v: Variable, matching: xr.DataArray ) -> None: - assert_linequal(make_lhs(v, subset), make_rhs(v, subset)) + assert_linequal(v + matching, matching + v) - def test_sub_var_anticommutative( - self, v: Variable, subset: xr.DataArray + def test_mul_commutativity_matching_coords( + self, v: Variable, matching: xr.DataArray ) -> None: - assert_linequal(subset - v, -v + subset) + assert_linequal(v * matching, matching * v) - def test_sub_expr_anticommutative( + def test_subset_raises_both_sides( self, v: Variable, subset: xr.DataArray ) -> None: - expr = v + 5 - assert_linequal(subset - expr, -(expr - subset)) + """Subset operations raise regardless of operand order.""" + with pytest.raises(ValueError, match="exact"): + v * subset + with pytest.raises(ValueError, match="exact"): + subset * v - def test_add_commutativity_full_coords(self, v: Variable) -> None: - full = xr.DataArray( - np.arange(20, dtype=float), - dims=["dim_2"], - coords={"dim_2": range(20)}, + def test_commutativity_with_join( + self, v: Variable, subset: xr.DataArray + ) -> None: + """Commutativity holds with explicit join.""" + assert_linequal( + v.add(subset, join="inner"), + subset + v.reindex({"dim_2": [1, 3]}), ) - assert_linequal(v + full, full + v) class TestQuadratic: - def 
test_quadexpr_add_subset( + """Under v1, subset operations on quadratic expressions raise.""" + + def test_quadexpr_add_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr + subset + + def test_quadexpr_add_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: qexpr = v * v - result = qexpr + subset + result = qexpr.add(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, expected_fill) - def test_quadexpr_sub_subset( + def test_quadexpr_sub_subset_raises( + self, v: Variable, subset: xr.DataArray + ) -> None: + qexpr = v * v + with pytest.raises(ValueError, match="exact"): + qexpr - subset + + def test_quadexpr_sub_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: qexpr = v * v - result = qexpr - subset + result = qexpr.sub(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.const.values).any() np.testing.assert_array_equal(result.const.values, -expected_fill) - def test_quadexpr_mul_subset( - self, - v: Variable, - subset: xr.DataArray, - expected_fill: np.ndarray, + def test_quadexpr_mul_subset_raises( + self, v: Variable, subset: xr.DataArray ) -> None: qexpr = v * v - result = qexpr * subset - assert isinstance(result, QuadraticExpression) - assert result.sizes["dim_2"] == v.sizes["dim_2"] - assert not np.isnan(result.coeffs.values).any() - np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + with pytest.raises(ValueError, match="exact"): + qexpr * subset - def test_subset_mul_quadexpr( + def test_quadexpr_mul_subset_join_left( self, v: Variable, subset: xr.DataArray, expected_fill: np.ndarray, ) -> None: 
qexpr = v * v - result = subset * qexpr + result = qexpr.mul(subset, join="left") assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == v.sizes["dim_2"] assert not np.isnan(result.coeffs.values).any() np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) - def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: + def test_quadexpr_add_matching( + self, v: Variable, matching: xr.DataArray + ) -> None: qexpr = v * v - assert_quadequal(subset + qexpr, qexpr + subset) + assert_quadequal(matching + qexpr, qexpr + matching) class TestMissingValues: """ Same shape as variable but with NaN entries in the constant. - NaN values are filled with operation-specific neutral elements: - - Addition/subtraction: NaN -> 0 (additive identity) - - Multiplication: NaN -> 0 (zeroes out the variable) - - Division: NaN -> 1 (multiplicative identity, no scaling) + Under v1 convention, NaN values propagate through arithmetic + (no implicit fillna). 
""" NAN_POSITIONS = [0, 5, 19] @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_add_nan_filled( + def test_add_nan_propagates( self, v: Variable, nan_constant: xr.DataArray | pd.Series, operand: str, ) -> None: - base_const = 0.0 if operand == "var" else 5.0 target = v if operand == "var" else v + 5 result = target + nan_constant assert result.sizes["dim_2"] == 20 - assert not np.isnan(result.const.values).any() - # At NaN positions, const should be unchanged (added 0) for i in self.NAN_POSITIONS: - assert result.const.values[i] == base_const + assert np.isnan(result.const.values[i]) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_sub_nan_filled( + def test_sub_nan_propagates( self, v: Variable, nan_constant: xr.DataArray | pd.Series, operand: str, ) -> None: - base_const = 0.0 if operand == "var" else 5.0 target = v if operand == "var" else v + 5 result = target - nan_constant assert result.sizes["dim_2"] == 20 - assert not np.isnan(result.const.values).any() - # At NaN positions, const should be unchanged (subtracted 0) for i in self.NAN_POSITIONS: - assert result.const.values[i] == base_const + assert np.isnan(result.const.values[i]) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_mul_nan_filled( + def test_mul_nan_propagates( self, v: Variable, nan_constant: xr.DataArray | pd.Series, @@ -859,13 +939,11 @@ def test_mul_nan_filled( target = v if operand == "var" else 1 * v result = target * nan_constant assert result.sizes["dim_2"] == 20 - assert not np.isnan(result.coeffs.squeeze().values).any() - # At NaN positions, coeffs should be 0 (variable zeroed out) for i in self.NAN_POSITIONS: - assert result.coeffs.squeeze().values[i] == 0.0 + assert np.isnan(result.coeffs.squeeze().values[i]) @pytest.mark.parametrize("operand", ["var", "expr"]) - def test_div_nan_filled( + def test_div_nan_propagates( self, v: Variable, nan_constant: xr.DataArray | pd.Series, @@ -874,11 +952,8 @@ def test_div_nan_filled( target = v if 
operand == "var" else 1 * v result = target / nan_constant assert result.sizes["dim_2"] == 20 - assert not np.isnan(result.coeffs.squeeze().values).any() - # At NaN positions, coeffs should be unchanged (divided by 1) - original_coeffs = (1 * v).coeffs.squeeze().values for i in self.NAN_POSITIONS: - assert result.coeffs.squeeze().values[i] == original_coeffs[i] + assert np.isnan(result.coeffs.squeeze().values[i]) def test_add_commutativity( self, @@ -887,8 +962,6 @@ def test_add_commutativity( ) -> None: result_a = v + nan_constant result_b = nan_constant + v - assert not np.isnan(result_a.const.values).any() - assert not np.isnan(result_b.const.values).any() np.testing.assert_array_equal(result_a.const.values, result_b.const.values) np.testing.assert_array_equal( result_a.coeffs.values, result_b.coeffs.values @@ -901,13 +974,11 @@ def test_mul_commutativity( ) -> None: result_a = v * nan_constant result_b = nan_constant * v - assert not np.isnan(result_a.coeffs.values).any() - assert not np.isnan(result_b.coeffs.values).any() np.testing.assert_array_equal( result_a.coeffs.values, result_b.coeffs.values ) - def test_quadexpr_add_nan( + def test_quadexpr_add_nan_propagates( self, v: Variable, nan_constant: xr.DataArray | pd.Series, @@ -916,75 +987,89 @@ def test_quadexpr_add_nan( result = qexpr + nan_constant assert isinstance(result, QuadraticExpression) assert result.sizes["dim_2"] == 20 - assert not np.isnan(result.const.values).any() + for i in self.NAN_POSITIONS: + assert np.isnan(result.const.values[i]) class TestExpressionWithNaN: - """Test that NaN in expression's own const/coeffs doesn't propagate.""" + """ + Under v1, NaN in expression's own const/coeffs propagates through + arithmetic (no implicit fillna). 
+ """ def test_shifted_expr_add_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr + 5 - assert not np.isnan(result.const.values).any() - assert result.const.values[0] == 5.0 + # Position 0 has NaN from shift, NaN + 5 = NaN under v1 + assert np.isnan(result.const.values[0]) def test_shifted_expr_mul_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr * 2 - assert not np.isnan(result.coeffs.squeeze().values).any() - assert result.coeffs.squeeze().values[0] == 0.0 + # Position 0 has NaN coeffs from shift, NaN * 2 = NaN under v1 + assert np.isnan(result.coeffs.squeeze().values[0]) def test_shifted_expr_add_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) expr = (1 * v).shift(dim_2=1) result = expr + arr - assert not np.isnan(result.const.values).any() - assert result.const.values[0] == 0.0 + # Position 0 has NaN const from shift, NaN + 0 = NaN under v1 + assert np.isnan(result.const.values[0]) def test_shifted_expr_mul_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 expr = (1 * v).shift(dim_2=1) result = expr * arr - assert not np.isnan(result.coeffs.squeeze().values).any() - assert result.coeffs.squeeze().values[0] == 0.0 + # Position 0 has NaN coeffs from shift, NaN * 1 = NaN under v1 + assert np.isnan(result.coeffs.squeeze().values[0]) def test_shifted_expr_div_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr / 2 - assert not np.isnan(result.coeffs.squeeze().values).any() - assert result.coeffs.squeeze().values[0] == 0.0 + assert np.isnan(result.coeffs.squeeze().values[0]) def test_shifted_expr_sub_scalar(self, v: Variable) -> None: expr = (1 * v).shift(dim_2=1) result = expr - 3 - assert not np.isnan(result.const.values).any() - assert result.const.values[0] == -3.0 + assert np.isnan(result.const.values[0]) def test_shifted_expr_div_array(self, v: Variable) -> None: arr = np.arange(v.sizes["dim_2"], dtype=float) 
+ 1 expr = (1 * v).shift(dim_2=1) result = expr / arr - assert not np.isnan(result.coeffs.squeeze().values).any() - assert result.coeffs.squeeze().values[0] == 0.0 + assert np.isnan(result.coeffs.squeeze().values[0]) def test_variable_to_linexpr_nan_coefficient(self, v: Variable) -> None: + """to_linexpr always fills NaN coefficients with 0 (not convention-aware).""" nan_coeff = np.ones(v.sizes["dim_2"]) nan_coeff[0] = np.nan result = v.to_linexpr(nan_coeff) - assert not np.isnan(result.coeffs.squeeze().values).any() assert result.coeffs.squeeze().values[0] == 0.0 class TestMultiDim: - def test_multidim_subset_mul(self, m: Model) -> None: + """Under v1, multi-dim subset operations raise.""" + + def test_multidim_subset_mul_raises(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = m.add_variables(coords=[coords_a, coords_b], name="w") + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + with pytest.raises(ValueError, match="exact"): + w * subset_2d + def test_multidim_subset_mul_join_left(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) - result = w * subset_2d + result = w.mul(subset_2d, join="left") assert result.sizes["a"] == 4 assert result.sizes["b"] == 5 assert not np.isnan(result.coeffs.values).any() @@ -993,23 +1078,17 @@ def test_multidim_subset_mul(self, m: Model) -> None: assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) - def test_multidim_subset_add(self, m: Model) -> None: + def test_multidim_subset_add_raises(self, m: Model) -> None: coords_a = pd.RangeIndex(4, name="a") coords_b = pd.RangeIndex(5, name="b") w = 
m.add_variables(coords=[coords_a, coords_b], name="w") - subset_2d = xr.DataArray( [[2.0, 3.0], [4.0, 5.0]], dims=["a", "b"], coords={"a": [1, 3], "b": [0, 4]}, ) - result = w + subset_2d - assert result.sizes["a"] == 4 - assert result.sizes["b"] == 5 - assert not np.isnan(result.const.values).any() - assert result.const.sel(a=1, b=0).item() == pytest.approx(2.0) - assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) - assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + with pytest.raises(ValueError, match="exact"): + w + subset_2d class TestXarrayCompat: def test_da_eq_da_still_works(self) -> None: @@ -1877,12 +1956,14 @@ def c(self, m2: Model) -> Variable: return m2.variables["c"] class TestAddition: - def test_add_join_none_preserves_default( + def test_add_join_none_raises_on_mismatch( self, a: Variable, b: Variable ) -> None: - result_default = a.to_linexpr() + b.to_linexpr() - result_none = a.to_linexpr().add(b.to_linexpr(), join=None) - assert_linequal(result_default, result_none) + # a has i=[0,1,2], b has i=[1,2,3] — exact default raises + with pytest.raises(ValueError, match="exact"): + a.to_linexpr() + b.to_linexpr() + with pytest.raises(ValueError, match="exact"): + a.to_linexpr().add(b.to_linexpr(), join=None) def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: result = a.to_linexpr().add(b.to_linexpr(), join="inner") @@ -2138,12 +2219,12 @@ def test_div_constant_outer_fill_values(self, a: Variable) -> None: class TestQuadratic: def test_quadratic_add_constant_join_inner( - self, a: Variable, b: Variable + self, a: Variable, c: Variable ) -> None: - quad = a.to_linexpr() * b.to_linexpr() + quad = a.to_linexpr() * c.to_linexpr() const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.add(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: quad = 
a.to_linexpr() * a.to_linexpr() @@ -2152,9 +2233,9 @@ def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: assert list(result.data.indexes["i"]) == [0, 1] def test_quadratic_mul_constant_join_inner( - self, a: Variable, b: Variable + self, a: Variable, c: Variable ) -> None: - quad = a.to_linexpr() * b.to_linexpr() + quad = a.to_linexpr() * c.to_linexpr() const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) result = quad.mul(const, join="inner") - assert list(result.data.indexes["i"]) == [1, 2, 3] + assert list(result.data.indexes["i"]) == [1, 2] diff --git a/test/test_linear_expression_legacy.py b/test/test_linear_expression_legacy.py new file mode 100644 index 00000000..1378f48d --- /dev/null +++ b/test/test_linear_expression_legacy.py @@ -0,0 +1,2160 @@ +#!/usr/bin/env python3 +""" +Created on Wed Mar 17 17:06:36 2021. + +@author: fabian +""" + +from __future__ import annotations + +from typing import Any + +import numpy as np +import pandas as pd +import polars as pl +import pytest +import xarray as xr +from xarray.testing import assert_equal + +from linopy import LinearExpression, Model, QuadraticExpression, Variable, merge +from linopy.constants import HELPER_DIMS, TERM_DIM +from linopy.expressions import ScalarLinearExpression +from linopy.testing import assert_linequal, assert_quadequal +from linopy.variables import ScalarVariable + + +def test_empty_linexpr(m: Model) -> None: + LinearExpression(None, m) + + +def test_linexpr_with_wrong_data(m: Model) -> None: + with pytest.raises(ValueError): + LinearExpression(xr.Dataset({"a": [1]}), m) + + coeffs = xr.DataArray([1, 2], dims=["a"]) + vars = xr.DataArray([1, 2], dims=["a"]) + data = xr.Dataset({"coeffs": coeffs, "vars": vars}) + with pytest.raises(ValueError): + LinearExpression(data, m) + + # with model as None + coeffs = xr.DataArray(np.array([1, 2]), dims=[TERM_DIM]) + vars = xr.DataArray(np.array([1, 2]), dims=[TERM_DIM]) + data = xr.Dataset({"coeffs": coeffs, "vars": 
vars}) + with pytest.raises(ValueError): + LinearExpression(data, None) # type: ignore + + +def test_linexpr_with_helper_dims_as_coords(m: Model) -> None: + coords = [pd.Index([0], name="a"), pd.Index([1, 2], name=TERM_DIM)] + coeffs = xr.DataArray(np.array([[1, 2]]), coords=coords) + vars = xr.DataArray(np.array([[1, 2]]), coords=coords) + + data = xr.Dataset({"coeffs": coeffs, "vars": vars}) + assert set(HELPER_DIMS).intersection(set(data.coords)) + + expr = LinearExpression(data, m) + assert not set(HELPER_DIMS).intersection(set(expr.data.coords)) + + +def test_linexpr_with_data_without_coords(m: Model) -> None: + lhs = 1 * m["x"] + vars = xr.DataArray(lhs.vars.values, dims=["dim_0", TERM_DIM]) + coeffs = xr.DataArray(lhs.coeffs.values, dims=["dim_0", TERM_DIM]) + data = xr.Dataset({"vars": vars, "coeffs": coeffs}) + expr = LinearExpression(data, m) + assert_linequal(expr, lhs) + + +def test_linexpr_from_constant_dataarray(m: Model) -> None: + const = xr.DataArray([1, 2], dims=["dim_0"]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pl_series(m: Model) -> None: + const = pl.Series([1, 2]) + expr = LinearExpression(const, m) + assert (expr.const == const.to_numpy()).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pandas_series(m: Model) -> None: + const = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_pandas_dataframe(m: Model) -> None: + const = pd.DataFrame([[1, 2], [3, 4]], columns=["a", "b"]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_linexpr_from_constant_numpy_array(m: Model) -> None: + const = np.array([1, 2]) + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def 
test_linexpr_from_constant_scalar(m: Model) -> None: + const = 1 + expr = LinearExpression(const, m) + assert (expr.const == const).all() + assert expr.nterm == 0 + + +def test_repr(m: Model) -> None: + expr = m.linexpr((10, "x"), (1, "y")) + expr.__repr__() + + +def test_fill_value() -> None: + isinstance(LinearExpression._fill_value, dict) + + +def test_linexpr_with_scalars(m: Model) -> None: + expr = m.linexpr((10, "x"), (1, "y")) + target = xr.DataArray( + [[10, 1], [10, 1]], coords={"dim_0": [0, 1]}, dims=["dim_0", TERM_DIM] + ) + assert_equal(expr.coeffs, target) + + +def test_linexpr_with_variables_and_constants( + m: Model, x: Variable, y: Variable +) -> None: + expr = m.linexpr((10, x), (1, y), 2) + assert (expr.const == 2).all() + + +def test_linexpr_with_series(m: Model, v: Variable) -> None: + lhs = pd.Series(np.arange(20)), v + expr = m.linexpr(lhs) + isinstance(expr, LinearExpression) + + +def test_linexpr_with_dataframe(m: Model, z: Variable) -> None: + lhs = pd.DataFrame(z.labels), z + expr = m.linexpr(lhs) + isinstance(expr, LinearExpression) + + +def test_linexpr_duplicated_index(m: Model) -> None: + expr = m.linexpr((10, "x"), (-1, "x")) + assert (expr.data._term == [0, 1]).all() + + +def test_linear_expression_with_multiplication(x: Variable) -> None: + expr = 1 * x + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert len(expr.vars.dim_0) == x.shape[0] + + expr = x * 1 + assert isinstance(expr, LinearExpression) + + expr2 = x.mul(1) + assert_linequal(expr, expr2) + + expr3 = expr.mul(1) + assert_linequal(expr, expr3) + + expr = x / 1 + assert isinstance(expr, LinearExpression) + + expr = x / 1.0 + assert isinstance(expr, LinearExpression) + + expr2 = x.div(1) + assert_linequal(expr, expr2) + + expr3 = expr.div(1) + assert_linequal(expr, expr3) + + expr = np.array([1, 2]) * x + assert isinstance(expr, LinearExpression) + + expr = np.array(1) * x + assert isinstance(expr, LinearExpression) + + expr = 
xr.DataArray(np.array([[1, 2], [2, 3]])) * x + assert isinstance(expr, LinearExpression) + + expr = pd.Series([1, 2], index=pd.RangeIndex(2, name="dim_0")) * x + assert isinstance(expr, LinearExpression) + + quad = x * x + assert isinstance(quad, QuadraticExpression) + + with pytest.raises(TypeError): + quad * quad + + expr = x * 1 + assert isinstance(expr, LinearExpression) + assert expr.__mul__(object()) is NotImplemented + assert expr.__rmul__(object()) is NotImplemented + + +def test_linear_expression_with_addition(m: Model, x: Variable, y: Variable) -> None: + expr = 10 * x + y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((10, "x"), (1, "y"))) + + expr = x + 8 * y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (8, "y"))) + + expr = x + y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (1, "y"))) + + expr2 = x.add(y) + assert_linequal(expr, expr2) + + expr3 = (x * 1).add(y) + assert_linequal(expr, expr3) + + expr3 = x + (x * x) + assert isinstance(expr3, QuadraticExpression) + + +def test_linear_expression_with_raddition(m: Model, x: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 + expr + assert isinstance(expr, LinearExpression) + expr_3: LinearExpression = expr + 10.0 + assert_linequal(expr_2, expr_3) + + +def test_linear_expression_with_subtraction(m: Model, x: Variable, y: Variable) -> None: + expr = x - y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((1, "x"), (-1, "y"))) + + expr2 = x.sub(y) + assert_linequal(expr, expr2) + + expr3: LinearExpression = x * 1 + expr4 = expr3.sub(y) + assert_linequal(expr, expr4) + + expr = -x - 8 * y + assert isinstance(expr, LinearExpression) + assert_linequal(expr, m.linexpr((-1, "x"), (-8, "y"))) + + +def test_linear_expression_rsubtraction(x: Variable, y: Variable) -> None: + expr = x * 1.0 + expr_2: LinearExpression = 10.0 - expr + assert 
isinstance(expr_2, LinearExpression) + expr_3: LinearExpression = (expr - 10.0) * -1 + assert_linequal(expr_2, expr_3) + assert expr.__rsub__(object()) is NotImplemented + + +def test_linear_expression_with_constant(m: Model, x: Variable, y: Variable) -> None: + expr = x + 1 + assert isinstance(expr, LinearExpression) + assert (expr.const == 1).all() + + expr = -x - 8 * y - 10 + assert isinstance(expr, LinearExpression) + assert (expr.const == -10).all() + assert expr.nterm == 2 + + +def test_linear_expression_with_constant_multiplication( + m: Model, x: Variable, y: Variable +) -> None: + expr = x + 1 + + obs = expr * 10 + assert isinstance(obs, LinearExpression) + assert (obs.const == 10).all() + + obs = expr * pd.Series([1, 2, 3], index=pd.RangeIndex(3, name="new_dim")) + assert isinstance(obs, LinearExpression) + assert obs.shape == (2, 3, 1) + + +def test_linear_expression_multi_indexed(u: Variable) -> None: + expr = 3 * u + 1 * u + assert isinstance(expr, LinearExpression) + + +def test_linear_expression_with_errors(m: Model, x: Variable) -> None: + with pytest.raises(TypeError): + x / x + + with pytest.raises(TypeError): + x / (1 * x) + + with pytest.raises(TypeError): + m.linexpr((10, x.labels), (1, "y")) + + with pytest.raises(TypeError): + m.linexpr(a=2) # type: ignore + + +def test_linear_expression_from_rule(m: Model, x: Variable, y: Variable) -> None: + def bound(m: Model, i: int) -> ScalarLinearExpression: + return ( + (i - 1) * x.at[i - 1] + y.at[i] + 1 * x.at[i] + if i == 1 + else i * x.at[i] - y.at[i] + ) + + expr = LinearExpression.from_rule(m, bound, x.coords) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 3 + repr(expr) # test repr + + +def test_linear_expression_from_rule_with_return_none( + m: Model, x: Variable, y: Variable +) -> None: + # with return type None + def bound(m: Model, i: int) -> ScalarLinearExpression | None: + if i == 1: + return (i - 1) * x.at[i - 1] + y.at[i] + return None + + expr = 
LinearExpression.from_rule(m, bound, x.coords) + assert isinstance(expr, LinearExpression) + assert (expr.vars[0] == -1).all() + assert (expr.vars[1] != -1).all() + assert expr.coeffs[0].isnull().all() + assert expr.coeffs[1].notnull().all() + repr(expr) # test repr + + +def test_linear_expression_addition(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + other = 2 * y + z + res = expr + other + + assert res.nterm == expr.nterm + other.nterm + assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() + assert (res.coords["dim_1"] == other.coords["dim_1"]).all() + assert res.data.notnull().all().to_array().all() + + res2 = expr.add(other) + assert_linequal(res, res2) + + assert isinstance(x - expr, LinearExpression) + assert isinstance(x + expr, LinearExpression) + + +def test_linear_expression_addition_with_constant( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + 10 + assert (expr.const == 10).all() + + expr = 10 * x + y + np.array([2, 3]) + assert list(expr.const) == [2, 3] + + expr = 10 * x + y + pd.Series([2, 3]) + assert list(expr.const) == [2, 3] + + +def test_linear_expression_subtraction(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y - 10 + assert (expr.const == -10).all() + + expr = 10 * x + y - np.array([2, 3]) + assert list(expr.const) == [-2, -3] + + expr = 10 * x + y - pd.Series([2, 3]) + assert list(expr.const) == [-2, -3] + + +def test_linear_expression_substraction( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + other = 2 * y - z + res = expr - other + + assert res.nterm == expr.nterm + other.nterm + assert (res.coords["dim_0"] == expr.coords["dim_0"]).all() + assert (res.coords["dim_1"] == other.coords["dim_1"]).all() + assert res.data.notnull().all().to_array().all() + + +def test_linear_expression_sum( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + res = expr.sum("dim_0") + + assert res.size == 
expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # test special case override coords + expr = v.loc[:9] + v.loc[10:] + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + +def test_linear_expression_sum_with_const( + x: Variable, y: Variable, z: Variable, v: Variable +) -> None: + expr = 10 * x + y + z + 10 + res = expr.sum("dim_0") + + assert res.size == expr.size + assert res.nterm == expr.nterm * len(expr.data.dim_0) + assert (res.const == 20).all() + + res = expr.sum() + assert res.size == expr.size + assert res.nterm == expr.size + assert res.data.notnull().all().to_array().all() + assert (res.const == 60).item() + + assert_linequal(expr.sum(["dim_0", TERM_DIM]), expr.sum("dim_0")) + + # test special case override coords + expr = v.loc[:9] + v.loc[10:] + assert expr.nterm == 2 + assert len(expr.coords["dim_2"]) == 10 + + +def test_linear_expression_sum_drop_zeros(z: Variable) -> None: + coeff = xr.zeros_like(z.labels) + coeff[1, 0] = 3 + coeff[0, 2] = 5 + expr = coeff * z + + res = expr.sum("dim_0", drop_zeros=True) + assert res.nterm == 1 + + res = expr.sum("dim_1", drop_zeros=True) + assert res.nterm == 1 + + coeff[1, 2] = 4 + expr.data["coeffs"] = coeff + res = expr.sum() + + res = expr.sum("dim_0", drop_zeros=True) + assert res.nterm == 2 + + res = expr.sum("dim_1", drop_zeros=True) + assert res.nterm == 2 + + +def test_linear_expression_sum_warn_using_dims(z: Variable) -> None: + with pytest.warns(DeprecationWarning): + (1 * z).sum(dims="dim_0") + + +def test_linear_expression_sum_warn_unknown_kwargs(z: Variable) -> None: + with pytest.raises(ValueError): + (1 * z).sum(unknown_kwarg="dim_0") + + +def test_linear_expression_power(x: Variable) -> None: + expr: LinearExpression = x * 1.0 + qd_expr = expr**2 + assert 
isinstance(qd_expr, QuadraticExpression) + + qd_expr2 = expr.pow(2) + assert_quadequal(qd_expr, qd_expr2) + + with pytest.raises(ValueError): + expr**3 + + +def test_linear_expression_multiplication( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + z + mexpr = expr * 10 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 100).item() + + mexpr = 10 * expr + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 100).item() + + mexpr = expr / 100 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item() + + mexpr = expr / 100.0 + assert (mexpr.coeffs.sel(dim_1=0, dim_0=0, _term=0) == 1 / 10).item() + + +def test_matmul_variable_and_const(x: Variable, y: Variable) -> None: + const = np.array([1, 2]) + expr = x @ const + assert expr.nterm == 2 + assert_linequal(expr, (x * const).sum()) + + assert_linequal(x @ const, (x * const).sum()) + + assert_linequal(x.dot(const), x @ const) + + +def test_matmul_expr_and_const(x: Variable, y: Variable) -> None: + expr = 10 * x + y + const = np.array([1, 2]) + res = expr @ const + target = (10 * x) @ const + y @ const + assert res.nterm == 4 + assert_linequal(res, target) + + assert_linequal(expr.dot(const), target) + + +def test_matmul_wrong_input(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + with pytest.raises(TypeError): + expr @ expr + + +def test_linear_expression_multiplication_invalid( + x: Variable, y: Variable, z: Variable +) -> None: + expr = 10 * x + y + z + + with pytest.raises(TypeError): + expr = 10 * x + y + z + expr * expr + + with pytest.raises(TypeError): + expr = 10 * x + y + z + expr / x + + +class TestCoordinateAlignment: + @pytest.fixture(params=["da", "series"]) + def subset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "da": + return xr.DataArray([10.0, 30.0], dims=["dim_2"], coords={"dim_2": [1, 3]}) + return pd.Series([10.0, 30.0], index=pd.Index([1, 3], name="dim_2")) + + @pytest.fixture(params=["da", 
"series"]) + def superset(self, request: Any) -> xr.DataArray | pd.Series: + if request.param == "da": + return xr.DataArray( + np.arange(25, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + return pd.Series( + np.arange(25, dtype=float), index=pd.Index(range(25), name="dim_2") + ) + + @pytest.fixture + def expected_fill(self) -> np.ndarray: + arr = np.zeros(20) + arr[1] = 10.0 + arr[3] = 30.0 + return arr + + @pytest.fixture(params=["xarray", "pandas_series"], ids=["da", "series"]) + def nan_constant(self, request: Any) -> xr.DataArray | pd.Series: + vals = np.arange(20, dtype=float) + vals[0] = np.nan + vals[5] = np.nan + vals[19] = np.nan + if request.param == "xarray": + return xr.DataArray(vals, dims=["dim_2"], coords={"dim_2": range(20)}) + return pd.Series(vals, index=pd.Index(range(20), name="dim_2")) + + class TestSubset: + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_subset_fills_zeros( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target * subset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_subset_fills_zeros( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) -> None: + if operand == "var": + result = v + subset + expected = expected_fill + else: + result = (v + 5) + subset + expected = expected_fill + 5 + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_subset_fills_negated( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + operand: str, + ) 
-> None: + if operand == "var": + result = v - subset + expected = -expected_fill + else: + result = (v + 5) - subset + expected = 5 - expected_fill + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected) + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_subset_inverts_nonzero( + self, v: Variable, subset: xr.DataArray, operand: str + ) -> None: + target = v if operand == "var" else 1 * v + result = target / subset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + assert result.coeffs.squeeze().sel(dim_2=1).item() == pytest.approx(0.1) + assert result.coeffs.squeeze().sel(dim_2=0).item() == pytest.approx(1.0) + + def test_subset_add_var_coefficients( + self, v: Variable, subset: xr.DataArray + ) -> None: + result = subset + v + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + + def test_subset_sub_var_coefficients( + self, v: Variable, subset: xr.DataArray + ) -> None: + result = subset - v + np.testing.assert_array_equal(result.coeffs.squeeze().values, -np.ones(20)) + + class TestSuperset: + def test_add_superset_pins_to_lhs_coords( + self, v: Variable, superset: xr.DataArray + ) -> None: + result = v + superset + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + + def test_add_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset + v, v + superset) + + def test_sub_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset - v, -v + superset) + + def test_mul_var_commutative(self, v: Variable, superset: xr.DataArray) -> None: + assert_linequal(superset * v, v * superset) + + def test_mul_superset_pins_to_lhs_coords( + self, v: Variable, superset: xr.DataArray + ) -> None: + result = v * superset + assert result.sizes["dim_2"] == 
v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + + def test_div_superset_pins_to_lhs_coords(self, v: Variable) -> None: + superset_nonzero = xr.DataArray( + np.arange(1, 26, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(25)}, + ) + result = v / superset_nonzero + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + + class TestDisjoint: + def test_add_disjoint_fills_zeros(self, v: Variable) -> None: + disjoint = xr.DataArray( + [100.0, 200.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v + disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, np.zeros(20)) + + def test_mul_disjoint_fills_zeros(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v * disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.zeros(20)) + + def test_div_disjoint_preserves_coeffs(self, v: Variable) -> None: + disjoint = xr.DataArray( + [10.0, 20.0], dims=["dim_2"], coords={"dim_2": [50, 60]} + ) + result = v / disjoint + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, np.ones(20)) + + class TestCommutativity: + @pytest.mark.parametrize( + "make_lhs,make_rhs", + [ + (lambda v, s: s * v, lambda v, s: v * s), + (lambda v, s: s * (1 * v), lambda v, s: (1 * v) * s), + (lambda v, s: s + v, lambda v, s: v + s), + (lambda v, s: s + (v + 5), lambda v, s: (v + 5) + s), + ], + ids=["subset*var", "subset*expr", "subset+var", "subset+expr"], + ) + def test_commutativity( + self, + v: Variable, + subset: xr.DataArray, + make_lhs: Any, + make_rhs: Any, + ) -> None: + assert_linequal(make_lhs(v, 
subset), make_rhs(v, subset)) + + def test_sub_var_anticommutative( + self, v: Variable, subset: xr.DataArray + ) -> None: + assert_linequal(subset - v, -v + subset) + + def test_sub_expr_anticommutative( + self, v: Variable, subset: xr.DataArray + ) -> None: + expr = v + 5 + assert_linequal(subset - expr, -(expr - subset)) + + def test_add_commutativity_full_coords(self, v: Variable) -> None: + full = xr.DataArray( + np.arange(20, dtype=float), + dims=["dim_2"], + coords={"dim_2": range(20)}, + ) + assert_linequal(v + full, full + v) + + class TestQuadratic: + def test_quadexpr_add_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr + subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, expected_fill) + + def test_quadexpr_sub_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr - subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.const.values).any() + np.testing.assert_array_equal(result.const.values, -expected_fill) + + def test_quadexpr_mul_subset( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = qexpr * subset + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + def test_subset_mul_quadexpr( + self, + v: Variable, + subset: xr.DataArray, + expected_fill: np.ndarray, + ) -> None: + qexpr = v * v + result = subset * qexpr + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == v.sizes["dim_2"] + assert 
not np.isnan(result.coeffs.values).any() + np.testing.assert_array_equal(result.coeffs.squeeze().values, expected_fill) + + def test_subset_add_quadexpr(self, v: Variable, subset: xr.DataArray) -> None: + qexpr = v * v + assert_quadequal(subset + qexpr, qexpr + subset) + + class TestMissingValues: + """ + Same shape as variable but with NaN entries in the constant. + + NaN values are filled with operation-specific neutral elements: + - Addition/subtraction: NaN -> 0 (additive identity) + - Multiplication: NaN -> 0 (zeroes out the variable) + - Division: NaN -> 1 (multiplicative identity, no scaling) + """ + + NAN_POSITIONS = [0, 5, 19] + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_add_nan_filled( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + base_const = 0.0 if operand == "var" else 5.0 + target = v if operand == "var" else v + 5 + result = target + nan_constant + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.const.values).any() + # At NaN positions, const should be unchanged (added 0) + for i in self.NAN_POSITIONS: + assert result.const.values[i] == base_const + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_sub_nan_filled( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + base_const = 0.0 if operand == "var" else 5.0 + target = v if operand == "var" else v + 5 + result = target - nan_constant + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.const.values).any() + # At NaN positions, const should be unchanged (subtracted 0) + for i in self.NAN_POSITIONS: + assert result.const.values[i] == base_const + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_mul_nan_filled( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target * nan_constant + assert result.sizes["dim_2"] == 20 + assert 
not np.isnan(result.coeffs.squeeze().values).any() + # At NaN positions, coeffs should be 0 (variable zeroed out) + for i in self.NAN_POSITIONS: + assert result.coeffs.squeeze().values[i] == 0.0 + + @pytest.mark.parametrize("operand", ["var", "expr"]) + def test_div_nan_filled( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + operand: str, + ) -> None: + target = v if operand == "var" else 1 * v + result = target / nan_constant + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.coeffs.squeeze().values).any() + # At NaN positions, coeffs should be unchanged (divided by 1) + original_coeffs = (1 * v).coeffs.squeeze().values + for i in self.NAN_POSITIONS: + assert result.coeffs.squeeze().values[i] == original_coeffs[i] + + def test_add_commutativity( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + result_a = v + nan_constant + result_b = nan_constant + v + assert not np.isnan(result_a.const.values).any() + assert not np.isnan(result_b.const.values).any() + np.testing.assert_array_equal(result_a.const.values, result_b.const.values) + np.testing.assert_array_equal( + result_a.coeffs.values, result_b.coeffs.values + ) + + def test_mul_commutativity( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + result_a = v * nan_constant + result_b = nan_constant * v + assert not np.isnan(result_a.coeffs.values).any() + assert not np.isnan(result_b.coeffs.values).any() + np.testing.assert_array_equal( + result_a.coeffs.values, result_b.coeffs.values + ) + + def test_quadexpr_add_nan( + self, + v: Variable, + nan_constant: xr.DataArray | pd.Series, + ) -> None: + qexpr = v * v + result = qexpr + nan_constant + assert isinstance(result, QuadraticExpression) + assert result.sizes["dim_2"] == 20 + assert not np.isnan(result.const.values).any() + + class TestExpressionWithNaN: + """Test that NaN in expression's own const/coeffs doesn't propagate.""" + + def test_shifted_expr_add_scalar(self, v: 
Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr + 5 + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == 5.0 + + def test_shifted_expr_mul_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr * 2 + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_add_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + expr = (1 * v).shift(dim_2=1) + result = expr + arr + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == 0.0 + + def test_shifted_expr_mul_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr * arr + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_div_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr / 2 + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_shifted_expr_sub_scalar(self, v: Variable) -> None: + expr = (1 * v).shift(dim_2=1) + result = expr - 3 + assert not np.isnan(result.const.values).any() + assert result.const.values[0] == -3.0 + + def test_shifted_expr_div_array(self, v: Variable) -> None: + arr = np.arange(v.sizes["dim_2"], dtype=float) + 1 + expr = (1 * v).shift(dim_2=1) + result = expr / arr + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + def test_variable_to_linexpr_nan_coefficient(self, v: Variable) -> None: + nan_coeff = np.ones(v.sizes["dim_2"]) + nan_coeff[0] = np.nan + result = v.to_linexpr(nan_coeff) + assert not np.isnan(result.coeffs.squeeze().values).any() + assert result.coeffs.squeeze().values[0] == 0.0 + + class TestMultiDim: + def test_multidim_subset_mul(self, m: Model) -> 
None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + result = w * subset_2d + assert result.sizes["a"] == 4 + assert result.sizes["b"] == 5 + assert not np.isnan(result.coeffs.values).any() + assert result.coeffs.squeeze().sel(a=1, b=0).item() == pytest.approx(2.0) + assert result.coeffs.squeeze().sel(a=3, b=4).item() == pytest.approx(5.0) + assert result.coeffs.squeeze().sel(a=0, b=0).item() == pytest.approx(0.0) + assert result.coeffs.squeeze().sel(a=1, b=2).item() == pytest.approx(0.0) + + def test_multidim_subset_add(self, m: Model) -> None: + coords_a = pd.RangeIndex(4, name="a") + coords_b = pd.RangeIndex(5, name="b") + w = m.add_variables(coords=[coords_a, coords_b], name="w") + + subset_2d = xr.DataArray( + [[2.0, 3.0], [4.0, 5.0]], + dims=["a", "b"], + coords={"a": [1, 3], "b": [0, 4]}, + ) + result = w + subset_2d + assert result.sizes["a"] == 4 + assert result.sizes["b"] == 5 + assert not np.isnan(result.const.values).any() + assert result.const.sel(a=1, b=0).item() == pytest.approx(2.0) + assert result.const.sel(a=3, b=4).item() == pytest.approx(5.0) + assert result.const.sel(a=0, b=0).item() == pytest.approx(0.0) + + class TestXarrayCompat: + def test_da_eq_da_still_works(self) -> None: + da1 = xr.DataArray([1, 2, 3]) + da2 = xr.DataArray([1, 2, 3]) + result = da1 == da2 + assert result.values.all() + + def test_da_eq_scalar_still_works(self) -> None: + da = xr.DataArray([1, 2, 3]) + result = da == 2 + np.testing.assert_array_equal(result.values, [False, True, False]) + + def test_da_truediv_var_raises(self, v: Variable) -> None: + da = xr.DataArray(np.ones(20), dims=["dim_2"], coords={"dim_2": range(20)}) + with pytest.raises(TypeError): + da / v # type: ignore[operator] + + +def test_expression_inherited_properties(x: Variable, y: Variable) 
-> None: + expr = 10 * x + y + assert isinstance(expr.attrs, dict) + assert isinstance(expr.coords, xr.Coordinates) + assert isinstance(expr.indexes, xr.core.indexes.Indexes) + assert isinstance(expr.sizes, xr.core.utils.Frozen) + + +def test_linear_expression_getitem_single(x: Variable, y: Variable) -> None: + expr = 10 * x + y + 3 + sel = expr[0] + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # one expression with two terms (constant is not counted) + assert sel.size == 2 + + +def test_linear_expression_getitem_slice(x: Variable, y: Variable) -> None: + expr = 10 * x + y + 3 + sel = expr[:1] + + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # one expression with two terms (constant is not counted) + assert sel.size == 2 + + +def test_linear_expression_getitem_list(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + z + 10 + sel = expr[:, [0, 2]] + assert isinstance(sel, LinearExpression) + assert sel.nterm == 2 + # four expressions with two terms (constant is not counted) + assert sel.size == 8 + + +def test_linear_expression_loc(x: Variable, y: Variable) -> None: + expr = x + y + assert expr.loc[0].size < expr.loc[:5].size + + +def test_linear_expression_empty(v: Variable) -> None: + expr = 7 * v + assert not expr.empty + assert expr.loc[[]].empty + + with pytest.warns(DeprecationWarning, match="use `.empty` property instead"): + assert expr.loc[[]].empty() + + +def test_linear_expression_isnull(v: Variable) -> None: + expr = np.arange(20) * v + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert expr.isnull().sum() == 10 + + +def test_linear_expression_flat(v: Variable) -> None: + coeff = np.arange(1, 21) # use non-zero coefficients + expr = coeff * v + df = expr.flat + assert isinstance(df, pd.DataFrame) + assert (df.coeffs == coeff).all() + + +def test_iterate_slices(x: Variable, y: Variable) -> None: + expr = x + 10 * y + for s in expr.iterate_slices(slice_size=2): + assert 
isinstance(s, LinearExpression) + assert s.nterm == expr.nterm + assert s.coord_dims == expr.coord_dims + + +def test_linear_expression_to_polars(v: Variable) -> None: + coeff = np.arange(1, 21) # use non-zero coefficients + expr = coeff * v + df = expr.to_polars() + assert isinstance(df, pl.DataFrame) + assert (df["coeffs"].to_numpy() == coeff).all() + + +def test_linear_expression_where(v: Variable) -> None: + expr = np.arange(20) * v + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + + expr = np.arange(20) * v + expr = expr.where(filter, drop=True).sum() + assert isinstance(expr, LinearExpression) + assert expr.nterm == 10 + + +def test_linear_expression_where_with_const(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert expr.const[:10].isnull().all() + assert (expr.const[10:] == 10).all() + + expr = np.arange(20) * v + 10 + expr = expr.where(filter, drop=True).sum() + assert isinstance(expr, LinearExpression) + assert expr.nterm == 10 + assert expr.const == 100 + + +def test_linear_expression_where_scalar_fill_value(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + expr = expr.where(filter, 200) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert (expr.const[:10] == 200).all() + assert (expr.const[10:] == 10).all() + + +def test_linear_expression_where_array_fill_value(v: Variable) -> None: + expr = np.arange(20) * v + 10 + filter = (expr.coeffs >= 10).any(TERM_DIM) + other = expr.coeffs + expr = expr.where(filter, other) + assert isinstance(expr, LinearExpression) + assert expr.nterm == 1 + assert (expr.const[:10] == other[:10]).all() + assert (expr.const[10:] == 10).all() + + +def test_linear_expression_where_expr_fill_value(v: Variable) -> 
None: + expr = np.arange(20) * v + 10 + expr2 = np.arange(20) * v + 5 + filter = (expr.coeffs >= 10).any(TERM_DIM) + res = expr.where(filter, expr2) + assert isinstance(res, LinearExpression) + assert res.nterm == 1 + assert (res.const[:10] == expr2.const[:10]).all() + assert (res.const[10:] == 10).all() + + +def test_where_with_helper_dim_false(v: Variable) -> None: + expr = np.arange(20) * v + with pytest.raises(ValueError): + filter = expr.coeffs >= 10 + expr.where(filter) + + +def test_linear_expression_shift(v: Variable) -> None: + shifted = v.to_linexpr().shift(dim_2=2) + assert shifted.nterm == 1 + assert shifted.coeffs.loc[:1].isnull().all() + assert (shifted.vars.loc[:1] == -1).all() + + +def test_linear_expression_swap_dims(v: Variable) -> None: + expr = v.to_linexpr() + expr = expr.assign_coords({"second": ("dim_2", expr.indexes["dim_2"] + 100)}) + expr = expr.swap_dims({"dim_2": "second"}) + assert isinstance(expr, LinearExpression) + assert expr.coord_dims == ("second",) + + +def test_linear_expression_set_index(v: Variable) -> None: + expr = v.to_linexpr() + expr = expr.assign_coords({"second": ("dim_2", expr.indexes["dim_2"] + 100)}) + expr = expr.set_index({"multi": ["dim_2", "second"]}) + assert isinstance(expr, LinearExpression) + assert expr.coord_dims == ("multi",) + assert isinstance(expr.indexes["multi"], pd.MultiIndex) + + +def test_linear_expression_fillna(v: Variable) -> None: + expr = np.arange(20) * v + 10 + assert expr.const.sum() == 200 + + filter = (expr.coeffs >= 10).any(TERM_DIM) + filtered = expr.where(filter) + assert isinstance(filtered, LinearExpression) + assert filtered.const.sum() == 100 + + filled = filtered.fillna(10) + assert isinstance(filled, LinearExpression) + assert filled.const.sum() == 200 + assert filled.coeffs.isnull().sum() == 10 + + +def test_variable_expand_dims(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim") + assert isinstance(result, LinearExpression) + assert result.coord_dims == 
("dim_2", "new_dim") + + +def test_variable_stack(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim").stack(new=("new_dim", "dim_2")) + assert isinstance(result, LinearExpression) + assert result.coord_dims == ("new",) + + +def test_linear_expression_unstack(v: Variable) -> None: + result = v.to_linexpr().expand_dims("new_dim").stack(new=("new_dim", "dim_2")) + result = result.unstack("new") + assert isinstance(result, LinearExpression) + assert result.coord_dims == ("new_dim", "dim_2") + + +def test_linear_expression_diff(v: Variable) -> None: + diff = v.to_linexpr().diff("dim_2") + assert diff.nterm == 2 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + dim = v.dims[0] + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords, name=dim) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert dim in grouped.dims + assert (grouped.data[dim] == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_on_same_name_as_target_dim( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True]) +def test_linear_expression_groupby_ndim(z: Variable, use_fallback: bool) -> None: + # TODO: implement fallback for n-dim groupby, see https://github.com/PyPSA/linopy/issues/299 + expr = 1 * z + groups = xr.DataArray([[1, 1, 2], [1, 3, 3]], coords=z.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # there are three groups, 1, 2 and 3, the largest group has 3 elements + assert (grouped.data.group == [1, 2, 3]).all() + 
assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_name(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords, name="my_group") + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "my_group" in grouped.dims + assert (grouped.data.my_group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"]) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_series_with_name( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes[v.dims[0]], name="my_group") + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "my_group" in grouped.dims + assert (grouped.data.my_group == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series_with_same_group_name( + v: Variable, use_fallback: bool +) -> None: + """ + Test that the group by works with a series whose name is the same as + the dimension to group. 
+ """ + expr = 1 * v + groups = pd.Series([1] * 10 + [2] * 10, index=v.indexes["dim_2"]) + groups.name = "dim_2" + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "dim_2" in grouped.dims + assert (grouped.data.dim_2 == [1, 2]).all() + assert grouped.nterm == 10 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_series_on_multiindex( + u: Variable, use_fallback: bool +) -> None: + expr = 1 * u + len_grouped_dim = len(u.data["dim_3"]) + groups = pd.Series([1] * len_grouped_dim, index=u.indexes["dim_3"]) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1]).all() + assert grouped.nterm == len_grouped_dim + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + groups = pd.DataFrame( + {"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(groups) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe_with_same_group_name( + v: Variable, use_fallback: bool +) -> None: + """ + Test that the group by works with a dataframe whose column name is the same as + the dimension to group. 
+ """ + expr = 1 * v + groups = pd.DataFrame( + {"dim_2": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, + index=v.indexes["dim_2"], + ) + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(groups) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataframe_on_multiindex( + u: Variable, use_fallback: bool +) -> None: + expr = 1 * u + len_grouped_dim = len(u.data["dim_3"]) + groups = pd.DataFrame({"a": [1] * len_grouped_dim}, index=u.indexes["dim_3"]) + + if use_fallback: + with pytest.raises(ValueError): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert isinstance(grouped.indexes["group"], pd.MultiIndex) + assert grouped.nterm == len_grouped_dim + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_dataarray( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + df = pd.DataFrame( + {"a": [1] * 10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + groups = xr.DataArray(df) + + # this should not be the case, see https://github.com/PyPSA/linopy/issues/351 + if use_fallback: + with pytest.raises((KeyError, IndexError)): + expr.groupby(groups).sum(use_fallback=use_fallback) + return + + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + index = pd.MultiIndex.from_frame(df) + assert "group" in grouped.dims + assert set(grouped.data.group.values) == set(index.values) + assert grouped.nterm == 3 + + +def test_linear_expression_groupby_with_dataframe_non_aligned(v: Variable) -> None: + expr = 1 * v + groups = pd.DataFrame( + {"a": [1] * 
10 + [2] * 10, "b": list(range(4)) * 5}, index=v.indexes["dim_2"] + ) + target = expr.groupby(groups).sum() + + groups_non_aligned = groups[::-1] + grouped = expr.groupby(groups_non_aligned).sum() + assert_linequal(grouped, target) + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_with_const(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + 15 + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + assert (grouped.const == 150).all() + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_asymmetric(v: Variable, use_fallback: bool) -> None: + expr = 1 * v + # now asymmetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. -1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.nterm == 12 + + +@pytest.mark.parametrize("use_fallback", [True, False]) +def test_linear_expression_groupby_asymmetric_with_const( + v: Variable, use_fallback: bool +) -> None: + expr = 1 * v + 15 + # now asymmetric groups which result in different nterms + groups = xr.DataArray([1] * 12 + [2] * 8, coords=v.coords) + grouped = expr.groupby(groups).sum(use_fallback=use_fallback) + assert "group" in grouped.dims + # first group must be full with vars + assert (grouped.data.sel(group=1) > 0).all() + # the last 4 entries of the second group must be empty, i.e. 
-1 + assert (grouped.data.sel(group=2).isel(_term=slice(None, -4)).vars >= 0).all() + assert (grouped.data.sel(group=2).isel(_term=slice(-4, None)).vars == -1).all() + assert grouped.nterm == 12 + assert list(grouped.const) == [180, 120] + + +def test_linear_expression_groupby_roll(v: Variable) -> None: + expr = 1 * v + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + + +def test_linear_expression_groupby_roll_with_const(v: Variable) -> None: + expr = 1 * v + np.arange(20) + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = expr.groupby(groups).roll(dim_2=1) + assert grouped.nterm == 1 + assert grouped.vars[0].item() == 19 + assert grouped.const[0].item() == 9 + + +def test_linear_expression_groupby_from_variable(v: Variable) -> None: + groups = xr.DataArray([1] * 10 + [2] * 10, coords=v.coords) + grouped = v.groupby(groups).sum() + assert "group" in grouped.dims + assert (grouped.data.group == [1, 2]).all() + assert grouped.nterm == 10 + + +def test_linear_expression_rolling(v: Variable) -> None: + expr = 1 * v + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + + with pytest.raises(ValueError): + expr.rolling().sum() + + +def test_linear_expression_rolling_with_const(v: Variable) -> None: + expr = 1 * v + 15 + rolled = expr.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + assert rolled.const[0].item() == 15 + assert (rolled.const[1:] == 30).all() + + rolled = expr.rolling(dim_2=3).sum() + assert rolled.nterm == 3 + assert rolled.const[0].item() == 15 + assert rolled.const[1].item() == 30 + assert (rolled.const[2:] == 45).all() + + +def test_linear_expression_rolling_from_variable(v: Variable) -> None: + rolled = v.rolling(dim_2=2).sum() + assert rolled.nterm == 2 + + +def test_linear_expression_from_tuples(x: Variable, y: Variable) 
-> None: + expr = LinearExpression.from_tuples((10, x), (1, y)) + assert isinstance(expr, LinearExpression) + + with pytest.warns(DeprecationWarning): + expr2 = LinearExpression.from_tuples((10, x), (1,)) + assert isinstance(expr2, LinearExpression) + assert (expr2.const == 1).all() + + expr3 = LinearExpression.from_tuples((10, x), 1) + assert isinstance(expr3, LinearExpression) + assert_linequal(expr2, expr3) + + expr4 = LinearExpression.from_tuples((10, x), (1, y), 1) + assert isinstance(expr4, LinearExpression) + assert (expr4.const == 1).all() + + expr5 = LinearExpression.from_tuples(1, model=x.model) + assert isinstance(expr5, LinearExpression) + + +def test_linear_expression_from_tuples_bad_calls( + m: Model, x: Variable, y: Variable +) -> None: + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x), (1, y), x) + + with pytest.raises(ValueError): + LinearExpression.from_tuples((10, x, 3), (1, y), 1) + + sv = ScalarVariable(label=0, model=m) + with pytest.raises(TypeError): + LinearExpression.from_tuples((np.array([1, 1]), sv)) + + with pytest.raises(TypeError): + LinearExpression.from_tuples((x, x)) + + with pytest.raises(ValueError): + LinearExpression.from_tuples(10) + + +def test_linear_expression_from_constant_scalar(m: Model) -> None: + expr = LinearExpression.from_constant(model=m, constant=10) + assert expr.is_constant + assert isinstance(expr, LinearExpression) + assert (expr.const == 10).all() + + +def test_linear_expression_from_constant_1D(m: Model) -> None: + arr = pd.Series(index=pd.Index([0, 1], name="t"), data=[10, 20]) + expr = LinearExpression.from_constant(model=m, constant=arr) + assert isinstance(expr, LinearExpression) + assert list(expr.coords.keys())[0] == "t" + assert expr.nterm == 0 + assert (expr.const.values == [10, 20]).all() + assert expr.is_constant + + +def test_constant_linear_expression_to_polars_2D(m: Model) -> None: + index_a = pd.Index([0, 1], name="a") + index_b = pd.Index([0, 1, 2], name="b") + arr = 
np.array([[10, 20, 30], [40, 50, 60]]) + const = xr.DataArray(data=arr, coords=[index_a, index_b]) + + le_variable = m.add_variables(name="var", coords=[index_a, index_b]) * 1 + const + assert not le_variable.is_constant + le_const = LinearExpression.from_constant(model=m, constant=const) + assert le_const.is_constant + + var_pol = le_variable.to_polars() + const_pol = le_const.to_polars() + assert var_pol.shape == const_pol.shape + assert var_pol.columns == const_pol.columns + assert all(const_pol["const"] == var_pol["const"]) + assert all(const_pol["coeffs"].is_null()) + assert all(const_pol["vars"].is_null()) + + +def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + assert isinstance(expr.sanitize(), LinearExpression) + + +def test_merge(x: Variable, y: Variable, z: Variable) -> None: + expr1 = (10 * x + y).sum("dim_0") + expr2 = z.sum("dim_0") + + res = merge([expr1, expr2], cls=LinearExpression) + assert res.nterm == 6 + + res: LinearExpression = merge([expr1, expr2]) # type: ignore + assert isinstance(res, LinearExpression) + + # now concat with same length of terms + expr1 = z.sel(dim_0=0).sum("dim_1") + expr2 = z.sel(dim_0=1).sum("dim_1") + + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) + assert res.nterm == 3 + + # now with different length of terms + expr1 = z.sel(dim_0=0, dim_1=slice(0, 1)).sum("dim_1") + expr2 = z.sel(dim_0=1).sum("dim_1") + + res = merge([expr1, expr2], dim="dim_1", cls=LinearExpression) + assert res.nterm == 3 + assert res.sel(dim_1=0).vars[2].item() == -1 + + with pytest.warns(DeprecationWarning): + merge(expr1, expr2) + + +def test_linear_expression_outer_sum(x: Variable, y: Variable) -> None: + expr = x + y + expr2: LinearExpression = sum([x, y]) # type: ignore + assert_linequal(expr, expr2) + + expr = 1 * x + 2 * y + expr2: LinearExpression = sum([1 * x, 2 * y]) # type: ignore + assert_linequal(expr, expr2) + + assert isinstance(expr.sum(), 
LinearExpression) + + +def test_rename(x: Variable, y: Variable, z: Variable) -> None: + expr = 10 * x + y + z + renamed = expr.rename({"dim_0": "dim_5"}) + assert set(renamed.dims) == {"dim_1", "dim_5", TERM_DIM} + assert renamed.nterm == 3 + + renamed = expr.rename({"dim_0": "dim_1", "dim_1": "dim_2"}) + assert set(renamed.dims) == {"dim_1", "dim_2", TERM_DIM} + assert renamed.nterm == 3 + + +@pytest.mark.parametrize("multiple", [1.0, 0.5, 2.0, 0.0]) +def test_cumsum(m: Model, multiple: float) -> None: + # Test cumsum on variable x + var = m.variables["x"] + cumsum = (multiple * var).cumsum() + cumsum.nterm == 2 + + # Test cumsum on sum of variables + expr = m.variables["x"] + m.variables["y"] + cumsum = (multiple * expr).cumsum() + cumsum.nterm == 2 + + +def test_simplify_basic(x: Variable) -> None: + """Test basic simplification with duplicate terms.""" + expr = 2 * x + 3 * x + 1 * x + simplified = expr.simplify() + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + + x_len = len(x.coords["dim_0"]) + # Check that the coefficient is 6 (2 + 3 + 1) + coeffs: np.ndarray = simplified.coeffs.values + assert len(coeffs) == x_len, f"Expected {x_len} coefficients, got {len(coeffs)}" + assert all(coeffs == 6.0), f"Expected coefficient 6.0, got {coeffs[0]}" + + +def test_simplify_multiple_dimensions() -> None: + model = Model() + a_index = pd.Index([0, 1, 2, 3], name="a") + b_index = pd.Index([0, 1, 2], name="b") + coords = [a_index, b_index] + x = model.add_variables(name="x", coords=coords) + + expr = 2 * x + 3 * x + x + # Simplify + simplified = expr.simplify() + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + assert simplified.ndim == 2, f"Expected 2 dimensions, got {simplified.ndim}" + assert all(simplified.coeffs.values.reshape(-1) == 6), ( + f"Expected coefficients of 6, got {simplified.coeffs.values}" + ) + + +def test_simplify_with_different_variables(x: Variable, y: Variable) -> None: + """Test that different 
variables are kept separate.""" + # Create expression: 2*x + 3*x + 4*y + expr = 2 * x + 3 * x + 4 * y + + # Simplify + simplified = expr.simplify() + # Should have 2 terms (one for x with coeff 5, one for y with coeff 4) + assert simplified.nterm == 2, f"Expected 2 terms, got {simplified.nterm}" + + coeffs: list[float] = simplified.coeffs.values.flatten().tolist() + assert set(coeffs) == {5.0, 4.0}, ( + f"Expected coefficients {{5.0, 4.0}}, got {set(coeffs)}" + ) + + +def test_simplify_with_constant(x: Variable) -> None: + """Test that constants are preserved.""" + expr = 2 * x + 3 * x + 10 + + # Simplify + simplified = expr.simplify() + + # Check constant is preserved + assert all(simplified.const.values == 10.0), ( + f"Expected constant 10.0, got {simplified.const.values}" + ) + + # Check coefficients + assert all(simplified.coeffs.values == 5.0), ( + f"Expected coefficient 5.0, got {simplified.coeffs.values}" + ) + + +def test_simplify_cancellation(x: Variable) -> None: + """Test that terms cancel out correctly when coefficients sum to zero.""" + expr = x - x + simplified = expr.simplify() + + assert simplified.nterm == 0, f"Expected 0 terms, got {simplified.nterm}" + assert simplified.coeffs.values.size == 0 + assert simplified.vars.values.size == 0 + + +def test_simplify_partial_cancellation(x: Variable, y: Variable) -> None: + """Test partial cancellation where some terms cancel but others remain.""" + expr = 2 * x - 2 * x + 3 * y + simplified = expr.simplify() + + assert simplified.nterm == 1, f"Expected 1 term, got {simplified.nterm}" + assert all(simplified.coeffs.values == 3.0), ( + f"Expected coefficient 3.0, got {simplified.coeffs.values}" + ) + + +def test_constant_only_expression_mul_dataarray(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + data_arr = xr.DataArray([10, 20], dims=["dim_0"]) + expected_const = const_arr 
* data_arr + + result = const_expr * data_arr + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = data_arr * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars(m: Model, x: Variable) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + assert const_expr.is_constant + assert const_expr.nterm == 0 + + expr_with_vars = 1 * x + 5 + expected_coeffs = const_arr + expected_const = const_arr * 5 + + result = const_expr * expr_with_vars + assert isinstance(result, LinearExpression) + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars * const_expr + assert isinstance(result_rev, LinearExpression) + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_constant_only(m: Model) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_arr2 = xr.DataArray([4, 5], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + const_expr2 = LinearExpression(const_arr2, m) + assert const_expr.is_constant + assert const_expr2.is_constant + + expected_const = const_arr * const_arr2 + + result = const_expr * const_expr2 + assert isinstance(result, LinearExpression) + assert result.is_constant + assert (result.const == expected_const).all() + + result_rev = const_expr2 * const_expr + assert isinstance(result_rev, LinearExpression) + assert result_rev.is_constant + assert (result_rev.const == expected_const).all() + + +def test_constant_only_expression_mul_linexpr_with_vars_and_const( + m: Model, x: Variable +) -> None: + const_arr = xr.DataArray([2, 3], dims=["dim_0"]) + const_expr = LinearExpression(const_arr, m) + 
assert const_expr.is_constant + + expr_with_vars_and_const = 4 * x + 10 + expected_coeffs = const_arr * 4 + expected_const = const_arr * 10 + + result = const_expr * expr_with_vars_and_const + assert isinstance(result, LinearExpression) + assert not result.is_constant + assert (result.coeffs == expected_coeffs).all() + assert (result.const == expected_const).all() + + result_rev = expr_with_vars_and_const * const_expr + assert isinstance(result_rev, LinearExpression) + assert not result_rev.is_constant + assert (result_rev.coeffs == expected_coeffs).all() + assert (result_rev.const == expected_const).all() + + +class TestJoinParameter: + @pytest.fixture + def m2(self) -> Model: + m = Model() + m.add_variables(coords=[pd.Index([0, 1, 2], name="i")], name="a") + m.add_variables(coords=[pd.Index([1, 2, 3], name="i")], name="b") + m.add_variables(coords=[pd.Index([0, 1, 2], name="i")], name="c") + return m + + @pytest.fixture + def a(self, m2: Model) -> Variable: + return m2.variables["a"] + + @pytest.fixture + def b(self, m2: Model) -> Variable: + return m2.variables["b"] + + @pytest.fixture + def c(self, m2: Model) -> Variable: + return m2.variables["c"] + + class TestAddition: + def test_add_join_none_preserves_default( + self, a: Variable, b: Variable + ) -> None: + result_default = a.to_linexpr() + b.to_linexpr() + result_none = a.to_linexpr().add(b.to_linexpr(), join=None) + assert_linequal(result_default, result_none) + + def test_add_expr_join_inner(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_add_expr_join_outer(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_add_expr_join_left(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="left") + assert list(result.data.indexes["i"]) == [0, 
1, 2] + + def test_add_expr_join_right(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().add(b.to_linexpr(), join="right") + assert list(result.data.indexes["i"]) == [1, 2, 3] + + def test_add_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().add(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_add_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().add(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_add_constant_join_override(self, a: Variable, c: Variable) -> None: + expr = a.to_linexpr() + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [0, 1, 2]}) + result = expr.add(const, join="override") + assert list(result.data.indexes["i"]) == [0, 1, 2] + assert (result.const.values == const.values).all() + + def test_add_same_coords_all_joins(self, a: Variable, c: Variable) -> None: + expr_a = 1 * a + 5 + const = xr.DataArray([1, 2, 3], dims=["i"], coords={"i": [0, 1, 2]}) + for join in ("override", "outer", "inner"): + result = expr_a.add(const, join=join) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [6, 7, 8]) + + def test_add_scalar_with_explicit_join(self, a: Variable) -> None: + expr = 1 * a + 5 + result = expr.add(10, join="override") + np.testing.assert_array_equal(result.const.values, [15, 15, 15]) + assert list(result.coords["i"].values) == [0, 1, 2] + + class TestSubtraction: + def test_sub_expr_join_inner(self, a: Variable, b: Variable) -> None: + result = a.to_linexpr().sub(b.to_linexpr(), join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_sub_constant_override(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 
7]}) + result = expr.sub(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [-5, -15, -25]) + + class TestMultiplication: + def test_mul_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_mul_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().mul(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + assert result.coeffs.sel(i=0).item() == 0 + assert result.coeffs.sel(i=1).item() == 2 + assert result.coeffs.sel(i=2).item() == 3 + + def test_mul_expr_with_join_raises(self, a: Variable, b: Variable) -> None: + with pytest.raises(TypeError, match="join parameter is not supported"): + a.to_linexpr().mul(b.to_linexpr(), join="inner") + + class TestDivision: + def test_div_constant_join_inner(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().div(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_div_constant_join_outer(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.to_linexpr().div(const, join="outer") + assert list(result.data.indexes["i"]) == [0, 1, 2, 3] + + def test_div_expr_with_join_raises(self, a: Variable, b: Variable) -> None: + with pytest.raises(TypeError): + a.to_linexpr().div(b.to_linexpr(), join="outer") + + class TestVariableOperations: + def test_variable_add_join(self, a: Variable, b: Variable) -> None: + result = a.add(b, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_sub_join(self, a: Variable, b: Variable) -> None: + result = a.sub(b, join="inner") + assert 
list(result.data.indexes["i"]) == [1, 2] + + def test_variable_mul_join(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_div_join(self, a: Variable) -> None: + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = a.div(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2] + + def test_variable_add_outer_values(self, a: Variable, b: Variable) -> None: + result = a.add(b, join="outer") + assert isinstance(result, LinearExpression) + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.nterm == 2 + + def test_variable_mul_override(self, a: Variable) -> None: + other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) + result = a.mul(other, join="override") + assert isinstance(result, LinearExpression) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) + + def test_variable_div_override(self, a: Variable) -> None: + other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) + result = a.div(other, join="override") + assert isinstance(result, LinearExpression) + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_almost_equal( + result.coeffs.squeeze().values, [0.5, 0.2, 0.1] + ) + + def test_same_shape_add_join_override(self, a: Variable, c: Variable) -> None: + result = a.to_linexpr().add(c.to_linexpr(), join="override") + assert list(result.data.indexes["i"]) == [0, 1, 2] + + class TestMerge: + def test_merge_join_parameter(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="inner" + ) + assert list(result.data.indexes["i"]) == [1, 2] + + def test_merge_outer_join(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + 
[a.to_linexpr(), b.to_linexpr()], join="outer" + ) + assert set(result.coords["i"].values) == {0, 1, 2, 3} + + def test_merge_join_left(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="left" + ) + assert list(result.data.indexes["i"]) == [0, 1, 2] + + def test_merge_join_right(self, a: Variable, b: Variable) -> None: + result: LinearExpression = merge( + [a.to_linexpr(), b.to_linexpr()], join="right" + ) + assert list(result.data.indexes["i"]) == [1, 2, 3] + + class TestValueVerification: + def test_add_expr_outer_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.add(expr_b, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 15 + assert result.const.sel(i=3).item() == 10 + + def test_add_expr_inner_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.add(expr_b, join="inner") + assert list(result.coords["i"].values) == [1, 2] + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 15 + + def test_add_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [1, 3]}) + result = expr.add(const, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == 15 + assert result.const.sel(i=2).item() == 5 + assert result.const.sel(i=3).item() == 20 + + def test_add_constant_inner_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [1, 3]}) + result = expr.add(const, join="inner") + assert list(result.coords["i"].values) == [1] + assert result.const.sel(i=1).item() == 15 + + def 
test_add_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.add(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [15, 25, 35]) + + def test_sub_expr_outer_const_values(self, a: Variable, b: Variable) -> None: + expr_a = 1 * a + 5 + expr_b = 2 * b + 10 + result = expr_a.sub(expr_b, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 5 + assert result.const.sel(i=1).item() == -5 + assert result.const.sel(i=2).item() == -5 + assert result.const.sel(i=3).item() == -10 + + def test_mul_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.mul(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [10, 15, 20]) + np.testing.assert_array_equal(result.coeffs.squeeze().values, [2, 3, 4]) + + def test_mul_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 5 + other = xr.DataArray([2, 3], dims=["i"], coords={"i": [1, 3]}) + result = expr.mul(other, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=0).item() == 0 + assert result.const.sel(i=1).item() == 10 + assert result.const.sel(i=2).item() == 0 + assert result.const.sel(i=3).item() == 0 + assert result.coeffs.squeeze().sel(i=1).item() == 2 + assert result.coeffs.squeeze().sel(i=0).item() == 0 + + def test_div_constant_override_positional(self, a: Variable) -> None: + expr = 1 * a + 10 + other = xr.DataArray([2.0, 5.0, 10.0], dims=["i"], coords={"i": [5, 6, 7]}) + result = expr.div(other, join="override") + assert list(result.coords["i"].values) == [0, 1, 2] + np.testing.assert_array_equal(result.const.values, [5.0, 2.0, 
1.0]) + + def test_div_constant_outer_fill_values(self, a: Variable) -> None: + expr = 1 * a + 10 + other = xr.DataArray([2.0, 5.0], dims=["i"], coords={"i": [1, 3]}) + result = expr.div(other, join="outer") + assert set(result.coords["i"].values) == {0, 1, 2, 3} + assert result.const.sel(i=1).item() == pytest.approx(5.0) + assert result.coeffs.squeeze().sel(i=1).item() == pytest.approx(0.5) + assert result.const.sel(i=0).item() == pytest.approx(10.0) + assert result.coeffs.squeeze().sel(i=0).item() == pytest.approx(1.0) + + class TestQuadratic: + def test_quadratic_add_constant_join_inner( + self, a: Variable, b: Variable + ) -> None: + quad = a.to_linexpr() * b.to_linexpr() + const = xr.DataArray([10, 20, 30], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.add(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2, 3] + + def test_quadratic_add_expr_join_inner(self, a: Variable) -> None: + quad = a.to_linexpr() * a.to_linexpr() + const = xr.DataArray([10, 20], dims=["i"], coords={"i": [0, 1]}) + result = quad.add(const, join="inner") + assert list(result.data.indexes["i"]) == [0, 1] + + def test_quadratic_mul_constant_join_inner( + self, a: Variable, b: Variable + ) -> None: + quad = a.to_linexpr() * b.to_linexpr() + const = xr.DataArray([2, 3, 4], dims=["i"], coords={"i": [1, 2, 3]}) + result = quad.mul(const, join="inner") + assert list(result.data.indexes["i"]) == [1, 2, 3] diff --git a/test/test_typing.py b/test/test_typing.py index 2375dc72..566583c2 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -1,8 +1,19 @@ +from collections.abc import Generator + +import pytest import xarray as xr import linopy +@pytest.fixture(autouse=True) +def _use_v1_convention() -> Generator[None, None, None]: + """Use v1 arithmetic convention for all tests in this module.""" + linopy.options["arithmetic_convention"] = "v1" + yield + linopy.options["arithmetic_convention"] = "legacy" + + def test_operations_with_data_arrays_are_typed_correctly() 
-> None: m = linopy.Model() diff --git a/test/test_typing_legacy.py b/test/test_typing_legacy.py new file mode 100644 index 00000000..99a27033 --- /dev/null +++ b/test/test_typing_legacy.py @@ -0,0 +1,25 @@ +import xarray as xr + +import linopy + + +def test_operations_with_data_arrays_are_typed_correctly() -> None: + m = linopy.Model() + + a: xr.DataArray = xr.DataArray([1, 2, 3]) + + v: linopy.Variable = m.add_variables(lower=0.0, name="v") + e: linopy.LinearExpression = v * 1.0 + q = v * v + + _ = a * v + _ = v * a + _ = v + a + + _ = a * e + _ = e * a + _ = e + a + + _ = a * q + _ = q * a + _ = q + a