# Source code for synapse.ml.services.openai.OpenAIChatCompletion

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.


import sys
# Python 2 compatibility shim: alias `basestring` to `str` on Python 3.
# Compare the numeric major version rather than the `sys.version` string,
# whose lexicographic ordering is not a reliable version comparison.
if sys.version_info[0] >= 3:
    basestring = str

from pyspark import SparkContext, SQLContext
from pyspark.sql import DataFrame
from pyspark.ml.param.shared import *
from pyspark import keyword_only
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from synapse.ml.core.platform import running_on_synapse_internal
from synapse.ml.core.serialize.java_params_patch import *
from pyspark.ml.wrapper import JavaTransformer, JavaEstimator, JavaModel
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.common import inherit_doc
from synapse.ml.core.schema.Utils import *
from pyspark.ml.param import TypeConverters
from synapse.ml.core.schema.TypeConversionUtils import generateTypeConverter, complexTypeConverter


@inherit_doc
class OpenAIChatCompletion(ComplexParamsMixin, JavaMLReadable, JavaMLWritable, JavaTransformer):
    """
    Python wrapper around the JVM transformer
    ``com.microsoft.azure.synapse.ml.services.openai.OpenAIChatCompletion``,
    which generates chat completions for the messages in ``messagesCol``.

    Params whose description starts with "ServiceParam:" live on the JVM
    object: ``setX`` forwards a literal value (Python lists are converted to
    Scala Seqs first) and ``setXCol`` forwards a column name. The remaining
    params are ordinary Spark ML params stored on the Python side via ``_set``.

    Args:
        AADToken (object): AAD Token used for authentication
        CustomAuthHeader (object): A Custom Value for Authorization Header
        apiVersion (object): version of the api
        bestOf (object): How many generations to create server side, and
            display only the best. Will not stream intermediate progress if
            best_of > 1. Has maximum value of 128.
        cacheLevel (object): can be used to disable any server-side caching,
            0=no cache, 1=prompt prefix enabled, 2=full cache
        concurrency (int): max number of concurrent calls
        concurrentTimeout (float): max number of seconds to wait on futures if
            concurrency >= 1
        deploymentName (object): The name of the deployment
        echo (object): Echo back the prompt in addition to the completion
        errorCol (str): column to hold http errors
        frequencyPenalty (object): How much to penalize new tokens based on
            their existing frequency in the text so far. Decreases the
            likelihood of the model repeating the same line verbatim.
        handler (object): Which strategy to use when handling requests
        logProbs (object): Include the log probabilities on the `logprobs`
            most likely tokens, as well as the chosen tokens. For example, if
            `logprobs` is 10, the API will return a list of the 10 most likely
            tokens. If `logprobs` is 0, only the chosen tokens will have
            logprobs returned. Minimum of 0 and maximum of 100 allowed.
        maxTokens (object): The maximum number of tokens to generate. Has
            minimum of 0.
        messagesCol (str): The column of messages to generate chat completions
            for, in the chat format. This column should have type
            Array(Struct(role: String, content: String)).
        n (object): How many snippets to generate for each prompt. Minimum of
            1 and maximum of 128 allowed.
        outputCol (str): The name of the output column
        presencePenalty (object): How much to penalize new tokens based on
            whether they appear in the text so far. Increases the likelihood
            of the model talking about new topics. Has minimum of -2 and
            maximum of 2.
        stop (object): A sequence which indicates the end of the current
            document.
        subscriptionKey (object): the API key to use
        temperature (object): What sampling temperature to use. Higher values
            mean the model will take more risks. Try 0.9 for more creative
            applications, and 0 (argmax sampling) for ones with a
            well-defined answer. We generally recommend using this or `top_p`
            but not both. Minimum of 0 and maximum of 2 allowed.
        timeout (float): number of seconds to wait before closing the
            connection
        topP (object): An alternative to sampling with temperature, called
            nucleus sampling, where the model considers the results of the
            tokens with top_p probability mass. So 0.1 means only the tokens
            comprising the top 10 percent probability mass are considered. We
            generally recommend using this or `temperature` but not both.
            Minimum of 0 and maximum of 1 allowed.
        url (str): Url of the service
        user (object): The ID of the end-user, for use in tracking and
            rate-limiting.
    """

    AADToken = Param(Params._dummy(), "AADToken", "ServiceParam: AAD Token used for authentication")

    CustomAuthHeader = Param(Params._dummy(), "CustomAuthHeader", "ServiceParam: A Custom Value for Authorization Header")

    apiVersion = Param(Params._dummy(), "apiVersion", "ServiceParam: version of the api")

    bestOf = Param(Params._dummy(), "bestOf", "ServiceParam: How many generations to create server side, and display only the best. Will not stream intermediate progress if best_of > 1. Has maximum value of 128.")

    cacheLevel = Param(Params._dummy(), "cacheLevel", "ServiceParam: can be used to disable any server-side caching, 0=no cache, 1=prompt prefix enabled, 2=full cache")

    concurrency = Param(Params._dummy(), "concurrency", "max number of concurrent calls", typeConverter=TypeConverters.toInt)

    concurrentTimeout = Param(Params._dummy(), "concurrentTimeout", "max number seconds to wait on futures if concurrency >= 1", typeConverter=TypeConverters.toFloat)

    deploymentName = Param(Params._dummy(), "deploymentName", "ServiceParam: The name of the deployment")

    echo = Param(Params._dummy(), "echo", "ServiceParam: Echo back the prompt in addition to the completion")

    errorCol = Param(Params._dummy(), "errorCol", "column to hold http errors", typeConverter=TypeConverters.toString)

    # NOTE(review): the runtime descriptions of frequencyPenalty and
    # presencePenalty appear swapped relative to the OpenAI API reference.
    # They presumably mirror the JVM-side definitions, so they are left as-is
    # here; the method docstrings below carry the corrected wording.
    frequencyPenalty = Param(Params._dummy(), "frequencyPenalty", "ServiceParam: How much to penalize new tokens based on whether they appear in the text so far. Increases the likelihood of the model to talk about new topics.")

    handler = Param(Params._dummy(), "handler", "Which strategy to use when handling requests")

    logProbs = Param(Params._dummy(), "logProbs", "ServiceParam: Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned. Minimum of 0 and maximum of 100 allowed.")

    maxTokens = Param(Params._dummy(), "maxTokens", "ServiceParam: The maximum number of tokens to generate. Has minimum of 0.")

    messagesCol = Param(Params._dummy(), "messagesCol", "The column messages to generate chat completions for, in the chat format. This column should have type Array(Struct(role: String, content: String)).", typeConverter=TypeConverters.toString)

    n = Param(Params._dummy(), "n", "ServiceParam: How many snippets to generate for each prompt. Minimum of 1 and maximum of 128 allowed.")

    outputCol = Param(Params._dummy(), "outputCol", "The name of the output column", typeConverter=TypeConverters.toString)

    presencePenalty = Param(Params._dummy(), "presencePenalty", "ServiceParam: How much to penalize new tokens based on their existing frequency in the text so far. Decreases the likelihood of the model to repeat the same line verbatim. Has minimum of -2 and maximum of 2.")

    stop = Param(Params._dummy(), "stop", "ServiceParam: A sequence which indicates the end of the current document.")

    subscriptionKey = Param(Params._dummy(), "subscriptionKey", "ServiceParam: the API key to use")

    temperature = Param(Params._dummy(), "temperature", "ServiceParam: What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. We generally recommend using this or `top_p` but not both. Minimum of 0 and maximum of 2 allowed.")

    timeout = Param(Params._dummy(), "timeout", "number of seconds to wait before closing the connection", typeConverter=TypeConverters.toFloat)

    topP = Param(Params._dummy(), "topP", "ServiceParam: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10 percent probability mass are considered. We generally recommend using this or `temperature` but not both. Minimum of 0 and maximum of 1 allowed.")

    url = Param(Params._dummy(), "url", "Url of the service", typeConverter=TypeConverters.toString)

    user = Param(Params._dummy(), "user", "ServiceParam: The ID of the end-user, for use in tracking and rate-limiting.")

    @keyword_only
    def __init__(
        self,
        java_obj=None,
        AADToken=None,
        AADTokenCol=None,
        CustomAuthHeader=None,
        CustomAuthHeaderCol=None,
        apiVersion=None,
        apiVersionCol=None,
        bestOf=None,
        bestOfCol=None,
        cacheLevel=None,
        cacheLevelCol=None,
        concurrency=1,
        concurrentTimeout=None,
        deploymentName=None,
        deploymentNameCol=None,
        echo=None,
        echoCol=None,
        errorCol="OpenAIChatCompletion_18fed8d2057a_error",
        frequencyPenalty=None,
        frequencyPenaltyCol=None,
        handler=None,
        logProbs=None,
        logProbsCol=None,
        maxTokens=None,
        maxTokensCol=None,
        messagesCol=None,
        n=None,
        nCol=None,
        outputCol="OpenAIChatCompletion_18fed8d2057a_output",
        presencePenalty=None,
        presencePenaltyCol=None,
        stop=None,
        stopCol=None,
        subscriptionKey=None,
        subscriptionKeyCol=None,
        temperature=None,
        temperatureCol=None,
        timeout=360.0,
        topP=None,
        topPCol=None,
        url=None,
        user=None,
        userCol=None
    ):
        """
        Create the wrapper, either around a new JVM object or an existing one.

        Args:
            java_obj: an existing JVM ``OpenAIChatCompletion`` to wrap. When
                given, the other keyword arguments are NOT applied to it.
            **params: any of the params listed on the class; each non-None
                value is applied through its ``set<Name>`` method.
        """
        super(OpenAIChatCompletion, self).__init__()
        if java_obj is None:
            self._java_obj = self._new_java_obj("com.microsoft.azure.synapse.ml.services.openai.OpenAIChatCompletion", self.uid)
        else:
            self._java_obj = java_obj
        self._setDefault(concurrency=1)
        self._setDefault(errorCol="OpenAIChatCompletion_18fed8d2057a_error")
        self._setDefault(outputCol="OpenAIChatCompletion_18fed8d2057a_output")
        self._setDefault(timeout=360.0)
        # keyword_only records the explicitly passed kwargs; the fallback
        # covers older pyspark versions that stored them on the function.
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            kwargs = self.__init__._input_kwargs
        if java_obj is None:
            self._applyKeywordSetters(kwargs)

    @keyword_only
    def setParams(
        self,
        AADToken=None,
        AADTokenCol=None,
        CustomAuthHeader=None,
        CustomAuthHeaderCol=None,
        apiVersion=None,
        apiVersionCol=None,
        bestOf=None,
        bestOfCol=None,
        cacheLevel=None,
        cacheLevelCol=None,
        concurrency=1,
        concurrentTimeout=None,
        deploymentName=None,
        deploymentNameCol=None,
        echo=None,
        echoCol=None,
        errorCol="OpenAIChatCompletion_18fed8d2057a_error",
        frequencyPenalty=None,
        frequencyPenaltyCol=None,
        handler=None,
        logProbs=None,
        logProbsCol=None,
        maxTokens=None,
        maxTokensCol=None,
        messagesCol=None,
        n=None,
        nCol=None,
        outputCol="OpenAIChatCompletion_18fed8d2057a_output",
        presencePenalty=None,
        presencePenaltyCol=None,
        stop=None,
        stopCol=None,
        subscriptionKey=None,
        subscriptionKeyCol=None,
        temperature=None,
        temperatureCol=None,
        timeout=360.0,
        topP=None,
        topPCol=None,
        url=None,
        user=None,
        userCol=None
    ):
        """
        Set the (keyword only) parameters.

        Only explicitly passed, non-None keywords are applied, each through its
        ``set<Name>`` method (the same path ``__init__`` uses). This makes
        ServiceParams and their ``*Col`` variants reach the JVM object instead
        of only the Python param map. ``*Col`` keywords have no Python-side
        Param, so ``self._set`` would reject them.

        Returns:
            self
        """
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            kwargs = self.setParams._input_kwargs
        return self._applyKeywordSetters(kwargs)

    def _applyKeywordSetters(self, kwargs):
        """Call ``set<Name>(value)`` for every non-None ``name=value`` in kwargs."""
        for name, value in kwargs.items():
            if value is not None:
                getattr(self, "set" + name[0].upper() + name[1:])(value)
        return self

    def _callJavaSetter(self, setterName, value):
        """Invoke ``setterName`` on the JVM object and keep the object it returns."""
        self._java_obj = getattr(self._java_obj, setterName)(value)
        return self

    def _setServiceParamValue(self, setterName, value):
        """Forward a literal ServiceParam value to the JVM, converting lists to Scala Seqs."""
        if isinstance(value, list):
            value = SparkContext._active_spark_context._jvm.com.microsoft.azure.synapse.ml.param.ServiceParam.toSeq(value)
        return self._callJavaSetter(setterName, value)

    @classmethod
    def read(cls):
        """ Returns an MLReader instance for this class. """
        return JavaMMLReader(cls)

    @staticmethod
    def getJavaPackage():
        """ Returns package name String. """
        return "com.microsoft.azure.synapse.ml.services.openai.OpenAIChatCompletion"

    @staticmethod
    def _from_java(java_stage):
        """Build the Python wrapper for ``java_stage`` via ``from_java``."""
        module_name = OpenAIChatCompletion.__module__
        module_name = module_name.rsplit(".", 1)[0] + ".OpenAIChatCompletion"
        return from_java(java_stage, module_name)

    def setAADToken(self, value):
        """
        Args:
            AADToken: AAD Token used for authentication
        """
        return self._setServiceParamValue("setAADToken", value)

    def setAADTokenCol(self, value):
        """
        Args:
            AADTokenCol: name of the column supplying AADToken (AAD Token used for authentication)
        """
        return self._callJavaSetter("setAADTokenCol", value)

    def setCustomAuthHeader(self, value):
        """
        Args:
            CustomAuthHeader: A Custom Value for Authorization Header
        """
        return self._setServiceParamValue("setCustomAuthHeader", value)

    def setCustomAuthHeaderCol(self, value):
        """
        Args:
            CustomAuthHeaderCol: name of the column supplying CustomAuthHeader (A Custom Value for Authorization Header)
        """
        return self._callJavaSetter("setCustomAuthHeaderCol", value)

    def setApiVersion(self, value):
        """
        Args:
            apiVersion: version of the api
        """
        return self._setServiceParamValue("setApiVersion", value)

    def setApiVersionCol(self, value):
        """
        Args:
            apiVersionCol: name of the column supplying apiVersion (version of the api)
        """
        return self._callJavaSetter("setApiVersionCol", value)

    def setBestOf(self, value):
        """
        Args:
            bestOf: How many generations to create server side, and display only the best.
                Will not stream intermediate progress if best_of > 1. Has maximum value of 128.
        """
        return self._setServiceParamValue("setBestOf", value)

    def setBestOfCol(self, value):
        """
        Args:
            bestOfCol: name of the column supplying bestOf (how many generations to create
                server side, and display only the best)
        """
        return self._callJavaSetter("setBestOfCol", value)

    def setCacheLevel(self, value):
        """
        Args:
            cacheLevel: can be used to disable any server-side caching,
                0=no cache, 1=prompt prefix enabled, 2=full cache
        """
        return self._setServiceParamValue("setCacheLevel", value)

    def setCacheLevelCol(self, value):
        """
        Args:
            cacheLevelCol: name of the column supplying cacheLevel (server-side caching level)
        """
        return self._callJavaSetter("setCacheLevelCol", value)

    def setConcurrency(self, value):
        """
        Args:
            concurrency: max number of concurrent calls
        """
        self._set(concurrency=value)
        return self

    def setConcurrentTimeout(self, value):
        """
        Args:
            concurrentTimeout: max number of seconds to wait on futures if concurrency >= 1
        """
        self._set(concurrentTimeout=value)
        return self

    def setDeploymentName(self, value):
        """
        Args:
            deploymentName: The name of the deployment
        """
        return self._setServiceParamValue("setDeploymentName", value)

    def setDeploymentNameCol(self, value):
        """
        Args:
            deploymentNameCol: name of the column supplying deploymentName (the name of the deployment)
        """
        return self._callJavaSetter("setDeploymentNameCol", value)

    def setEcho(self, value):
        """
        Args:
            echo: Echo back the prompt in addition to the completion
        """
        return self._setServiceParamValue("setEcho", value)

    def setEchoCol(self, value):
        """
        Args:
            echoCol: name of the column supplying echo (echo back the prompt in addition to the completion)
        """
        return self._callJavaSetter("setEchoCol", value)

    def setErrorCol(self, value):
        """
        Args:
            errorCol: column to hold http errors
        """
        self._set(errorCol=value)
        return self

    def setFrequencyPenalty(self, value):
        """
        Args:
            frequencyPenalty: How much to penalize new tokens based on their existing
                frequency in the text so far. Decreases the likelihood of the model
                repeating the same line verbatim.
        """
        return self._setServiceParamValue("setFrequencyPenalty", value)

    def setFrequencyPenaltyCol(self, value):
        """
        Args:
            frequencyPenaltyCol: name of the column supplying frequencyPenalty
        """
        return self._callJavaSetter("setFrequencyPenaltyCol", value)

    def setHandler(self, value):
        """
        Args:
            handler: Which strategy to use when handling requests
        """
        self._set(handler=value)
        return self

    def setLogProbs(self, value):
        """
        Args:
            logProbs: Include the log probabilities on the `logprobs` most likely tokens,
                as well as the chosen tokens. For example, if `logprobs` is 10, the API
                will return a list of the 10 most likely tokens. If `logprobs` is 0, only
                the chosen tokens will have logprobs returned. Minimum of 0 and maximum
                of 100 allowed.
        """
        return self._setServiceParamValue("setLogProbs", value)

    def setLogProbsCol(self, value):
        """
        Args:
            logProbsCol: name of the column supplying logProbs
        """
        return self._callJavaSetter("setLogProbsCol", value)

    def setMaxTokens(self, value):
        """
        Args:
            maxTokens: The maximum number of tokens to generate. Has minimum of 0.
        """
        return self._setServiceParamValue("setMaxTokens", value)

    def setMaxTokensCol(self, value):
        """
        Args:
            maxTokensCol: name of the column supplying maxTokens (maximum number of tokens to generate)
        """
        return self._callJavaSetter("setMaxTokensCol", value)

    def setMessagesCol(self, value):
        """
        Args:
            messagesCol: The column of messages to generate chat completions for, in the
                chat format. This column should have type
                Array(Struct(role: String, content: String)).
        """
        self._set(messagesCol=value)
        return self

    def setN(self, value):
        """
        Args:
            n: How many snippets to generate for each prompt. Minimum of 1 and maximum of 128 allowed.
        """
        return self._setServiceParamValue("setN", value)

    def setNCol(self, value):
        """
        Args:
            nCol: name of the column supplying n (how many snippets to generate for each prompt)
        """
        return self._callJavaSetter("setNCol", value)

    def setOutputCol(self, value):
        """
        Args:
            outputCol: The name of the output column
        """
        self._set(outputCol=value)
        return self

    def setPresencePenalty(self, value):
        """
        Args:
            presencePenalty: How much to penalize new tokens based on whether they appear
                in the text so far. Increases the likelihood of the model talking about
                new topics. Has minimum of -2 and maximum of 2.
        """
        return self._setServiceParamValue("setPresencePenalty", value)

    def setPresencePenaltyCol(self, value):
        """
        Args:
            presencePenaltyCol: name of the column supplying presencePenalty
        """
        return self._callJavaSetter("setPresencePenaltyCol", value)

    def setStop(self, value):
        """
        Args:
            stop: A sequence which indicates the end of the current document.
        """
        return self._setServiceParamValue("setStop", value)

    def setStopCol(self, value):
        """
        Args:
            stopCol: name of the column supplying stop (sequence which indicates the end of the current document)
        """
        return self._callJavaSetter("setStopCol", value)

    def setSubscriptionKey(self, value):
        """
        Args:
            subscriptionKey: the API key to use
        """
        return self._setServiceParamValue("setSubscriptionKey", value)

    def setSubscriptionKeyCol(self, value):
        """
        Args:
            subscriptionKeyCol: name of the column supplying subscriptionKey (the API key to use)
        """
        return self._callJavaSetter("setSubscriptionKeyCol", value)

    def setTemperature(self, value):
        """
        Args:
            temperature: What sampling temperature to use. Higher values mean the model
                will take more risks. Try 0.9 for more creative applications, and 0
                (argmax sampling) for ones with a well-defined answer. We generally
                recommend using this or `top_p` but not both. Minimum of 0 and maximum
                of 2 allowed.
        """
        return self._setServiceParamValue("setTemperature", value)

    def setTemperatureCol(self, value):
        """
        Args:
            temperatureCol: name of the column supplying temperature (sampling temperature)
        """
        return self._callJavaSetter("setTemperatureCol", value)

    def setTimeout(self, value):
        """
        Args:
            timeout: number of seconds to wait before closing the connection
        """
        self._set(timeout=value)
        return self

    def setTopP(self, value):
        """
        Args:
            topP: An alternative to sampling with temperature, called nucleus sampling,
                where the model considers the results of the tokens with top_p
                probability mass. So 0.1 means only the tokens comprising the top 10
                percent probability mass are considered. We generally recommend using
                this or `temperature` but not both. Minimum of 0 and maximum of 1 allowed.
        """
        return self._setServiceParamValue("setTopP", value)

    def setTopPCol(self, value):
        """
        Args:
            topPCol: name of the column supplying topP (nucleus-sampling probability mass)
        """
        return self._callJavaSetter("setTopPCol", value)

    def setUrl(self, value):
        """
        Args:
            url: Url of the service
        """
        self._set(url=value)
        return self

    def setUser(self, value):
        """
        Args:
            user: The ID of the end-user, for use in tracking and rate-limiting.
        """
        return self._setServiceParamValue("setUser", value)

    def setUserCol(self, value):
        """
        Args:
            userCol: name of the column supplying user (the ID of the end-user)
        """
        return self._callJavaSetter("setUserCol", value)

    def getAADToken(self):
        """
        Returns:
            AADToken: AAD Token used for authentication
        """
        return self._java_obj.getAADToken()

    def getCustomAuthHeader(self):
        """
        Returns:
            CustomAuthHeader: A Custom Value for Authorization Header
        """
        return self._java_obj.getCustomAuthHeader()

    def getApiVersion(self):
        """
        Returns:
            apiVersion: version of the api
        """
        return self._java_obj.getApiVersion()

    def getBestOf(self):
        """
        Returns:
            bestOf: How many generations to create server side, and display only the best.
                Will not stream intermediate progress if best_of > 1. Has maximum value of 128.
        """
        return self._java_obj.getBestOf()

    def getCacheLevel(self):
        """
        Returns:
            cacheLevel: can be used to disable any server-side caching,
                0=no cache, 1=prompt prefix enabled, 2=full cache
        """
        return self._java_obj.getCacheLevel()

    def getConcurrency(self):
        """
        Returns:
            concurrency: max number of concurrent calls
        """
        return self.getOrDefault(self.concurrency)

    def getConcurrentTimeout(self):
        """
        Returns:
            concurrentTimeout: max number of seconds to wait on futures if concurrency >= 1
        """
        return self.getOrDefault(self.concurrentTimeout)

    def getDeploymentName(self):
        """
        Returns:
            deploymentName: The name of the deployment
        """
        return self._java_obj.getDeploymentName()

    def getEcho(self):
        """
        Returns:
            echo: Echo back the prompt in addition to the completion
        """
        return self._java_obj.getEcho()

    def getErrorCol(self):
        """
        Returns:
            errorCol: column to hold http errors
        """
        return self.getOrDefault(self.errorCol)

    def getFrequencyPenalty(self):
        """
        Returns:
            frequencyPenalty: How much to penalize new tokens based on their existing
                frequency in the text so far. Decreases the likelihood of the model
                repeating the same line verbatim.
        """
        return self._java_obj.getFrequencyPenalty()

    def getHandler(self):
        """
        Returns:
            handler: Which strategy to use when handling requests
        """
        return self.getOrDefault(self.handler)

    def getLogProbs(self):
        """
        Returns:
            logProbs: Include the log probabilities on the `logprobs` most likely tokens,
                as well as the chosen tokens. Minimum of 0 and maximum of 100 allowed.
        """
        return self._java_obj.getLogProbs()

    def getMaxTokens(self):
        """
        Returns:
            maxTokens: The maximum number of tokens to generate. Has minimum of 0.
        """
        return self._java_obj.getMaxTokens()

    def getMessagesCol(self):
        """
        Returns:
            messagesCol: The column of messages to generate chat completions for, in the
                chat format. This column should have type
                Array(Struct(role: String, content: String)).
        """
        return self.getOrDefault(self.messagesCol)

    def getN(self):
        """
        Returns:
            n: How many snippets to generate for each prompt. Minimum of 1 and maximum of 128 allowed.
        """
        return self._java_obj.getN()

    def getOutputCol(self):
        """
        Returns:
            outputCol: The name of the output column
        """
        return self.getOrDefault(self.outputCol)

    def getPresencePenalty(self):
        """
        Returns:
            presencePenalty: How much to penalize new tokens based on whether they appear
                in the text so far. Increases the likelihood of the model talking about
                new topics. Has minimum of -2 and maximum of 2.
        """
        return self._java_obj.getPresencePenalty()

    def getStop(self):
        """
        Returns:
            stop: A sequence which indicates the end of the current document.
        """
        return self._java_obj.getStop()

    def getSubscriptionKey(self):
        """
        Returns:
            subscriptionKey: the API key to use
        """
        return self._java_obj.getSubscriptionKey()

    def getTemperature(self):
        """
        Returns:
            temperature: What sampling temperature to use. Higher values mean the model
                will take more risks. Minimum of 0 and maximum of 2 allowed.
        """
        return self._java_obj.getTemperature()

    def getTimeout(self):
        """
        Returns:
            timeout: number of seconds to wait before closing the connection
        """
        return self.getOrDefault(self.timeout)

    def getTopP(self):
        """
        Returns:
            topP: An alternative to sampling with temperature, called nucleus sampling.
                Minimum of 0 and maximum of 1 allowed.
        """
        return self._java_obj.getTopP()

    def getUrl(self):
        """
        Returns:
            url: Url of the service
        """
        return self.getOrDefault(self.url)

    def getUser(self):
        """
        Returns:
            user: The ID of the end-user, for use in tracking and rate-limiting.
        """
        return self._java_obj.getUser()

    def setCustomServiceName(self, value):
        """Forward ``value`` to the JVM object's ``setCustomServiceName``."""
        return self._callJavaSetter("setCustomServiceName", value)

    def setEndpoint(self, value):
        """Forward ``value`` to the JVM object's ``setEndpoint``."""
        return self._callJavaSetter("setEndpoint", value)

    def setDefaultInternalEndpoint(self, value):
        """Forward ``value`` to the JVM object's ``setDefaultInternalEndpoint``."""
        return self._callJavaSetter("setDefaultInternalEndpoint", value)

    def _transform(self, dataset: DataFrame) -> DataFrame:
        """Delegate to ``JavaTransformer._transform`` (runs the JVM transformer)."""
        return super()._transform(dataset)