# Source code for marivo.semantic.parity

"""SQL parity checking and status propagation for marivo.semantic v1.1.

Implements parity_check and the derived-metric parity status propagation
algorithm.
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from marivo.datasource.ir import TableSourceIR, qualify_provenance_sql
from marivo.semantic.errors import ErrorKind, SemanticParityError, _raise
from marivo.semantic.ir import MetricIR, ParityStatus

if TYPE_CHECKING:
    from marivo.semantic.catalog import SemanticCatalog
    from marivo.semantic.reader import SemanticProject

__all__ = [
    "ParityResult",
    "parity_check",
    "propagated_parity_status",
]


@dataclass(frozen=True)
class ParityResult:
    """Result of a single metric parity check.

    Attributes:
        ok: True when the metric value matched the source SQL value under
            the requested tolerances.
        expected: Scalar produced by the source (provenance) SQL, or None
            when it could not be computed.
        actual: Scalar produced by the metric expression, or None when it
            could not be computed.
        rel_tol: Relative tolerance the check was run with, if any.
        abs_tol: Absolute tolerance the check was run with, if any.
        error: Error describing an execution failure; None when both sides
            were evaluated.
    """

    ok: bool
    expected: float | int | None = None
    actual: float | int | None = None
    rel_tol: float | None = None
    abs_tol: float | None = None
    error: SemanticParityError | None = None
def _extract_scalar(
    result: Any,
    metric_id: str,
    label: str,
) -> float:
    """Extract a single scalar float from an ibis to_pandas() result.

    Scalar expressions return a plain number; table-like results return a
    DataFrame, which must then be exactly 1 row by 1 column.

    Args:
        result: Value returned by ``to_pandas()``.
        metric_id: Metric the result belongs to (used in error reporting).
        label: Human-readable origin of the result, used as the prefix of
            the error message (e.g. ``"Metric"`` or ``"Source SQL"``).

    Returns:
        The result converted to ``float``.

    Raises:
        SemanticParityError: (PARITY_NOT_SCALAR) if the result is not a
            single value, or that value cannot be converted to ``float``.
    """
    if isinstance(result, (int, float)):
        return float(result)

    value = result
    # DataFrame-like result — use duck-typing to avoid importing pandas
    if hasattr(result, "iloc") and hasattr(result, "columns"):
        if len(result) != 1 or len(result.columns) != 1:
            _raise(
                ErrorKind.PARITY_NOT_SCALAR,
                f"{label} for metric {metric_id!r} did not produce a single scalar value.",
                cls=SemanticParityError,
                refs=(metric_id,),
            )
        value = result.iloc[0, 0]

    # One conversion for both shapes, so a non-numeric DataFrame cell is
    # reported the same way as a non-numeric plain value.
    try:
        return float(value)
    except (TypeError, ValueError):
        _raise(
            ErrorKind.PARITY_NOT_SCALAR,
            f"{label} for metric {metric_id!r} did not produce a scalar value: {type(value).__name__}",
            cls=SemanticParityError,
            refs=(metric_id,),
        )


def _get_metric_or_raise(project: SemanticProject, metric_id: str) -> MetricIR:
    """Look up a metric by ID or raise METRIC_NOT_FOUND.

    Raises:
        SemanticParityError: (METRIC_NOT_FOUND) if the project has no
            registry loaded or the registry has no metric with this ID.
    """
    reg = project._registry
    if reg is None:
        _raise(
            ErrorKind.METRIC_NOT_FOUND,
            f"Metric {metric_id!r} not found: project is not loaded.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )
    metric_ir = reg.metrics.get(metric_id)
    if metric_ir is None:
        _raise(
            ErrorKind.METRIC_NOT_FOUND,
            f"Metric {metric_id!r} not found in registry.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )
    return metric_ir


def _table_qualifiers(metric_ir: MetricIR, reg: Any) -> dict[str, str]:
    """Map each table-backed entity's table name to its qualified name.

    When ``source.database`` is set on the entity it is used directly.
    Otherwise the datasource's ``database`` field is used as a fallback
    (e.g. MySQL/ClickHouse datasources declare database at the connection
    level). Entities with no database from either place are left out.
    """
    qualifiers: dict[str, str] = {}
    for entity_ref in metric_ir.entities:
        entity_ir = reg.entities.get(entity_ref)
        if entity_ir is None:
            continue
        source = entity_ir.source
        if not isinstance(source, TableSourceIR):
            continue

        db: str | tuple[str, ...] | None = source.database
        if db is None:
            # Fall back to the datasource's database field.
            datasource_ir = reg.datasources.get(entity_ir.datasource)
            if datasource_ir is not None:
                ds_db = datasource_ir.fields.get("database")
                if isinstance(ds_db, str):
                    db = ds_db
        if db is None:
            continue

        if isinstance(db, tuple):
            db = ".".join(db)
        qualifiers[source.table] = f"{db}.{source.table}"
    return qualifiers


def parity_check(
    project: SemanticProject | SemanticCatalog,
    metric_id: str,
    *,
    rel_tol: float | None = None,
    abs_tol: float | None = None,
    force: bool = False,
) -> ParityResult:
    """Run parity check for a base metric against its source SQL.

    The metric expression and the (table-qualified) provenance SQL are both
    executed and their scalar results compared.

    Caching: a completed comparison is stored per metric ID on the project.
    Unless ``force=True``, a stored result is returned as-is, whatever
    tolerances were passed to this call. Results describing an execution
    failure are returned but not stored.

    Args:
        project: A semantic project, or a catalog wrapping one.
        metric_id: ID of the base metric to check.
        rel_tol: Optional relative tolerance for the comparison.
        abs_tol: Optional absolute tolerance for the comparison.
        force: Re-run the check even if a cached result exists.

    Returns:
        A ParityResult on match or mismatch, or one with ``ok=False`` and
        ``error`` set when executing the metric or the source SQL failed
        with a non-parity exception.

    Raises:
        SemanticParityError: for pre-condition violations:
            - Metric not found
            - Derived metric (not supported for direct SQL parity)
            - Missing provenance SQL or provenance dialect
            - Metric references entities from more than one datasource
            - No datasource can be determined from the metric's entities
            and when either side does not produce a single scalar value.
    """
    from marivo.semantic.catalog import SemanticCatalog, SemanticKind
    from marivo.semantic.refs import make_ref

    catalog = project if isinstance(project, SemanticCatalog) else SemanticCatalog(project)
    cache_project = catalog._project

    # Check cache first
    cached = cache_project._parity_results.get(metric_id)
    if cached is not None and not force:
        return cached

    metric_ir = _get_metric_or_raise(cache_project, metric_id)

    # Derived metrics don't support direct SQL parity
    if metric_ir.metric_type == "derived":
        _raise(
            ErrorKind.PROVENANCE_DIALECT_MISSING,
            f"Derived metric {metric_id!r} does not support direct SQL parity check. "
            f"Check component metrics instead.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )

    # Must have provenance SQL (enables parity verification)
    if metric_ir.provenance is None or not metric_ir.provenance.sql:
        _raise(
            ErrorKind.PROVENANCE_DIALECT_MISSING,
            f"Metric {metric_id!r} has no provenance SQL. "
            f"Add provenance=ms.from_sql(...) to the decorator before running parity checks.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )

    # Must have dialect
    assert metric_ir.provenance is not None
    if not metric_ir.provenance.dialect:
        _raise(
            ErrorKind.PROVENANCE_DIALECT_MISSING,
            f"Metric {metric_id!r} has no provenance dialect. "
            f"Add dialect= to ms.from_sql(...) before running parity checks.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )

    # Validate single datasource
    reg = cache_project._registry
    assert reg is not None  # Already validated above
    datasource_ids: set[str] = set()
    for ds_ref in metric_ir.entities:
        ds_ir = reg.entities.get(ds_ref)
        if ds_ir is not None:
            datasource_ids.add(ds_ir.datasource)

    if len(datasource_ids) > 1:
        _raise(
            ErrorKind.CROSS_DATASOURCE_NOT_SUPPORTED,
            f"Metric {metric_id!r} references entities from "
            f"multiple datasources: {datasource_ids}. "
            f"All entities in a metric must share the same datasource.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )

    # Determine the single datasource
    if not datasource_ids:
        _raise(
            ErrorKind.PROVENANCE_DIALECT_MISSING,
            f"Metric {metric_id!r} has no entities; cannot determine datasource.",
            cls=SemanticParityError,
            refs=(metric_id,),
        )
    datasource_id = next(iter(datasource_ids))

    # Execute the ibis metric -> single scalar
    try:
        resolver = catalog._resolver()
        metric_expr = resolver.metric(make_ref(metric_id, SemanticKind.METRIC))
        actual_result = metric_expr.to_pandas()
        actual_val = _extract_scalar(actual_result, metric_id, "Metric")
    except SemanticParityError:
        raise
    except Exception as exc:
        return ParityResult(
            ok=False,
            expected=None,
            actual=None,
            rel_tol=rel_tol,
            abs_tol=abs_tol,
            error=SemanticParityError(
                kind=ErrorKind.MATERIALIZE_FAILED,
                message=f"Failed to materialize metric {metric_id!r}: {exc}",
                refs=(metric_id,),
            ),
        )

    # Qualify table names in the provenance SQL from the entity sources.
    qualified_sql = qualify_provenance_sql(
        metric_ir.provenance.sql,
        _table_qualifiers(metric_ir, reg),
        dialect=metric_ir.provenance.dialect,
    )

    # Execute the source SQL -> single scalar
    try:
        service = cache_project._connection_service()
        with service.use_backend(datasource_id) as backend:
            sql_result = backend.sql(qualified_sql)
            sql_pandas = sql_result.to_pandas()
            expected_val = _extract_scalar(sql_pandas, metric_id, "Source SQL")
    except SemanticParityError:
        raise
    except Exception as exc:
        return ParityResult(
            ok=False,
            expected=None,
            actual=actual_val,
            rel_tol=rel_tol,
            abs_tol=abs_tol,
            error=SemanticParityError(
                kind=ErrorKind.COMPILE_ERROR,
                message=f"Failed to execute source SQL for metric {metric_id!r}: {exc}",
                refs=(metric_id,),
            ),
        )

    # Compare values
    ok = _values_match(actual_val, expected_val, rel_tol=rel_tol, abs_tol=abs_tol)
    result = ParityResult(
        ok=ok,
        expected=expected_val,
        actual=actual_val,
        rel_tol=rel_tol,
        abs_tol=abs_tol,
    )

    # Cache the result
    cache_project._parity_results[metric_id] = result
    return result


def _values_match(
    actual: float,
    expected: float,
    *,
    rel_tol: float | None = None,
    abs_tol: float | None = None,
) -> bool:
    """Compare two numeric values with optional tolerances.

    If both rel_tol and abs_tol are None, requires exact match.
    Otherwise uses math.isclose with the provided tolerances; a tolerance
    left as None is treated as 0.0.

    NaN never matches anything, including another NaN, in either mode.
    """
    if rel_tol is None and abs_tol is None:
        return actual == expected
    rt = rel_tol if rel_tol is not None else 0.0
    at = abs_tol if abs_tol is not None else 0.0
    return math.isclose(actual, expected, rel_tol=rt, abs_tol=at)


def compute_self_status(
    project: Any,
    metric_id: str,
) -> ParityStatus:
    """Compute the self-status of a metric based on provenance and parity results.

    This is Step 1 of the two-step status computation.

    | provenance_sql | last parity_check | self status |
    |----------------|-------------------|-------------|
    | absent         | any               | VERIFIED    |
    | present        | no / not run      | UNVERIFIED  |
    | present        | ok=True           | VERIFIED    |
    | present        | ok=False          | DRIFTED     |

    Derived metrics do not have a self verification mode. Their status is
    determined by propagated component statuses, so this function always
    reports UNVERIFIED for them.

    Raises:
        SemanticParityError: (METRIC_NOT_FOUND) if the metric is unknown.
    """
    metric_ir = _get_metric_or_raise(project, metric_id)
    if metric_ir.metric_type == "derived":
        return ParityStatus.UNVERIFIED

    prov = metric_ir.provenance
    # "Absent" covers both no provenance and provenance with empty SQL:
    # parity_check refuses both, so neither can ever produce a result.
    if prov is None or not prov.sql:
        return ParityStatus.VERIFIED

    # Metrics with SQL provenance compute status from the latest in-memory parity result.
    parity_result = project._parity_results.get(metric_id)
    if parity_result is None:
        # No parity check has been run
        return ParityStatus.UNVERIFIED
    return ParityStatus.VERIFIED if parity_result.ok else ParityStatus.DRIFTED


def propagated_parity_status(
    project: Any,
    metric_id: str,
) -> ParityStatus:
    """Compute the effective parity status for a metric, including propagation.

    Base metrics: the result is their self-status (see compute_self_status).

    Derived metrics cannot be directly SQL-parity-checked; their own
    self-status is not consulted and the effective status is derived
    recursively from their components.

    Propagation rules (derived metrics only):
    - If there are no components -> UNVERIFIED
    - If any component is DRIFTED -> DRIFTED
    - If any component is UNVERIFIED -> UNVERIFIED
    - If all components are VERIFIED -> VERIFIED

    Raises:
        SemanticParityError: (METRIC_NOT_FOUND) if the metric, or any
            component reached during propagation, is unknown.
    """
    metric_ir = _get_metric_or_raise(project, metric_id)

    # Base metrics: just return self status
    if metric_ir.metric_type != "derived":
        return compute_self_status(project, metric_id)

    # Derived metrics: collect component statuses
    component_statuses: list[ParityStatus] = []
    if metric_ir.composition is not None:
        from marivo.semantic.ir import composition_components

        component_statuses = [
            propagated_parity_status(project, comp_id)
            for comp_id in composition_components(metric_ir.composition).values()
        ]

    if not component_statuses:
        return ParityStatus.UNVERIFIED
    if any(s == ParityStatus.DRIFTED for s in component_statuses):
        return ParityStatus.DRIFTED
    if any(s == ParityStatus.UNVERIFIED for s in component_statuses):
        return ParityStatus.UNVERIFIED
    if all(s == ParityStatus.VERIFIED for s in component_statuses):
        return ParityStatus.VERIFIED
    # Only reachable if a status outside the three handled above appears.
    return ParityStatus.UNVERIFIED