# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
import sys
from abc import ABCMeta
from typing import List, Mapping, Optional

from py4j import java_gateway
# Python 3 has no ``basestring`` builtin; alias it to ``str`` so the name is
# still available in this module. The interpreter version is compared as a
# tuple (``sys.version_info``) rather than as text: comparing the
# ``sys.version`` string is lexicographic, so "10.x" would sort before "3".
if sys.version_info >= (3,):
    basestring = str
from synapse.ml.onnx._ONNXModel import _ONNXModel
from pyspark.ml.common import inherit_doc
from py4j.java_gateway import JavaObject
class NodeInfo:
    """Pairs a node's name with the Python description of its value.

    The Java value-info object handed to the constructor is converted with
    ``ValueInfo.from_java``; the converted object is what ``value_info`` holds.
    """

    def __init__(self, name: str, value_info: JavaObject):
        # Convert the Java-side description once, at construction time.
        self.value_info = ValueInfo.from_java(value_info)
        self.name = name

    def __repr__(self) -> str:
        # Same text as ``str(self)``.
        return self.__str__()

    def __str__(self) -> str:
        pieces = ["NodeInfo(name=", self.name, ",info=", str(self.value_info), ")"]
        return "".join(pieces)
@inherit_doc
class ONNXModel(_ONNXModel):
    """
    Python-side wrapper whose methods forward to the wrapped Java object
    (``self._java_obj``).

    Args:
        SparkSession (SparkSession): The SparkSession that will be used to find the model
        location (str): The location of the model, either on local or HDFS
    """

    # NOTE(review): the methods below are documented with comments instead of
    # docstrings. pyspark's ``inherit_doc`` is understood to copy a parent
    # docstring only onto members that have none, so adding docstrings here
    # could change what ``help()`` shows -- confirm before converting them.

    def setModelLocation(self, location):
        # Forward to the Java setter and keep the object it returns as the
        # wrapped instance; return ``self`` so calls can be chained.
        updated = self._java_obj.setModelLocation(location)
        self._java_obj = updated
        return self

    def setMiniBatchSize(self, n):
        # Same pattern as ``setModelLocation``: forward, re-wrap, return self.
        updated = self._java_obj.setMiniBatchSize(n)
        self._java_obj = updated
        return self

    def __parse_node_info(self, node_info: JavaObject) -> "NodeInfo":
        # Read the name and the value info off the Java node (in that order)
        # and wrap them in a Python ``NodeInfo``.
        return NodeInfo(node_info.getName(), node_info.getInfo())

    def getModelOutputs(self) -> Mapping[str, NodeInfo]:
        # Push the Python-side params to the Java object first, then ask it
        # for its outputs and convert every entry to a ``NodeInfo``.
        self._transfer_params_to_java()
        java_outputs = self._java_obj.modelOutputJava()
        outputs = {}
        for key, java_node in java_outputs.items():
            outputs[key] = self.__parse_node_info(java_node)
        return outputs
class ValueInfo(metaclass=ABCMeta):
    """Common base class of ``TensorInfo``, ``MapInfo`` and ``SequenceInfo``."""

    @classmethod
    def from_java(cls, java_value_info: JavaObject) -> "ValueInfo":
        """Convert a Java value-info object into its Python counterpart.

        The Java class name selects the converter:
        ``ai.onnxruntime.TensorInfo`` is handled by ``TensorInfo`` and
        ``ai.onnxruntime.MapInfo`` by ``MapInfo``. Every other class name is
        handed to ``SequenceInfo``.
        """
        java_class_name = java_value_info.getClass().getName()
        converters = {
            "ai.onnxruntime.TensorInfo": TensorInfo,
            "ai.onnxruntime.MapInfo": MapInfo,
        }
        # Anything that is not a tensor or a map falls through to SequenceInfo.
        target = converters.get(java_class_name, SequenceInfo)
        return target.from_java(java_value_info)
class TensorInfo(ValueInfo):
    """Shape and type of a tensor value."""

    def __init__(self, shape: List[int], type: str):
        self.type = type
        self.shape = shape

    def __str__(self):
        # Dimensions are rendered comma-separated inside square brackets.
        dims = ",".join(str(dim) for dim in self.shape)
        return "TensorInfo(shape={}, type={})".format("[" + dims + "]", self.type)

    def __repr__(self):
        return str(self)

    @classmethod
    def from_java(cls, java_tensor_info: JavaObject) -> "TensorInfo":
        """Build a ``TensorInfo`` from a Java tensor-info object.

        The shape comes from ``getShape()`` (copied into a Python list) and
        the type from the string form of the Java object's ``type`` field.
        """
        dims = list(java_tensor_info.getShape())
        type_field = java_gateway.get_field(java_tensor_info, "type")
        return cls(dims, type_field.toString())
class MapInfo(ValueInfo):
    """Key type, value type and size of a map value.

    A ``size`` of ``-1`` (the default) is rendered as ``UNKNOWN`` by
    ``__str__``.
    """

    def __init__(self, key_type: str, value_type: str, size: int = -1):
        self.key_type = key_type
        self.value_type = value_type
        self.size = size

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        size_text = "UNKNOWN" if self.size == -1 else str(self.size)
        return "MapInfo(size=" + size_text + ",keyType=" + self.key_type + ",valueType=" + self.value_type + ")"

    @classmethod
    def from_java(cls, java_map_info: JavaObject) -> Optional["MapInfo"]:
        """Build a ``MapInfo`` from a Java map-info object.

        Args:
            java_map_info: The Java object to convert, or ``None``.

        Returns:
            ``None`` when ``java_map_info`` is ``None``; otherwise a new
            instance filled from the Java object's ``keyType``, ``valueType``
            (both via ``toString()``) and ``size`` fields.
        """
        # Identity test rather than ``== None``: it cannot be routed through
        # the argument's own ``__eq__`` implementation.
        if java_map_info is None:
            return None

        key_type = java_gateway.get_field(java_map_info, "keyType").toString()
        value_type = java_gateway.get_field(java_map_info, "valueType").toString()
        size = java_gateway.get_field(java_map_info, "size")
        return cls(key_type, value_type, size)
class SequenceInfo(ValueInfo):
    """Length and element description of a sequence value.

    ``sequence_of_maps`` decides which attribute ``__str__`` shows as the
    element type: ``map_info`` when it is truthy, ``sequence_type`` otherwise.
    A ``length`` of ``-1`` is rendered as ``UNKNOWN``. ``map_info`` may be
    ``None``, since ``from_java`` stores whatever ``MapInfo.from_java``
    returns.
    """

    def __init__(self, length: int, sequence_of_maps: bool, map_info: MapInfo, sequence_type: str):
        self.length = length
        self.sequence_of_maps = sequence_of_maps
        self.map_info = map_info
        self.sequence_type = sequence_type

    def __str__(self) -> str:
        shown_length = "UNKNOWN" if self.length == -1 else str(self.length)
        element = self.map_info if self.sequence_of_maps else self.sequence_type
        return "SequenceInfo(length=" + shown_length + ",type=" + str(element) + ")"

    def __repr__(self) -> str:
        return str(self)

    @classmethod
    def from_java(cls, java_sequence_info: JavaObject) -> "SequenceInfo":
        """Build a ``SequenceInfo`` from a Java sequence-info object.

        Reads the Java fields ``length``, ``sequenceOfMaps``, ``mapInfo``
        (converted with ``MapInfo.from_java``) and ``sequenceType`` (via
        ``toString()``), in that order.
        """

        def read(field_name):
            # Single place for the py4j field access on the Java object.
            return java_gateway.get_field(java_sequence_info, field_name)

        seq_length = read("length")
        holds_maps = read("sequenceOfMaps")
        element_map_info = MapInfo.from_java(read("mapInfo"))
        element_type = read("sequenceType").toString()
        return cls(seq_length, holds_maps, element_map_info, element_type)