__author__ = 'rolevin'
from typing import List
from synapse.ml.cyber.utils.spark_utils import DataFrameUtils, ExplainBuilder, HasSetInputCol, HasSetOutputCol
from pyspark.ml import Estimator, Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params
from pyspark.sql import DataFrame, functions as f
[docs]class IdIndexerModel(Transformer, HasSetInputCol, HasSetOutputCol):
partitionKey = Param(
Params._dummy(),
"partitionKey",
"The name of the column to partition by, i.e., make sure the indexing takes the partition into account. "
"This is exemplified in reset_per_partition."
)
def __init__(self, input_col: str, partition_key: str, output_col: str, vocab_df: DataFrame):
super().__init__()
ExplainBuilder.build(self, inputCol=input_col, partitionKey=partition_key, outputCol=output_col)
self._vocab_df = vocab_df
def _transform(self, df):
ucols = [self.partition_key, self.input_col]
input_col = self.input_col
output_col = self.output_col
vocab_df = self._vocab_df
return df.join(vocab_df, on=ucols, how='left_outer').withColumn(
output_col,
f.when(f.col(output_col).isNotNull(), f.col(output_col)).otherwise(f.lit(0))
).drop(
input_col
)
[docs]class IdIndexer(Estimator, HasSetInputCol, HasSetOutputCol):
partitionKey = Param(
Params._dummy(),
"partitionKey",
"The name of the column to partition by, i.e., make sure the indexing takes the partition into account. "
"This is exemplified in reset_per_partition."
)
resetPerPartition = Param(
Params._dummy(),
"resetPerPartition",
"When set to True then indexing is consecutive from [1..n] for each value of the partition column. "
"When set to False then indexing is consecutive for all partition and column values."
)
def __init__(self, input_col: str, partition_key: str, output_col: str, reset_per_partition: bool):
super().__init__()
ExplainBuilder.build(
self,
inputCol=input_col,
partitionKey=partition_key,
outputCol=output_col,
resetPerPartition=reset_per_partition
)
def _make_vocab_df(self, df):
ucols = [self.getPartitionKey(), self.getInputCol()]
the_df = df.select(ucols).distinct().orderBy(ucols)
return DataFrameUtils.zip_with_index(
df=the_df,
start_index=1,
col_name=self.getOutputCol(),
partition_col=self.getPartitionKey(),
order_by_col=self.getInputCol()
) if self.getResetPerPartition() else DataFrameUtils.zip_with_index(
df=the_df,
start_index=1,
col_name=self.getOutputCol()
)
def _fit(self, df: DataFrame) -> IdIndexerModel:
return IdIndexerModel(
self.input_col, self.partition_key, self.output_col, self._make_vocab_df(df).cache()
)
[docs]class MultiIndexerModel(Transformer):
def __init__(self, models: List[IdIndexerModel]):
super().__init__()
self.models = models
[docs] def get_model_by_output_col(self, output_col):
for m in self.models:
if m.output_col == output_col:
return m
return None
def _transform(self, df: DataFrame) -> DataFrame:
curr_df = df.cache()
for model in self.models:
curr_df = model.transform(curr_df).cache()
return curr_df
[docs]class MultiIndexer(Estimator):
def __init__(self, indexers: List[IdIndexer]):
super().__init__()
self.indexers = indexers
def _fit(self, df: DataFrame) -> MultiIndexerModel:
return MultiIndexerModel([i.fit(df) for i in self.indexers])