Source code for synapse.ml.nn.ConditionalBallTree

# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys
from pyspark import SparkContext

if sys.version >= "3":
    basestring = str


[docs]class ConditionalBallTree(object): def __init__(self, keys, values, labels, leafSize, java_obj=None): """ Create a conditional ball tree. :param keys: 2D array representing the data, shape: n_points x n_features :param values: 1D array :param labels: 1D array :param leafSize: int """ if java_obj is None: self._jconditional_balltree = SparkContext._active_spark_context._jvm.com.microsoft.azure.synapse.ml.nn.ConditionalBallTree.apply( keys, values, labels, leafSize, ) else: self._jconditional_balltree = java_obj
[docs] def findMaximumInnerProducts(self, queryPoint, conditioner, k): """ Find the best match to the queryPoint given the conditioner and k from self. :param queryPoint: array vector to use to query for NNs :param conditioner: set of labels that will subset or condition the NN query :param k: int representing the maximum number of neighbors to return :return: array of tuples representing the index of the match and its distance """ return [ (bm.index(), bm.distance()) for bm in self._jconditional_balltree.findMaximumInnerProducts( queryPoint, conditioner, k, ) ]
[docs] def save(self, filename): self._jconditional_balltree.save(filename)
[docs] @staticmethod def load(filename): java_obj = SparkContext._active_spark_context._jvm.com.microsoft.azure.synapse.ml.nn.ConditionalBallTree.load( filename, ) return ConditionalBallTree(None, None, None, None, java_obj=java_obj)