TensorLabel est un wrapper utilitaire pour TensorBuffers avec des étiquettes significatives sur un axe.
Par exemple, un modèle de classification d'images peut avoir un tenseur de sortie de forme {1, 10}, où 1 est la taille du lot et 10 est le nombre de catégories. En fait, sur le 2ème axe, on pourrait étiqueter chaque sous-tenseur avec le nom ou la description de chaque catégorie correspondante. TensorLabel
pourrait aider à convertir le Tensor simple dans TensorBuffer
en une carte d'étiquettes prédéfinies en sous-tenseurs. Dans ce cas, si 10 étiquettes sont fournies pour le 2ème axe, TensorLabel
pourrait convertir le Tensor {1, 10} d'origine en une carte de 10 éléments, dont chaque valeur est un Tensor en forme {} (scalaire). Exemple d'utilisation :
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Remarque : actuellement, nous ne prenons en charge la conversion tenseur en carte que pour la première étiquette dont la taille est supérieure à 1.
Constructeurs Publics
TensorLabel ( Map < Integer , List < String >> axisLabels, TensorBuffer tensorBuffer) Crée un objet TensorLabel capable d'étiqueter sur les axes de tenseurs multidimensionnels. | |
TensorLabel ( Liste <String> axisLabels, TensorBuffer tensorBuffer) Crée un objet TensorLabel capable d'étiqueter sur un axe de tenseurs multidimensionnels. |
Méthodes publiques
Liste < Catégorie > | getCatégorieListe () Obtient une liste de Category à partir de l'objet TensorLabel . |
Carte < String , Float > | getMapWithFloatValue () Obtient une carte qui mappe l'étiquette à float. |
Carte < String , TensorBuffer > | getMapWithTensorBuffer () Obtient la carte avec une paire de l'étiquette et du TensorBuffer correspondant. |
Méthodes héritées
Constructeurs Publics
public TensorLabel ( Map < Integer , List < String >> axisLabels, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter sur les axes de tenseurs multidimensionnels.
Paramètres
axisLabels | Une carte dont la clé est l'identifiant de l'axe (à partir de 0) et la valeur correspond aux étiquettes correspondantes. Remarque : La taille des étiquettes doit être la même que la taille du tenseur sur cet axe. |
---|---|
tensorBuffer | Le TensorBuffer à étiqueter. |
Jetés
NullPointerException | si axisLabels ou tensorBuffer est nul, ou si toute valeur dans axisLabels est nulle. |
---|---|
Exception d'argument illégal | si une clé dans axisLabels est hors plage (par rapport à la forme de tensorBuffer , ou si toute valeur (étiquettes) a une taille différente avec le tensorBuffer sur la dimension donnée. |
public TensorLabel ( Liste <String> axisLabels, TensorBuffer tensorBuffer)
Crée un objet TensorLabel capable d'étiqueter sur un axe de tenseurs multidimensionnels.
Remarque : Les étiquettes sont appliquées sur le premier axe dont la taille est supérieure à 1. Par exemple, si la forme du tenseur est [1, 10, 3], les étiquettes seront appliquées sur l'axe 1 (id commençant à 0), et la taille de axisLabels
devrait également être de 10.
Paramètres
axisLabels | Une liste d'étiquettes dont la taille doit être la même que la taille du tenseur sur l'axe à étiqueter. |
---|---|
tensorBuffer | Le TensorBuffer à étiqueter. |
Méthodes publiques
liste publique < Catégorie > getCategoryList ()
Obtient une liste de Category
à partir de l'objet TensorLabel
.
L'axe de l'étiquette doit être effectivement le dernier axe (ce qui signifie que chaque sous-tenseur spécifié par cet axe doit avoir une taille plate de 1), de sorte que chaque sous-tenseur étiqueté puisse être converti en un score de valeur flottante. Exemple : Un TensorLabel
avec la forme {2, 5, 3}
et l'axe 2 est valide. Si axis est 1 ou 0, il ne peut pas être converti en Category
.
getMapWithFloatValue()
est une alternative mais renvoie une Map
comme résultat.
Jetés
IllegalStateException | si la taille d'un sous-tenseur sur chaque étiquette n'est pas 1. |
---|
public Map < String , Float > getMapWithFloatValue ()
Obtient une carte qui mappe l'étiquette à float. Autorisez uniquement le mappage sur le premier axe avec une taille supérieure à 1, et l'axe doit être effectivement le dernier axe (ce qui signifie que chaque sous-tenseur spécifié par cet axe doit avoir une taille plate de 1).
getCategoryList()
est une API alternative pour obtenir le résultat.
Jetés
IllegalStateException | si la taille d'un sous-tenseur sur chaque étiquette n'est pas 1. |
---|
public Map < String , TensorBuffer > getMapWithTensorBuffer ()
Obtient la carte avec une paire de l'étiquette et du TensorBuffer correspondant. Autoriser uniquement le mappage sur le premier axe de taille supérieure à 1 actuellement.