From 6500a3e3d5ff4054140b98c354fd30c0835bc9ae Mon Sep 17 00:00:00 2001 From: Jeremi Joslin Date: Tue, 17 Feb 2026 11:26:26 +0700 Subject: [PATCH 1/2] fix(spp_cel_domain): handle relational fields and SQL JOINs in CEL - Handle one2many/many2many fields as bare CEL predicates - Use query.select() for SQL with JOINs instead of raw field access - Add tests for relational predicate handling --- spp_cel_domain/models/cel_executor.py | 19 +- spp_cel_domain/models/cel_sql_builder.py | 47 +-- spp_cel_domain/models/cel_translator.py | 7 +- spp_cel_domain/models/cel_variable.py | 21 +- spp_cel_domain/tests/__init__.py | 1 + .../tests/test_cel_relational_predicate.py | 82 +++++ spp_cel_domain/tests/test_cel_variable.py | 279 ++++++++++++++++++ 7 files changed, 396 insertions(+), 60 deletions(-) create mode 100644 spp_cel_domain/tests/test_cel_relational_predicate.py diff --git a/spp_cel_domain/models/cel_executor.py b/spp_cel_domain/models/cel_executor.py index 753dc629..6b9d1aec 100644 --- a/spp_cel_domain/models/cel_executor.py +++ b/spp_cel_domain/models/cel_executor.py @@ -85,20 +85,11 @@ def _domain_to_id_sql(self, model: str, domain: list[Any]) -> SQL | None: expr = osv_expression.expression(model=Model, domain=domain) query = expr.query - # Get the WHERE clause - where_clause = query.where_clause - if not where_clause: - # No WHERE clause means all records - return SQL("(SELECT id FROM %s)", SQL.identifier(table)) - - # Build the full SELECT query - # Note: where_clause is already a SQL object in Odoo 19 - return SQL( - "(SELECT %s.id FROM %s WHERE %s)", - SQL.identifier(table), - SQL.identifier(table), - where_clause, - ) + # Use query.select() to get the full SQL including FROM clause + # with all JOINs (needed for related field domains like gender_id.uri) + select_sql = query.select(SQL.identifier(table, "id")) + + return SQL("(%s)", select_sql) except Exception as e: self._logger.debug("[CEL SQL] Failed to convert domain to SQL: %s", e) return None diff --git a/spp_cel_domain/models/cel_sql_builder.py b/spp_cel_domain/models/cel_sql_builder.py index dfbf6c48..423b0c8d 100644 --- a/spp_cel_domain/models/cel_sql_builder.py +++ b/spp_cel_domain/models/cel_sql_builder.py @@ -157,57 +157,28 @@ def select_ids_from_domain(self, model: str, domain: list[Any]) -> SQL | None: """Generate SELECT id FROM model WHERE domain. Uses expression.expression() to include record rules. + Uses Query.select() to include all necessary JOINs (e.g. for + related field lookups like gender_id.uri). Returns None if domain cannot be converted. This is the PRIMARY method for generating subqueries. All other methods should use this for model references. """ - if not domain: - # Empty domain selects all IDs (respecting record rules) - Model = self.env[model] - table = Model._table - # Still need to apply record rules even for empty domain - try: - from odoo.osv import expression as osv_expression - - expr = osv_expression.expression(model=Model, domain=[]) - query = expr.query - where_clause = query.where_clause - if where_clause: - return SQL( - "(SELECT %s.id FROM %s WHERE %s)", - SQL.identifier(table), - SQL.identifier(table), - where_clause, - ) - return SQL("(SELECT id FROM %s)", SQL.identifier(table)) - except Exception as e: - _logger.debug("[SQLBuilder] Failed for empty domain: %s", e) - return SQL("(SELECT id FROM %s)", SQL.identifier(table)) - try: from odoo.osv import expression as osv_expression Model = self.env[model] table = Model._table - # Build the expression and extract query - expr = osv_expression.expression(model=Model, domain=domain) + # Build the expression (applies record rules even for empty domain) + expr = osv_expression.expression(model=Model, domain=domain or []) query = expr.query - # Get the WHERE clause - where_clause = query.where_clause - if not where_clause: - # No WHERE clause means all records - return SQL("(SELECT id FROM %s)", SQL.identifier(table)) - - # Build the full SELECT query - return SQL( - "(SELECT %s.id FROM %s WHERE %s)", - SQL.identifier(table), - SQL.identifier(table), - where_clause, - ) + # Use query.select() to get the full SQL including FROM clause + # with all JOINs (needed for related field domains like gender_id.uri) + select_sql = query.select(SQL.identifier(table, "id")) + + return SQL("(%s)", select_sql) except Exception as e: _logger.debug("[SQLBuilder] Failed to convert domain to SQL: %s", e) return None diff --git a/spp_cel_domain/models/cel_translator.py b/spp_cel_domain/models/cel_translator.py index 85095f2e..0389bc70 100644 --- a/spp_cel_domain/models/cel_translator.py +++ b/spp_cel_domain/models/cel_translator.py @@ -628,7 +628,7 @@ def _flatten_attr(a): rec = self.env["spp.program"].search([("name", "=", name)], limit=1) pid = rec.id or None return LeafDomain(model, [("id", "!=", 0)]), f"PROGRAM({name})={pid}" - # Boolean field used as predicate (e.g., m._link.is_ended) + # Field used as bare predicate (e.g., m._link.is_ended, program_membership_ids) if isinstance(node, P.Attr | P.Ident): fld, mdl = self._resolve_field(model, node, cfg, ctx) target_model = mdl or model @@ -637,7 +637,10 @@ def _flatten_attr(a): ft = model_fields.get(fld) if ft and getattr(ft, "type", None) == "boolean": return LeafDomain(target_model, [(fld, "=", True)]), f"{fld} is True" - # Field exists but is not boolean — cannot use as bare predicate + # Relational fields as bare predicates: treat as "has records" + if ft and getattr(ft, "type", None) in ("one2many", "many2many"): + return LeafDomain(target_model, [(fld, "!=", False)]), f"{fld} is not empty" + # Field exists but is not boolean or relational — cannot use as bare predicate if ft: raise NotImplementedError( f"Field '{fld}' is of type '{getattr(ft, 'type', '?')}', not boolean. " diff --git a/spp_cel_domain/models/cel_variable.py b/spp_cel_domain/models/cel_variable.py index ca991c40..0b7db977 100644 --- a/spp_cel_domain/models/cel_variable.py +++ b/spp_cel_domain/models/cel_variable.py @@ -530,16 +530,25 @@ def unlink(self): return super().unlink() def _invalidate_resolver_cache(self): - """Invalidate the variable resolver cache. + """Invalidate all CEL caches. - Called when variable definitions change to ensure deferred resolution - uses the updated definitions. + Called when variable definitions change to ensure: + - Variable resolver cache uses updated variable definitions + - Translation cache rebuilds with new variable expansions + - Profile cache reflects any configuration changes + + This prevents stale cache issues where expressions continue to + use old variable definitions after modifications. """ try: - resolver = self.env["spp.cel.variable.resolver"] - resolver.invalidate_variable_cache() + # Invalidate all CEL caches via the service facade + # This ensures profile cache, translation cache, and resolver cache + # are all cleared in a coordinated manner + cel_service = self.env["spp.cel.service"] + cel_service.invalidate_caches() + _logger.debug("CEL caches invalidated after variable change") except Exception as e: - _logger.debug("Could not invalidate resolver cache: %s", e) + _logger.warning("Could not invalidate CEL caches: %s", e) # ═══════════════════════════════════════════════════════════════════════ # HELPER METHODS diff --git a/spp_cel_domain/tests/__init__.py b/spp_cel_domain/tests/__init__.py index 18643e9c..325ae510 100644 --- a/spp_cel_domain/tests/__init__.py +++ b/spp_cel_domain/tests/__init__.py @@ -25,3 +25,4 @@ from . import test_data_value from . import test_data_provider from . import test_multi_company +from . import test_cel_relational_predicate diff --git a/spp_cel_domain/tests/test_cel_relational_predicate.py b/spp_cel_domain/tests/test_cel_relational_predicate.py new file mode 100644 index 00000000..aa9cd003 --- /dev/null +++ b/spp_cel_domain/tests/test_cel_relational_predicate.py @@ -0,0 +1,82 @@ +# Part of OpenSPP. See LICENSE file for full copyright and licensing details. +"""Tests for relational fields (one2many, many2many) used as bare predicates. + +When a user writes a CEL expression like `program_membership_ids` (without +a comparison operator), the translator should treat it as a truthiness +check: "has records" → True, "no records" → False. This matches Python +and CEL semantics where non-empty collections are truthy. +""" + +from odoo.tests import TransactionCase, tagged + + +@tagged("post_install", "-at_install") +class TestRelationalBarePredicates(TransactionCase): + """Test that one2many and many2many fields work as bare predicates.""" + + def setUp(self): + super().setUp() + self.service = self.env["spp.cel.service"] + + def test_one2many_field_as_bare_predicate_compiles(self): + """A one2many field used as a bare predicate should compile. + + Expression like `program_membership_ids` should be treated as + checking whether the field has records (not empty). + """ + # program_membership_ids is a One2many on res.partner + # (from spp_programs module) + if "program_membership_ids" not in self.env["res.partner"]._fields: + self.skipTest("spp_programs not installed (no program_membership_ids field)") + + result = self.service.compile_expression( + "program_membership_ids", + "registry_groups", + ) + self.assertTrue(result["valid"], f"Error: {result.get('error')}") + # Should produce a domain checking the field is not empty + self.assertIsInstance(result["domain"], list) + + def test_many2many_field_as_bare_predicate_compiles(self): + """A many2many field used as a bare predicate should compile.""" + # Find any many2many field on res.partner for testing + partner_fields = self.env["res.partner"]._fields + m2m_field = None + for fname, fobj in partner_fields.items(): + if getattr(fobj, "type", None) == "many2many": + m2m_field = fname + break + + if not m2m_field: + self.skipTest("No many2many field found on res.partner") + + result = self.service.compile_expression( + m2m_field, + "registry_groups", + ) + self.assertTrue(result["valid"], f"Error: {result.get('error')}") + self.assertIsInstance(result["domain"], list) + + def test_one2many_predicate_produces_correct_domain(self): + """Bare one2many predicate should produce a '!= False' domain. + + This is the standard Odoo pattern for checking if a relational + field has records. + """ + if "program_membership_ids" not in self.env["res.partner"]._fields: + self.skipTest("spp_programs not installed (no program_membership_ids field)") + + result = self.service.compile_expression( + "program_membership_ids", + "registry_groups", + ) + self.assertTrue(result["valid"], f"Error: {result.get('error')}") + # The domain should contain a check for "has records" + domain = result["domain"] + # Look for the field check in the domain + found = False + for leaf in domain: + if isinstance(leaf, tuple) and leaf[0] == "program_membership_ids" and leaf[1] == "!=": + found = True + break + self.assertTrue(found, f"Expected '!= False' domain for one2many, got: {domain}") diff --git a/spp_cel_domain/tests/test_cel_variable.py b/spp_cel_domain/tests/test_cel_variable.py index 8e8cd713..feb1a063 100644 --- a/spp_cel_domain/tests/test_cel_variable.py +++ b/spp_cel_domain/tests/test_cel_variable.py @@ -785,3 +785,282 @@ def test_invalidate_caches_includes_resolver(self): # This shouldn't crash result = self.Service.invalidate_caches() self.assertTrue(result) + + +@tagged("post_install", "-at_install") +class TestCELVariableCacheInvalidation(TransactionCase): + """Test automatic cache invalidation when variables change.""" + + def setUp(self): + super().setUp() + self.Service = self.env["spp.cel.service"] + self.Variable = self.env["spp.cel.variable"] + self.Resolver = self.env["spp.cel.variable.resolver"] + self.Registry = self.env["spp.cel.registry"] + self.Translator = self.env["spp.cel.translator"] + + # Import cache modules for direct inspection + from ..models import cel_registry, cel_translator + + self.cel_registry = cel_registry + self.cel_translator = cel_translator + + # Clear all caches before each test + self.Service.invalidate_caches() + + # Create test variable with unique name + self.test_accessor = _unique("cache_test_var") + self.test_var = self.Variable.create( + { + "name": _unique("cache_test_variable"), + "cel_accessor": self.test_accessor, + "source_type": "constant", + "value_type": "number", + "default_value": "100", + "applies_to": "both", + } + ) + + def test_variable_write_invalidates_all_caches(self): + """Test that modifying a variable invalidates all CEL caches.""" + # Step 1: Compile an expression using the variable with a valid field comparison + # Use an expression that makes sense: r.id > constant_value + expr = f"r.id > {self.test_accessor}" + result1 = self.Service.compile_expression(expr, "registry_individuals") + self.assertTrue(result1["valid"], f"Initial compilation failed: {result1.get('error')}") + + # Step 2: Populate caches by resolving the expression + resolution1 = self.Resolver.resolve_for_evaluation(expr, context_type="individual") + self.assertIn("100", resolution1["expression"], "Should expand to r.id > 100") + + # Verify caches are populated + profile_cache_size_before = len(self.cel_registry._profile_cache) + translation_cache_size_before = len(self.cel_translator._translation_cache) + resolver_cache_size_before = len(self.Resolver._variable_cache) + + # At least one cache should have content + self.assertGreater( + profile_cache_size_before + translation_cache_size_before + resolver_cache_size_before, + 0, + "Caches should be populated after compilation", + ) + + # Step 3: Modify the variable's expression + self.test_var.write({"default_value": "200"}) + + # Step 4: Verify all caches were invalidated + profile_cache_size_after = len(self.cel_registry._profile_cache) + translation_cache_size_after = len(self.cel_translator._translation_cache) + resolver_cache_size_after = len(self.Resolver._variable_cache) + + # Caches should be cleared + self.assertEqual(profile_cache_size_after, 0, "Profile cache should be cleared") + self.assertEqual(translation_cache_size_after, 0, "Translation cache should be cleared") + self.assertEqual(resolver_cache_size_after, 0, "Resolver cache should be cleared") + + # Step 5: Compile again and verify new value is used + resolution2 = self.Resolver.resolve_for_evaluation(expr, context_type="individual") + self.assertIn("200", resolution2["expression"], "Should use new value after cache invalidation") + self.assertNotIn("100", resolution2["expression"], "Should not contain old value") + + def test_variable_unlink_invalidates_caches(self): + """Test that deleting a variable invalidates all caches.""" + # Step 1: Compile expression using the variable with valid field comparison + expr = f"r.id > {self.test_accessor}" + result1 = self.Service.compile_expression(expr, "registry_individuals") + self.assertTrue(result1["valid"], f"Initial compilation failed: {result1.get('error')}") + + # Populate resolver cache + resolution1 = self.Resolver.resolve_for_evaluation(expr, context_type="individual") + self.assertIn("100", resolution1["expression"]) + + # Verify caches are populated + cache_sizes_before = ( + len(self.cel_registry._profile_cache) + + len(self.cel_translator._translation_cache) + + len(self.Resolver._variable_cache) + ) + self.assertGreater(cache_sizes_before, 0, "Caches should be populated") + + # Step 2: Delete the variable + self.test_var.unlink() + + # Step 3: Verify all caches were cleared + profile_cache_size = len(self.cel_registry._profile_cache) + translation_cache_size = len(self.cel_translator._translation_cache) + resolver_cache_size = len(self.Resolver._variable_cache) + + self.assertEqual(profile_cache_size, 0, "Profile cache should be cleared") + self.assertEqual(translation_cache_size, 0, "Translation cache should be cleared") + self.assertEqual(resolver_cache_size, 0, "Resolver cache should be cleared") + + # Step 4: Compile again - should find variable missing + resolution2 = self.Resolver.resolve_for_evaluation(expr, context_type="individual") + self.assertIn(self.test_accessor, resolution2["missing_variables"], "Variable should be missing after unlink") + + def test_non_relevant_field_changes_do_not_invalidate(self): + """Test that changing non-cache-relevant fields doesn't invalidate caches.""" + # Compile expression to populate caches - use valid field comparison + expr = f"r.id > {self.test_accessor}" + result = self.Service.compile_expression(expr, "registry_individuals") + self.assertTrue(result["valid"], f"Compilation failed: {result.get('error')}") + self.Resolver.resolve_for_evaluation(expr, context_type="individual") + + # Record cache sizes + profile_cache_size_before = len(self.cel_registry._profile_cache) + translation_cache_size_before = len(self.cel_translator._translation_cache) + resolver_cache_size_before = len(self.Resolver._variable_cache) + + # Verify caches are populated + self.assertGreater(profile_cache_size_before, 0, "Profile cache should be populated") + self.assertGreater(translation_cache_size_before, 0, "Translation cache should be populated") + self.assertGreater(resolver_cache_size_before, 0, "Resolver cache should be populated") + + # Change a non-relevant field (e.g., sequence) + # sequence is NOT in cache_invalidating_fields, so no invalidation should occur + self.test_var.write({"sequence": 999}) + + # Caches should remain unchanged (current behavior: selective invalidation) + profile_cache_size_after = len(self.cel_registry._profile_cache) + translation_cache_size_after = len(self.cel_translator._translation_cache) + resolver_cache_size_after = len(self.Resolver._variable_cache) + + # Verify caches were NOT invalidated for non-relevant field changes + self.assertEqual(profile_cache_size_after, profile_cache_size_before, "Profile cache should not be cleared") + self.assertEqual( + translation_cache_size_after, + translation_cache_size_before, + "Translation cache should not be cleared", + ) + self.assertEqual(resolver_cache_size_after, resolver_cache_size_before, "Resolver cache should not be cleared") + + def test_variable_create_invalidates_caches(self): + """Test that creating a new variable invalidates caches.""" + # Compile expression to populate caches - use valid field comparison + expr = f"r.id > {self.test_accessor}" + result = self.Service.compile_expression(expr, "registry_individuals") + self.assertTrue(result["valid"], f"Compilation failed: {result.get('error')}") + self.Resolver.resolve_for_evaluation(expr, context_type="individual") + + # Verify caches are populated + cache_sizes_before = ( + len(self.cel_registry._profile_cache) + + len(self.cel_translator._translation_cache) + + len(self.Resolver._variable_cache) + ) + self.assertGreater(cache_sizes_before, 0) + + # Create a new variable + new_accessor = _unique("new_var") + self.Variable.create( + { + "name": _unique("new_variable"), + "cel_accessor": new_accessor, + "source_type": "constant", + "value_type": "number", + "default_value": "50", + "applies_to": "both", + } + ) + + # Caches should be cleared + self.assertEqual(len(self.cel_registry._profile_cache), 0) + self.assertEqual(len(self.cel_translator._translation_cache), 0) + self.assertEqual(len(self.Resolver._variable_cache), 0) + + def test_multiple_write_invalidations_safe(self): + """Test that multiple sequential writes don't cause errors.""" + # Use valid field comparison expression + expr = f"r.id > {self.test_accessor}" + + # Write multiple times in sequence + for i in range(5): + self.test_var.write({"default_value": str(100 + i * 10)}) + + # Should compile successfully after each write + result = self.Service.compile_expression(expr, "registry_individuals") + self.assertTrue(result["valid"], f"Compilation should succeed after write {i}: {result.get('error')}") + + # Resolver should use new value + resolution = self.Resolver.resolve_for_evaluation(expr, context_type="individual") + expected_value = str(100 + i * 10) + self.assertIn( + expected_value, resolution["expression"], f"Should use value {expected_value} after write {i}" + ) + + def test_cache_invalidation_with_aggregate_variable(self): + """Test cache invalidation when modifying aggregate variables.""" + # Create an aggregate variable + agg_accessor = _unique("test_agg") + agg_var = self.Variable.create( + { + "name": _unique("test_aggregate"), + "cel_accessor": agg_accessor, + "source_type": "aggregate", + "value_type": "number", + "aggregate_type": "count", + "aggregate_target": "members", + "aggregate_filter": "true", + "applies_to": "group", + } + ) + + # Compile expression using aggregate - this is already a valid expression + # because aggregate variables expand to members.count(...) which is valid + expr = f"{agg_accessor} >= 3" + result1 = self.Service.compile_expression(expr, "registry_groups") + self.assertTrue(result1["valid"], f"Initial compilation failed: {result1.get('error')}") + + # Populate caches + self.Resolver.resolve_for_evaluation(expr, context_type="group") + + # Verify caches populated + cache_sizes = ( + len(self.cel_registry._profile_cache) + + len(self.cel_translator._translation_cache) + + len(self.Resolver._variable_cache) + ) + self.assertGreater(cache_sizes, 0) + + # Modify the aggregate filter + agg_var.write({"aggregate_filter": "age_years(m.birthdate) < 18"}) + + # All caches should be cleared + self.assertEqual(len(self.cel_registry._profile_cache), 0) + self.assertEqual(len(self.cel_translator._translation_cache), 0) + self.assertEqual(len(self.Resolver._variable_cache), 0) + + # Compile again - should use new filter + resolution = self.Resolver.resolve_for_evaluation(expr, context_type="group") + self.assertIn("age_years", resolution["expression"], "Should use new aggregate filter") + + def test_cache_invalidation_cascades_properly(self): + """Test that cache invalidation cascades through all three caches.""" + # This is a comprehensive test to ensure all caches are cleared together + + # Step 1: Populate all caches + # Use valid field comparison expression + expr = f"r.id > {self.test_accessor}" + + # Populate profile cache + self.Registry.load_profile("registry_individuals") + + # Populate translation cache + cfg = self.Registry.load_profile("registry_individuals") + self.Translator.translate("res.partner", "r.id > 0", cfg) + + # Populate resolver cache + self.Resolver.resolve_for_evaluation(expr, context_type="individual") + + # Verify all caches have content + self.assertGreater(len(self.cel_registry._profile_cache), 0, "Profile cache should be populated") + self.assertGreater(len(self.cel_translator._translation_cache), 0, "Translation cache should be populated") + self.assertGreater(len(self.Resolver._variable_cache), 0, "Resolver cache should be populated") + + # Step 2: Modify variable + self.test_var.write({"default_value": "150"}) + + # Step 3: Verify ALL caches cleared in one operation + self.assertEqual(len(self.cel_registry._profile_cache), 0, "Profile cache should be cleared") + self.assertEqual(len(self.cel_translator._translation_cache), 0, "Translation cache should be cleared") + self.assertEqual(len(self.Resolver._variable_cache), 0, "Resolver cache should be cleared") From 3b5c0ce46b0a0c32e967c2f695d707c34ffdc472 Mon Sep 17 00:00:00 2001 From: Jeremi Joslin Date: Wed, 18 Feb 2026 16:37:52 +0700 Subject: [PATCH 2/2] fix(spp_cel_domain): strengthen relational predicate test assertion Replace weak operator-only check with an exact tuple match to prevent false positives from incorrect values (e.g. True instead of False). --- spp_cel_domain/tests/test_cel_relational_predicate.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/spp_cel_domain/tests/test_cel_relational_predicate.py b/spp_cel_domain/tests/test_cel_relational_predicate.py index aa9cd003..45d8e947 100644 --- a/spp_cel_domain/tests/test_cel_relational_predicate.py +++ b/spp_cel_domain/tests/test_cel_relational_predicate.py @@ -73,10 +73,7 @@ def test_one2many_predicate_produces_correct_domain(self): self.assertTrue(result["valid"], f"Error: {result.get('error')}") # The domain should contain a check for "has records" domain = result["domain"] - # Look for the field check in the domain - found = False - for leaf in domain: - if isinstance(leaf, tuple) and leaf[0] == "program_membership_ids" and leaf[1] == "!=": - found = True - break - self.assertTrue(found, f"Expected '!= False' domain for one2many, got: {domain}") + # Look for the exact expected leaf in the domain + expected_leaf = ("program_membership_ids", "!=", False) + found = any(leaf == expected_leaf for leaf in domain if isinstance(leaf, tuple)) + self.assertTrue(found, f"Expected '{expected_leaf}' to be in domain, but got: {domain}")