# Source code for synapse.ml.causal.DiffInDiffModel

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

# Python 2 compatibility alias: on Python 3 there is no ``basestring``, so
# bind it to ``str``. The check uses the structured ``sys.version_info``
# rather than the ``sys.version`` string, because string comparison is
# lexicographic and a major version of 10 or above would sort before "3".
if sys.version_info[0] >= 3:
    basestring = str

from synapse.ml.causal._DiffInDiffModel import _DiffInDiffModel
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession, DataFrame
from pyspark import SparkContext, SQLContext


@inherit_doc
class DiffInDiffModel(_DiffInDiffModel):
    """Python-side wrapper that copies a Java model's summary onto attributes.

    The constructor reads the wrapped Java object's summary and weight tables
    once and stores the results on the instance. Values that arrive as
    option-like objects (anything exposing ``isDefined()`` and ``get()`` --
    presumably Scala ``Option`` values via Py4J) are stored as the contained
    value, or ``None`` when undefined.
    """

    @staticmethod
    def _mapOption(option, func):
        """Return ``func(option.get())`` if ``option`` is defined, else ``None``.

        Args:
            option: Object exposing ``isDefined()`` and ``get()``.
            func: One-argument callable applied to the contained value.
        """
        if not option.isDefined():
            return None
        return func(option.get())

    @staticmethod
    def _unwrapOption(option):
        """Return the value held by ``option``, or ``None`` if it is undefined.

        Args:
            option: Object exposing ``isDefined()`` and ``get()``.
        """
        return option.get() if option.isDefined() else None

    def __init__(self, java_obj=None) -> None:
        """Initialise the wrapper and populate the summary attributes.

        Args:
            java_obj: Java-side model object, forwarded to the parent
                constructor. Despite the ``None`` default it is dereferenced
                unconditionally here (``getSummary()``, ``getTimeWeights()``,
                ``getUnitWeights()``), so passing ``None`` fails with an
                ``AttributeError``.
        """
        super(DiffInDiffModel, self).__init__(java_obj=java_obj)

        # A SQLContext bound to the active SparkContext is needed to wrap the
        # Java-side weight tables as Python DataFrames.
        sql_context = SQLContext.getOrCreate(SparkContext._active_spark_context)

        def as_dataframe(jdf):
            return DataFrame(jdf, sql_context)

        unwrap = DiffInDiffModel._unwrapOption
        map_option = DiffInDiffModel._mapOption

        summary = java_obj.getSummary()
        self.summary = summary

        # Plain (non-optional) summary values.
        self.treatmentEffect = summary.treatmentEffect()
        self.standardError = summary.standardError()

        # Optional summary values: None when the Java side has no value.
        self.timeIntercept = unwrap(summary.timeIntercept())
        self.unitIntercept = unwrap(summary.unitIntercept())

        # Optional weight tables, exposed as DataFrames when present.
        self.timeWeights = map_option(java_obj.getTimeWeights(), as_dataframe)
        self.unitWeights = map_option(java_obj.getUnitWeights(), as_dataframe)

        self.timeRMSE = unwrap(summary.timeRMSE())
        self.unitRMSE = unwrap(summary.unitRMSE())
        self.zeta = unwrap(summary.zeta())

        self.lossHistoryTimeWeights = unwrap(
            summary.getLossHistoryTimeWeightsJava()
        )
        self.lossHistoryUnitWeights = unwrap(
            summary.getLossHistoryUnitWeightsJava()
        )