# Source code for synapse.ml.causal.SyntheticDiffInDiffEstimator

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.


import sys
# Python 3 has no ``basestring``; alias it to ``str`` so legacy string checks
# keep working. ``sys.version_info`` is compared numerically, unlike the
# free-form ``sys.version`` string, which should not be compared lexically.
if sys.version_info[0] >= 3:
    basestring = str

from pyspark import SparkContext, SQLContext
from pyspark.sql import DataFrame
from pyspark.ml.param.shared import *
from pyspark import keyword_only
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from synapse.ml.core.platform import running_on_synapse_internal
from synapse.ml.core.serialize.java_params_patch import *
from pyspark.ml.wrapper import JavaTransformer, JavaEstimator, JavaModel
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.common import inherit_doc
from synapse.ml.core.schema.Utils import *
from pyspark.ml.param import TypeConverters
from synapse.ml.core.schema.TypeConversionUtils import generateTypeConverter, complexTypeConverter
from synapse.ml.causal.DiffInDiffModel import DiffInDiffModel

@inherit_doc
class SyntheticDiffInDiffEstimator(ComplexParamsMixin, JavaMLReadable, JavaMLWritable, JavaEstimator):
    """
    Python wrapper around the JVM estimator
    ``com.microsoft.azure.synapse.ml.causal.SyntheticDiffInDiffEstimator``.
    Fitting it returns a :class:`DiffInDiffModel`.

    Args:
        epsilon (float): This value is added to the weights when we fit the final linear model for SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid zero weights.
        handleMissingOutcome (str): How to handle missing outcomes. Options are skip (which will filter out units with missing outcomes), zero (fill in missing outcomes with zero), or impute (impute with nearest available outcomes, or mean if two nearest outcomes are available)
        localSolverThreshold (long): threshold for using local solver on driver node. Local solver is faster but relies on part of data being collected on driver node.
        maxIter (int): maximum number of iterations (>= 0)
        numIterNoChange (int): Early termination when number of iterations without change reached.
        outcomeCol (str): outcome column
        postTreatmentCol (str): post treatment indicator column
        stepSize (float): Step size to be used for each iteration of optimization (> 0)
        timeCol (str): Specify the column that identifies the time when outcome is measured in the panel data. For example, if the outcome is measured daily, this column could be the Date column.
        tol (float): the convergence tolerance for iterative algorithms (>= 0)
        treatmentCol (str): treatment column
        unitCol (str): Specify the name of the column which contains an identifier for each observed unit in the panel data. For example, if the observed units are users, this column could be the UserId column.
        zeta (float): The zeta value for regularization term when fitting unit weights. If not specified, a default value will be computed based on formula (2.2) specified in https://www.nber.org/system/files/working_papers/w25532/w25532.pdf. For large scale data, one may want to tune the zeta value, minimizing the loss of the unit weights regression.
    """

    epsilon = Param(Params._dummy(), "epsilon", "This value is added to the weights when we fit the final linear model for SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid zero weights.", typeConverter=TypeConverters.toFloat)

    handleMissingOutcome = Param(Params._dummy(), "handleMissingOutcome", "How to handle missing outcomes. Options are skip (which will filter out units with missing outcomes), zero (fill in missing outcomes with zero), or impute (impute with nearest available outcomes, or mean if two nearest outcomes are available)", typeConverter=TypeConverters.toString)

    # NOTE(review): unlike the other params, no typeConverter is set here, so
    # values are stored exactly as passed (no Python-side int coercion).
    localSolverThreshold = Param(Params._dummy(), "localSolverThreshold", "threshold for using local solver on driver node. Local solver is faster but relies on part of data being collected on driver node.")

    maxIter = Param(Params._dummy(), "maxIter", "maximum number of iterations (>= 0)", typeConverter=TypeConverters.toInt)

    numIterNoChange = Param(Params._dummy(), "numIterNoChange", "Early termination when number of iterations without change reached.", typeConverter=TypeConverters.toInt)

    outcomeCol = Param(Params._dummy(), "outcomeCol", "outcome column", typeConverter=TypeConverters.toString)

    postTreatmentCol = Param(Params._dummy(), "postTreatmentCol", "post treatment indicator column", typeConverter=TypeConverters.toString)

    stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization (> 0)", typeConverter=TypeConverters.toFloat)

    timeCol = Param(Params._dummy(), "timeCol", "Specify the column that identifies the time when outcome is measured in the panel data. For example, if the outcome is measured daily, this column could be the Date column.", typeConverter=TypeConverters.toString)

    tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms (>= 0)", typeConverter=TypeConverters.toFloat)

    treatmentCol = Param(Params._dummy(), "treatmentCol", "treatment column", typeConverter=TypeConverters.toString)

    unitCol = Param(Params._dummy(), "unitCol", "Specify the name of the column which contains an identifier for each observed unit in the panel data. For example, if the observed units are users, this column could be the UserId column.", typeConverter=TypeConverters.toString)

    zeta = Param(Params._dummy(), "zeta", "The zeta value for regularization term when fitting unit weights. If not specified, a default value will be computed based on formula (2.2) specified in https://www.nber.org/system/files/working_papers/w25532/w25532.pdf. For large scale data, one may want to tune the zeta value, minimizing the loss of the unit weights regression.", typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(
        self,
        java_obj=None,
        epsilon=1.0E-10,
        handleMissingOutcome="zero",
        localSolverThreshold=1000000,
        maxIter=100,
        numIterNoChange=None,
        outcomeCol=None,
        postTreatmentCol=None,
        stepSize=1.0,
        timeCol=None,
        tol=0.001,
        treatmentCol=None,
        unitCol=None,
        zeta=None
        ):
        """
        Create the estimator, optionally wrapping an existing JVM object.

        Args:
            java_obj: an existing JVM estimator to wrap. When None, a new
                JVM object is created with this instance's uid and the
                keyword arguments explicitly passed are applied through
                their setters. When given, keyword arguments are NOT applied.
            Remaining keyword arguments: see the class docstring.
        """
        super(SyntheticDiffInDiffEstimator, self).__init__()
        if java_obj is None:
            self._java_obj = self._new_java_obj("com.microsoft.azure.synapse.ml.causal.SyntheticDiffInDiffEstimator", self.uid)
        else:
            self._java_obj = java_obj
        self._setDefault(epsilon=1.0E-10)
        self._setDefault(handleMissingOutcome="zero")
        self._setDefault(localSolverThreshold=1000000)
        self._setDefault(maxIter=100)
        self._setDefault(stepSize=1.0)
        self._setDefault(tol=0.001)
        # keyword_only records the explicitly passed kwargs either on the
        # instance or on the wrapped function, depending on the PySpark version.
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            kwargs = self.__init__._input_kwargs

        if java_obj is None:
            for k, v in kwargs.items():
                # None means "leave unset"; dispatch e.g. "maxIter" -> setMaxIter.
                if v is not None:
                    getattr(self, "set" + k[0].upper() + k[1:])(v)

    @keyword_only
    def setParams(
        self,
        epsilon=1.0E-10,
        handleMissingOutcome="zero",
        localSolverThreshold=1000000,
        maxIter=100,
        numIterNoChange=None,
        outcomeCol=None,
        postTreatmentCol=None,
        stepSize=1.0,
        timeCol=None,
        tol=0.001,
        treatmentCol=None,
        unitCol=None,
        zeta=None
        ):
        """
        Set the (keyword only) parameters

        Only the keyword arguments explicitly passed to this call are set.

        Returns:
            The result of ``self._set``.
        """
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            # Fall back to the kwargs recorded on *this* method's wrapper;
            # reading ``__init__``'s would re-apply the constructor arguments.
            kwargs = self.setParams._input_kwargs
        return self._set(**kwargs)

    @classmethod
    def read(cls):
        """ Returns an MLReader instance for this class. """
        return JavaMMLReader(cls)

    @staticmethod
    def getJavaPackage():
        """ Returns package name String. """
        return "com.microsoft.azure.synapse.ml.causal.SyntheticDiffInDiffEstimator"

    @staticmethod
    def _from_java(java_stage):
        """Rebuild the Python wrapper for ``java_stage`` via ``from_java``."""
        module_name=SyntheticDiffInDiffEstimator.__module__
        # Swap the last dotted component for the class name to get the wrapper's module path.
        module_name=module_name.rsplit(".", 1)[0] + ".SyntheticDiffInDiffEstimator"
        return from_java(java_stage, module_name)

    def setEpsilon(self, value):
        """
        Args:
            epsilon: This value is added to the weights when we fit the final linear model for SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid zero weights.
        """
        self._set(epsilon=value)
        return self

    def setHandleMissingOutcome(self, value):
        """
        Args:
            handleMissingOutcome: How to handle missing outcomes. Options are skip (which will filter out units with missing outcomes), zero (fill in missing outcomes with zero), or impute (impute with nearest available outcomes, or mean if two nearest outcomes are available)
        """
        self._set(handleMissingOutcome=value)
        return self

    def setLocalSolverThreshold(self, value):
        """
        Args:
            localSolverThreshold: threshold for using local solver on driver node. Local solver is faster but relies on part of data being collected on driver node.
        """
        self._set(localSolverThreshold=value)
        return self

    def setMaxIter(self, value):
        """
        Args:
            maxIter: maximum number of iterations (>= 0)
        """
        self._set(maxIter=value)
        return self

    def setNumIterNoChange(self, value):
        """
        Args:
            numIterNoChange: Early termination when number of iterations without change reached.
        """
        self._set(numIterNoChange=value)
        return self

    def setOutcomeCol(self, value):
        """
        Args:
            outcomeCol: outcome column
        """
        self._set(outcomeCol=value)
        return self

    def setPostTreatmentCol(self, value):
        """
        Args:
            postTreatmentCol: post treatment indicator column
        """
        self._set(postTreatmentCol=value)
        return self

    def setStepSize(self, value):
        """
        Args:
            stepSize: Step size to be used for each iteration of optimization (> 0)
        """
        self._set(stepSize=value)
        return self

    def setTimeCol(self, value):
        """
        Args:
            timeCol: Specify the column that identifies the time when outcome is measured in the panel data. For example, if the outcome is measured daily, this column could be the Date column.
        """
        self._set(timeCol=value)
        return self

    def setTol(self, value):
        """
        Args:
            tol: the convergence tolerance for iterative algorithms (>= 0)
        """
        self._set(tol=value)
        return self

    def setTreatmentCol(self, value):
        """
        Args:
            treatmentCol: treatment column
        """
        self._set(treatmentCol=value)
        return self

    def setUnitCol(self, value):
        """
        Args:
            unitCol: Specify the name of the column which contains an identifier for each observed unit in the panel data. For example, if the observed units are users, this column could be the UserId column.
        """
        self._set(unitCol=value)
        return self

    def setZeta(self, value):
        """
        Args:
            zeta: The zeta value for regularization term when fitting unit weights. If not specified, a default value will be computed based on formula (2.2) specified in https://www.nber.org/system/files/working_papers/w25532/w25532.pdf. For large scale data, one may want to tune the zeta value, minimizing the loss of the unit weights regression.
        """
        self._set(zeta=value)
        return self

    def getEpsilon(self):
        """
        Returns:
            epsilon: This value is added to the weights when we fit the final linear model for SyntheticControlEstimator and SyntheticDiffInDiffEstimator in order to avoid zero weights.
        """
        return self.getOrDefault(self.epsilon)

    def getHandleMissingOutcome(self):
        """
        Returns:
            handleMissingOutcome: How to handle missing outcomes. Options are skip (which will filter out units with missing outcomes), zero (fill in missing outcomes with zero), or impute (impute with nearest available outcomes, or mean if two nearest outcomes are available)
        """
        return self.getOrDefault(self.handleMissingOutcome)

    def getLocalSolverThreshold(self):
        """
        Returns:
            localSolverThreshold: threshold for using local solver on driver node. Local solver is faster but relies on part of data being collected on driver node.
        """
        return self.getOrDefault(self.localSolverThreshold)

    def getMaxIter(self):
        """
        Returns:
            maxIter: maximum number of iterations (>= 0)
        """
        return self.getOrDefault(self.maxIter)

    def getNumIterNoChange(self):
        """
        Returns:
            numIterNoChange: Early termination when number of iterations without change reached.
        """
        return self.getOrDefault(self.numIterNoChange)

    def getOutcomeCol(self):
        """
        Returns:
            outcomeCol: outcome column
        """
        return self.getOrDefault(self.outcomeCol)

    def getPostTreatmentCol(self):
        """
        Returns:
            postTreatmentCol: post treatment indicator column
        """
        return self.getOrDefault(self.postTreatmentCol)

    def getStepSize(self):
        """
        Returns:
            stepSize: Step size to be used for each iteration of optimization (> 0)
        """
        return self.getOrDefault(self.stepSize)

    def getTimeCol(self):
        """
        Returns:
            timeCol: Specify the column that identifies the time when outcome is measured in the panel data. For example, if the outcome is measured daily, this column could be the Date column.
        """
        return self.getOrDefault(self.timeCol)

    def getTol(self):
        """
        Returns:
            tol: the convergence tolerance for iterative algorithms (>= 0)
        """
        return self.getOrDefault(self.tol)

    def getTreatmentCol(self):
        """
        Returns:
            treatmentCol: treatment column
        """
        return self.getOrDefault(self.treatmentCol)

    def getUnitCol(self):
        """
        Returns:
            unitCol: Specify the name of the column which contains an identifier for each observed unit in the panel data. For example, if the observed units are users, this column could be the UserId column.
        """
        return self.getOrDefault(self.unitCol)

    def getZeta(self):
        """
        Returns:
            zeta: The zeta value for regularization term when fitting unit weights. If not specified, a default value will be computed based on formula (2.2) specified in https://www.nber.org/system/files/working_papers/w25532/w25532.pdf. For large scale data, one may want to tune the zeta value, minimizing the loss of the unit weights regression.
        """
        return self.getOrDefault(self.zeta)

    def _create_model(self, java_model):
        """
        Wrap a fitted JVM model in a :class:`DiffInDiffModel`.

        First tries the ``java_obj=`` constructor and copies params from the
        JVM side. If that constructor raises TypeError, falls back to
        ``DiffInDiffModel._from_java``.
        """
        try:
            model = DiffInDiffModel(java_obj=java_model)
            model._transfer_params_from_java()
        except TypeError:
            model = DiffInDiffModel._from_java(java_model)
        return model

    def _fit(self, dataset):
        """Fit the JVM estimator on ``dataset`` and return the wrapped model."""
        java_model = self._fit_java(dataset)
        return self._create_model(java_model)