Source code for marivo.datasource.ir

"""Intermediate representation for project-level datasources."""

from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import PurePosixPath
from typing import Any, Literal

# Public names exported by ``from <this module> import *``; keep this list in
# sync with the definitions below when adding or renaming public IR types.
__all__ = [
    "AiContextIR",
    "CsvSourceIR",
    "DatasourceAiContextIR",
    "DatasourceIR",
    "DatasourceSourceLocation",
    "EntitySourceIR",
    "ParquetSourceIR",
    "TableSourceIR",
    "qualify_provenance_sql",
    "source_name",
    "source_to_dict",
]


@dataclass(frozen=True)
class DatasourceSourceLocation:
    """Point in a project file where a datasource was declared.

    Stored on datasource IR objects so error messages can cite the absolute
    file and line that produced a problem.
    """

    # Absolute path of the file containing the declaration.
    file: str
    # Line number of the declaration within ``file``.
    line: int
@dataclass(frozen=True)
class AiContextIR:
    """AI-facing context attached to semantic and datasource objects.

    Every field is optional. Sequence-valued fields are tuples that default
    to empty, which keeps the defaults immutable and the instance hashable.
    """

    # Plain-language definition of the object, if provided.
    business_definition: str | None = None
    guardrails: tuple[str, ...] = ()
    synonyms: tuple[str, ...] = ()
    examples: tuple[str, ...] = ()
    instructions: str | None = None
    owner_notes: str | None = None


# Second public name for the very same class (``DatasourceAiContextIR is
# AiContextIR``); it is not a distinct type.
DatasourceAiContextIR = AiContextIR
@dataclass(frozen=True)
class DatasourceIR:
    """Configuration of one project-level datasource.

    NOTE: ``fields`` and ``env_refs`` are plain dicts. Because the dataclass
    is frozen, its attributes cannot be rebound, but the dict contents stay
    mutable. The generated ``__hash__`` also tries to hash those dicts, so
    ``hash()`` on an instance raises ``TypeError``.
    """

    semantic_id: str
    name: str
    # Backend identifier string (presumably selects the connector; confirm).
    backend_type: str
    # Configuration values keyed by name (contents assumed backend-specific).
    fields: dict[str, Any]
    # presumably field name -> environment variable name; verify against caller.
    env_refs: dict[str, str]
    ai_context: AiContextIR
    # Python symbol the datasource was declared as (format unconfirmed).
    python_symbol: str
    # Declaration site, used when reporting datasource errors.
    location: DatasourceSourceLocation
# ---------------------------------------------------------------------------
# Physical source descriptors
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class TableSourceIR:
    """Physical table source for a dataset.

    ``database`` may be a single name or a tuple of name parts; the tuple
    form is emitted as a list by :meth:`to_dict`.
    """

    table: str
    database: str | tuple[str, ...] | None = None
    kind: Literal["table"] = "table"

    def to_dict(self) -> dict[str, object]:
        """Return a plain dict with keys ``kind``, ``table`` and ``database``."""
        if isinstance(self.database, tuple):
            database: str | list[str] | None = list(self.database)
        else:
            database = self.database
        return {"kind": self.kind, "table": self.table, "database": database}

    def to_ir(self) -> TableSourceIR:
        """Return ``self``; this descriptor is already in IR form."""
        return self
@dataclass(frozen=True)
class ParquetSourceIR:
    """Physical parquet source for an entity.

    ``columns`` is an optional tuple of column names; ``None`` means no
    explicit column list was given.
    """

    path: str
    hive_partitioning: bool = False
    columns: tuple[str, ...] | None = None
    kind: Literal["parquet"] = "parquet"

    def to_dict(self) -> dict[str, object]:
        """Return a plain dict; the ``columns`` tuple becomes a list."""
        column_list = None if self.columns is None else list(self.columns)
        return dict(
            kind=self.kind,
            path=self.path,
            hive_partitioning=self.hive_partitioning,
            columns=column_list,
        )

    def to_ir(self) -> ParquetSourceIR:
        """Return ``self``; this descriptor is already in IR form."""
        return self
@dataclass(frozen=True)
class CsvSourceIR:
    """Physical CSV source for an entity.

    ``header`` and ``delimiter`` describe the file layout; ``columns`` is an
    optional tuple of column names (``None`` when not given).
    """

    path: str
    header: bool = True
    delimiter: str = ","
    columns: tuple[str, ...] | None = None
    kind: Literal["csv"] = "csv"

    def to_dict(self) -> dict[str, object]:
        """Return a plain dict; the ``columns`` tuple becomes a list."""
        column_list = None if self.columns is None else list(self.columns)
        return dict(
            kind=self.kind,
            path=self.path,
            header=self.header,
            delimiter=self.delimiter,
            columns=column_list,
        )

    def to_ir(self) -> CsvSourceIR:
        """Return ``self``; this descriptor is already in IR form."""
        return self
EntitySourceIR = TableSourceIR | ParquetSourceIR | CsvSourceIR

# Characters that mark a path component as a glob pattern rather than a
# literal name: ``*``, ``?``, ``[`` and a backslash.
_GLOB_CHARS = re.compile(r"[*?\\[]")
# Runs of characters outside ``[0-9A-Za-z_]``; each run becomes one "_".
_SOURCE_NAME_CHARS = re.compile(r"[^0-9A-Za-z_]+")


def _sanitize_source_name(value: str) -> str:
    """Reduce *value* to lowercase ASCII letters, digits and underscores.

    Runs of other characters collapse to a single ``_``, and leading and
    trailing underscores are stripped. Returns ``"file_source"`` when nothing
    usable remains.
    """
    name = _SOURCE_NAME_CHARS.sub("_", value).strip("_").lower()
    return name or "file_source"


def source_name(source: EntitySourceIR) -> str:
    """Derive a short name for *source*.

    Table sources return their table name verbatim, without sanitizing.

    File sources use the file stem (``data/orders.parquet`` -> ``orders``).
    When the last path component is a glob, the name comes from the nearest
    ancestor directory whose name contains no glob characters. For example,
    ``data/orders/**/*.parquet`` and ``orders/year=*/*.parquet`` both give
    ``orders``.

    File-derived names are passed through :func:`_sanitize_source_name`.
    """
    if isinstance(source, TableSourceIR):
        return source.table
    # Accept Windows-style separators and ignore a trailing slash.
    path = PurePosixPath(source.path.replace("\\", "/").rstrip("/"))
    if not _GLOB_CHARS.search(path.name):
        return _sanitize_source_name(path.stem)
    # Glob path: skip every glob directory (``**``, ``year=*``, ...). This
    # keeps recursive and hive-partitioned layouts from being named after a
    # partition key, or falling back to "file_source". The root/"." parent has
    # an empty name, so the search always terminates.
    anchor = next(
        (parent.name for parent in path.parents if not _GLOB_CHARS.search(parent.name)),
        "",
    )
    return _sanitize_source_name(anchor)


def source_to_dict(source: EntitySourceIR) -> dict[str, object]:
    """Serialize *source* to a plain dict, with tuples emitted as lists.

    Every source descriptor implements ``to_dict``. Delegating to it keeps a
    single serialization per type instead of a duplicated table branch here.
    """
    return source.to_dict()


def qualify_provenance_sql(
    provenance_sql: str,
    table_qualifiers: dict[str, str],
    *,
    dialect: str | None = None,
) -> str:
    """Qualify unqualified table references in provenance SQL.

    Rewrites bare table names that match keys in *table_qualifiers* to their
    fully-qualified form (e.g. ``orders`` -> ``iceberg_inf.orders``). Tables
    that are already qualified, or that reference CTE aliases, are left
    unchanged. If sqlglot cannot tokenize or parse the SQL, the original
    string is returned unchanged.

    CTE references are matched by name only. A real table that shares a
    CTE's name is therefore also left unqualified.

    Args:
        provenance_sql: Raw SQL string from metric provenance.
        table_qualifiers: Mapping from bare table name to database-qualified
            name (e.g. ``{"orders": "iceberg_inf.orders"}``). Values must have
            two (``db.table``) or three (``catalog.db.table``) dot-separated
            parts; entries with any other shape are ignored.
        dialect: Optional sqlglot dialect for parsing and generating.

    Returns:
        SQL string with unqualified table references replaced by qualified
        ones. When *table_qualifiers* is non-empty and the SQL parses, the
        text is regenerated by sqlglot. Its formatting may therefore differ
        from the input even if no table was rewritten.
    """
    if not table_qualifiers:
        return provenance_sql

    # Imported here rather than at module level, so sqlglot is only loaded
    # when there is something to qualify.
    import sqlglot
    from sqlglot import exp
    from sqlglot.errors import SqlglotError

    try:
        parsed = sqlglot.parse_one(provenance_sql, dialect=dialect)
    except SqlglotError:
        # SqlglotError is the base of both ParseError and TokenError (e.g. an
        # unterminated string literal). Catching only ParseError would let
        # tokenizer failures escape despite the "return unchanged" contract.
        return provenance_sql

    # Collect CTE alias names so we don't qualify CTE references.
    cte_names: set[str] = set()
    for cte in parsed.find_all(exp.CTE):
        alias = cte.alias
        cte_names.add(alias if isinstance(alias, str) else alias.sql(dialect=dialect))

    # Materialize the matches first: the loop mutates nodes, and find_all is a
    # lazy tree walk.
    for table_node in list(parsed.find_all(exp.Table)):
        # Skip CTE references and tables that are already qualified.
        if table_node.name in cte_names or table_node.db:
            continue
        qualified = table_qualifiers.get(table_node.name)
        if qualified is None:
            continue

        # "db.table"         -> db + table
        # "catalog.db.table" -> catalog + db + table
        parts = qualified.split(".")
        if len(parts) == 2:
            table_node.set("db", exp.to_identifier(parts[0]))
            table_node.set("this", exp.to_identifier(parts[1]))
        elif len(parts) == 3:
            table_node.set("catalog", exp.to_identifier(parts[0]))
            table_node.set("db", exp.to_identifier(parts[1]))
            table_node.set("this", exp.to_identifier(parts[2]))
        # Any other shape can't be mapped onto catalog/db/table; leave as is.

    return parsed.sql(dialect=dialect)