# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
import sys
# Python 2/3 compatibility shim: on Python 3 there is no ``basestring``, so
# alias it to ``str`` for any code that still checks against it.
# Compare the numeric ``sys.version_info`` tuple instead of the ``sys.version``
# string: string ordering is lexicographic, not a version ordering
# (e.g. "10.0" >= "3" is False).
if sys.version_info[0] >= 3:
    basestring = str
import pyspark
from pyspark import SparkContext
def getOption(opt):
    """Unwrap an option-like JVM object into a plain Python value.

    :param opt: object exposing ``isDefined()`` and ``get()`` -- e.g. a
        Scala ``Option`` returned through the py4j gateway.
    :return: ``opt.get()`` when ``opt.isDefined()`` is true, otherwise ``None``.
    """
    return opt.get() if opt.isDefined() else None
class OpenAIDefaults:
    """Python accessor for SynapseML's JVM-side ``OpenAIDefaults`` object.

    Every ``set_*`` / ``get_*`` / ``reset_*`` method forwards to the matching
    method on the JVM object
    ``com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults``, which
    is reached through the current SparkContext's py4j gateway (``_jvm``).
    Getters pass the JVM result through ``getOption``, so they return the
    stored value or ``None`` when nothing is set.

    NOTE(review): these values are presumably the fallbacks used by the
    SynapseML OpenAI transformers when a parameter is not set explicitly --
    confirm against the Scala ``OpenAIDefaults`` implementation.
    """

    def __init__(self):
        # Constructing this object requires a usable Spark/JVM session:
        # getOrCreate() returns the active SparkContext or starts one.
        self.defaults = (
            SparkContext.getOrCreate()._jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults
        )

    @staticmethod
    def _float_in_range(value, label, low, high):
        """Return ``float(value)``; raise ValueError unless ``low <= value <= high``.

        ``float()`` itself raises ValueError/TypeError for non-numeric input.
        NaN fails every comparison, so it is rejected as out of range.
        """
        as_float = float(value)
        if not (low <= as_float <= high):
            raise ValueError(
                f"{label} must be between {low} and {high}, got: {as_float}"
            )
        return as_float

    # -- deployment name ------------------------------------------------------

    def set_deployment_name(self, name):
        """Set the default deployment name."""
        self.defaults.setDeploymentName(name)

    def get_deployment_name(self):
        """Return the default deployment name, or ``None`` if unset."""
        return getOption(self.defaults.getDeploymentName())

    def reset_deployment_name(self):
        """Reset the default deployment name (JVM ``resetDeploymentName``)."""
        self.defaults.resetDeploymentName()

    # -- subscription key -----------------------------------------------------

    def set_subscription_key(self, key):
        """Set the default subscription key."""
        self.defaults.setSubscriptionKey(key)

    def get_subscription_key(self):
        """Return the default subscription key, or ``None`` if unset."""
        return getOption(self.defaults.getSubscriptionKey())

    def reset_subscription_key(self):
        """Reset the default subscription key (JVM ``resetSubscriptionKey``)."""
        self.defaults.resetSubscriptionKey()

    # -- temperature ----------------------------------------------------------

    def set_temperature(self, temp):
        """Set the default temperature.

        :param temp: anything ``float()`` accepts; must lie in [0.0, 2.0].
        :raises ValueError: if the value is outside [0.0, 2.0] (or is NaN).
        """
        self.defaults.setTemperature(
            self._float_in_range(temp, "Temperature", 0.0, 2.0)
        )

    def get_temperature(self):
        """Return the default temperature, or ``None`` if unset."""
        return getOption(self.defaults.getTemperature())

    def reset_temperature(self):
        """Reset the default temperature (JVM ``resetTemperature``)."""
        self.defaults.resetTemperature()

    # -- URL --------------------------------------------------------------------

    def set_URL(self, URL):
        """Set the default URL."""
        self.defaults.setURL(URL)

    def get_URL(self):
        """Return the default URL, or ``None`` if unset."""
        return getOption(self.defaults.getURL())

    def reset_URL(self):
        """Reset the default URL (JVM ``resetURL``)."""
        self.defaults.resetURL()

    # -- seed -------------------------------------------------------------------

    def set_seed(self, seed):
        """Set the default seed; the value is coerced with ``int()``."""
        self.defaults.setSeed(int(seed))

    def get_seed(self):
        """Return the default seed, or ``None`` if unset."""
        return getOption(self.defaults.getSeed())

    def reset_seed(self):
        """Reset the default seed (JVM ``resetSeed``)."""
        self.defaults.resetSeed()

    # -- top_p ------------------------------------------------------------------

    def set_top_p(self, top_p):
        """Set the default top-p value.

        :param top_p: anything ``float()`` accepts; must lie in [0.0, 1.0].
        :raises ValueError: if the value is outside [0.0, 1.0] (or is NaN).
        """
        self.defaults.setTopP(self._float_in_range(top_p, "TopP", 0.0, 1.0))

    def get_top_p(self):
        """Return the default top-p value, or ``None`` if unset."""
        return getOption(self.defaults.getTopP())

    def reset_top_p(self):
        """Reset the default top-p value (JVM ``resetTopP``)."""
        self.defaults.resetTopP()

    # -- API version ------------------------------------------------------------

    def set_api_version(self, api_version):
        """Set the default API version."""
        self.defaults.setApiVersion(api_version)

    def get_api_version(self):
        """Return the default API version, or ``None`` if unset."""
        return getOption(self.defaults.getApiVersion())

    def reset_api_version(self):
        """Reset the default API version (JVM ``resetApiVersion``)."""
        self.defaults.resetApiVersion()

    # -- model ------------------------------------------------------------------

    def set_model(self, ai_foundry_model):
        """Set the default model name."""
        self.defaults.setModel(ai_foundry_model)

    def get_model(self):
        """Return the default model name, or ``None`` if unset."""
        return getOption(self.defaults.getModel())

    def reset_model(self):
        """Reset the default model (JVM ``resetModel``)."""
        self.defaults.resetModel()