# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
import sys
# Python 2/3 compatibility shim: on Python 3 there is no ``basestring``, so alias
# it to ``str``. Compare ``sys.version_info`` (a tuple of ints) rather than the
# human-readable ``sys.version`` string, whose lexicographic ordering is fragile.
if sys.version_info[0] >= 3:
    basestring = str
from pyspark import SparkContext, SQLContext
from pyspark.sql import DataFrame
from pyspark.ml.param.shared import *
from pyspark import keyword_only
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from synapse.ml.core.serialize.java_params_patch import *
from pyspark.ml.wrapper import JavaTransformer, JavaEstimator, JavaModel
from pyspark.ml.evaluation import JavaEvaluator
from pyspark.ml.common import inherit_doc
from synapse.ml.core.schema.Utils import *
from pyspark.ml.param import TypeConverters
from synapse.ml.core.schema.TypeConversionUtils import generateTypeConverter, complexTypeConverter
@inherit_doc
class SpeechToTextSDK(ComplexParamsMixin, JavaMLReadable, JavaMLWritable, JavaTransformer):
    """
    Python wrapper for the JVM stage ``com.microsoft.azure.synapse.ml.cognitive.SpeechToTextSDK``.

    Plain parameters (audioDataCol, endpointId, extraFfmpegArgs, outputCol,
    recordAudioData, recordedFileNameCol, streamIntermediateResults, url) are
    stored in the Python param map via ``self._set``. Service parameters
    (fileType, format, language, participantsJson, profanity, subscriptionKey)
    are written directly to the wrapped Java object. Each of these can be given
    either as a literal value (``set<Name>``) or as a column name
    (``set<Name>Col``; ``<name>Col`` keyword in ``__init__``/``setParams``).

    Args:
        audioDataCol (object): Column holding audio data, must be either ByteArrays or Strings representing file URIs
        endpointId (object): endpoint for custom speech models
        extraFfmpegArgs (list): extra arguments for ffmpeg output decoding
        fileType (object): The file type of the sound files, supported types: wav, ogg, mp3
        format (object): Specifies the result format. Accepted values are simple and detailed. Default is simple.
        language (object): Identifies the spoken language that is being recognized.
        outputCol (object): The name of the output column
        participantsJson (object): a json representation of a list of conversation participants (email, language, user)
        profanity (object): Specifies how to handle profanity in recognition results. Accepted values are masked, which replaces profanity with asterisks, removed, which removes all profanity from the result, or raw, which includes the profanity in the result. The default setting is masked.
        recordAudioData (bool): Whether to record audio data to a file location, for use only with m3u8 streams
        recordedFileNameCol (object): Column holding file names to write audio data to if ``recordAudioData`` is set to true
        streamIntermediateResults (bool): Whether or not to immediately return intermediate results, or group in a sequence
        subscriptionKey (object): the API key to use
        url (object): Url of the service
    """

    # Param descriptions below are runtime strings (shown by explainParam) and are kept verbatim.
    audioDataCol = Param(Params._dummy(), "audioDataCol", "Column holding audio data, must be either ByteArrays or Strings representing file URIs")
    endpointId = Param(Params._dummy(), "endpointId", "endpoint for custom speech models")
    extraFfmpegArgs = Param(Params._dummy(), "extraFfmpegArgs", "extra arguments to for ffmpeg output decoding", typeConverter=TypeConverters.toListString)
    fileType = Param(Params._dummy(), "fileType", "The file type of the sound files, supported types: wav, ogg, mp3")
    format = Param(Params._dummy(), "format", " Specifies the result format. Accepted values are simple and detailed. Default is simple. ")
    language = Param(Params._dummy(), "language", " Identifies the spoken language that is being recognized. ")
    outputCol = Param(Params._dummy(), "outputCol", "The name of the output column")
    participantsJson = Param(Params._dummy(), "participantsJson", "a json representation of a list of conversation participants (email, language, user)")
    profanity = Param(Params._dummy(), "profanity", " Specifies how to handle profanity in recognition results. Accepted values are masked, which replaces profanity with asterisks, removed, which remove all profanity from the result, or raw, which includes the profanity in the result. The default setting is masked. ")
    recordAudioData = Param(Params._dummy(), "recordAudioData", "Whether to record audio data to a file location, for use only with m3u8 streams", typeConverter=TypeConverters.toBoolean)
    recordedFileNameCol = Param(Params._dummy(), "recordedFileNameCol", "Column holding file names to write audio data to if ``recordAudioData'' is set to true")
    streamIntermediateResults = Param(Params._dummy(), "streamIntermediateResults", "Whether or not to immediately return itermediate results, or group in a sequence", typeConverter=TypeConverters.toBoolean)
    subscriptionKey = Param(Params._dummy(), "subscriptionKey", "the API key to use")
    url = Param(Params._dummy(), "url", "Url of the service")

    @keyword_only
    def __init__(
        self,
        java_obj=None,
        audioDataCol=None,
        endpointId=None,
        extraFfmpegArgs=None,
        fileType=None,
        fileTypeCol=None,
        format=None,
        formatCol=None,
        language=None,
        languageCol=None,
        outputCol=None,
        participantsJson=None,
        participantsJsonCol=None,
        profanity=None,
        profanityCol=None,
        recordAudioData=False,
        recordedFileNameCol=None,
        streamIntermediateResults=True,
        subscriptionKey=None,
        subscriptionKeyCol=None,
        url=None
    ):
        """
        Create the transformer and apply any keyword arguments.

        Only keywords the caller actually passed are applied (``keyword_only``
        captures them), so the signature defaults are never read directly; the
        effective defaults are the ones registered with ``_setDefault`` below.

        Args:
            java_obj: an existing Java ``SpeechToTextSDK`` to wrap. When given, no
                new Java object is created and the other keyword arguments are
                not applied.
            Remaining keywords: see the class docstring. ``<name>Col`` variants
                of the service parameters take a column name instead of a value.
        """
        super(SpeechToTextSDK, self).__init__()
        if java_obj is None:
            self._java_obj = self._new_java_obj("com.microsoft.azure.synapse.ml.cognitive.SpeechToTextSDK", self.uid)
        else:
            self._java_obj = java_obj
        self._setDefault(extraFfmpegArgs=[], recordAudioData=False, streamIntermediateResults=True)
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            # Older keyword_only implementations stored kwargs on the wrapped function.
            kwargs = self.__init__._input_kwargs
        if java_obj is None:
            self._applySetters(kwargs)

    @keyword_only
    def setParams(
        self,
        audioDataCol=None,
        endpointId=None,
        extraFfmpegArgs=None,
        fileType=None,
        fileTypeCol=None,
        format=None,
        formatCol=None,
        language=None,
        languageCol=None,
        outputCol=None,
        participantsJson=None,
        participantsJsonCol=None,
        profanity=None,
        profanityCol=None,
        recordAudioData=False,
        recordedFileNameCol=None,
        streamIntermediateResults=True,
        subscriptionKey=None,
        subscriptionKeyCol=None,
        url=None
    ):
        """
        Set the (keyword only) parameters.

        Each explicitly passed, non-None keyword is routed through its
        ``set<Name>`` method, the same dispatch ``__init__`` uses. This lets
        service parameters and their ``<name>Col`` variants (which have no
        Python-side Param attribute) reach the wrapped Java object.

        Returns:
            SpeechToTextSDK: ``self``, to allow call chaining
        """
        if hasattr(self, "_input_kwargs"):
            kwargs = self._input_kwargs
        else:
            # Older keyword_only implementations stored kwargs on the wrapped function.
            kwargs = self.setParams._input_kwargs
        self._applySetters(kwargs)
        return self

    def _applySetters(self, kwargs):
        """
        Route each non-None keyword argument through its ``set<Name>`` method.

        ``None`` values are skipped, i.e. treated as "not supplied".

        Args:
            kwargs (dict): parameter name -> value, as captured by ``keyword_only``
        """
        for name, value in kwargs.items():
            if value is not None:
                getattr(self, "set" + name[0].upper() + name[1:])(value)

    def _setServiceValue(self, javaSetterName, value):
        """
        Forward a literal service-parameter value to the wrapped Java object.

        Python lists are converted with the JVM helper ``ServiceParam.toSeq``
        before the call.

        Args:
            javaSetterName (str): name of the Java setter to invoke, e.g. ``"setFileType"``
            value: the value to set

        Returns:
            SpeechToTextSDK: ``self``, to allow call chaining
        """
        if isinstance(value, list):
            value = SparkContext._active_spark_context._jvm.org.apache.spark.ml.param.ServiceParam.toSeq(value)
        self._java_obj = getattr(self._java_obj, javaSetterName)(value)
        return self

    @classmethod
    def read(cls):
        """ Returns an MLReader instance for this class. """
        return JavaMMLReader(cls)

    @staticmethod
    def getJavaPackage():
        """ Returns package name String. """
        return "com.microsoft.azure.synapse.ml.cognitive.SpeechToTextSDK"

    @staticmethod
    def _from_java(java_stage):
        """
        Wrap a Java stage in this Python class.

        The target module name is derived from this class's module: its last
        dotted component is replaced with ``SpeechToTextSDK``.
        """
        module_name = SpeechToTextSDK.__module__
        module_name = module_name.rsplit(".", 1)[0] + ".SpeechToTextSDK"
        return from_java(java_stage, module_name)

    def setAudioDataCol(self, value):
        """
        Args:
            audioDataCol: Column holding audio data, must be either ByteArrays or Strings representing file URIs
        """
        self._set(audioDataCol=value)
        return self

    def setEndpointId(self, value):
        """
        Args:
            endpointId: endpoint for custom speech models
        """
        self._set(endpointId=value)
        return self

    def setExtraFfmpegArgs(self, value):
        """
        Args:
            extraFfmpegArgs: extra arguments for ffmpeg output decoding
        """
        self._set(extraFfmpegArgs=value)
        return self

    def setFileType(self, value):
        """
        Args:
            fileType: The file type of the sound files, supported types: wav, ogg, mp3
        """
        return self._setServiceValue("setFileType", value)

    def setFileTypeCol(self, value):
        """
        Args:
            fileTypeCol: name of the column to take ``fileType`` values from
        """
        self._java_obj = self._java_obj.setFileTypeCol(value)
        return self

    def setFormat(self, value):
        """
        Args:
            format: Specifies the result format. Accepted values are simple and detailed. Default is simple.
        """
        return self._setServiceValue("setFormat", value)

    def setFormatCol(self, value):
        """
        Args:
            formatCol: name of the column to take ``format`` values from
        """
        self._java_obj = self._java_obj.setFormatCol(value)
        return self

    def setLanguage(self, value):
        """
        Args:
            language: Identifies the spoken language that is being recognized.
        """
        return self._setServiceValue("setLanguage", value)

    def setLanguageCol(self, value):
        """
        Args:
            languageCol: name of the column to take ``language`` values from
        """
        self._java_obj = self._java_obj.setLanguageCol(value)
        return self

    def setOutputCol(self, value):
        """
        Args:
            outputCol: The name of the output column
        """
        self._set(outputCol=value)
        return self

    def setParticipantsJson(self, value):
        """
        Args:
            participantsJson: a json representation of a list of conversation participants (email, language, user)
        """
        return self._setServiceValue("setParticipantsJson", value)

    def setParticipantsJsonCol(self, value):
        """
        Args:
            participantsJsonCol: name of the column to take ``participantsJson`` values from
        """
        self._java_obj = self._java_obj.setParticipantsJsonCol(value)
        return self

    def setProfanity(self, value):
        """
        Args:
            profanity: Specifies how to handle profanity in recognition results. Accepted values are masked, which replaces profanity with asterisks, removed, which removes all profanity from the result, or raw, which includes the profanity in the result. The default setting is masked.
        """
        return self._setServiceValue("setProfanity", value)

    def setProfanityCol(self, value):
        """
        Args:
            profanityCol: name of the column to take ``profanity`` values from
        """
        self._java_obj = self._java_obj.setProfanityCol(value)
        return self

    def setRecordAudioData(self, value):
        """
        Args:
            recordAudioData: Whether to record audio data to a file location, for use only with m3u8 streams
        """
        self._set(recordAudioData=value)
        return self

    def setRecordedFileNameCol(self, value):
        """
        Args:
            recordedFileNameCol: Column holding file names to write audio data to if ``recordAudioData`` is set to true
        """
        self._set(recordedFileNameCol=value)
        return self

    def setStreamIntermediateResults(self, value):
        """
        Args:
            streamIntermediateResults: Whether or not to immediately return intermediate results, or group in a sequence
        """
        self._set(streamIntermediateResults=value)
        return self

    def setSubscriptionKey(self, value):
        """
        Args:
            subscriptionKey: the API key to use
        """
        return self._setServiceValue("setSubscriptionKey", value)

    def setSubscriptionKeyCol(self, value):
        """
        Args:
            subscriptionKeyCol: name of the column to take ``subscriptionKey`` values from
        """
        self._java_obj = self._java_obj.setSubscriptionKeyCol(value)
        return self

    def setUrl(self, value):
        """
        Args:
            url: Url of the service
        """
        self._set(url=value)
        return self

    def getAudioDataCol(self):
        """
        Returns:
            audioDataCol: Column holding audio data, must be either ByteArrays or Strings representing file URIs
        """
        return self.getOrDefault(self.audioDataCol)

    def getEndpointId(self):
        """
        Returns:
            endpointId: endpoint for custom speech models
        """
        return self.getOrDefault(self.endpointId)

    def getExtraFfmpegArgs(self):
        """
        Returns:
            extraFfmpegArgs: extra arguments for ffmpeg output decoding
        """
        return self.getOrDefault(self.extraFfmpegArgs)

    # NOTE(review): the service-parameter getters below (fileType, format, language,
    # participantsJson, profanity, subscriptionKey) read the Python param map, while
    # their setters write only to the wrapped Java object. A value set through
    # set<Name> may therefore not be visible here -- confirm against java_params_patch.

    def getFileType(self):
        """
        Returns:
            fileType: The file type of the sound files, supported types: wav, ogg, mp3
        """
        return self.getOrDefault(self.fileType)

    def getFormat(self):
        """
        Returns:
            format: Specifies the result format. Accepted values are simple and detailed. Default is simple.
        """
        return self.getOrDefault(self.format)

    def getLanguage(self):
        """
        Returns:
            language: Identifies the spoken language that is being recognized.
        """
        return self.getOrDefault(self.language)

    def getOutputCol(self):
        """
        Returns:
            outputCol: The name of the output column
        """
        return self.getOrDefault(self.outputCol)

    def getParticipantsJson(self):
        """
        Returns:
            participantsJson: a json representation of a list of conversation participants (email, language, user)
        """
        return self.getOrDefault(self.participantsJson)

    def getProfanity(self):
        """
        Returns:
            profanity: Specifies how to handle profanity in recognition results. Accepted values are masked, which replaces profanity with asterisks, removed, which removes all profanity from the result, or raw, which includes the profanity in the result. The default setting is masked.
        """
        return self.getOrDefault(self.profanity)

    def getRecordAudioData(self):
        """
        Returns:
            recordAudioData: Whether to record audio data to a file location, for use only with m3u8 streams
        """
        return self.getOrDefault(self.recordAudioData)

    def getRecordedFileNameCol(self):
        """
        Returns:
            recordedFileNameCol: Column holding file names to write audio data to if ``recordAudioData`` is set to true
        """
        return self.getOrDefault(self.recordedFileNameCol)

    def getStreamIntermediateResults(self):
        """
        Returns:
            streamIntermediateResults: Whether or not to immediately return intermediate results, or group in a sequence
        """
        return self.getOrDefault(self.streamIntermediateResults)

    def getSubscriptionKey(self):
        """
        Returns:
            subscriptionKey: the API key to use
        """
        return self.getOrDefault(self.subscriptionKey)

    def getUrl(self):
        """
        Returns:
            url: Url of the service
        """
        return self.getOrDefault(self.url)

    def setLocation(self, value):
        """
        Forward ``value`` to the wrapped Java object's ``setLocation``.

        Args:
            value: location string (presumably the service region used to build
                the endpoint -- confirm against the Scala implementation)
        """
        self._java_obj = self._java_obj.setLocation(value)
        return self

    def setLinkedService(self, value):
        """
        Forward ``value`` to the wrapped Java object's ``setLinkedService``.

        Args:
            value: linked-service name (presumably a Synapse linked service
                supplying connection details -- confirm against the Scala implementation)
        """
        self._java_obj = self._java_obj.setLinkedService(value)
        return self