Source code for mmlspark.plot.plot

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= '3':
    basestring = str

import pyspark
from pyspark import SparkContext
from pyspark import sql
from pyspark.sql import DataFrame
import warnings
import numpy as np
import itertools

[docs]def confusionMatrix(df, y_col, y_hat_col, labels): from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt if isinstance(df, pyspark.sql.dataframe.DataFrame): df = df.select([y_col, y_hat_col]).toPandas() y, y_hat = df[y_col], df[y_hat_col] accuracy = np.mean([1. if pred==true else 0. for (pred,true) in zip(y_hat,y)]) cm = confusion_matrix(y, y_hat) cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] plt.text(-.3, -.55,"$Accuracy$ $=$ ${}\%$".format(round(accuracy*100,1)),fontsize=18) tick_marks = np.arange(len(labels)) plt.xticks(tick_marks, labels, rotation=0) plt.yticks(tick_marks, labels, rotation=90) plt.imshow(cmn, interpolation="nearest", cmap=plt.cm.Blues, vmin=0, vmax=1) thresh = .1 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", fontsize=18, color="white" if cmn[i, j] > thresh else "black") plt.colorbar() plt.xlabel("Predicted Label", fontsize=18) plt.ylabel("True Label", fontsize=18)
[docs]def roc(df, y_col, y_hat_col, thresh=.5): from sklearn.metrics import roc_curve import matplotlib.pyplot as plt if isinstance(df, pyspark.sql.dataframe.DataFrame): df = df.select([y_col, y_hat_col]).toPandas() def f2i(X): return [int(x>thresh) for x in X] y, y_hat = f2i(df[y_col]), df[y_hat_col] fpr, tpr, thresholds = roc_curve(y, y_hat) plt.plot(fpr, tpr) plt.xlabel("False Positive Rate", fontsize=20) plt.ylabel("True Positive Rate", fontsize=20)