
Classes

spark_expectations.sinks.utils.writer.SparkExpectationsWriter dataclass

This class implements writing data into the sink system: the error table, the DQ stats table, and, when enabled, a Kafka stats stream.

Functions

generate_rules_exceeds_threshold(rules: dict) -> None

Builds a summary of the row-DQ rules that have failed rows, recording each rule's configured error-drop threshold alongside the observed error-drop percentage, and stores the list in the context.

Parameters:

rules (dict, required): rule metadata keyed by rule type.

Returns:

None

Source code in spark_expectations/sinks/utils/writer.py
def generate_rules_exceeds_threshold(self, rules: dict) -> None:
    """
    This function implements/supports summarising row dq error threshold
    Args:
        rules: accepts rule metadata within dict
    Returns:
        None
    """
    try:
        error_threshold_list = []
        rules_failed_row_count: Dict[str, int] = {}
        if self._context.get_summarised_row_dq_res is None:
            return None

        rules_failed_row_count = {
            itr["rule"]: int(itr["failed_row_count"])
            for itr in self._context.get_summarised_row_dq_res
        }

        for rule in rules[f"{self._context.get_row_dq_rule_type_name}_rules"]:
            # if (
            #         not rule["enable_error_drop_alert"]
            #         or rule["rule"] not in rules_failed_row_count.keys()
            # ):
            #     continue  # pragma: no cover

            rule_name = rule["rule"]
            rule_action = rule["action_if_failed"]
            if rule_name in rules_failed_row_count.keys():
                failed_row_count = int(rules_failed_row_count[rule_name])
            else:
                failed_row_count = 0

            if failed_row_count is not None and failed_row_count > 0:
                error_drop_percentage = round(
                    (failed_row_count / self._context.get_input_count) * 100, 2
                )
                error_threshold_list.append(
                    {
                        "rule_name": rule_name,
                        "action_if_failed": rule_action,
                        "description": rule["description"],
                        "rule_type": rule["rule_type"],
                        "error_drop_threshold": str(rule["error_drop_threshold"]),
                        "error_drop_percentage": str(error_drop_percentage),
                    }
                )

        if len(error_threshold_list) > 0:
            self._context.set_rules_exceeds_threshold(error_threshold_list)

    except Exception as e:
        raise SparkExpectationsMiscException(
            f"An error occurred while creating error threshold list : {e}"
        )
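
Each appended entry is simply the rule's metadata plus the observed error-drop percentage. A minimal sketch of that calculation in plain Python, using a hypothetical rule name and input count (not taken from the library):

summarised_row_dq_res = [{"rule": "age_gt_zero", "failed_row_count": "25"}]  # hypothetical summary
input_count = 1000  # hypothetical input size

# per-rule failed counts, as built inside generate_rules_exceeds_threshold
rules_failed_row_count = {
    itr["rule"]: int(itr["failed_row_count"]) for itr in summarised_row_dq_res
}

failed_row_count = rules_failed_row_count.get("age_gt_zero", 0)
error_drop_percentage = round((failed_row_count / input_count) * 100, 2)  # 2.5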

generate_summarised_row_dq_res(df: DataFrame, rule_type: str) -> None

Summarises the row-DQ error results from the error dataframe, counting the failed rows per rule, and stores the summary in the context.

Parameters:

df (DataFrame, required): the error dataframe.
rule_type (str, required): the type of the rule.

Returns:

None

Source code in spark_expectations/sinks/utils/writer.py
def generate_summarised_row_dq_res(self, df: DataFrame, rule_type: str) -> None:
    """
    This function implements/supports summarising row dq error result
    Args:
        df: error dataframe(DataFrame)
        rule_type: type of the rule(str)

    Returns:
        None

    """
    try:

        def update_dict(accumulator: dict) -> dict:  # pragma: no cover
            if accumulator.get("failed_row_count") is None:  # pragma: no cover
                accumulator["failed_row_count"] = str(2)  # pragma: no cover
            else:  # pragma: no cover
                accumulator["failed_row_count"] = str(  # pragma: no cover
                    int(accumulator["failed_row_count"]) + 1  # pragma: no cover
                )  # pragma: no cover

            return accumulator  # pragma: no cover

        summarised_row_dq_dict: Dict[str, Dict[str, str]] = (
            df.select(explode(f"meta_{rule_type}_results").alias("row_dq_res"))
            .rdd.map(
                lambda rule_meta_dict: (
                    rule_meta_dict[0]["rule"],
                    {**rule_meta_dict[0], "failed_row_count": 1},
                )
            )
            .reduceByKey(lambda acc, itr: update_dict(acc))
        ).collectAsMap()

        self._context.set_summarised_row_dq_res(
            list(summarised_row_dq_dict.values())
        )

    except Exception as e:
        raise SparkExpectationsMiscException(
            f"error occurred created summarised row dq statistics {e}"
        )
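
The reduceByKey fold counts, per rule, how many error rows carried that rule in their meta_<rule_type>_results entry: the first record for a rule keeps failed_row_count at 1, and every subsequent merge bumps it by one. A pure-Python sketch of the same accumulation, with hypothetical rule metadata:

# hypothetical (rule, metadata) pairs, one per failing row, mirroring the rdd.map output
records = [
    ("rule_a", {"rule": "rule_a", "failed_row_count": 1}),
    ("rule_a", {"rule": "rule_a", "failed_row_count": 1}),
    ("rule_b", {"rule": "rule_b", "failed_row_count": 1}),
]

summarised: dict = {}
for rule_name, meta in records:
    if rule_name not in summarised:
        summarised[rule_name] = dict(meta)  # first occurrence keeps failed_row_count = 1
    else:
        acc = summarised[rule_name]
        acc["failed_row_count"] = str(int(acc["failed_row_count"]) + 1)

# summarised["rule_a"]["failed_row_count"] == "2"; "rule_b" stays at 1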

save_df_as_table(df: DataFrame, table_name: str, config: dict, stats_table: bool = False) -> None

Writes a dataframe into a table using the supplied writer configuration.

Parameters:

df (DataFrame, required): the dataframe to be written as a table.
table_name (str, required): the name of the table the dataframe is written to.
config (dict, required): the writer configuration (mode, format, partitionBy, sortBy, bucketBy, options).
stats_table (bool, default False): whether the write targets the stats table; when False, the run id and run date columns are added before writing.

Returns:

None
Source code in spark_expectations/sinks/utils/writer.py
def save_df_as_table(
    self, df: DataFrame, table_name: str, config: dict, stats_table: bool = False
) -> None:
    """
    This function takes a dataframe and writes into a table

    Args:
        df: Provide the dataframe which need to be written as a table
        table_name: Provide the table name to which the dataframe need to be written to
        config: Provide the config to write the dataframe into the table
        stats_table: Provide if this is for writing stats table

    Returns:
        None:

    """
    try:
        print("run date ", self._context.get_run_date)
        if not stats_table:
            df = df.withColumn(
                self._context.get_run_id_name, lit(f"{self._context.get_run_id}")
            ).withColumn(
                self._context.get_run_date_name,
                to_timestamp(
                    lit(f"{self._context.get_run_date}"), "yyyy-MM-dd HH:mm:ss"
                ),
            )
        _log.info("_save_df_as_table started")

        _df_writer = df.write

        if config["mode"] is not None:
            _df_writer = _df_writer.mode(config["mode"])
        if config["format"] is not None:
            _df_writer = _df_writer.format(config["format"])
        if config["partitionBy"] is not None and config["partitionBy"] != []:
            _df_writer = _df_writer.partitionBy(config["partitionBy"])
        if config["sortBy"] is not None and config["sortBy"] != []:
            _df_writer = _df_writer.sortBy(config["sortBy"])
        if config["bucketBy"] is not None and config["bucketBy"] != {}:
            bucket_by_config = config["bucketBy"]
            _df_writer = _df_writer.bucketBy(
                bucket_by_config["numBuckets"], bucket_by_config["colName"]
            )
        if config["options"] is not None and config["options"] != {}:
            _df_writer = _df_writer.options(**config["options"])

        if config["format"] == "bigquery":
            _df_writer.option("table", table_name).save()
        else:
            _df_writer.saveAsTable(name=table_name)
            self.spark.sql(
                f"ALTER TABLE {table_name} SET TBLPROPERTIES ('product_id' = '{self._context.product_id}')"
            )
        _log.info("finished writing records to table: %s,", table_name)

    except Exception as e:
        raise SparkExpectationsUserInputOrConfigInvalidException(
            f"error occurred while writing data in to the table - {table_name}: {e}"
        )
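
Because the writer indexes the config keys directly (config["mode"], config["format"], and so on), the config dict is expected to carry every key shown above, with None or empty collections for unused features. A hypothetical example, assuming a writer instance, dataframe, and target table name that do not come from the library itself:

# hypothetical writer config covering every key save_df_as_table reads
writer_config = {
    "mode": "append",
    "format": "delta",
    "partitionBy": [],
    "sortBy": [],
    "bucketBy": {},
    "options": {"mergeSchema": "true"},
}

# `writer`, `df`, and the table name are placeholders for this sketch
writer.save_df_as_table(df, "dq_db.customer_order_error", writer_config)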

write_error_records_final(df: DataFrame, error_table: str, rule_type: str) -> Tuple[int, DataFrame]

Collects each row's failed <rule_type> results into a meta_<rule_type>_results column, writes the failing rows to the error table, and returns the error count together with the dataframe carrying the results column.

Source code in spark_expectations/sinks/utils/writer.py
def write_error_records_final(
    self, df: DataFrame, error_table: str, rule_type: str
) -> Tuple[int, DataFrame]:
    try:
        _log.info("_write_error_records_final started")

        failed_records = [
            f"size({dq_column}) != 0"
            for dq_column in df.columns
            if dq_column.startswith(f"{rule_type}")
        ]

        failed_records_rules = " or ".join(failed_records)
        # df = df.filter(expr(failed_records_rules))

        df = df.withColumn(
            f"meta_{rule_type}_results",
            when(
                expr(failed_records_rules),
                array(
                    *[
                        _col
                        for _col in df.columns
                        if _col.startswith(f"{rule_type}")
                    ]
                ),
            ).otherwise(array(create_map())),
        ).drop(*[_col for _col in df.columns if _col.startswith(f"{rule_type}")])

        df = (
            df.withColumn(
                f"meta_{rule_type}_results",
                remove_empty_maps(df[f"meta_{rule_type}_results"]),
            )
            .withColumn(
                self._context.get_run_id_name, lit(self._context.get_run_id)
            )
            .withColumn(
                self._context.get_run_date_name,
                lit(self._context.get_run_date),
            )
        )
        error_df = df.filter(f"size(meta_{rule_type}_results) != 0")
        self._context.print_dataframe_with_debugger(error_df)

        self.save_df_as_table(
            error_df,
            error_table,
            self._context.get_target_and_error_table_writer_config,
        )

        _error_count = error_df.count()
        if _error_count > 0:
            self.generate_summarised_row_dq_res(error_df, rule_type)

        _log.info("_write_error_records_final ended")
        return _error_count, df

    except Exception as e:
        raise SparkExpectationsMiscException(
            f"error occurred while saving data into the final error table {e}"
        )
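
The failure predicate is built by OR-ing a size(...) != 0 check over every column whose name starts with the rule type; rows matching it keep their rule results in meta_<rule_type>_results, all others receive an empty map that is later stripped. A small sketch of that expression for hypothetical row_dq result columns:

rule_type = "row_dq"
columns = ["row_dq_age_gt_zero", "row_dq_email_not_null", "customer_id"]  # hypothetical column names

failed_records = [f"size({c}) != 0" for c in columns if c.startswith(rule_type)]
failed_records_rules = " or ".join(failed_records)
# "size(row_dq_age_gt_zero) != 0 or size(row_dq_email_not_null) != 0"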

write_error_stats() -> None

Collects the data-quality run statistics from the context and writes them into the DQ stats table; when streaming is enabled, the stats record is also published to Kafka.

Returns:

None
Source code in spark_expectations/sinks/utils/writer.py
def write_error_stats(self) -> None:
    """
    This functions takes the stats table and write it into error table

    Args:
        config: Provide the config to write the dataframe into the table

    Returns:
        None:

    """
    try:
        self.spark.conf.set("spark.sql.session.timeZone", "Etc/UTC")
        from datetime import date

        table_name: str = self._context.get_table_name
        input_count: int = self._context.get_input_count
        error_count: int = self._context.get_error_count
        output_count: int = self._context.get_output_count
        source_agg_dq_result: Optional[
            List[Dict[str, str]]
        ] = self._context.get_source_agg_dq_result
        final_agg_dq_result: Optional[
            List[Dict[str, str]]
        ] = self._context.get_final_agg_dq_result
        source_query_dq_result: Optional[
            List[Dict[str, str]]
        ] = self._context.get_source_query_dq_result
        final_query_dq_result: Optional[
            List[Dict[str, str]]
        ] = self._context.get_final_query_dq_result

        error_stats_data = [
            (
                self._context.product_id,
                table_name,
                input_count,
                error_count,
                output_count,
                self._context.get_output_percentage,
                self._context.get_success_percentage,
                self._context.get_error_percentage,
                source_agg_dq_result
                if source_agg_dq_result and len(source_agg_dq_result) > 0
                else None,
                final_agg_dq_result
                if final_agg_dq_result and len(final_agg_dq_result) > 0
                else None,
                source_query_dq_result
                if source_query_dq_result and len(source_query_dq_result) > 0
                else None,
                final_query_dq_result
                if final_query_dq_result and len(final_query_dq_result) > 0
                else None,
                self._context.get_summarised_row_dq_res,
                self._context.get_rules_exceeds_threshold,
                {
                    "run_status": self._context.get_dq_run_status,
                    "source_agg_dq": self._context.get_source_agg_dq_status,
                    "source_query_dq": self._context.get_source_query_dq_status,
                    "row_dq": self._context.get_row_dq_status,
                    "final_agg_dq": self._context.get_final_agg_dq_status,
                    "final_query_dq": self._context.get_final_query_dq_status,
                },
                {
                    "run_time": self._context.get_dq_run_time,
                    "source_agg_dq_run_time": self._context.get_source_agg_dq_run_time,
                    "source_query_dq_run_time": self._context.get_source_query_dq_run_time,
                    "row_dq_run_time": self._context.get_row_dq_run_time,
                    "final_agg_dq_run_time": self._context.get_final_agg_dq_run_time,
                    "final_query_dq_run_time": self._context.get_final_query_dq_run_time,
                },
                {
                    "rules": {
                        "num_row_dq_rules": self._context.get_num_row_dq_rules,
                        "num_dq_rules": self._context.get_num_dq_rules,
                    },
                    "agg_dq_rules": self._context.get_num_agg_dq_rules,
                    "query_dq_rules": self._context.get_num_query_dq_rules,
                },
                self._context.get_run_id,
                date.fromisoformat(self._context.get_run_date[0:10]),
                datetime.strptime(
                    self._context.get_run_date,
                    "%Y-%m-%d %H:%M:%S",
                ),
            )
        ]
        error_stats_rdd = self.spark.sparkContext.parallelize(error_stats_data)

        from pyspark.sql.types import (
            StructType,
            StructField,
            StringType,
            IntegerType,
            LongType,
            FloatType,
            DateType,
            ArrayType,
            MapType,
            TimestampType,
        )

        error_stats_schema = StructType(
            [
                StructField("product_id", StringType(), True),
                StructField("table_name", StringType(), True),
                StructField("input_count", LongType(), True),
                StructField("error_count", LongType(), True),
                StructField("output_count", LongType(), True),
                StructField("output_percentage", FloatType(), True),
                StructField("success_percentage", FloatType(), True),
                StructField("error_percentage", FloatType(), True),
                StructField(
                    "source_agg_dq_results",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField(
                    "final_agg_dq_results",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField(
                    "source_query_dq_results",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField(
                    "final_query_dq_results",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField(
                    "row_dq_res_summary",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField(
                    "row_dq_error_threshold",
                    ArrayType(MapType(StringType(), StringType())),
                    True,
                ),
                StructField("dq_status", MapType(StringType(), StringType()), True),
                StructField(
                    "dq_run_time", MapType(StringType(), FloatType()), True
                ),
                StructField(
                    "dq_rules",
                    MapType(StringType(), MapType(StringType(), IntegerType())),
                    True,
                ),
                StructField(self._context.get_run_id_name, StringType(), True),
                StructField(self._context.get_run_date_name, DateType(), True),
                StructField(
                    self._context.get_run_date_time_name, TimestampType(), True
                ),
            ]
        )

        df = self.spark.createDataFrame(error_stats_rdd, schema=error_stats_schema)
        self._context.print_dataframe_with_debugger(df)

        df = (
            df.withColumn("output_percentage", sql_round(df.output_percentage, 2))
            .withColumn("success_percentage", sql_round(df.success_percentage, 2))
            .withColumn("error_percentage", sql_round(df.error_percentage, 2))
        )
        _log.info(
            "Writing metrics to the stats table: %s, started",
            self._context.get_dq_stats_table_name,
        )
        if self._context.get_stats_table_writer_config["format"] == "bigquery":
            df = df.withColumn("dq_rules", to_json(df["dq_rules"]))

        self.save_df_as_table(
            df,
            self._context.get_dq_stats_table_name,
            config=self._context.get_stats_table_writer_config,
            stats_table=True,
        )

        _log.info(
            "Writing metrics to the stats table: %s, ended",
            self._context.get_dq_stats_table_name,
        )

        # TODO check if streaming_stats is set to off, if it's enabled only then this should run

        _se_stats_dict = self._context.get_se_streaming_stats_dict
        if _se_stats_dict["se.enable.streaming"]:
            secret_handler = SparkExpectationsSecretsBackend(
                secret_dict=_se_stats_dict
            )
            kafka_write_options: dict = (
                {
                    "kafka.bootstrap.servers": "localhost:9092",
                    "topic": self._context.get_se_streaming_stats_topic_name,
                    "failOnDataLoss": "true",
                }
                if self._context.get_env == "local"
                else (
                    {
                        "kafka.bootstrap.servers": f"{secret_handler.get_secret(self._context.get_server_url_key)}",
                        "kafka.security.protocol": "SASL_SSL",
                        "kafka.sasl.mechanism": "OAUTHBEARER",
                        "kafka.sasl.jaas.config": "kafkashaded.org.apache.kafka.common.security.oauthbearer."
                        "OAuthBearerLoginModule required oauth.client.id="
                        f"'{secret_handler.get_secret(self._context.get_client_id)}'  "
                        + "oauth.client.secret="
                        f"'{secret_handler.get_secret(self._context.get_token)}' "
                        "oauth.token.endpoint.uri="
                        f"'{secret_handler.get_secret(self._context.get_token_endpoint_url)}'; ",
                        "kafka.sasl.login.callback.handler.class": "io.strimzi.kafka.oauth.client"
                        ".JaasClientOauthLoginCallbackHandler",
                        "topic": (
                            self._context.get_se_streaming_stats_topic_name
                            if self._context.get_env == "local"
                            else secret_handler.get_secret(
                                self._context.get_topic_name
                            )
                        ),
                    }
                )
            )

            _sink_hook.writer(
                _write_args={
                    "product_id": self._context.product_id,
                    "enable_se_streaming": _se_stats_dict[
                        user_config.se_enable_streaming
                    ],
                    "kafka_write_options": kafka_write_options,
                    "stats_df": df,
                }
            )
        else:
            _log.info(
                "Streaming stats to kafka is disabled, hence skipping writing to kafka"
            )

    except Exception as e:
        raise SparkExpectationsMiscException(
            f"error occurred while saving the data into the stats table {e}"
        )
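
After a run, the record written here can be read back from the stats table for inspection; the column names are those declared in error_stats_schema. A minimal sketch, assuming an active SparkSession and a hypothetical stats table name:

# hypothetical stats table name; columns come from error_stats_schema above
stats_df = spark.table("dq_db.dq_stats")
stats_df.select(
    "product_id",
    "table_name",
    "input_count",
    "error_count",
    "output_count",
    "error_percentage",
    "dq_status",
    "dq_run_time",
).show(truncate=False)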

Functions