# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
import sys
# Python 2 compatibility shim: ``basestring`` does not exist on Python 3, so
# alias it to ``str`` there. Compare ``sys.version_info`` (a tuple of ints)
# rather than the ``sys.version`` string, whose lexicographic ordering would
# misjudge a multi-digit major version (e.g. "10.0" < "3").
if sys.version_info[0] >= 3:
    basestring = str
import pyspark
from pyspark import SparkContext
def getOption(opt):
    """Unwrap an option-like object into a plain Python value.

    :param opt: object exposing ``isDefined()`` and ``get()``; presumably a
        ``scala.Option`` proxy returned over py4j (verify against callers).
    :return: ``opt.get()`` when ``opt.isDefined()`` is truthy, else ``None``.
        ``get()`` is never called on an undefined option.
    """
    return opt.get() if opt.isDefined() else None
class OpenAIDefaults:
    """Python facade over the JVM-side SynapseML ``OpenAIDefaults`` object.

    The JVM object is
    ``com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults``. It is
    reached through the ``_jvm`` view of the current (or newly created)
    ``SparkContext`` and stored on ``self.defaults``. Every method forwards
    to that object:

    * ``set_*`` passes a value to the matching JVM ``set*`` method.
    * ``get_*`` unwraps the JVM ``get*`` result with ``getOption``, so it
      returns ``None`` when nothing is defined.
    * ``reset_*`` calls the matching JVM ``reset*`` method.
    """

    def __init__(self):
        # getOrCreate() reuses an active SparkContext rather than starting a new one.
        jvm = SparkContext.getOrCreate()._jvm
        self.defaults = jvm.com.microsoft.azure.synapse.ml.services.openai.OpenAIDefaults

    # -- deployment name -------------------------------------------------

    def set_deployment_name(self, name):
        """Pass ``name`` to the JVM ``setDeploymentName``."""
        self.defaults.setDeploymentName(name)

    def get_deployment_name(self):
        """Return the JVM default deployment name, or ``None`` if undefined."""
        current = self.defaults.getDeploymentName()
        return getOption(current)

    def reset_deployment_name(self):
        """Invoke the JVM ``resetDeploymentName``."""
        self.defaults.resetDeploymentName()

    # -- subscription key ------------------------------------------------

    def set_subscription_key(self, key):
        """Pass ``key`` to the JVM ``setSubscriptionKey``."""
        self.defaults.setSubscriptionKey(key)

    def get_subscription_key(self):
        """Return the JVM default subscription key, or ``None`` if undefined."""
        current = self.defaults.getSubscriptionKey()
        return getOption(current)

    def reset_subscription_key(self):
        """Invoke the JVM ``resetSubscriptionKey``."""
        self.defaults.resetSubscriptionKey()

    # -- temperature -----------------------------------------------------

    def set_temperature(self, temp):
        """Coerce ``temp`` with ``float()`` and pass it to ``setTemperature``.

        Because of the coercion, ints and numeric strings are accepted too.
        Presumably the JVM setter expects a double; confirm on the Scala side.
        """
        as_float = float(temp)
        self.defaults.setTemperature(as_float)

    def get_temperature(self):
        """Return the JVM default temperature, or ``None`` if undefined."""
        current = self.defaults.getTemperature()
        return getOption(current)

    def reset_temperature(self):
        """Invoke the JVM ``resetTemperature``."""
        self.defaults.resetTemperature()

    # -- service URL -----------------------------------------------------

    def set_URL(self, URL):
        """Pass ``URL`` to the JVM ``setURL``."""
        self.defaults.setURL(URL)

    def get_URL(self):
        """Return the JVM default URL, or ``None`` if undefined."""
        current = self.defaults.getURL()
        return getOption(current)

    def reset_URL(self):
        """Invoke the JVM ``resetURL``."""
        self.defaults.resetURL()