# Source code for synapse.ml.automl.HyperparamBuilder

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

# Python 2 compatibility shim: alias ``basestring`` to ``str`` on Python 3.
# Compare the numeric ``sys.version_info`` tuple rather than the
# ``sys.version`` string, whose lexicographic ordering would misorder a
# multi-digit major version (e.g. "10" < "3" as strings).
if sys.version_info >= (3,):
    basestring = str

import pyspark
from pyspark import SparkContext
from pyspark.ml.common import inherit_doc
from pyspark.sql.types import *


class HyperparamBuilder(object):
    """
    Specifies the search space for hyperparameters.
    """

    def __init__(self):
        ctx = SparkContext.getOrCreate()
        # ``getOrCreate()`` already returned the active context; calling it
        # again on the instance before reading ``_jvm`` was redundant.
        self.jvm = ctx._jvm
        # Maps each param to its ``(estimator, hyperparam)`` pair. Adding the
        # same param again overwrites the earlier entry.
        self.hyperparams = {}

    def addHyperparam(self, est, param, hyperParam):
        """
        Add a hyperparam to the builder.

        Args:
            est: The estimator that owns ``param``. GridSpace and RandomSpace
                read ``est._java_obj``, so this is presumably a Java-backed
                PySpark estimator -- confirm against callers.
            param (Param): The param to tune.
            hyperParam: Specification of the values to try for ``param``,
                e.g. a DiscreteHyperParam or RangeHyperParam.

        Returns:
            HyperparamBuilder: ``self``, so calls can be chained.
        """
        self.hyperparams[param] = (est, hyperParam)
        return self

    def build(self):
        """
        Builds the search space of hyperparameters.

        Returns:
            A ``dict.items()`` view of ``(param, (est, hyperParam))`` pairs,
            in the shape GridSpace and RandomSpace iterate over.
        """
        return self.hyperparams.items()
class DiscreteHyperParam(object):
    """
    Specifies a discrete list of values.
    """

    def __init__(self, values, seed=0):
        """
        Args:
            values: Candidate values, forwarded unchanged to the JVM-side
                ``HyperParamUtils.getDiscreteHyperParam``.
            seed (int): Seed forwarded to the JVM side (default 0).
        """
        ctx = SparkContext.getOrCreate()
        # ``ctx`` is already the active context; no second getOrCreate().
        self.jvm = ctx._jvm
        self.hyperParam = self.jvm.com.microsoft.azure.synapse.ml.automl.HyperParamUtils.getDiscreteHyperParam(
            values,
            seed,
        )

    def get(self):
        """Return the JVM hyperparameter object built in ``__init__``."""
        return self.hyperParam
class RangeHyperParam(object):
    """
    Specifies a range of values.
    """

    # NOTE: ``min``/``max`` shadow the builtins, but they are public keyword
    # parameter names and are kept for backward compatibility.
    def __init__(self, min, max, seed=0):
        """
        Args:
            min: Lower end of the range, forwarded to the JVM-side
                ``HyperParamUtils.getRangeHyperParam``.
            max: Upper end of the range, forwarded likewise.
            seed (int): Seed forwarded to the JVM side (default 0).
        """
        ctx = SparkContext.getOrCreate()
        # ``ctx`` is already the active context; no second getOrCreate().
        self.jvm = ctx._jvm
        self.rangeParam = self.jvm.com.microsoft.azure.synapse.ml.automl.HyperParamUtils.getRangeHyperParam(
            min,
            max,
            seed,
        )

    def get(self):
        """Return the JVM range hyperparameter object built in ``__init__``."""
        return self.rangeParam
class GridSpace(object):
    """
    Specifies a predetermined grid of values to search through.
    """

    def __init__(self, paramValues):
        """
        Args:
            paramValues: Iterable of ``(param, (est, hyperparam))`` pairs,
                such as the result of ``HyperparamBuilder.build()``. Every
                ``hyperparam`` must be a DiscreteHyperParam.

        Raises:
            ValueError: If any ``hyperparam`` is not a DiscreteHyperParam.
        """
        # Materialize once so the entries can be validated and then consumed
        # even when a one-shot iterator/generator is passed in.
        entries = list(paramValues)
        # Validate everything up front so bad input fails fast, before
        # SparkContext.getOrCreate() or any JVM call is made.
        for k, (_est, hyperparam) in entries:
            if not isinstance(hyperparam, DiscreteHyperParam):
                raise ValueError(
                    "GridSpace only supports DiscreteHyperParam, but hyperparam {} is of type {}".format(
                        k,
                        type(hyperparam),
                    ),
                )

        ctx = SparkContext.getOrCreate()
        # ``ctx`` is already the active context; no second getOrCreate().
        self.jvm = ctx._jvm
        hyperparamBuilder = self.jvm.org.apache.spark.ml.tuning.ParamGridBuilder()
        for k, (est, hyperparam) in entries:
            # Resolve the Python param to its JVM counterpart by name.
            javaParam = est._java_obj.getParam(k.name)
            values = hyperparam.get().getValues()
            hyperparamBuilder.addGrid(javaParam, self.jvm.PythonUtils.toList(values))
        self.gridSpace = self.jvm.com.microsoft.azure.synapse.ml.automl.GridSpace(
            hyperparamBuilder.build(),
        )

    def space(self):
        """Return the JVM GridSpace object built in ``__init__``."""
        return self.gridSpace
class RandomSpace(object):
    """
    Specifies a random streaming range of values to search through.
    """

    def __init__(self, paramDistributions):
        """
        Args:
            paramDistributions: Iterable of ``(param, (est, hyperparam))``
                pairs, such as the result of ``HyperparamBuilder.build()``.
                Each ``hyperparam`` must expose ``get()`` returning its JVM
                hyperparameter object (as DiscreteHyperParam and
                RangeHyperParam do).
        """
        ctx = SparkContext.getOrCreate()
        # ``ctx`` is already the active context; no second getOrCreate().
        self.jvm = ctx._jvm
        hyperparamBuilder = (
            self.jvm.com.microsoft.azure.synapse.ml.automl.HyperparamBuilder()
        )
        for k, (est, hyperparam) in paramDistributions:
            # Resolve the Python param to its JVM counterpart by name.
            javaParam = est._java_obj.getParam(k.name)
            hyperparamBuilder.addHyperparam(javaParam, hyperparam.get())
        self.paramSpace = self.jvm.com.microsoft.azure.synapse.ml.automl.RandomSpace(
            hyperparamBuilder.build(),
        )

    def space(self):
        """Return the JVM RandomSpace object built in ``__init__``."""
        return self.paramSpace