def _build_spark_partition_snapshot(
    df, *, dataset_name: str, table_name: str, partition_column: str,
    business_keys: list[str], watermark_column: str | None, run_id: str | None,
) -> list[dict]:
    """Build a per-partition snapshot of a Spark DataFrame.

    For each distinct value of ``partition_column`` the snapshot records the row count,
    the number of distinct business keys, an order-independent hash of the per-row
    business-key hashes, the min/max watermark (when ``watermark_column`` is given),
    and a combined partition hash, returned as a list of dicts sorted by partition value.
    """
    from pyspark.sql import functions as F

    generated_at = datetime.now(timezone.utc).isoformat()

    # Hash the business-key columns of every row; NULLs are coalesced to "" so the
    # row hash stays stable even when parts of the key are missing.
    key_cols = [F.coalesce(F.col(c).cast("string"), F.lit("")) for c in business_keys]
    with_key = df.withColumn("_business_key_row_hash", F.sha2(F.concat_ws("||", *key_cols), 256))

    # Per-partition aggregates: row count, distinct business keys, and an
    # order-independent hash over the sorted set of per-row business-key hashes.
    agg_exprs = [
        F.count(F.lit(1)).alias("row_count"),
        F.countDistinct(*[F.col(c) for c in business_keys]).alias("business_key_count"),
        F.sha2(F.concat_ws("##", F.sort_array(F.collect_set(F.col("_business_key_row_hash")))), 256).alias("business_key_hash"),
    ]
    if watermark_column:
        agg_exprs.extend([F.max(F.col(watermark_column)).alias("max_watermark"),
                          F.min(F.col(watermark_column)).alias("min_watermark")])
    else:
        agg_exprs.extend([F.lit(None).alias("max_watermark"), F.lit(None).alias("min_watermark")])

    snapshot_df = with_key.groupBy(F.col(partition_column)).agg(*agg_exprs)
    # One aggregated row per partition value, so collecting to the driver stays small.
    collected = snapshot_df.collect()

    rows = []
    for row in collected:
        part_val = row[partition_column]
        max_w = to_jsonable(row["max_watermark"])
        min_w = to_jsonable(row["min_watermark"])
        bkh = str(row["business_key_hash"])
        rows.append(
            {
                "dataset_name": str(dataset_name),
                "table_name": str(table_name),
                "run_id": run_id,
                "engine": "spark",
                "generated_at": generated_at,
                "partition_column": str(partition_column),
                "partition_value": to_jsonable(part_val),
                "row_count": int(row["row_count"]),
                "business_key_count": int(row["business_key_count"]),
                "max_watermark": max_w,
                "min_watermark": min_w,
                "partition_hash": _build_partition_hash(
                    part_val, int(row["row_count"]), int(row["business_key_count"]), max_w, min_w, bkh
                ),
                "business_key_hash": bkh,
            }
        )
    # Sort by stringified partition value so repeated runs emit a deterministic order.
    return sorted(rows, key=lambda r: str(r["partition_value"]))
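

# --- Illustrative usage sketch (not part of the snapshot builder itself) ---
# A minimal, hedged example of calling _build_spark_partition_snapshot on a tiny local
# DataFrame. It assumes pyspark is installed, a local SparkSession can be created, and
# that the module-level helpers used above (to_jsonable, _build_partition_hash) plus the
# datetime/timezone imports are defined elsewhere in this module. The column, dataset,
# and run names below are made up purely for illustration.
if __name__ == "__main__":
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("partition-snapshot-demo").getOrCreate()
    demo_df = spark.createDataFrame(
        [
            ("2024-01-01", "cust-1", "2024-01-01T10:00:00"),
            ("2024-01-01", "cust-2", "2024-01-01T11:30:00"),
            ("2024-01-02", "cust-1", "2024-01-02T09:15:00"),
        ],
        ["event_date", "customer_id", "updated_at"],
    )
    snapshot = _build_spark_partition_snapshot(
        demo_df,
        dataset_name="demo_dataset",
        table_name="demo_table",
        partition_column="event_date",
        business_keys=["customer_id"],
        watermark_column="updated_at",
        run_id="run-001",
    )
    # Each entry describes one partition: counts, watermarks, and hashes.
    for entry in snapshot:
        print(entry)
    spark.stop()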