class DoubleMLEstimator extends Estimator[DoubleMLModel] with ComplexParamsWritable with DoubleMLParams with SynapseMLLogging with Wrappable

Double ML estimator. The estimator follows a two-stage process: in the first stage, a set of nuisance functions is estimated in a cross-fitting manner, and in the final stage, the average treatment effect (ATE) model is estimated. Our goal is to estimate the constant marginal ATE :math:`\Theta(X)`.
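
The first-stage cross-fitting can be illustrated with a small plain-Scala sketch. It is not SynapseML's implementation. It assumes a single numeric covariate x (no W), uses a hypothetical 1-D least-squares helper `ols` as a stand-in for the SparkML outcome and treatment models, and assumes a two-fold split in which the nuisance fits from each fold produce the residuals for the other fold. Fold membership is drawn per row, so fold sizes only approximately follow `ratio`.

```scala
import scala.util.Random

object CrossFitSketch {
  /** Hypothetical stand-in for a SparkML regressor: 1-D ordinary least squares
    * of y on x, returning (intercept, slope). */
  def ols(x: IndexedSeq[Double], y: IndexedSeq[Double]): (Double, Double) = {
    val mx = x.sum / x.length
    val my = y.sum / y.length
    val sxy = x.indices.map(i => (x(i) - mx) * (y(i) - my)).sum
    val sxx = x.map(v => (v - mx) * (v - mx)).sum
    val slope = sxy / sxx
    (my - slope * mx, slope)
  }

  /** Assigns each row to one of two folds with probabilities given by `ratio`
    * (e.g. Array(0.5, 0.5)), fits q(x) ~ E[Y | x] and f(x) ~ E[T | x] on one fold,
    * computes residuals on the other fold, then swaps the roles of the folds.
    * Returns out-of-fold residuals (yRes, tRes) aligned with the input rows. */
  def crossFitResiduals(
      x: Array[Double],
      t: Array[Double],
      y: Array[Double],
      ratio: Array[Double],
      seed: Long): (Array[Double], Array[Double]) = {
    val n = x.length
    val p0 = ratio(0) / ratio.sum
    val rng = new Random(seed)
    val inFirst = Array.fill(n)(rng.nextDouble() < p0)
    val folds = Seq((0 until n).filter(i => inFirst(i)), (0 until n).filter(i => !inFirst(i)))
    val yRes = new Array[Double](n)
    val tRes = new Array[Double](n)
    for (k <- 0 until 2) {
      val train = folds(k)
      val (qa, qb) = ols(train.map(i => x(i)), train.map(i => y(i))) // outcome nuisance q
      val (fa, fb) = ols(train.map(i => x(i)), train.map(i => t(i))) // treatment nuisance f
      for (i <- folds(1 - k)) {
        yRes(i) = y(i) - (qa + qb * x(i))
        tRes(i) = t(i) - (fa + fb * x(i))
      }
    }
    (yRes, tRes)
  }
}
```

Because every residual comes from a model that was not trained on that row, overfitting in the nuisance models does not leak directly into the final-stage estimate, which is the usual motivation for cross-fitting.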

In this estimator, the ATE is estimated using the following estimating equation:

.. math::

    Y - \mathbb{E}[Y \mid X, W] = \Theta(X) \cdot (T - \mathbb{E}[T \mid X, W]) + \epsilon

Thus, if we estimate the nuisance functions :math:`q(X, W) = \mathbb{E}[Y \mid X, W]` and :math:`f(X, W) = \mathbb{E}[T \mid X, W]` in the first stage, we can estimate the final-stage ATE for each treatment t by running a regression that minimizes the residual-on-residual square loss. Estimating :math:`\Theta(X)` is therefore a final regression problem: regress :math:`\tilde{Y}` on X and :math:`\tilde{T}`.

.. math::

    \hat{\theta} = \arg\min_{\Theta} \mathbb{E}_n\left[ (\tilde{Y} - \Theta(X) \cdot \tilde{T})^2 \right]

where :math:`\tilde{Y} = Y - \mathbb{E}[Y \mid X, W]` and :math:`\tilde{T} = T - \mathbb{E}[T \mid X, W]` denote the residual outcome and the residual treatment, respectively.
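
When the effect is constant (:math:`\Theta(X) = \theta`), setting the derivative of the loss above to zero gives the closed form :math:`\hat{\theta} = \sum_i \tilde{Y}_i \tilde{T}_i / \sum_i \tilde{T}_i^2`, i.e. a least-squares regression of the residual outcome on the residual treatment through the origin. The plain-Scala sketch below computes it from precomputed residual arrays. The helper name `finalStageAte` is hypothetical, and this is an illustration of the formula, not SynapseML's final-stage code.

```scala
object FinalStageSketch {
  /** Hypothetical helper: argmin over theta of sum_i (yRes(i) - theta * tRes(i))^2,
    * which for a constant effect equals sum(yRes * tRes) / sum(tRes^2). */
  def finalStageAte(yRes: Array[Double], tRes: Array[Double]): Double = {
    require(yRes.length == tRes.length, "residual arrays must have equal length")
    val num = yRes.indices.map(i => yRes(i) * tRes(i)).sum
    val den = tRes.map(v => v * v).sum
    require(den > 0.0, "treatment residuals must not all be zero")
    num / den
  }

  def main(args: Array[String]): Unit = {
    // Residuals where the outcome residual is exactly twice the treatment residual.
    val tRes = Array(-1.0, 0.5, 1.5, -0.25)
    val yRes = tRes.map(_ * 2.0)
    println(finalStageAte(yRes, tRes)) // prints 2.0
  }
}
```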

Estimating the nuisance function :math:`q` is a standard machine learning problem: use setOutcomeModel to set an arbitrary SparkML estimator, which is used internally to solve it.

Estimating the nuisance function :math:`f` is also a machine learning problem: use setTreatmentModel to set an arbitrary SparkML estimator, which is used internally to solve it.
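
The two model setters are typically combined with the other parameters listed under Value Members. The sketch below configures an estimator for a binary treatment and a continuous outcome. It assumes SynapseML's Scala package `com.microsoft.azure.synapse.ml.causal` is on the classpath; the column names `treatment`, `outcome`, and `features` are hypothetical placeholders, and `setMaxIter(20)` is an arbitrary illustrative value, not a recommendation from this documentation.

```scala
import com.microsoft.azure.synapse.ml.causal.DoubleMLEstimator
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.regression.LinearRegression

object DoubleMLConfigSketch {
  /** Builds an estimator for a binary treatment and a continuous outcome.
    * Column names are hypothetical placeholders for your own dataset. */
  def build(): DoubleMLEstimator =
    new DoubleMLEstimator()
      // f(X, W) = E[T | X, W]: a ProbabilisticClassifier for a binary treatment.
      .setTreatmentModel(new LogisticRegression())
      .setTreatmentCol("treatment")
      // q(X, W) = E[Y | X, W]: a Regressor for a continuous outcome.
      .setOutcomeModel(new LinearRegression())
      .setOutcomeCol("outcome")
      // Assumed to be an assembled feature vector column.
      .setFeaturesCol("features")
      // Documented defaults, set explicitly: two-way split and a 95% interval.
      .setSampleSplitRatio(Array(0.5, 0.5))
      .setConfidenceLevel(0.975)
      // Greater than 1 so that a confidence interval is bootstrapped.
      .setMaxIter(20)
}
```

Calling `fit(dataset)` on the returned estimator yields a `DoubleMLModel`, from which the ATE and confidence interval values can be retrieved.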

Linear Supertypes

  1. Wrappable
  2. DotnetWrappable
  3. RWrappable
  4. PythonWrappable
  5. BaseWrappable
  6. SynapseMLLogging
  7. DoubleMLParams
  8. HasParallelismInjected
  9. HasParallelism
  10. HasWeightCol
  11. HasMaxIter
  12. HasFeaturesCol
  13. HasOutcomeCol
  14. HasTreatmentCol
  15. ComplexParamsWritable
  16. MLWritable
  17. Estimator[DoubleMLModel]
  18. PipelineStage
  19. Logging
  20. Params
  21. Serializable
  22. Serializable
  23. Identifiable
  24. AnyRef
  25. Any

Instance Constructors

  1. new DoubleMLEstimator()
  2. new DoubleMLEstimator(uid: String)

Value Members

  1. final def !=(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  2. final def ##(): Int
    Definition Classes
    AnyRef → Any
  3. final def $[T](param: Param[T]): T
    Attributes
    protected
    Definition Classes
    Params
  4. final def ==(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  5. final def asInstanceOf[T0]: T0
    Definition Classes
    Any
  6. def awaitFutures[T](futures: Array[Future[T]]): Seq[T]
    Attributes
    protected
    Definition Classes
    HasParallelismInjected
  7. lazy val classNameHelper: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  8. final def clear(param: Param[_]): DoubleMLEstimator.this.type
    Definition Classes
    Params
  9. def clone(): AnyRef
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  10. def companionModelClassName: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  11. val confidenceLevel: DoubleParam
    Definition Classes
    DoubleMLParams
  12. def copy(extra: ParamMap): Estimator[DoubleMLModel]
    Definition Classes
    DoubleMLEstimator → Estimator → PipelineStage → Params
  13. def copyValues[T <: Params](to: T, extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  14. lazy val copyrightLines: String
    Attributes
    protected
    Definition Classes
    BaseWrappable
  15. final def defaultCopy[T <: Params](extra: ParamMap): T
    Attributes
    protected
    Definition Classes
    Params
  16. def dotnetAdditionalMethods: String
    Definition Classes
    DotnetWrappable
  17. def dotnetClass(): String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  18. lazy val dotnetClassName: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  19. lazy val dotnetClassNameString: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  20. lazy val dotnetClassWrapperName: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  21. lazy val dotnetCopyrightLines: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  22. def dotnetExtraEstimatorImports: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  23. def dotnetExtraMethods: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  24. lazy val dotnetInternalWrapper: Boolean
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  25. def dotnetMLReadWriteMethods: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  26. lazy val dotnetNamespace: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  27. lazy val dotnetObjectBaseClass: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  28. def dotnetParamGetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  29. def dotnetParamGetters: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  30. def dotnetParamSetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  31. def dotnetParamSetters: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  32. def dotnetWrapAsTypeMethod: String
    Attributes
    protected
    Definition Classes
    DotnetWrappable
  33. final def eq(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  34. def equals(arg0: Any): Boolean
    Definition Classes
    AnyRef → Any
  35. def explainParam(param: Param[_]): String
    Definition Classes
    Params
  36. def explainParams(): String
    Definition Classes
    Params
  37. final def extractParamMap(): ParamMap
    Definition Classes
    Params
  38. final def extractParamMap(extra: ParamMap): ParamMap
    Definition Classes
    Params
  39. val featuresCol: Param[String]

    The name of the features column

    Definition Classes
    HasFeaturesCol
  40. def finalize(): Unit
    Attributes
    protected[lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  41. def fit(dataset: Dataset[_]): DoubleMLModel

    Fits the DoubleML model.

    dataset

    The input dataset to train.

    returns

    The trained DoubleML model, from which you can get the ATE and confidence interval (CI) values.

    Definition Classes
    DoubleMLEstimator → Estimator
  42. def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[DoubleMLModel]
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  43. def fit(dataset: Dataset[_], paramMap: ParamMap): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" )
  44. def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DoubleMLModel
    Definition Classes
    Estimator
    Annotations
    @Since( "2.0.0" ) @varargs()
  45. final def get[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  46. final def getClass(): Class[_]
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  47. def getConfidenceLevel: Double
    Definition Classes
    DoubleMLParams
  48. final def getDefault[T](param: Param[T]): Option[T]
    Definition Classes
    Params
  49. def getExecutionContextProxy: ExecutionContext
    Definition Classes
    HasParallelismInjected
  50. def getFeaturesCol: String

    Definition Classes
    HasFeaturesCol
  51. final def getMaxIter: Int
    Definition Classes
    HasMaxIter
  52. final def getOrDefault[T](param: Param[T]): T
    Definition Classes
    Params
  53. def getOutcomeCol: String
    Definition Classes
    HasOutcomeCol
  54. def getOutcomeModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  55. def getParallelism: Int
    Definition Classes
    HasParallelism
  56. def getParam(paramName: String): Param[Any]
    Definition Classes
    Params
  57. def getParamInfo(p: Param[_]): ParamInfo[_]
    Definition Classes
    BaseWrappable
  58. def getSampleSplitRatio: Array[Double]
    Definition Classes
    DoubleMLParams
  59. def getTreatmentCol: String
    Definition Classes
    HasTreatmentCol
  60. def getTreatmentModel: Estimator[_ <: Model[_]]
    Definition Classes
    DoubleMLParams
  61. def getWeightCol: String

    Definition Classes
    HasWeightCol
  62. final def hasDefault[T](param: Param[T]): Boolean
    Definition Classes
    Params
  63. def hasParam(paramName: String): Boolean
    Definition Classes
    Params
  64. def hashCode(): Int
    Definition Classes
    AnyRef → Any
    Annotations
    @native()
  65. def initializeLogIfNecessary(isInterpreter: Boolean, silent: Boolean): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  66. def initializeLogIfNecessary(isInterpreter: Boolean): Unit
    Attributes
    protected
    Definition Classes
    Logging
  67. final def isDefined(param: Param[_]): Boolean
    Definition Classes
    Params
  68. final def isInstanceOf[T0]: Boolean
    Definition Classes
    Any
  69. final def isSet(param: Param[_]): Boolean
    Definition Classes
    Params
  70. def isTraceEnabled(): Boolean
    Attributes
    protected
    Definition Classes
    Logging
  71. def log: Logger
    Attributes
    protected
    Definition Classes
    Logging
  72. def logBase(info: SynapseMLLogInfo): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  73. def logBase(methodName: String, columns: Option[Int]): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  74. def logClass(): Unit
    Definition Classes
    SynapseMLLogging
  75. def logDebug(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  76. def logDebug(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  77. def logError(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  78. def logError(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  79. def logErrorBase(methodName: String, e: Exception): Unit
    Attributes
    protected
    Definition Classes
    SynapseMLLogging
  80. def logFit[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  81. def logInfo(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  82. def logInfo(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  83. def logName: String
    Attributes
    protected
    Definition Classes
    Logging
  84. def logTrace(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  85. def logTrace(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  86. def logTrain[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  87. def logTransform[T](f: ⇒ T, columns: Int): T
    Definition Classes
    SynapseMLLogging
  88. def logVerb[T](verb: String, f: ⇒ T, columns: Int = -1): T
    Definition Classes
    SynapseMLLogging
  89. def logWarning(msg: ⇒ String, throwable: Throwable): Unit
    Attributes
    protected
    Definition Classes
    Logging
  90. def logWarning(msg: ⇒ String): Unit
    Attributes
    protected
    Definition Classes
    Logging
  91. def makeDotnetFile(conf: CodegenConfig): Unit
    Definition Classes
    DotnetWrappable
  92. def makePyFile(conf: CodegenConfig): Unit
    Definition Classes
    PythonWrappable
  93. def makeRFile(conf: CodegenConfig): Unit
    Definition Classes
    RWrappable
  94. final val maxIter: IntParam
    Definition Classes
    HasMaxIter
  95. final def ne(arg0: AnyRef): Boolean
    Definition Classes
    AnyRef
  96. final def notify(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  97. final def notifyAll(): Unit
    Definition Classes
    AnyRef
    Annotations
    @native()
  98. val outcomeCol: Param[String]
    Definition Classes
    HasOutcomeCol
  99. val outcomeModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  100. val parallelism: IntParam
    Definition Classes
    HasParallelism
  101. lazy val params: Array[Param[_]]
    Definition Classes
    Params
  102. def pyAdditionalMethods: String
    Definition Classes
    PythonWrappable
  103. lazy val pyClassDoc: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  104. lazy val pyClassName: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  105. def pyExtraEstimatorImports: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  106. def pyExtraEstimatorMethods: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  107. lazy val pyInheritedClasses: Seq[String]
    Attributes
    protected
    Definition Classes
    PythonWrappable
  108. def pyInitFunc(): String
    Definition Classes
    PythonWrappable
  109. lazy val pyInternalWrapper: Boolean
    Attributes
    protected
    Definition Classes
    PythonWrappable
  110. lazy val pyObjectBaseClass: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  111. def pyParamArg[T](p: Param[T]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  112. def pyParamDefault[T](p: Param[T]): Option[String]
    Attributes
    protected
    Definition Classes
    PythonWrappable
  113. def pyParamGetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  114. def pyParamSetter(p: Param[_]): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  115. def pyParamsArgs: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  116. def pyParamsDefaults: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  117. lazy val pyParamsDefinitions: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  118. def pyParamsGetters: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  119. def pyParamsSetters: String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  120. def pythonClass(): String
    Attributes
    protected
    Definition Classes
    PythonWrappable
  121. def rClass(): String
    Attributes
    protected
    Definition Classes
    RWrappable
  122. def rDocString: String
    Attributes
    protected
    Definition Classes
    RWrappable
  123. def rExtraBodyLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  124. def rExtraInitLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  125. lazy val rFuncName: String
    Attributes
    protected
    Definition Classes
    RWrappable
  126. lazy val rInternalWrapper: Boolean
    Attributes
    protected
    Definition Classes
    RWrappable
  127. def rParamArg[T](p: Param[T]): String
    Attributes
    protected
    Definition Classes
    RWrappable
  128. def rParamsArgs: String
    Attributes
    protected
    Definition Classes
    RWrappable
  129. def rSetterLines: String
    Attributes
    protected
    Definition Classes
    RWrappable
  130. val sampleSplitRatio: DoubleArrayParam
    Definition Classes
    DoubleMLParams
  131. def save(path: String): Unit
    Definition Classes
    MLWritable
    Annotations
    @Since( "1.6.0" ) @throws( ... )
  132. final def set(paramPair: ParamPair[_]): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  133. final def set(param: String, value: Any): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  134. final def set[T](param: Param[T], value: T): DoubleMLEstimator.this.type
    Definition Classes
    Params
  135. def setConfidenceLevel(value: Double): DoubleMLEstimator.this.type

    Sets the upper-bound percentile of the ATE distribution. Default is 0.975. The lower-bound percentile is calculated automatically as 100*(1 - confidenceLevel). By default, therefore, a 95% confidence interval is computed, spanning the [2.5%, 97.5%] percentiles of the ATE distribution.

    Definition Classes
    DoubleMLParams
  136. final def setDefault(paramPairs: ParamPair[_]*): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  137. final def setDefault[T](param: Param[T], value: T): DoubleMLEstimator.this.type
    Attributes
    protected
    Definition Classes
    Params
  138. def setFeaturesCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasFeaturesCol
  139. def setMaxIter(value: Int): DoubleMLEstimator.this.type

    Sets the maximum number of bootstrapping iterations used to compute the confidence interval. Default is 1, which means no confidence interval is calculated; to get confidence interval (CI) values, set it to a larger value.

    Definition Classes
    DoubleMLParams
  140. def setOutcomeCol(value: String): DoubleMLEstimator.this.type

    Sets the name of the column to use as the outcome.

    Definition Classes
    HasOutcomeCol
  141. def setOutcomeModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Sets the outcome model. It can be any estimator derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  142. def setParallelism(value: Int): DoubleMLEstimator.this.type
    Definition Classes
    DoubleMLParams
  143. def setSampleSplitRatio(value: Array[Double]): DoubleMLEstimator.this.type

    Sets the sample split ratio. Default is Array(0.5, 0.5).

    Definition Classes
    DoubleMLParams
  144. def setTreatmentCol(value: String): DoubleMLEstimator.this.type

    Sets the name of the column to use as the treatment.

    Definition Classes
    HasTreatmentCol
  145. def setTreatmentModel(value: Estimator[_ <: Model[_]]): DoubleMLEstimator.this.type

    Sets the treatment model. It can be any estimator derived from 'org.apache.spark.ml.regression.Regressor' or 'org.apache.spark.ml.classification.ProbabilisticClassifier'.

    Definition Classes
    DoubleMLParams
  146. def setWeightCol(value: String): DoubleMLEstimator.this.type

    Definition Classes
    HasWeightCol
  147. final def synchronized[T0](arg0: ⇒ T0): T0
    Definition Classes
    AnyRef
  148. val thisStage: Params
    Attributes
    protected
    Definition Classes
    BaseWrappable
  149. def toString(): String
    Definition Classes
    Identifiable → AnyRef → Any
  150. def transformSchema(schema: StructType): StructType
    Definition Classes
    DoubleMLEstimator → PipelineStage
    Annotations
    @DeveloperApi()
  151. def transformSchema(schema: StructType, logging: Boolean): StructType
    Attributes
    protected
    Definition Classes
    PipelineStage
    Annotations
    @DeveloperApi()
  152. val treatmentCol: Param[String]
    Definition Classes
    HasTreatmentCol
  153. val treatmentModel: EstimatorParam
    Definition Classes
    DoubleMLParams
  154. val uid: String
    Definition Classes
    DoubleMLEstimator → SynapseMLLogging → Identifiable
  155. def validateColTypeWithModel(dataset: Dataset[_], colName: String, model: Estimator[_]): Unit
    Attributes
    protected
  156. final def wait(): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  157. final def wait(arg0: Long, arg1: Int): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  158. final def wait(arg0: Long): Unit
    Definition Classes
    AnyRef
    Annotations
    @throws( ... ) @native()
  159. val weightCol: Param[String]

    The name of the weight column

    Definition Classes
    HasWeightCol
  160. def write: MLWriter
    Definition Classes
    ComplexParamsWritable → MLWritable
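
Per setMaxIter and setConfidenceLevel, the confidence interval comes from bootstrapping the ATE and reading off percentiles of its distribution. The plain-Scala sketch below illustrates that idea under stated assumptions: it resamples precomputed residual pairs with replacement (an actual implementation may instead resample rows and refit the nuisance models in every iteration), uses a nearest-rank percentile, and returns `None` when `maxIter <= 1` to mirror the documented default of not computing an interval. All names in it are hypothetical.

```scala
import scala.util.Random

object BootstrapCiSketch {
  /** Constant-effect final stage: sum(y * t) / sum(t * t). */
  def ate(yRes: IndexedSeq[Double], tRes: IndexedSeq[Double]): Double = {
    val num = yRes.indices.map(i => yRes(i) * tRes(i)).sum
    val den = tRes.map(v => v * v).sum
    num / den
  }

  /** Nearest-rank percentile of already sorted values, with q in [0, 1]. */
  def percentile(sorted: IndexedSeq[Double], q: Double): Double = {
    val rank = math.ceil(q * sorted.length).toInt
    sorted(math.min(math.max(rank - 1, 0), sorted.length - 1))
  }

  /** Resamples residual pairs with replacement maxIter times and returns the
    * (1 - confidenceLevel, confidenceLevel) percentiles of the bootstrapped ATEs.
    * Returns None when maxIter <= 1, i.e. no interval is computed. */
  def bootstrapCi(
      yRes: Array[Double],
      tRes: Array[Double],
      maxIter: Int,
      confidenceLevel: Double,
      seed: Long): Option[(Double, Double)] = {
    if (maxIter <= 1) None
    else {
      val rng = new Random(seed)
      val n = yRes.length
      val ates = (0 until maxIter).map { _ =>
        val idx = IndexedSeq.fill(n)(rng.nextInt(n)) // one bootstrap resample
        ate(idx.map(i => yRes(i)), idx.map(i => tRes(i)))
      }.sorted
      Some((percentile(ates, 1.0 - confidenceLevel), percentile(ates, confidenceLevel)))
    }
  }
}
```

With the default confidenceLevel of 0.975, the sketch returns the 2.5th and 97.5th percentiles of the bootstrapped ATEs, matching the 95% interval described for setConfidenceLevel.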
