class DoubleMLEstimator extends Estimator[DoubleMLModel] with ComplexParamsWritable with DoubleMLParams with SynapseMLLogging with Wrappable

Double ML estimator. The estimator follows a two-stage process: a set of nuisance functions is estimated in the first stage in a cross-fitting manner, and a final stage estimates the average treatment effect (ATE) model. The goal is to estimate the constant marginal ATE :math:`\Theta(X)`.

In this estimator, the ATE is estimated by using the following estimating equation:

.. math::

    Y - \mathbb{E}[Y \mid X, W] = \Theta(X) \cdot (T - \mathbb{E}[T \mid X, W]) + \epsilon

Thus, if we estimate the nuisance functions :math:`q(X, W) = \mathbb{E}[Y \mid X, W]` and :math:`f(X, W) = \mathbb{E}[T \mid X, W]` in the first stage, we can estimate the final-stage ATE for each treatment t by running a regression that minimizes the residual-on-residual square loss. Estimating :math:`\Theta(X)` is therefore a final regression problem, regressing :math:`\tilde{Y}` on :math:`X` and :math:`\tilde{T}`:

.. math::

    \hat{\Theta} = \arg\min_{\Theta} \mathbb{E}_n\left[ (\tilde{Y} - \Theta(X) \cdot \tilde{T})^2 \right]

where :math:`\tilde{Y} = Y - \mathbb{E}[Y \mid X, W]` and :math:`\tilde{T} = T - \mathbb{E}[T \mid X, W]` denote the residual outcome and the residual treatment.
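
For the special case of a constant effect and a scalar treatment, the minimization above has a closed form. The derivation below is a worked sketch under that assumption; the scalar symbol θ and the constant-effect restriction are introduced here for illustration, and E_n denotes the sample mean. It requires that the residual treatment is not identically zero.

```latex
\begin{aligned}
\frac{\partial}{\partial\theta}\,
  \mathbb{E}_n\left[ (\tilde{Y} - \theta\,\tilde{T})^2 \right]
  &= -2\,\mathbb{E}_n\left[ \tilde{T}\,(\tilde{Y} - \theta\,\tilde{T}) \right] = 0 \\
\Longrightarrow\quad
\hat{\theta}
  &= \frac{\mathbb{E}_n[\tilde{T}\,\tilde{Y}]}{\mathbb{E}_n[\tilde{T}^2]}
\end{aligned}
```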

Estimating the nuisance function :math:`q` is a standard machine learning problem. Use setOutcomeModel to set the Spark ML model that is used internally to solve it.

Estimating the nuisance function :math:`f` is also a machine learning problem. Use setTreatmentModel to set the Spark ML model that is used internally to solve it.
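
The two-stage process described above can be sketched in plain Scala without Spark. This is an illustration of the general cross-fitting idea, not the library's implementation: the names `CrossFitSketch`, `Obs`, `groupMeans`, `foldEffect`, and `crossFitEffect` are hypothetical, per-group means over a discrete control stand in for the outcome and treatment models, and the effect is assumed constant and the treatment scalar.

```scala
object CrossFitSketch {
  // One observation: discrete control w, scalar treatment t, outcome y.
  final case class Obs(w: Int, t: Double, y: Double)

  // Stand-in nuisance "model": the mean of `value` within each level of w,
  // learned on the training fold only.
  def groupMeans(train: Seq[Obs], value: Obs => Double): Map[Int, Double] =
    train.groupBy(_.w).map { case (w, rows) => w -> rows.map(value).sum / rows.size }

  // Fit the nuisances on `train`, residualize `test`, then solve the
  // residual-on-residual least-squares problem for a constant effect.
  // Assumes every level of w in `test` also appears in `train`.
  def foldEffect(train: Seq[Obs], test: Seq[Obs]): Double = {
    val q = groupMeans(train, _.y) // stands in for E[Y | W]
    val f = groupMeans(train, _.t) // stands in for E[T | W]
    val residuals = test.map(o => (o.y - q(o.w), o.t - f(o.w)))
    val num = residuals.map { case (yRes, tRes) => yRes * tRes }.sum
    val den = residuals.map { case (_, tRes) => tRes * tRes }.sum
    require(den > 0.0, "treatment residuals must not all be zero")
    num / den
  }

  // Two-fold cross-fitting: each fold is used once to fit the nuisances
  // and once for the final stage; the two estimates are averaged.
  def crossFitEffect(foldA: Seq[Obs], foldB: Seq[Obs]): Double =
    (foldEffect(foldA, foldB) + foldEffect(foldB, foldA)) / 2.0
}
```

Fitting the nuisances on one fold and computing residuals on the other keeps the final-stage regression from reusing the rows the nuisance models were trained on.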

Linear Supertypes
Wrappable, RWrappable, PythonWrappable, BaseWrappable, SynapseMLLogging, DoubleMLParams, HasParallelismInjected, HasParallelism, HasWeightCol, HasMaxIter, HasFeaturesCol, HasOutcomeCol, HasTreatmentCol, ComplexParamsWritable, MLWritable, Estimator[DoubleMLModel], PipelineStage, Logging, Params, Serializable, Serializable, Identifiable, AnyRef, Any

Instance Constructors

  1. new DoubleMLEstimator()
  2. new DoubleMLEstimator(uid: String)
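
A minimal configuration sketch that uses only the constructor and setters listed on this page. Several details are assumptions for illustration: the import path `com.microsoft.azure.synapse.ml.causal`, the choice of `LogisticRegression` and `LinearRegression` as nuisance models, the column names, the parameter values, and the wrapper `DoubleMLUsage.fitDoubleML`. It needs Spark ML and SynapseML on the classpath.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.DataFrame
// Assumed package; adjust to the SynapseML version on your classpath.
import com.microsoft.azure.synapse.ml.causal.{DoubleMLEstimator, DoubleMLModel}

object DoubleMLUsage {
  // `df` is a hypothetical DataFrame with a binary "treatment" column,
  // a numeric "outcome" column, and the remaining columns as covariates.
  def fitDoubleML(df: DataFrame): DoubleMLModel = {
    val estimator = new DoubleMLEstimator()
      .setTreatmentCol("treatment")
      .setTreatmentModel(new LogisticRegression()) // a ProbabilisticClassifier
      .setOutcomeCol("outcome")
      .setOutcomeModel(new LinearRegression())     // a Regressor
      .setSampleSplitRatio(Array(0.5, 0.5))        // documented default
      .setConfidenceLevel(0.975)                   // documented default
      .setMaxIter(20)                              // illustrative; values above 1 enable the interval
    estimator.fit(df)
  }
}
```

Every setter returns `this.type`, which is what allows the chained style. The accessors for the ATE and confidence interval belong to DoubleMLModel and are not listed on this page, so check that class's reference before relying on specific names.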

Value Members

  1. final def clear(param: Param[_]): DoubleMLEstimator.this.type
    Definition Classes
    Params
  2. val confidenceLevel: DoubleParam
    Definition Classes
    DoubleMLParams
  3. def copy(extra: ParamMap): Estimator[DoubleMLModel]
    Definition Classes
    DoubleMLEstimator → Estimator → PipelineStage → Params
  4. def explainParam(param: Param[_]): String
    Definition Classes
    Params
  5. def explainParams(): String
    Definition Classes
    Params
  6. final def extractParamMap(): ParamMap
    Definition Classes
    Params
  7. final def extractParamMap(extra: ParamMap): ParamMap
    Definition Classes
    Params
  8. val featuresCol: Param[String]

    The name of the features column

    Definition Classes
    HasFeaturesCol
  9. def fit(dataset: Dataset[_]): DoubleMLModel

    Fits the DoubleML model.

    dataset

    The input dataset to train.

    returns

    The trained DoubleML model, from which you can get the ATE and confidence interval (CI) values.

    Definition Classes
    DoubleMLEstimator → Estimator
  10. def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[DoubleMLModel]
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  11. def fit(dataset: Dataset[_], paramMap: ParamMap): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  12. def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" ) @varargs()
  13. final def get[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  14. def getConfidenceLevel: Double
    Definition Classes
    DoubleMLParams
  15. final def getDefault[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  16. def getExecutionContextProxy: ExecutionContext
    Definition Classes
    HasParallelismInjected
  17. def getFeaturesCol: String

    Definition Classes
    HasFeaturesCol
  18. final def getMaxIter: Int
    Definition Classes
    HasMaxIter
  19. final def getOrDefault[T](param: Param[T]): T
    Definition Classes
    Params
  20. def getOutcomeCol: String
    Definition Classes
    HasOutcomeCol
  21. def getOutcomeModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  22. def getParallelism: Int
    Definition Classes
    HasParallelism
  23. def getParam(paramName: String): Param[Any]
    Definition Classes
    Params
  24. def getParamInfo(p: Param[_]): ParamInfo[_]
    Definition Classes
    BaseWrappable
  25. def getSampleSplitRatio: Array[Double]
    Definition Classes
    DoubleMLParams
  26. def getTreatmentCol: String
    Definition Classes
    HasTreatmentCol
  27. def getTreatmentModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  28. def getWeightCol: String

    Definition Classes
    HasWeightCol
  29. final def hasDefault[T](param: Param[T]): Boolean
    Definition Classes
    Params
  30. def hasParam(paramName: String): Boolean
    Definition Classes
    Params
  31. final def isDefined(param: Param[_]): Boolean
    Definition Classes
    Params
  32. final def isSet(param: Param[_]): Boolean
    Definition Classes
    Params
  33. def logClass(featureName: String): Unit
    Definition Classes
    SynapseMLLogging
  34. def logFit[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  35. def logTransform[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  36. def logVerb[T](verb: String, f: ⇒ T, columns: Option[Int] = None): T
    Definition Classes
    SynapseMLLogging
  37. def makePyFile(conf: CodegenConfig): Unit
    Definition Classes
    PythonWrappable
  38. def makeRFile(conf: CodegenConfig): Unit
    Definition Classes
    RWrappable
  39. final val maxIter: IntParam
    Definition Classes
    HasMaxIter
  40. val outcomeCol: Param[String]
    Definition Classes
    HasOutcomeCol
  41. val outcomeModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  42. val parallelism: IntParam
    Definition Classes
    HasParallelism
  43. lazy val params: Array[Param[_]]
    Definition Classes
    Params
  44. def pyAdditionalMethods: String
    Definition Classes
    PythonWrappable
  45. def pyInitFunc(): String
    Definition Classes
    PythonWrappable
  46. val sampleSplitRatio: DoubleArrayParam
    Definition Classes
    DoubleMLParams
  47. def save(path: String): Unit
    Definition Classes
    MLWritable
    Annotations
    @Since( "1.6.0" ) @throws( ... )
  48. final def set[T](param: Param[T], value: T): DoubleMLEstimator.this.type
    Definition Classes
    Params
  49. def setConfidenceLevel(value: Double): DoubleMLEstimator.this.type

    Set the upper-bound percentile of the ATE distribution. Default is 0.975. The lower bound is calculated automatically as the 100*(1-confidenceLevel) percentile. By default this gives a 95% confidence interval, spanning the [2.5%, 97.5%] percentiles of the ATE distribution.

    Definition Classes
    DoubleMLParams
  50. def setFeaturesCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasFeaturesCol
  51. def setMaxIter(value: Int): DoubleMLEstimator.this.type

    Set the maximum number of confidence interval bootstrapping iterations. Default is 1, which means no confidence interval is calculated. To get CI values, set a value greater than 1.

    Definition Classes
    DoubleMLParams
  52. def setOutcomeCol(value: String): DoubleMLEstimator.this.type

    Set the name of the column used as the outcome.

    Definition Classes
    HasOutcomeCol
  53. def setOutcomeModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Set the outcome model. It can be any estimator derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  54. def setParallelism(value: Int): DoubleMLEstimator.this.type
    Definition Classes
    DoubleMLParams
  55. def setSampleSplitRatio(value: Array[Double]): DoubleMLEstimator.this.type

    Set the sample split ratio. Default is Array(0.5, 0.5).

    Definition Classes
    DoubleMLParams
  56. def setTreatmentCol(value: String): DoubleMLEstimator.this.type

    Set the name of the column used as the treatment.

    Definition Classes
    HasTreatmentCol
  57. def setTreatmentModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Set the treatment model. It can be any estimator derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  58. def setWeightCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasWeightCol
  59. def toString(): String
    Definition Classes
    Identifiable → AnyRef → Any
  60. def transformSchema(schema: StructType): StructType
    Definition Classes
    DoubleMLEstimator → PipelineStage
    Annotations
    @DeveloperApi()
  61. val treatmentCol: Param[String]
    Definition Classes
    HasTreatmentCol
  62. val treatmentModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  63. val uid: String
    Definition Classes
    DoubleMLEstimator → SynapseMLLogging → Identifiable
  64. val weightCol: Param[String]

    The name of the weight column

    Definition Classes
    HasWeightCol
  65. def write: MLWriter
    Definition Classes
    ComplexParamsWritable → MLWritable
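
The relationship between setConfidenceLevel and the interval bounds can be illustrated in plain Scala. This is a sketch, not the library's code: `PercentileInterval`, `percentile`, and `interval` are hypothetical names, `ateDraws` stands for the ATE values collected over the bootstrapping iterations configured with setMaxIter, and the linear-interpolation percentile rule is an assumption that may differ from the rule used internally.

```scala
object PercentileInterval {
  // Linear-interpolation percentile over a non-empty sample, with p in [0, 1].
  def percentile(values: Seq[Double], p: Double): Double = {
    require(values.nonEmpty, "need at least one value")
    require(p >= 0.0 && p <= 1.0, "p must be in [0, 1]")
    val sorted = values.toVector.sorted
    val pos = p * (sorted.length - 1)
    val lo = math.floor(pos).toInt
    val hi = math.ceil(pos).toInt
    sorted(lo) + (pos - lo) * (sorted(hi) - sorted(lo))
  }

  // Upper bound: the confidenceLevel percentile.
  // Lower bound: the (1 - confidenceLevel) percentile.
  def interval(ateDraws: Seq[Double], confidenceLevel: Double): (Double, Double) =
    (percentile(ateDraws, 1.0 - confidenceLevel), percentile(ateDraws, confidenceLevel))
}
```

With the default confidenceLevel of 0.975, the two bounds are the 2.5% and 97.5% percentiles, which is why the default corresponds to a 95% interval. With a single draw, both bounds collapse to the same value, consistent with the default maxIter of 1 producing no usable interval.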