class DoubleMLEstimator extends Estimator[DoubleMLModel] with ComplexParamsWritable with DoubleMLParams with SynapseMLLogging with Wrappable

Double ML estimator. The estimator follows a two-stage process: a set of nuisance functions is estimated in the first stage with cross-fitting, and a final stage estimates the average treatment effect (ATE) model. The goal is to estimate the constant marginal ATE :math:`\Theta(X)`.

In this estimator, the ATE is estimated using the following estimating equation:

.. math::

    Y - \mathbb{E}[Y | X, W] = \Theta(X) \cdot (T - \mathbb{E}[T | X, W]) + \epsilon

Thus, if we estimate the nuisance functions :math:`q(X, W) = \mathbb{E}[Y | X, W]` and :math:`f(X, W) = \mathbb{E}[T | X, W]` in the first stage, we can estimate the final-stage ATE for each treatment t by running a regression that minimizes the residual-on-residual squared loss. Estimating :math:`\Theta(X)` is therefore a final regression problem: regressing :math:`\tilde{Y}` on :math:`X` and :math:`\tilde{T}`:

.. math::

    \hat{\Theta} = \arg\min_{\Theta} \mathbb{E}_n\left[ (\tilde{Y} - \Theta(X) \cdot \tilde{T})^2 \right]

where :math:`\tilde{Y} = Y - \mathbb{E}[Y | X, W]` and :math:`\tilde{T} = T - \mathbb{E}[T | X, W]` denote the residual outcome and the residual treatment, respectively.
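For intuition, consider the special case where the effect is a single constant, written here as θ. This is an assumption made for illustration; the page does not state how the final stage is solved internally. Setting the derivative of the squared loss to zero gives the usual residual-on-residual least-squares slope:

```latex
\begin{aligned}
\frac{\partial}{\partial\theta}\,\mathbb{E}_n\!\left[(\tilde{Y} - \theta\,\tilde{T})^2\right]
  &= -2\,\mathbb{E}_n\!\left[\tilde{T}\,(\tilde{Y} - \theta\,\tilde{T})\right] = 0 \\
\Longrightarrow\quad
\hat{\theta}
  &= \frac{\mathbb{E}_n[\tilde{T}\,\tilde{Y}]}{\mathbb{E}_n[\tilde{T}^{2}]}
   = \frac{\sum_{i=1}^{n} \tilde{T}_i\,\tilde{Y}_i}{\sum_{i=1}^{n} \tilde{T}_i^{2}}
\end{aligned}
```

As a small made-up check, residuals \tilde{T} = (1, -1, 2) and \tilde{Y} = (2, -2, 4) give a numerator of 2 + 2 + 8 = 12 and a denominator of 1 + 1 + 4 = 6, so \hat{\theta} = 2.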

Estimating the nuisance function :math:`q` is a standard machine learning problem; use setOutcomeModel to set an arbitrary Spark ML model that is used internally to solve it.

Estimating the nuisance function :math:`f` is likewise a machine learning problem; use setTreatmentModel to set an arbitrary Spark ML model that is used internally to solve it.
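A minimal configuration sketch using only setters listed on this page. It makes several assumptions: the import path com.microsoft.azure.synapse.ml.causal is not shown on this page, the column names "treatment" and "outcome" are hypothetical, and LogisticRegression and LinearRegression are just one plausible pairing for a binary treatment and a continuous outcome. Spark ML and SynapseML must be on the classpath.

```scala
// Assumed package path; this page does not show it.
import com.microsoft.azure.synapse.ml.causal.{DoubleMLEstimator, DoubleMLModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.Dataset

object DoubleMLSketch {

  // Builds an estimator; "treatment" and "outcome" are hypothetical column names.
  def buildEstimator(): DoubleMLEstimator =
    new DoubleMLEstimator()
      // Treatment nuisance model f(X, W) = E[T | X, W]:
      // a ProbabilisticClassifier, chosen here for a binary treatment.
      .setTreatmentModel(new LogisticRegression())
      .setTreatmentCol("treatment")
      // Outcome nuisance model q(X, W) = E[Y | X, W]:
      // a Regressor, chosen here for a continuous outcome.
      .setOutcomeModel(new LinearRegression())
      .setOutcomeCol("outcome")
      // More than one iteration, so that a confidence interval is bootstrapped.
      .setMaxIter(20)

  // Fits on a caller-supplied dataset that contains the columns named above.
  def fitAte(data: Dataset[_]): DoubleMLModel =
    buildEstimator().fit(data)
}
```

The ATE and confidence interval are then read from the returned DoubleMLModel; see that class's own page for the accessors.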

Linear Supertypes
Wrappable, RWrappable, PythonWrappable, BaseWrappable, SynapseMLLogging, DoubleMLParams, HasParallelismInjected, HasParallelism, HasWeightCol, HasMaxIter, HasFeaturesCol, HasOutcomeCol, HasTreatmentCol, ComplexParamsWritable, MLWritable, Estimator[DoubleMLModel], PipelineStage, Logging, Params, Serializable, Serializable, Identifiable, AnyRef, Any

Instance Constructors

  1. new DoubleMLEstimator()
  2. new DoubleMLEstimator(uid: String)

Value Members

  1. final def !=(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  2. final def ##(): Int
    Definition Classes
    AnyRef → Any
  3. final def $[T](param: Param[T]): T
    Attributes
    protected
    Definition Classes
    Params
  4. final def ==(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  5. final def asInstanceOf[T0]: T0
    Definition Classes
    Any
  6. def awaitFutures[T](futures: Array[Future[T]]): Seq[T]
    Attributes
    protected
    Definition Classes
    HasParallelismInjected
  7. lazy val classNameHelper: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  8. final def clear(param: Param[_]): DoubleMLEstimator.this.type
    Definition Classes
    Params
  9. def clone(): AnyRef
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  10. def companionModelClassName: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  11. val confidenceLevel: DoubleParam
    Definition Classes
    DoubleMLParams
  12. def copy(extra: ParamMap): Estimator[DoubleMLModel]
    Definition Classes
    DoubleMLEstimator → Estimator → PipelineStage → Params
  13. def copyValues[T <: Params](to: T, extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  14. lazy val copyrightLines: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  15. final def defaultCopy[T <: Params](extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  16. final def eq(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  17. def equals(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  18. def explainParam(param: Param[_]): String
    Definition Classes
    Params
  19. def explainParams(): String
    Definition Classes
    Params
  20. final def extractParamMap(): ParamMap
    Definition Classes
    Params
  21. final def extractParamMap(extra: ParamMap): ParamMap
    Definition Classes
    Params
  22. val featuresCol: Param[String]

    The name of the features column

    Definition Classes
    HasFeaturesCol
  23. def finalize(): Unit
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  24. def fit(dataset: Dataset[_]): DoubleMLModel

    Fits the DoubleML model.

    dataset

    The input dataset to train.

    returns

    The trained DoubleML model, from which you can get the ATE and confidence interval (CI) values.

    Definition Classes
    DoubleMLEstimator → Estimator
  25. def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[DoubleMLModel]
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  26. def fit(dataset: Dataset[_], paramMap: ParamMap): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  27. def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" ) @varargs()
  28. final def get[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  29. final def getClass(): Class[_]
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  30. def getConfidenceLevel: Double
    Definition Classes
    DoubleMLParams
  31. final def getDefault[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  32. def getExecutionContextProxy: ExecutionContext
    Definition Classes
    HasParallelismInjected
  33. def getFeaturesCol: String

    Definition Classes
    HasFeaturesCol
  34. final def getMaxIter: Int
    Definition Classes
    HasMaxIter
  35. final def getOrDefault[T](param: Param[T]): T
    Definition Classes
    Params
  36. def getOutcomeCol: String
    Definition Classes
    HasOutcomeCol
  37. def getOutcomeModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  38. def getParallelism: Int
    Definition Classes
    HasParallelism
  39. def getParam(paramName: String): Param[Any]
    Definition Classes
    Params
  40. def getParamInfo(p: Param[_]): ParamInfo[_]
    Definition Classes
    BaseWrappable
  41. def getPayload(methodName: String, numCols: Option[Int], executionSeconds: Option[Double], exception: Option[Exception]): Map[String, String]
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  42. def getSampleSplitRatio: Array[Double]
    Definition Classes
    DoubleMLParams
  43. def getTreatmentCol: String
    Definition Classes
    HasTreatmentCol
  44. def getTreatmentModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  45. def getWeightCol: String

    Definition Classes
    HasWeightCol
  46. final def hasDefault[T](param: Param[T]): Boolean
    Definition Classes
    Params
  47. def hasParam(paramName: String): Boolean
    Definition Classes
    Params
  48. def hashCode(): Int
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  49. def initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  50. def initializeLogIfNecessary(isInterpreter: Boolean): Unit
    Attributes
    protected
    Definition Classes
    Logging
  51. final def isDefined(param: Param[_]): Boolean
    Definition Classes
    Params
  52. final def isInstanceOf[T0]: Boolean
    Definition Classes
    Any
  53. final def isSet(param: Param[_]): Boolean
    Definition Classes
    Params
  54. def isTraceEnabled(): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  55. def log: Logger
    Attributes
    protected
    Definition Classes
    Logging
  56. def logBase(info: Map[String, String], featureName: Option[String]): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  57. def logBase(methodName: String, numCols: Option[Int], executionSeconds: Option[Double], featureName: Option[String]): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  58. def logClass(featureName: String): Unit
    Definition Classes
    SynapseMLLogging
  59. def logDebug(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  60. def logDebug(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  61. def logError(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  62. def logError(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  63. def logErrorBase(methodName: String, e: Exception): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  64. def logFit[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  65. def logInfo(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  66. def logInfo(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  67. def logName: String
    Attributes
    protected
    Definition Classes
    Logging
  68. def logTrace(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  69. def logTrace(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  70. def logTransform[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  71. def logVerb[T](verb: String, f: ⇒ T, columns: Option[Int] = None): T
    Definition Classes
    SynapseMLLogging
  72. def logWarning(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  73. def logWarning(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  74. def makePyFile(conf: CodegenConfig): Unit
    Definition Classes
    PythonWrappable
  75. def makeRFile(conf: CodegenConfig): Unit
    Definition Classes
    RWrappable
  76. final val maxIter: IntParam
    Definition Classes
    HasMaxIter
  77. final def ne(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  78. final def notify(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  79. final def notifyAll(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  80. val outcomeCol: Param[String]
    Definition Classes
    HasOutcomeCol
  81. val outcomeModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  82. val parallelism: IntParam
    Definition Classes
    HasParallelism
  83. lazy val params: Array[Param[_]]
    Definition Classes
    Params
  84. def pyAdditionalMethods: String
    Definition Classes
    PythonWrappable
  85. lazy val pyClassDoc: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  86. lazy val pyClassName: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  87. def pyExtraEstimatorImports: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  88. def pyExtraEstimatorMethods: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  89. lazy val pyInheritedClasses: Seq[String]
    Attributes
    protected
    Definition Classes
    PythonWrappable
  90. def pyInitFunc(): String
    Definition Classes
    PythonWrappable
  91. lazy val pyInternalWrapper: Boolean
    Attributes
    protected
    Definition Classes
    PythonWrappable
  92. lazy val pyObjectBaseClass: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  93. def pyParamArg[T](p: Param[T]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  94. def pyParamDefault[T](p: Param[T]): Option[String]
    Attributes
    protected
    Definition Classes
    PythonWrappable
  95. def pyParamGetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  96. def pyParamSetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  97. def pyParamsArgs: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  98. def pyParamsDefaults: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  99. lazy val pyParamsDefinitions: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  100. def pyParamsGetters: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  101. def pyParamsSetters: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  102. def pythonClass(): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  103. def rClass(): String
    Attributes
    protected
    Definition Classes
    RWrappable
  104. def rDocString: String
    Attributes
    protected
    Definition Classes
    RWrappable
  105. def rExtraBodyLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  106. def rExtraInitLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  107. lazy val rFuncName: String
    Attributes
    protected
    Definition Classes
    RWrappable
  108. lazy val rInternalWrapper: Boolean
    Attributes
    protected
    Definition Classes
    RWrappable
  109. def rParamArg[T](p: Param[T]): String
    Attributes
    protected
    Definition Classes
    RWrappable
  110. def rParamsArgs: String
    Attributes
    protected
    Definition Classes
    RWrappable
  111. def rSetterLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  112. val sampleSplitRatio: DoubleArrayParam
    Definition Classes
    DoubleMLParams
  113. def save(path: String): Unit
    Definition Classes
    MLWritable
    Annotations
    @Since( "1.6.0" ) @throws( ... )
  114. final def set(paramPair: ParamPair[_]): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  115. final def set(param: String, value: Any): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  116. final def set[T](param: Param[T], value: T): DoubleMLEstimator.this.type
    Definition Classes
    Params
  117. def setConfidenceLevel(value: Double): DoubleMLEstimator.this.type

    Set the upper-bound percentile of the ATE distribution. Default is 0.975. The lower bound is calculated automatically as 100*(1-confidenceLevel). By default this yields a 95% confidence interval, that is, the [2.5%, 97.5%] percentile range of the ATE distribution.

    Definition Classes
    DoubleMLParams
  118. final def setDefault(paramPairs: ParamPair[_]*): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  119. final def setDefault[T](param: Param[T], value: T): DoubleMLEstimator.this.type
    Attributes
    protected[org.apache.spark.ml]
    Definition Classes
    Params
  120. def setFeaturesCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasFeaturesCol
  121. def setMaxIter(value: Int): DoubleMLEstimator.this.type

    Set the maximum number of bootstrap iterations used to compute the confidence interval. Default is 1, which means no confidence interval is calculated. To get CI values, set a larger value.

    Definition Classes
    DoubleMLParams
  122. def setOutcomeCol(value: String): DoubleMLEstimator.this.type

    Set the name of the column to use as the outcome.

    Definition Classes
    HasOutcomeCol
  123. def setOutcomeModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Set the outcome model. It can be any model derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  124. def setParallelism(value: Int): DoubleMLEstimator.this.type
    Definition Classes
    DoubleMLParams
  125. def setSampleSplitRatio(value: Array[Double]): DoubleMLEstimator.this.type

    Set the sample split ratio. Default is Array(0.5, 0.5).

    Definition Classes
    DoubleMLParams
  126. def setTreatmentCol(value: String): DoubleMLEstimator.this.type

    Set the name of the column to use as the treatment.

    Definition Classes
    HasTreatmentCol
  127. def setTreatmentModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Set the treatment model. It can be any model derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  128. def setWeightCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasWeightCol
  129. final def synchronized[T0](arg0: ⇒ T0): T0
    Definition Classes
    AnyRef
  130. val thisStage: Params
    Attributes
    protected
    Definition Classes
    BaseWrappable
  131. def toString(): String
    Definition Classes
    Identifiable → AnyRef → Any
  132. def transformSchema(schema: StructType): StructType
    Definition Classes
    DoubleMLEstimator → PipelineStage
    Annotations
    @DeveloperApi()
  133. def transformSchema(schema: StructType, logging: Boolean): StructType
    Attributes
    protected
    Definition Classes
    PipelineStage
    Annotations
    @DeveloperApi()
  134. val treatmentCol: Param[String]
    Definition Classes
    HasTreatmentCol
  135. val treatmentModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  136. val uid: String
    Definition Classes
    DoubleMLEstimator → SynapseMLLogging → Identifiable
  137. def validateColTypeWithModel(dataset: Dataset[_], colName: String, model: Estimator[_]): Unit
    Attributes
    protected
  138. final def wait(): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  139. final def wait(arg0: Long, arg1: Int): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  140. final def wait(arg0: Long): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  141. val weightCol: Param[String]

    The name of the weight column

    Definition Classes
    HasWeightCol
  142. def write: MLWriter
    Definition Classes
    ComplexParamsWritable → MLWritable
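
The percentile arithmetic described for setConfidenceLevel can be sketched in plain Scala. This only illustrates the documented formula and is not the library's internal code; the object name, the method name, and the accepted range (0.5, 1.0) are assumptions made for this example.

```scala
object ConfidenceBounds {

  /** Returns (lowerPercentile, upperPercentile), both in percent, using the
    * documented rule lower = 100 * (1 - confidenceLevel).
    */
  def percentileBounds(confidenceLevel: Double): (Double, Double) = {
    // Assumed range for this sketch: the upper bound must exceed the lower bound.
    require(confidenceLevel > 0.5 && confidenceLevel < 1.0,
      "confidenceLevel is assumed here to lie in (0.5, 1.0)")
    val upper = 100.0 * confidenceLevel
    val lower = 100.0 * (1.0 - confidenceLevel)
    (lower, upper)
  }

  def main(args: Array[String]): Unit = {
    // The documented default of 0.975 gives roughly the 2.5th and 97.5th
    // percentiles (up to floating-point rounding), a 95% interval.
    val (lower, upper) = percentileBounds(0.975)
    println(f"[$lower%.1f%%, $upper%.1f%%]")
  }
}
```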
