# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.
from abc import ABCMeta
import sys
from typing import Mapping, List
from py4j import java_gateway
# Python 3 has no ``basestring``; alias it to ``str`` so code written against
# the Python 2 name keeps working.  The check uses the structured
# ``sys.version_info`` instead of the ``sys.version`` string, because a string
# comparison is lexicographic and would misorder a two-digit major version.
if sys.version_info[0] >= 3:
    basestring = str
from synapse.ml.onnx._ONNXModel import _ONNXModel
from pyspark.ml.common import inherit_doc
from py4j.java_gateway import JavaObject
class NodeInfo(object):
    """Name of a model node together with its converted value info.

    Attributes:
        name: The name passed to the constructor.
        value_info: Result of ``ValueInfo.from_java`` applied to the Java
            value-info object passed to the constructor.
    """

    def __init__(self, name: str, value_info: JavaObject):
        """Store the name and convert the Java value info.

        Args:
            name: Name of the node.
            value_info: Java-side value-info object; converted with
                ``ValueInfo.from_java`` before being stored.
        """
        self.name = name
        self.value_info = ValueInfo.from_java(value_info)

    def __str__(self) -> str:
        """Return ``NodeInfo(name=<name>,info=<value_info>)``."""
        return "NodeInfo(name={},info={})".format(self.name, self.value_info)

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)
@inherit_doc
class ONNXModel(_ONNXModel):
    """
    Extends ``_ONNXModel`` with setters that forward to the wrapped Java
    object and with access to the model's output metadata.

    Args:
        SparkSession (SparkSession): The SparkSession that will be used to find the model
        location (str): The location of the model, either on local or HDFS
    """

    # NOTE(review): the public methods below are documented with comments
    # rather than docstrings, on the assumption that @inherit_doc only fills
    # in docstrings that are missing -- confirm before converting them.

    # Forward ``location`` to the Java object's ``setModelLocation`` and keep
    # the object it returns as the new wrapped Java object.  Returns ``self``
    # so calls can be chained.
    def setModelLocation(self, location):
        self._java_obj = self._java_obj.setModelLocation(location)
        return self

    # Forward ``n`` to the Java object's ``setMiniBatchSize`` and keep the
    # object it returns as the new wrapped Java object.  Returns ``self`` so
    # calls can be chained.
    def setMiniBatchSize(self, n):
        self._java_obj = self._java_obj.setMiniBatchSize(n)
        return self

    def __parse_node_info(self, node_info: JavaObject) -> "NodeInfo":
        """Build a ``NodeInfo`` from a Java node-info object.

        Reads the name via ``getName()`` and the value info via ``getInfo()``.
        """
        return NodeInfo(node_info.getName(), node_info.getInfo())

    # Return the model outputs as a dict mapping each output name to its
    # ``NodeInfo``.  Python-side params are transferred to the Java object
    # first, then ``modelOutputJava()`` is queried.
    def getModelOutputs(self) -> Mapping[str, NodeInfo]:
        self._transfer_params_to_java()
        java_outputs = self._java_obj.modelOutputJava()
        return {
            output_name: self.__parse_node_info(java_node)
            for output_name, java_node in java_outputs.items()
        }
class ValueInfo(metaclass=ABCMeta):
    """Common base of ``TensorInfo``, ``MapInfo`` and ``SequenceInfo``."""

    @classmethod
    def from_java(cls, java_value_info: JavaObject) -> "ValueInfo":
        """Convert a Java value-info object into the matching Python class.

        The Java class name selects the converter:

        * ``ai.onnxruntime.TensorInfo`` -> ``TensorInfo.from_java``
        * ``ai.onnxruntime.MapInfo`` -> ``MapInfo.from_java``
        * any other class -> ``SequenceInfo.from_java``

        Args:
            java_value_info: Java object exposing ``getClass().getName()``.

        Returns:
            The value produced by the selected ``from_java`` converter.
        """
        class_name = java_value_info.getClass().getName()
        if class_name == "ai.onnxruntime.TensorInfo":
            return TensorInfo.from_java(java_value_info)
        if class_name == "ai.onnxruntime.MapInfo":
            return MapInfo.from_java(java_value_info)
        # Every remaining class name is handled as a sequence.
        return SequenceInfo.from_java(java_value_info)
class TensorInfo(ValueInfo):
    """Shape and type of a tensor value.

    Attributes:
        shape: The tensor's dimensions.
        type: The tensor's type, as a string.
    """

    # ``type`` shadows the builtin, but it is part of the public signature
    # (callers may pass it by keyword), so the name is kept.
    def __init__(self, shape: List[int], type: str):
        self.shape = shape
        self.type = type

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)

    def __str__(self) -> str:
        """Return ``TensorInfo(shape=[d0,d1,...], type=<type>)``."""
        dims = ",".join(map(str, self.shape))
        return "TensorInfo(shape=[{}], type={})".format(dims, self.type)

    @classmethod
    def from_java(cls, java_tensor_info: JavaObject) -> "TensorInfo":
        """Build a ``TensorInfo`` from a Java tensor-info object.

        The shape is read through ``getShape()`` and copied into a Python
        list; the type is read from the Java field ``type`` and converted
        with ``toString()``.
        """
        shape = list(java_tensor_info.getShape())
        type_name = java_gateway.get_field(java_tensor_info, "type").toString()
        return cls(shape, type_name)
class MapInfo(ValueInfo):
    """Key type, value type and size of a map value.

    Attributes:
        key_type: The key type, as a string.
        value_type: The value type, as a string.
        size: Number of entries; ``-1`` is rendered as ``UNKNOWN``.
    """

    def __init__(self, key_type: str, value_type: str, size: int = -1):
        self.key_type = key_type
        self.value_type = value_type
        self.size = size

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)

    def __str__(self) -> str:
        """Return ``MapInfo(size=<size>,keyType=<k>,valueType=<v>)``.

        A size of ``-1`` is printed as ``UNKNOWN``.
        """
        size_text = "UNKNOWN" if self.size == -1 else str(self.size)
        return "MapInfo(size={},keyType={},valueType={})".format(
            size_text,
            self.key_type,
            self.value_type,
        )

    @classmethod
    def from_java(cls, java_map_info: JavaObject) -> "MapInfo":
        """Build a ``MapInfo`` from a Java map-info object.

        Args:
            java_map_info: Java object with the fields ``keyType``,
                ``valueType`` and ``size``, or ``None``.

        Returns:
            A new ``MapInfo``, or ``None`` when ``java_map_info`` is ``None``.
        """
        # Identity test: ``None`` is a singleton, and ``==`` would be routed
        # through the argument's own ``__eq__``.
        if java_map_info is None:
            return None
        key_type = java_gateway.get_field(java_map_info, "keyType").toString()
        value_type = java_gateway.get_field(java_map_info, "valueType").toString()
        size = java_gateway.get_field(java_map_info, "size")
        return cls(key_type, value_type, size)
class SequenceInfo(ValueInfo):
    """Length and element description of a sequence value.

    Attributes:
        length: Number of elements; ``-1`` is rendered as ``UNKNOWN``.
        sequence_of_maps: Whether ``__str__`` describes the elements with
            ``map_info`` (true) or ``sequence_type`` (false).
        map_info: ``MapInfo`` of the elements; ``from_java`` stores ``None``
            here when the Java ``mapInfo`` field is ``None``.
        sequence_type: The element type, as a string.
    """

    def __init__(
        self,
        length: int,
        sequence_of_maps: bool,
        map_info: MapInfo,
        sequence_type: str,
    ):
        self.length = length
        self.sequence_of_maps = sequence_of_maps
        self.map_info = map_info
        self.sequence_type = sequence_type

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)

    def __str__(self) -> str:
        """Return ``SequenceInfo(length=<length>,type=<element>)``.

        A length of ``-1`` is printed as ``UNKNOWN``.  ``<element>`` is
        ``map_info`` for a sequence of maps, otherwise ``sequence_type``.
        """
        length_text = "UNKNOWN" if self.length == -1 else str(self.length)
        element = self.map_info if self.sequence_of_maps else self.sequence_type
        return "SequenceInfo(length={},type={})".format(length_text, element)

    @classmethod
    def from_java(cls, java_sequence_info: JavaObject) -> "SequenceInfo":
        """Build a ``SequenceInfo`` from a Java sequence-info object.

        Reads the Java fields ``length``, ``sequenceOfMaps``, ``mapInfo``
        (converted with ``MapInfo.from_java``) and ``sequenceType``
        (converted with ``toString()``).
        """
        length = java_gateway.get_field(java_sequence_info, "length")
        sequence_of_maps = java_gateway.get_field(
            java_sequence_info,
            "sequenceOfMaps",
        )
        map_info = MapInfo.from_java(
            java_gateway.get_field(java_sequence_info, "mapInfo"),
        )
        sequence_type = java_gateway.get_field(
            java_sequence_info,
            "sequenceType",
        ).toString()
        return cls(length, sequence_of_maps, map_info, sequence_type)