Source code for marivo.datasource.manage

"""Unified datasource management API (md.*)."""

from __future__ import annotations

import builtins
import time
from collections.abc import Iterable, Iterator, Mapping
from contextlib import suppress
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

from marivo.datasource import backends as _backends
from marivo.datasource import secrets as _secrets
from marivo.datasource import store as _store
from marivo.datasource.authoring import (
    DatasourceSpec,
)
from marivo.datasource.errors import DatasourceMissingError, DatasourcePreviewError
from marivo.datasource.ir import CsvSourceIR, EntitySourceIR, ParquetSourceIR, TableSourceIR
from marivo.datasource.metadata import TableMetadata
from marivo.datasource.metadata import inspect_source as _inspect_source
from marivo.datasource.runtime import DatasourceConnectionService
from marivo.datasource.scan import (
    ColumnInspection,
    ColumnProfile,
    JoinKeyProbe,
    JoinSide,
    ScanReport,
    ScanScope,
)
from marivo.preview import (
    PREVIEW_DEFAULT_LIMIT,
    PreviewFilter,
    PreviewOrder,
    PreviewResult,
    PreviewSamplePolicy,
    preview_ibis_table,
)
from marivo.render import format_bounded_card, result_repr


@dataclass(frozen=True, repr=False)
class DatasourceSummary:
    """Name and backend type of a single configured project datasource."""

    name: str
    backend_type: str

    @property
    def semantic_id(self) -> str:
        """Return the id used by discovery surfaces, which is simply ``name``."""
        return self.name

    def _repr_identity(self) -> str:
        # One identity line shared by ``render()`` and ``__repr__``.
        return f"DatasourceSummary name={self.name} backend={self.backend_type}"

    def __repr__(self) -> str:
        identity = self._repr_identity()
        return result_repr(identity)

    def render(self) -> str:
        """Return the card built by ``format_bounded_card`` for this summary."""
        card = format_bounded_card(
            identity=self._repr_identity(),
            status=None,
            available=(".render()", ".show()"),
        )
        return card

    def show(self) -> None:
        """Print the rendered card."""
        rendered = self.render()
        print(rendered)
@dataclass(frozen=True, repr=False)
class DatasourceList:
    """Immutable, displayable sequence of project datasource summaries."""

    _items: tuple[DatasourceSummary, ...]

    @property
    def items(self) -> tuple[DatasourceSummary, ...]:
        """Return every datasource summary row."""
        return self._items

    def ids(self) -> builtins.list[str]:
        """Return the datasource names, in the order they are displayed."""
        return [summary.name for summary in self._items]

    def __iter__(self) -> Iterator[DatasourceSummary]:
        return iter(self._items)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int) -> DatasourceSummary:
        return self._items[index]

    def _repr_identity(self) -> str:
        # One identity line shared by ``render()`` and ``__repr__``.
        count = len(self._items)
        return f"DatasourceList count={count}"

    def __repr__(self) -> str:
        identity = self._repr_identity()
        return result_repr(identity)

    def render(self) -> str:
        """Return a card with one ``name``/``backend`` row per datasource."""
        table_rows = [[summary.name, summary.backend_type] for summary in self._items]
        card = format_bounded_card(
            identity=self._repr_identity(),
            columns=["name", "backend"],
            rows=table_rows,
            row_count=len(self._items),
            preview_truncation_hint="inspect .items for all datasources",
            available=(".items", ".ids()", ".render()", ".show()"),
        )
        return card

    def show(self) -> None:
        """Print the rendered card."""
        rendered = self.render()
        print(rendered)
@dataclass(frozen=True, repr=False)
class DatasourceDescription:
    """Literal field values and env-var references of one datasource."""

    name: str
    backend_type: str
    literal_fields: dict[str, Any]
    env_refs: dict[str, str]

    def _repr_identity(self) -> str:
        # Only the counts are shown here; field names appear in ``render()``.
        field_count = len(self.literal_fields)
        env_ref_count = len(self.env_refs)
        return (
            f"DatasourceDescription name={self.name} backend={self.backend_type} "
            f"fields={field_count} env_refs={env_ref_count}"
        )

    def __repr__(self) -> str:
        identity = self._repr_identity()
        return result_repr(identity)

    def render(self) -> str:
        """Return a card listing up to eight field names and eight env refs.

        Names are sorted; env-ref names are shown with an ``_env`` suffix.
        """
        shown_fields = sorted(self.literal_fields)[:8]
        shown_env_refs = [f"{name}_env" for name in sorted(self.env_refs)[:8]]
        card = format_bounded_card(
            identity=self._repr_identity(),
            columns=shown_fields + shown_env_refs,
            available=(".render()", ".show()"),
        )
        return card

    def show(self) -> None:
        """Print the rendered card."""
        rendered = self.render()
        print(rendered)
@dataclass(frozen=True, repr=False)
class DatasourceTestResult:
    """Outcome of one datasource connectivity round-trip."""

    name: str
    ok: bool
    error: str | None
    latency_ms: int | None

    def _repr_identity(self) -> str:
        # A missing latency is shown as "n/a" rather than "None".
        if self.latency_ms is None:
            latency_label = "n/a"
        else:
            latency_label = f"{self.latency_ms}ms"
        return f"DatasourceTestResult name={self.name} ok={self.ok} latency={latency_label}"

    def __repr__(self) -> str:
        identity = self._repr_identity()
        return result_repr(identity)

    def render(self) -> str:
        """Return a card whose status line carries the error text, if any."""
        card = format_bounded_card(
            identity=self._repr_identity(),
            status=self.error,
            available=(".render()", ".show()"),
        )
        return card

    def show(self) -> None:
        """Print the rendered card."""
        rendered = self.render()
        print(rendered)
def register(
    spec: DatasourceSpec,
    *,
    project_root: Path | None = None,
) -> DatasourceSummary:
    """Write (or overwrite) the project datasource file described by ``spec``.

    Args:
        spec: An internal backend datasource spec such as ``_DuckDBSpec`` or
            ``_TrinoSpec``.
        project_root: Optional project root directory; defaults to cwd.

    Returns:
        A ``DatasourceSummary`` describing the datasource that was stored.

    Example:
        >>> from marivo.datasource.authoring import _DuckDBSpec
        >>> import marivo.datasource as md
        >>> md.register(_DuckDBSpec(name="wh", path=":memory:"))

    Constraints:
        Use one of the internal backend specs such as ``_DuckDBSpec`` or
        ``_TrinoSpec``. Sensitive fields use named ``*_env`` references, not
        plaintext literals or generic keyword bags. For datasource file
        authoring, prefer ``md.duckdb()``, ``md.trino()``, etc.
    """
    saved = _store.save_one(spec, project_root=project_root)
    # The summary reflects what the store reports back, not the input spec.
    summary = DatasourceSummary(name=saved.name, backend_type=saved.backend_type)
    return summary
def remove(name: str) -> bool:
    """Remove the project datasource file registered under ``name``.

    Args:
        name: The datasource name to remove.

    Returns:
        True if the file existed and was deleted; False if it was not found.

    Example:
        >>> import marivo.datasource as md
        >>> md.remove("wh")
        True

    Constraints:
        Only the project-local ``models/datasources/<name>.py`` file is removed.
    """
    was_deleted = _store.delete_one(name)
    return was_deleted
def list() -> DatasourceList:
    """Return every configured project datasource as a ``DatasourceList``.

    Returns:
        ``DatasourceList`` containing ``DatasourceSummary`` rows sorted by name.

    Example:
        >>> import marivo.datasource as md
        >>> md.list().show()
        >>> md.list().items

    Constraints:
        Only datasources with a persisted project file are included.
    """
    # This function shadows the builtin ``list`` at module level, so the
    # builtin is reached through ``builtins``.
    stored = builtins.list(_store.load_all().values())
    stored.sort(key=lambda entry: entry.name)
    summaries = [
        DatasourceSummary(name=entry.name, backend_type=entry.backend_type)
        for entry in stored
    ]
    return DatasourceList(tuple(summaries))
def describe(name: str) -> DatasourceDescription:
    """Return the literal fields and env refs stored for one datasource.

    Args:
        name: The datasource name to describe.

    Returns:
        A ``DatasourceDescription`` with literal_fields and env_refs.

    Example:
        >>> import marivo.datasource as md
        >>> md.describe("wh")

    Constraints:
        Raises ``DatasourceMissingError`` when the name has no project file.
    """
    stored = _store.load_one(name)
    if stored is not None:
        # Copy both mappings so the description does not share the stored dicts.
        return DatasourceDescription(
            name=stored.name,
            backend_type=stored.backend_type,
            literal_fields=dict(stored.fields),
            env_refs=dict(stored.env_refs),
        )
    error_details = {"datasource": name, "available": _store.list_names()}
    raise DatasourceMissingError(
        message=f"datasource {name!r} is not configured",
        details=error_details,
    )
def connect(name: str) -> Any:
    """Open a live ibis backend for ``name``; the caller must disconnect it.

    Args:
        name: The datasource name to connect to.

    Returns:
        A live ibis backend. The caller must call ``.disconnect()`` when done.

    Example:
        >>> import marivo.datasource as md
        >>> backend = md.connect("wh")
        >>> try:
        ...     backend.raw_sql("SELECT 1")
        ... finally:
        ...     backend.disconnect()

    Constraints:
        The caller owns the backend lifetime and must call ``disconnect()``.
        Env-sourced secrets used to open this backend are remembered on the
        backend object so that a subsequent round-trip validation can persist
        them via ``secrets.persist_backend_env_sourced``.
    """
    stored = _store.load_one(name)
    if stored is None:
        error_details = {"datasource": name, "available": _store.list_names()}
        raise DatasourceMissingError(
            message=f"datasource {name!r} is not configured",
            details=error_details,
        )
    opened = _backends.build_backend_with_secrets(stored)
    live_backend = opened.backend
    # Record which secrets came from the environment before handing the
    # backend to the caller.
    _secrets.remember_env_sourced(live_backend, opened.env_sourced_secrets)
    return live_backend
def _preview_ref(datasource: str, table: str, database: str | tuple[str, ...] | None) -> str: if database is None: return f"{datasource}.{table}" namespace = ".".join(database) if isinstance(database, tuple) else database return f"{datasource}.{namespace}.{table}" def _validate_filter(raw_filter: object) -> PreviewFilter: if not isinstance(raw_filter, Mapping): raise DatasourcePreviewError( message="preview where entries must be structured preview filter mappings", details={"field": "where", "value": repr(raw_filter)}, ) column = raw_filter.get("column") op = raw_filter.get("op") if not isinstance(column, str) or not column: raise DatasourcePreviewError( message="preview filter column must be a non-empty string", details={"field": "where.column", "value": repr(column)}, ) allowed_ops = {"=", "!=", "<", "<=", ">", ">=", "in", "is_null", "is_not_null"} if op not in allowed_ops: raise DatasourcePreviewError( message="preview filter op is not supported", details={"field": "where.op", "value": repr(op), "allowed": sorted(allowed_ops)}, ) if op not in {"is_null", "is_not_null"} and "value" not in raw_filter: raise DatasourcePreviewError( message="preview filter value is required for this op", details={"field": "where.value", "op": op}, ) out: PreviewFilter = {"column": column, "op": op} if "value" in raw_filter: out["value"] = raw_filter["value"] return out def _validate_order(raw_order: object) -> PreviewOrder: if not isinstance(raw_order, Mapping): raise DatasourcePreviewError( message="preview order_by entries must be structured preview order mappings", details={"field": "order_by", "value": repr(raw_order)}, ) column = raw_order.get("column") direction = raw_order.get("direction", "asc") if not isinstance(column, str) or not column: raise DatasourcePreviewError( message="preview order column must be a non-empty string", details={"field": "order_by.column", "value": repr(column)}, ) if direction not in {"asc", "desc"}: raise DatasourcePreviewError( message="preview 
order direction must be 'asc' or 'desc'", details={"field": "order_by.direction", "value": repr(direction)}, ) return {"column": column, "direction": direction} def _require_column(available: Iterable[str], column: str, *, field: str) -> None: available_columns = tuple(available) if column not in available_columns: raise DatasourcePreviewError( message=f"preview references unknown column {column!r}", details={"field": field, "column": column, "available": available_columns}, ) def _apply_preview_filter(expr: Any, preview_filter: PreviewFilter) -> Any: column = preview_filter["column"] _require_column(expr.columns, column, field="where.column") value = expr[column] op = preview_filter["op"] if op == "=": return expr.filter(value == preview_filter["value"]) if op == "!=": return expr.filter(value != preview_filter["value"]) if op == "<": return expr.filter(value < preview_filter["value"]) if op == "<=": return expr.filter(value <= preview_filter["value"]) if op == ">": return expr.filter(value > preview_filter["value"]) if op == ">=": return expr.filter(value >= preview_filter["value"]) if op == "in": raw_value = preview_filter["value"] if isinstance(raw_value, str) or not isinstance(raw_value, Iterable): raise DatasourcePreviewError( message="preview 'in' filter value must be a non-string iterable", details={"field": "where.value", "op": "in", "value": repr(raw_value)}, ) return expr.filter(value.isin(builtins.list(raw_value))) if op == "is_null": return expr.filter(value.isnull()) if op == "is_not_null": return expr.filter(value.notnull()) raise DatasourcePreviewError( message="preview filter op is not supported", details={"field": "where.op", "value": op}, ) def _apply_preview_order(expr: Any, preview_order: PreviewOrder) -> tuple[Any, str]: column = preview_order["column"] _require_column(expr.columns, column, field="order_by.column") direction = preview_order.get("direction", "asc") column_expr = expr[column] if direction == "desc": return 
expr.order_by(column_expr.desc()), f"{column} desc" return expr.order_by(column_expr), f"{column} asc"
def preview(
    datasource: str,
    *,
    table: str,
    database: str | tuple[str, ...] | None = None,
    columns: Iterable[str] | None = None,
    limit: int = PREVIEW_DEFAULT_LIMIT,
    where: Iterable[PreviewFilter] | None = None,
    order_by: Iterable[PreviewOrder] | None = None,
    include_types: bool = True,
) -> PreviewResult:
    """Return a bounded, optionally filtered and ordered preview of one table.

    Args:
        datasource: Name of the project datasource.
        table: Table name within the datasource.
        database: Optional database/catalog path.
        columns: Optional column subset to select.
        limit: Maximum rows to return (defaults to ``PREVIEW_DEFAULT_LIMIT``).
        where: Structured filter mappings (column, op, value).
        order_by: Structured order mappings (column, direction).
        include_types: Whether to include column type information.

    Returns:
        The ``PreviewResult`` produced by ``preview_ibis_table``.

    Example:
        >>> import marivo.datasource as md
        >>> md.preview("wh", table="orders", limit=5)

    Constraints:
        The backend is always disconnected before returning, even on error.
        Raw SQL filters are rejected; use structured ``where`` mappings.
        Any failure that is not already a ``DatasourcePreviewError`` is
        re-raised as one, chained to the original exception.
    """
    backend: Any | None = None
    try:
        backend = connect(datasource)
        if database is None:
            expr = backend.table(table)
        else:
            expr = backend.table(table, database=database)

        # Projection happens first, so filters and orders below can only
        # reference columns that survive the selection.
        projection = tuple(columns or ())
        for column_name in projection:
            _require_column(expr.columns, column_name, field="columns")
        if projection:
            expr = expr.select(*projection)

        # Validate every filter before applying any of them.
        filters = tuple([_validate_filter(entry) for entry in (where or ())])
        for validated_filter in filters:
            expr = _apply_preview_filter(expr, validated_filter)

        orders = tuple([_validate_order(entry) for entry in (order_by or ())])
        order_labels: builtins.list[str] = []
        for validated_order in orders:
            expr, order_label = _apply_preview_order(expr, validated_order)
            order_labels.append(order_label)

        policy_method = "ordered_limit" if order_labels else "bounded_limit"
        sample_policy = PreviewSamplePolicy(
            method=policy_method,
            limit=limit,
            order_by=tuple(order_labels),
            filters=filters,
        )

        from marivo.datasource.timezone import system_timezone_name

        report_tz = system_timezone_name()
        return preview_ibis_table(
            expr,
            kind="datasource_table",
            ref=_preview_ref(datasource, table, database),
            limit=limit,
            sample_policy=sample_policy,
            include_types=include_types,
            report_tz=report_tz,
        )
    except DatasourcePreviewError:
        # Already structured; let it through untouched.
        raise
    except Exception as exc:
        failure_details = {"datasource": datasource, "table": table, "database": database}
        raise DatasourcePreviewError(
            message=f"failed to preview datasource table {datasource!r}.{table!r}: {exc}",
            details=failure_details,
        ) from exc
    finally:
        # Best-effort cleanup: a failing disconnect must not mask the result
        # or the original error.
        if backend is not None:
            disconnect = getattr(backend, "disconnect", None)
            if callable(disconnect):
                with suppress(Exception):
                    disconnect()
def inspect_table(
    datasource: str,
    source: EntitySourceIR | None = None,
    *,
    table: str | None = None,
    database: str | tuple[str, ...] | None = None,
    include_partitions: bool = True,
    project_root: Path | None = None,
) -> TableMetadata:
    """Return schema, comment, nullability and partition metadata for a table.

    Args:
        datasource: Name of the project datasource.
        source: An ``EntitySourceIR`` (from ``md.table()``, ``md.parquet()``,
            or ``md.csv()``). Pass either ``source`` or ``table``, not both.
        table: Table name within the datasource (alternative to ``source``).
        database: Optional database/catalog path; only used together with
            ``table``.
        include_partitions: Whether to include partition hints.
        project_root: Optional project root directory; defaults to cwd.

    Returns:
        A ``TableMetadata`` with columns, warnings, and optional partitions.

    Raises:
        TypeError: If both ``source`` and ``table`` are given, or neither is.

    Example:
        >>> import marivo.datasource as md
        >>> md.inspect_table("wh", md.table("orders"))

    Constraints:
        Opens and closes a backend connection internally.
    """
    if source is None:
        if table is None:
            raise TypeError("inspect_table requires a structured source or table name.")
        # A bare table name is wrapped in the structured form the inspector expects.
        resolved_source: EntitySourceIR = TableSourceIR(table=table, database=database)
    elif table is not None:
        raise TypeError("Pass either source or table, not both.")
    else:
        resolved_source = source

    return _inspect_source(
        datasource,
        source=resolved_source,
        include_partitions=include_partitions,
        project_root=project_root,
    )
def inspect_source(
    datasource: str,
    *,
    source: EntitySourceIR,
    include_partitions: bool = True,
    project_root: Path | None = None,
) -> TableMetadata:
    """Return table metadata for a semantic entity source (table or file).

    Args:
        datasource: Name of the project datasource.
        source: An ``EntitySourceIR`` describing the table or file.
        include_partitions: Whether to include partition hints.
        project_root: Optional project root directory; defaults to cwd.

    Returns:
        A ``TableMetadata`` with columns, warnings, and optional partitions.

    Example:
        >>> import marivo.datasource as md
        >>> md.inspect_source("wh", source=source_ir)

    Constraints:
        Opens and closes a backend connection internally.
    """
    # Thin public wrapper over ``marivo.datasource.metadata.inspect_source``.
    table_metadata = _inspect_source(
        datasource,
        source=source,
        include_partitions=include_partitions,
        project_root=project_root,
    )
    return table_metadata
def inspect_columns(
    datasource: str,
    source: EntitySourceIR,
    *,
    columns: tuple[str, ...] | None = None,
    scope: ScanScope | None = None,
    project_root: Path | None = None,
) -> ColumnInspection:
    """Profile selected columns of a datasource source using a bounded scan.

    Args:
        datasource: Name of the project datasource.
        source: An ``EntitySourceIR`` (from ``md.table()``, ``md.parquet()``,
            or ``md.csv()``).
        columns: Column names to profile; ``None`` profiles all columns
            (capped by ``scope.max_columns``).
        scope: Bounded scan configuration; defaults to ``ScanScope()``.
        project_root: Optional project root directory; defaults to cwd.

    Returns:
        A ``ColumnInspection`` with per-column profiles and a ``ScanReport``.

    Example:
        >>> import marivo.datasource as md
        >>> md.inspect_columns(
        ...     "wh",
        ...     md.table("orders"),
        ...     columns=("status", "amount"),
        ...     scope=md.ScanScope(partition=None, max_rows=100),
        ... )

    Constraints:
        The backend is always disconnected before returning, even on error.
        Scan scope limits (max_rows, max_columns) are always enforced.
    """
    if scope is None:
        scope = ScanScope()

    metadata = _inspect_source(
        datasource,
        source=source,
        include_partitions=False,
        project_root=project_root,
    )

    # Pick the columns to profile, capped at ``max_columns``.
    if columns is not None:
        requested = columns
    else:
        requested = tuple(meta_column.name for meta_column in metadata.columns)
    selected_columns = requested[: scope.max_columns]

    warnings: builtins.list[str] = []
    if len(requested) > scope.max_columns:
        omitted = requested[scope.max_columns :]
        omitted_count = len(requested) - len(selected_columns)
        first_omitted = ", ".join(str(c) for c in omitted[:3])
        warnings.append(
            f"column list truncated by max_columns={scope.max_columns}: "
            f"{omitted_count} columns not profiled "
            f"(first omitted: {first_omitted}); "
            f"pass scope=ScanScope(max_columns={len(requested)}) to profile all columns"
        )

    # name -> (type, nullable, comment), taken from the inspected schema.
    column_specs: dict[str, tuple[str, bool | None, str | None]] = {}
    for meta_column in metadata.columns:
        column_specs[meta_column.name] = (
            meta_column.type,
            meta_column.nullable,
            meta_column.comment,
        )

    # Describe how the partition was chosen, for the scan report.
    partition_used: Mapping[str, str] | None = None
    partition_resolution: str
    if scope.partition is None:
        partition_resolution = "unpruned"
    elif scope.partition == "latest":
        partition_resolution = "latest"
    else:
        partition_resolution = "explicit"
        partition_used = dict(scope.partition)

    # Run the bounded sample and time it.
    started = time.perf_counter()
    frame = _execute_scoped_sample(
        datasource,
        source,
        selected_columns=selected_columns,
        scope=scope,
        project_root=project_root,
    )
    elapsed = time.perf_counter() - started
    rows_scanned = len(frame)
    # Hitting the row cap is treated as "there may be more rows".
    truncated = rows_scanned >= scope.max_rows

    def _placeholder_profile(
        name: str,
        data_type: str,
        nullable: bool | None,
        comment: str | None,
    ) -> ColumnProfile:
        # All-zero profile for a column that could not be sampled.
        return ColumnProfile(
            name=name,
            data_type=data_type,
            nullable=nullable,
            comment=comment,
            null_count=0,
            empty_count=0,
            distinct_count=0,
            top_values=(),
            sample_values=(),
            min_value=None,
            max_value=None,
        )

    profiles: builtins.list[ColumnProfile] = []
    for column_name in selected_columns:
        spec = column_specs.get(column_name)
        if spec is None:
            warnings.append(f"column {column_name!r} absent from source schema")
            profiles.append(_placeholder_profile(column_name, "UNKNOWN", None, None))
        elif column_name not in frame:
            warnings.append(f"column {column_name!r} absent from bounded sample")
            profiles.append(_placeholder_profile(column_name, *spec))
        else:
            data_type, nullable, comment = spec
            profiles.append(_profile_column(frame, column_name, data_type, nullable, comment))

    scan_report = ScanReport(
        partition_used=partition_used,
        partition_resolution=partition_resolution,  # type: ignore[arg-type]
        rows_scanned=rows_scanned,
        columns_scanned=tuple(selected_columns),
        truncated=truncated,
        elapsed_seconds=elapsed,
        warnings=tuple(warnings),
    )

    # Anything that is not one of the known IR types is reported as a table
    # source named after its string form.
    if isinstance(source, (TableSourceIR, ParquetSourceIR, CsvSourceIR)):
        reported_source = source
    else:
        reported_source = TableSourceIR(table=str(source))

    return ColumnInspection(
        datasource=datasource,
        source=reported_source,
        profiles=tuple(profiles),
        scan=scan_report,
    )
def _execute_scoped_sample(
    datasource: str,
    source: EntitySourceIR,
    *,
    selected_columns: tuple[str, ...],
    scope: ScanScope,
    project_root: Path | None,
) -> Any:
    """Run a row-limited sample against ``source`` and return the result.

    Args:
        datasource: Name of the project datasource to open.
        source: Table, Parquet or CSV source IR to read.
        selected_columns: Columns to project; names missing from the source
            are dropped silently, and if none are present no projection is
            applied.
        scope: Supplies ``max_rows`` and the optional explicit partition.
        project_root: Optional project root passed to the connection service.

    Returns:
        Whatever ``expr.execute()`` returns for the limited expression
        (callers treat it as a pandas DataFrame).

    Raises:
        TypeError: If ``source`` is not one of the supported IR types.
    """
    service = DatasourceConnectionService(project_root)
    with service.use_backend(datasource) as backend:
        expr: Any
        if isinstance(source, TableSourceIR):
            if source.database is None:
                expr = backend.table(source.table)
            else:
                expr = backend.table(source.table, database=source.database)
        elif isinstance(source, ParquetSourceIR):
            # Only non-default reader options are forwarded.
            pq_kwargs: dict[str, object] = {}
            if source.hive_partitioning:
                pq_kwargs["hive_partitioning"] = source.hive_partitioning
            if source.columns is not None:
                pq_kwargs["columns"] = builtins.list(source.columns)
            expr = backend.read_parquet(source.path, **pq_kwargs)
        elif isinstance(source, CsvSourceIR):
            csv_kwargs: dict[str, object] = {}
            if not source.header:
                csv_kwargs["header"] = source.header
            if source.delimiter != ",":
                csv_kwargs["delimiter"] = source.delimiter
            if source.columns is not None:
                csv_kwargs["columns"] = builtins.list(source.columns)
            expr = backend.read_csv(source.path, **csv_kwargs)
        else:
            raise TypeError(f"unsupported source type: {type(source).__name__}")

        # Only an explicit partition mapping is turned into filters; "latest"
        # is not resolved here. Partition columns missing from the source are
        # skipped.
        if (
            scope.partition is not None
            and scope.partition != "latest"
            and isinstance(scope.partition, Mapping)
        ):
            for column, value in scope.partition.items():
                if column in expr.columns:
                    expr = expr.filter(expr[column] == value)

        # Project the requested columns that exist, then cap the row count.
        if selected_columns:
            available = set(expr.columns)
            present = [col for col in selected_columns if col in available]
            if present:
                expr = expr.select(*present)
        expr = expr.limit(scope.max_rows)
        return expr.execute()


def _profile_column(
    frame: Any,
    column_name: str,
    data_type: str,
    nullable: bool | None,
    comment: str | None,
) -> ColumnProfile:
    """Profile a single column of a sampled pandas DataFrame.

    Args:
        frame: Sampled DataFrame that contains ``column_name``.
        column_name: Column to profile.
        data_type: Declared column type, copied into the profile.
        nullable: Declared nullability, copied into the profile.
        comment: Declared column comment, copied into the profile.

    Returns:
        A ``ColumnProfile`` with null/empty/distinct counts, the ten most
        common values, the first ten non-null values, and min/max when the
        values can be ordered.
    """
    from collections import Counter

    from marivo.preview import normalize_preview_cell

    series = frame[column_name]
    non_null = series.dropna()
    null_count = int(series.isna().sum())

    # Empty strings are only counted for object-dtype columns. Checking the
    # Python type per value avoids comparing nested cells against "".
    empty_count = 0
    if series.dtype == object:
        empty_count = sum(1 for value in non_null if isinstance(value, str) and value == "")

    try:
        distinct_count = int(non_null.nunique())
        frequencies: builtins.list[tuple[object, int]] = Counter(non_null).most_common(10)
    except TypeError:
        # Cells such as lists or dicts are unhashable, so neither ``nunique``
        # nor ``Counter`` can handle them. Group by ``repr`` instead, keeping
        # the first value seen as the representative of each group.
        representatives: dict[str, object] = {}
        repr_counts: Counter[str] = Counter()
        for value in non_null:
            key = repr(value)
            representatives.setdefault(key, value)
            repr_counts[key] += 1
        distinct_count = len(repr_counts)
        frequencies = [(representatives[key], count) for key, count in repr_counts.most_common(10)]

    top_values = tuple((normalize_preview_cell(value), count) for value, count in frequencies)

    # Sample values are the first ten non-null cells.
    sample_values = tuple(normalize_preview_cell(value) for value in non_null.head(10))

    # Min/max only make sense for orderable values; anything else yields None.
    min_value: object | None = None
    max_value: object | None = None
    if not non_null.empty:
        try:
            min_value = normalize_preview_cell(non_null.min())
            max_value = normalize_preview_cell(non_null.max())
        except (TypeError, ValueError):
            min_value = None
            max_value = None

    return ColumnProfile(
        name=column_name,
        data_type=data_type,
        nullable=nullable,
        comment=comment,
        null_count=null_count,
        empty_count=empty_count,
        distinct_count=distinct_count,
        top_values=top_values,
        sample_values=sample_values,
        min_value=min_value,
        max_value=max_value,
    )


def _join_side_scan_report(
    side: JoinSide,
    scope: ScanScope,
    *,
    rows_scanned: int,
    elapsed_seconds: float,
    warnings: Iterable[str],
) -> ScanReport:
    """Build the ``ScanReport`` for one sampled join side."""
    partition_resolution: str
    partition_used: Mapping[str, str] | None = None
    if scope.partition is None:
        partition_resolution = "unpruned"
    elif scope.partition == "latest":
        partition_resolution = "latest"
    else:
        partition_resolution = "explicit"
        partition_used = dict(scope.partition)
    return ScanReport(
        partition_used=partition_used,
        partition_resolution=partition_resolution,  # type: ignore[arg-type]
        rows_scanned=rows_scanned,
        columns_scanned=tuple(side.columns),
        # Hitting the row cap is treated as "there may be more rows".
        truncated=rows_scanned >= scope.max_rows,
        elapsed_seconds=elapsed_seconds,
        warnings=tuple(warnings),
    )


def _sample_distinct_keys(
    side: JoinSide,
    scope: ScanScope,
    key_sample_size: int,
    project_root: Path | None,
) -> tuple[builtins.list[tuple[object, ...]], ScanReport]:
    """Sample distinct key tuples from one join side.

    Rows with a null in any key column are skipped and reported through a
    scan-report warning. Otherwise every NaN row would count as its own
    "distinct" key, because NaN never compares equal to NaN.

    Args:
        side: Join side naming the datasource, source and key columns.
        scope: Bounded scan configuration.
        key_sample_size: Maximum number of distinct keys to return; zero or
            less returns no keys.
        project_root: Optional project root directory.

    Returns:
        A pair of (distinct key tuples in first-seen order, scan report).
    """
    start = time.perf_counter()
    frame = _execute_scoped_sample(
        side.datasource,
        side.source,
        selected_columns=tuple(side.columns),
        scope=scope,
        project_root=project_root,
    )
    elapsed = time.perf_counter() - start
    rows_scanned = len(frame)
    warnings: builtins.list[str] = []

    key_columns = builtins.list(side.columns)
    key_frame = frame[key_columns]
    usable_keys = key_frame.dropna()
    skipped_rows = len(key_frame) - len(usable_keys)
    if skipped_rows:
        warnings.append(
            f"{skipped_rows} sampled rows skipped because a join key column was null"
        )

    seen: set[tuple[object, ...]] = set()
    distinct_keys: builtins.list[tuple[object, ...]] = []
    for row_values in usable_keys.itertuples(index=False, name=None):
        # Checked before appending so the cap is honoured even when it is 0.
        if len(distinct_keys) >= key_sample_size:
            break
        key_tuple = tuple(row_values)
        if key_tuple not in seen:
            seen.add(key_tuple)
            distinct_keys.append(key_tuple)

    scan_report = _join_side_scan_report(
        side,
        scope,
        rows_scanned=rows_scanned,
        elapsed_seconds=elapsed,
        warnings=warnings,
    )
    return distinct_keys, scan_report


def _count_matching_keys(
    side: JoinSide,
    key_tuples: builtins.list[tuple[object, ...]],
    scope: ScanScope,
    project_root: Path | None,
) -> tuple[dict[tuple[object, ...], int], ScanReport]:
    """Count how many sampled to-side rows match each from-side key.

    Args:
        side: Join side naming the datasource, source and key columns.
        key_tuples: Key tuples sampled from the from-side.
        scope: Bounded scan configuration.
        project_root: Optional project root directory.

    Returns:
        A pair of (key -> row count mapping, scan report). Keys with no
        matching row are absent from the mapping.
    """
    start = time.perf_counter()
    frame = _execute_scoped_sample(
        side.datasource,
        side.source,
        selected_columns=tuple(side.columns),
        scope=scope,
        project_root=project_root,
    )
    elapsed = time.perf_counter() - start
    rows_scanned = len(frame)
    warnings: builtins.list[str] = []

    # Tally only the to-side rows whose key was sampled on the from-side.
    key_columns = builtins.list(side.columns)
    from_key_set = set(key_tuples)
    counts: dict[tuple[object, ...], int] = {}
    for row_values in frame[key_columns].itertuples(index=False, name=None):
        key_tuple = tuple(row_values)
        if key_tuple in from_key_set:
            counts[key_tuple] = counts.get(key_tuple, 0) + 1

    scan_report = _join_side_scan_report(
        side,
        scope,
        rows_scanned=rows_scanned,
        elapsed_seconds=elapsed,
        warnings=warnings,
    )
    return counts, scan_report
def probe_join_keys(
    *,
    from_side: JoinSide,
    to_side: JoinSide,
    scope: ScanScope | None = None,
    key_sample_size: int = 500,
    project_root: Path | None = None,
) -> JoinKeyProbe:
    """Probe join compatibility between two sources on the given key columns.

    Distinct keys are sampled from the from-side; rows carrying those keys are
    then counted in a sample of the to-side to estimate match rate and join
    cardinality.

    Args:
        from_side: The left side of the join, defining keys to probe.
        to_side: The right side of the join, checked for key matches.
        scope: Bounded scan configuration; defaults to ``ScanScope()``.
        key_sample_size: Maximum distinct keys to sample from the from-side.
        project_root: Optional project root directory; defaults to cwd.

    Returns:
        A ``JoinKeyProbe`` with match statistics and cardinality estimate.

    Example:
        >>> import marivo.datasource as md
        >>> md.probe_join_keys(
        ...     from_side=md.JoinSide("wh", md.table("orders"), columns=("customer_id",)),
        ...     to_side=md.JoinSide("wh", md.table("customers"), columns=("customer_id",)),
        ...     scope=md.ScanScope(partition=None, max_rows=100),
        ...     project_root=project_root,
        ... )

    Constraints:
        Both from-side and to-side may reference the same or different
        datasources. Key comparison uses tuple equality. Matching is performed
        client-side after a bounded sample.
    """
    effective_scope = ScanScope() if scope is None else scope

    # Sample the from-side first, then count its keys in the to-side sample.
    distinct_keys, from_scan = _sample_distinct_keys(
        from_side, effective_scope, key_sample_size, project_root
    )
    counts_by_key, to_scan = _count_matching_keys(
        to_side, distinct_keys, effective_scope, project_root
    )

    # Per-key row counts, aligned with ``distinct_keys``; unmatched keys are 0.
    rows_per_key = [counts_by_key.get(key, 0) for key in distinct_keys]
    sampled_key_count = len(distinct_keys)
    matched_key_count = sum(1 for key in distinct_keys if key in counts_by_key)
    total_rows = sum(rows_per_key)
    max_rows_per_key = max([0, *rows_per_key])

    # Rates are reported as 0.0 when nothing was sampled.
    if sampled_key_count > 0:
        match_rate = matched_key_count / sampled_key_count
        avg_rows_per_key = total_rows / sampled_key_count
    else:
        match_rate = 0.0
        avg_rows_per_key = 0.0

    cardinality_estimate: Literal["one_to_one", "many_to_one", "indeterminate"]
    if matched_key_count == 0:
        cardinality_estimate = "indeterminate"
    elif max_rows_per_key > 1:
        cardinality_estimate = "many_to_one"
    else:
        cardinality_estimate = "one_to_one"

    return JoinKeyProbe(
        type_compatible=True,
        sampled_key_count=sampled_key_count,
        matched_key_count=matched_key_count,
        match_rate=match_rate,
        max_rows_per_key=max_rows_per_key,
        avg_rows_per_key=avg_rows_per_key,
        cardinality_estimate=cardinality_estimate,
        from_scan=from_scan,
        to_scan=to_scan,
    )
def test(name: str) -> DatasourceTestResult:
    """Round-trip the backend once and persist validated env secrets.

    Args:
        name: The datasource name to test.

    Returns:
        A ``DatasourceTestResult`` with ok/error status and latency. Failures
        are reported in the result rather than raised.

    Example:
        >>> import marivo.datasource as md
        >>> md.test("wh")

    Constraints:
        On success, env-sourced secrets that resolved correctly are persisted
        to the user-global plaintext cache. The backend is always disconnected.
    """
    started = time.perf_counter()

    def _elapsed_ms() -> int:
        # Whole milliseconds since the test started.
        return int((time.perf_counter() - started) * 1000)

    backend: Any | None = None
    try:
        backend = connect(name)
        backend.raw_sql("SELECT 1")
        # Persist secrets only after the round-trip has succeeded.
        _secrets.persist_backend_env_sourced(backend)
        return DatasourceTestResult(
            name=name,
            ok=True,
            error=None,
            latency_ms=_elapsed_ms(),
        )
    except Exception as exc:
        failure_latency = _elapsed_ms()
        failure_text = f"{type(exc).__name__}: {exc}"
        return DatasourceTestResult(
            name=name,
            ok=False,
            error=failure_text,
            latency_ms=failure_latency,
        )
    finally:
        # Best-effort cleanup; a failing disconnect must not change the result.
        if backend is not None:
            disconnect = getattr(backend, "disconnect", None)
            if callable(disconnect):
                with suppress(Exception):
                    disconnect()