Source code for synapse.ml.hf.HuggingFaceCausalLMTransform

import os

from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import (
    HasInputCol,
    HasOutputCol,
    Param,
    Params,
    TypeConverters,
)
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import StringType, StructField, StructType
from transformers import AutoModelForCausalLM, AutoTokenizer


class _PeekableIterator:
    def __init__(self, iterable):
        self._iterator = iter(iterable)
        self._cache = []

    def __iter__(self):
        return self

    def __next__(self):
        if self._cache:
            return self._cache.pop(0)
        else:
            return next(self._iterator)

    def peek(self, n=1):
        """Peek at the next n elements without consuming them."""
        while len(self._cache) < n:
            try:
                self._cache.append(next(self._iterator))
            except StopIteration:
                break
        if n == 1:
            return self._cache[0] if self._cache else None
        else:
            return self._cache[:n]


class _ModelParam:
    def __init__(self, **kwargs):
        self.param = {}
        self.param.update(kwargs)

    def get_param(self):
        return self.param


class _ModelConfig:
    def __init__(self, **kwargs):
        self.config = {}
        self.config.update(kwargs)

    def get_config(self):
        return self.config

    def set_config(self, **kwargs):
        self.config.update(kwargs)


def broadcast_model(cachePath, modelConfig):
    bc_computable = _BroadcastableModel(cachePath, modelConfig)
    sc = SparkSession.builder.getOrCreate().sparkContext
    return sc.broadcast(bc_computable)


class _BroadcastableModel:
    def __init__(self, model_path=None, model_config=None):
        self.model_path = model_path
        self.model = None
        self.tokenizer = None
        self.model_config = model_config

    def load_model(self):
        if self.model_path and os.path.exists(self.model_path):
            model_config = self.model_config.get_config()
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path, local_files_only=True, **model_config
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path, local_files_only=True
            )
        else:
            raise ValueError(f"Model path {self.model_path} does not exist.")

    def __getstate__(self):
        return {"model_path": self.model_path, "model_config": self.model_config}

    def __setstate__(self, state):
        self.model_path = state.get("model_path")
        self.model_config = state.get("model_config")
        self.model = None
        self.tokenizer = None
        if self.model_path:
            self.load_model()


class HuggingFaceCausalLM(
    Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):
    modelName = Param(
        Params._dummy(),
        "modelName",
        "huggingface causal lm model name",
        typeConverter=TypeConverters.toString,
    )
    inputCol = Param(
        Params._dummy(),
        "inputCol",
        "input column",
        typeConverter=TypeConverters.toString,
    )
    outputCol = Param(
        Params._dummy(),
        "outputCol",
        "output column",
        typeConverter=TypeConverters.toString,
    )
    task = Param(
        Params._dummy(),
        "task",
        "Specifies the task, can be chat or completion.",
        typeConverter=TypeConverters.toString,
    )
    modelParam = Param(
        Params._dummy(),
        "modelParam",
        "Model Parameters, passed to .generate(). For more details, check https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig",
    )
    modelConfig = Param(
        Params._dummy(),
        "modelConfig",
        "Model configuration, passed to AutoModelForCausalLM.from_pretrained(). For more details, check https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModelForCausalLM",
    )
    cachePath = Param(
        Params._dummy(),
        "cachePath",
        "cache path for the model. A shared location between the workers, could be a lakehouse path",
        typeConverter=TypeConverters.toString,
    )
    deviceMap = Param(
        Params._dummy(),
        "deviceMap",
        "Specifies a model parameter for the device map. It can also be set with modelParam. Commonly used values include 'auto', 'cuda', or 'cpu'. You may want to check your model documentation for device map",
        typeConverter=TypeConverters.toString,
    )
    torchDtype = Param(
        Params._dummy(),
        "torchDtype",
        "Specifies a model parameter for the torch dtype. It can be set with modelParam. The most commonly used value is 'auto'. You may want to check your model documentation for torch dtype.",
        typeConverter=TypeConverters.toString,
    )

    @keyword_only
    def __init__(
        self,
        modelName=None,
        inputCol=None,
        outputCol=None,
        task="chat",
        cachePath=None,
        deviceMap=None,
        torchDtype=None,
    ):
        super(HuggingFaceCausalLM, self).__init__()
        self._setDefault(
            modelName=modelName,
            inputCol=inputCol,
            outputCol=outputCol,
            modelParam=_ModelParam(),
            modelConfig=_ModelConfig(),
            task=task,
            cachePath=None,
            deviceMap=None,
            torchDtype=None,
        )
        kwargs = self._input_kwargs
        self.setParams(**kwargs)
    @keyword_only
    def setParams(
        self,
        modelName=None,
        inputCol=None,
        outputCol=None,
        task="chat",
        cachePath=None,
        deviceMap=None,
        torchDtype=None,
    ):
        # keyword_only stores the caller's keyword arguments in _input_kwargs,
        # so the signature must accept the same keywords that __init__ forwards.
        kwargs = self._input_kwargs
        return self._set(**kwargs)
    def setModelName(self, value):
        return self._set(modelName=value)

    def getModelName(self):
        return self.getOrDefault(self.modelName)

    def setInputCol(self, value):
        return self._set(inputCol=value)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

    def setOutputCol(self, value):
        return self._set(outputCol=value)

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def setModelParam(self, **kwargs):
        param = _ModelParam(**kwargs)
        return self._set(modelParam=param)

    def getModelParam(self):
        return self.getOrDefault(self.modelParam)

    def setModelConfig(self, **kwargs):
        config = _ModelConfig(**kwargs)
        return self._set(modelConfig=config)

    def getModelConfig(self):
        return self.getOrDefault(self.modelConfig)

    def setTask(self, value):
        supported_values = ["completion", "chat"]
        if value not in supported_values:
            raise ValueError(
                f"Task must be one of {supported_values}, but got '{value}'."
            )
        return self._set(task=value)

    def getTask(self):
        return self.getOrDefault(self.task)

    def setCachePath(self, value):
        return self._set(cachePath=value)

    def getCachePath(self):
        return self.getOrDefault(self.cachePath)

    def setDeviceMap(self, value):
        return self._set(deviceMap=value)

    def getDeviceMap(self):
        return self.getOrDefault(self.deviceMap)

    def setTorchDtype(self, value):
        return self._set(torchDtype=value)

    def getTorchDtype(self):
        return self.getOrDefault(self.torchDtype)

    def getBCObject(self):
        return self.bcObject
    def _predict_single_completion(self, prompt, model, tokenizer):
        param = self.getModelParam().get_param()
        # Keep the inputs on the model's device, matching the chat path.
        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
        outputs = model.generate(inputs, **param)
        decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return decoded_output

    def _predict_single_chat(self, prompt, model, tokenizer):
        param = self.getModelParam().get_param()
        if isinstance(prompt, list):
            chat = prompt
        else:
            chat = [{"role": "user", "content": prompt}]
        formatted_chat = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
        tokenized_chat = tokenizer(
            formatted_chat, return_tensors="pt", add_special_tokens=False
        )
        inputs = {
            key: tensor.to(model.device) for key, tensor in tokenized_chat.items()
        }
        merged_inputs = {**inputs, **param}
        outputs = model.generate(**merged_inputs)
        decoded_output = tokenizer.decode(
            outputs[0][inputs["input_ids"].size(1) :], skip_special_tokens=True
        )
        return decoded_output

    def _process_partition(self, iterator, bc_object):
        """Process each partition of the data."""
        peekable_iterator = _PeekableIterator(iterator)
        # peek() returns None (it does not raise StopIteration) when the
        # partition is empty, so skip model loading for empty partitions.
        if peekable_iterator.peek() is None:
            return

        if bc_object:
            lc_object = bc_object.value
            model = lc_object.model
            tokenizer = lc_object.tokenizer
        else:
            model_name = self.getModelName()
            model_config = self.getModelConfig().get_config()
            model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
            tokenizer = AutoTokenizer.from_pretrained(model_name)

        task = self.getTask() if self.getTask() else "chat"

        for row in peekable_iterator:
            prompt = row[self.getInputCol()]
            if task == "chat":
                result = self._predict_single_chat(prompt, model, tokenizer)
            elif task == "completion":
                result = self._predict_single_completion(prompt, model, tokenizer)
            else:
                raise ValueError(
                    f"Unsupported task '{task}'. Supported tasks are 'chat' and 'completion'."
                )
            row_dict = row.asDict()
            row_dict[self.getOutputCol()] = result
            yield Row(**row_dict)

    def _transform(self, dataset):
        if self.getCachePath():
            bc_object = broadcast_model(self.getCachePath(), self.getModelConfig())
        else:
            bc_object = None
        input_schema = dataset.schema
        output_schema = StructType(
            input_schema.fields
            + [StructField(self.getOutputCol(), StringType(), True)]
        )
        result_rdd = dataset.rdd.mapPartitions(
            lambda partition: self._process_partition(partition, bc_object)
        )
        result_df = result_rdd.toDF(output_schema)
        return result_df
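

A minimal usage sketch follows; it is not part of the module source. It assumes an active Spark session with SynapseML installed, and the model id, column names, and generation settings are illustrative assumptions rather than values required by the transformer.

from pyspark.sql import SparkSession

from synapse.ml.hf.HuggingFaceCausalLMTransform import HuggingFaceCausalLM

spark = SparkSession.builder.getOrCreate()

# Toy DataFrame with one prompt per row; the column name is arbitrary.
df = spark.createDataFrame(
    [("What is the capital of France?",), ("Write a haiku about Apache Spark.",)],
    ["prompt"],
)

# Illustrative configuration: the model id and generation settings are
# assumptions for this sketch, not defaults of HuggingFaceCausalLM.
llm = (
    HuggingFaceCausalLM()
    .setModelName("microsoft/Phi-3-mini-4k-instruct")
    .setInputCol("prompt")
    .setOutputCol("reply")
    .setTask("chat")
)
llm.setModelParam(max_new_tokens=128)  # forwarded to model.generate()

result = llm.transform(df)
result.select("prompt", "reply").show(truncate=False)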