# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
import sys
# Python 3 has no ``basestring``; alias it to ``str``. It is not used in the
# code visible here — presumably kept for Python 2-era string checks elsewhere.
# Compare ``sys.version_info`` rather than the ``sys.version`` banner string:
# lexicographic string comparison is not a reliable version check.
if sys.version_info[0] >= 3:
    basestring = str
import pyspark
from pyspark import SparkContext
from pyspark.ml.common import inherit_doc
from pyspark.sql.types import *
class HyperparamBuilder(object):
    """
    Specifies the search space for hyperparameters.

    Collects ``(estimator, hyperparameter)`` pairs keyed by the PySpark
    ``Param`` being tuned. The result of :meth:`build` has the
    ``(param, (est, hyperParam))`` shape that :class:`GridSpace` and
    :class:`RandomSpace` iterate over.
    """

    def __init__(self):
        # Make sure a SparkContext exists and keep a handle to its JVM view.
        # ``ctx`` is already the active context, so there is no need to call
        # ``getOrCreate()`` on it a second time.
        ctx = SparkContext.getOrCreate()
        self.jvm = ctx._jvm
        self.hyperparams = {}

    def addHyperparam(self, est, param, hyperParam):
        """
        Add a hyperparam to the builder.

        Args:
            est: The PySpark estimator that owns ``param``. Its ``_java_obj``
                is later used to look up the Java-side param by name.
            param (Param): The param to tune. Adding the same param again
                replaces the earlier entry.
            hyperParam: The values to search, e.g. a
                :class:`DiscreteHyperParam` or :class:`RangeHyperParam`
                (anything exposing ``get()``).

        Returns:
            HyperparamBuilder: ``self``, so calls can be chained.
        """
        self.hyperparams[param] = (est, hyperParam)
        return self

    def build(self):
        """
        Build the search space of hyperparameters.

        Returns:
            A view of ``(param, (est, hyperParam))`` pairs over the builder's
            internal dict, suitable for :class:`GridSpace` or
            :class:`RandomSpace`.
        """
        return self.hyperparams.items()
class DiscreteHyperParam(object):
    """
    Specifies a discrete list of values.

    Thin wrapper around the JVM object returned by
    ``com.microsoft.azure.synapse.ml.automl.HyperParamUtils.getDiscreteHyperParam``.
    """

    def __init__(self, values, seed=0):
        """
        Args:
            values: Candidate values, forwarded unchanged to the JVM helper
                (assumes Py4J can convert this collection — TODO confirm the
                accepted element types).
            seed (int): Seed forwarded to the JVM helper; presumably controls
                how values are sampled — verify on the Scala side.
        """
        # ``ctx`` is already the active context; no second getOrCreate() call.
        ctx = SparkContext.getOrCreate()
        self.jvm = ctx._jvm
        utils = self.jvm.com.microsoft.azure.synapse.ml.automl.HyperParamUtils
        self.hyperParam = utils.getDiscreteHyperParam(values, seed)

    def get(self):
        """Return the wrapped JVM hyperparameter object."""
        return self.hyperParam
class RangeHyperParam(object):
    """
    Specifies a range of values.

    Thin wrapper around the JVM object returned by
    ``com.microsoft.azure.synapse.ml.automl.HyperParamUtils.getRangeHyperParam``.
    """

    # NOTE: ``min``/``max`` shadow the builtins inside ``__init__``. They are
    # kept because callers may pass them by keyword.
    def __init__(self, min, max, seed=0):
        """
        Args:
            min: Lower end of the range, forwarded to the JVM helper.
            max: Upper end of the range, forwarded to the JVM helper.
                Whether the bounds are inclusive is defined on the Scala
                side — TODO confirm.
            seed (int): Seed forwarded to the JVM helper; presumably used
                when sampling from the range.
        """
        # ``ctx`` is already the active context; no second getOrCreate() call.
        ctx = SparkContext.getOrCreate()
        self.jvm = ctx._jvm
        utils = self.jvm.com.microsoft.azure.synapse.ml.automl.HyperParamUtils
        self.rangeParam = utils.getRangeHyperParam(min, max, seed)

    def get(self):
        """Return the wrapped JVM range hyperparameter object."""
        return self.rangeParam
class GridSpace(object):
    """
    Specifies a predetermined grid of values to search through.
    """

    def __init__(self, paramValues):
        """
        Args:
            paramValues: Iterable of ``(param, (est, hyperParam))`` pairs, as
                returned by ``HyperparamBuilder.build()``. Each ``est`` must
                have a ``_java_obj`` and each ``hyperParam`` a ``get()``.
        """
        # ``ctx`` is already the active context; no second getOrCreate() call.
        ctx = SparkContext.getOrCreate()
        self.jvm = ctx._jvm
        automl = self.jvm.com.microsoft.azure.synapse.ml.automl
        hyperparamBuilder = automl.HyperparamBuilder()
        for k, (est, hyperparam) in paramValues:
            # Map the Python Param to the same-named Param on the estimator's
            # backing Java object, then register its JVM hyperparameter.
            javaParam = est._java_obj.getParam(k.name)
            hyperparamBuilder.addHyperparam(javaParam, hyperparam.get())
        self.gridSpace = automl.GridSpace(hyperparamBuilder.build())

    def space(self):
        """Return the JVM ``GridSpace`` object built in ``__init__``."""
        return self.gridSpace
class RandomSpace(object):
    """
    Specifies a random streaming range of values to search through.
    """

    def __init__(self, paramDistributions):
        """
        Args:
            paramDistributions: Iterable of ``(param, (est, hyperParam))``
                pairs, as returned by ``HyperparamBuilder.build()``. Each
                ``est`` must have a ``_java_obj`` and each ``hyperParam`` a
                ``get()``.
        """
        # ``ctx`` is already the active context; no second getOrCreate() call.
        ctx = SparkContext.getOrCreate()
        self.jvm = ctx._jvm
        automl = self.jvm.com.microsoft.azure.synapse.ml.automl
        hyperparamBuilder = automl.HyperparamBuilder()
        for k, (est, hyperparam) in paramDistributions:
            # Map the Python Param to the same-named Param on the estimator's
            # backing Java object, then register its JVM hyperparameter.
            javaParam = est._java_obj.getParam(k.name)
            hyperparamBuilder.addHyperparam(javaParam, hyperparam.get())
        self.paramSpace = automl.RandomSpace(hyperparamBuilder.build())

    def space(self):
        """Return the JVM ``RandomSpace`` object built in ``__init__``."""
        return self.paramSpace