TensorLabel to użyteczne opakowanie dla TensorBuffers ze znaczącymi etykietami na osi.
Na przykład model klasyfikacji obrazu może mieć tensor wyjściowy o kształcie {1, 10}, gdzie 1 to rozmiar partii, a 10 to liczba kategorii. W rzeczywistości na drugiej osi moglibyśmy oznaczyć każdy podtensor nazwą lub opisem każdej odpowiedniej kategorii. TensorLabel
może pomóc w przekonwertowaniu zwykłego Tensora w TensorBuffer
na mapę z predefiniowanych etykiet na pod-tensory. W tym przypadku, jeśli podano 10 etykiet dla drugiej osi, TensorLabel
może przekonwertować oryginalny Tensor {1, 10} na mapę składającą się z 10 elementów, której każda wartość ma kształt tensora {} (skalar). Przykład użycia:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Uwaga: obecnie obsługujemy konwersję tensora na mapę tylko dla pierwszej etykiety o rozmiarze większym niż 1.
Konstruktorzy publiczni
TensorLabel ( Mapa < Liczba całkowita , Lista < Ciąg >> ośLabels, TensorBuffer tensorBuffer) Tworzy obiekt TensorLabel, który może oznaczać osie wielowymiarowych tensorów. | |
TensorLabel ( List <String> axisLabels, TensorBuffer tensorBuffer) Tworzy obiekt TensorLabel, który może opisywać na jednej osi wielowymiarowe tensory. |
Metody publiczne
Lista < Kategoria > | pobierz listę kategorii () Pobiera listę Category z obiektu TensorLabel . |
Mapa < String , Float > | getMapWithFloatValue () Pobiera mapę mapującą etykietę do pływającej. |
Mapa < String , TensorBuffer > | getMapWithTensorBuffer () Pobiera mapę z parą etykiety i odpowiedniego TensorBuffer. |
Metody dziedziczone
Konstruktorzy publiczni
public TensorLabel ( Map < Integer , List < String >> osiLabels, TensorBuffer tensorBuffer)
Tworzy obiekt TensorLabel, który może oznaczać osie wielowymiarowych tensorów.
Parametry
Etykiety osi | Mapa, której kluczem jest identyfikator osi (zaczynając od 0), a wartością odpowiadające etykiety. Uwaga: Rozmiar etykiet powinien być taki sam jak rozmiar tensora na tej osi. |
---|---|
bufor tensora | TensorBuffer, który ma być oznaczony etykietą. |
Rzuca
Wyjątek NullPointer | jeśli axisLabels lub tensorBuffer ma wartość null lub dowolna wartość w axisLabels ma wartość null. |
---|---|
Wyjątek IllegalArgument | jeśli którykolwiek klucz axisLabels jest poza zakresem (w porównaniu z kształtem tensorBuffer lub dowolna wartość (etykiety) ma inny rozmiar z tensorBuffer w danym wymiarze. |
public TensorLabel ( List <String> axisLabels, TensorBuffer tensorBuffer)
Tworzy obiekt TensorLabel, który może opisywać na jednej osi wielowymiarowe tensory.
Uwaga: Etykiety nanoszone są na pierwszą oś, której rozmiar jest większy niż 1. Przykładowo, jeśli kształt tensora wynosi [1, 10, 3], etykiety zostaną naniesione na oś 1 (identyfikator zaczynający się od 0), i rozmiar axisLabels
również powinien wynosić 10.
Parametry
Etykiety osi | Lista etykiet, których rozmiar powinien być taki sam jak rozmiar tensora na osi, która ma być etykietowana. |
---|---|
bufor tensora | TensorBuffer, który ma być oznaczony etykietą. |
Metody publiczne
Lista publiczna < Kategoria > getCategoryList ()
Pobiera listę Category
z obiektu TensorLabel
.
Oś etykiety powinna być w rzeczywistości ostatnią osią (co oznacza, że każdy podtensor określony przez tę oś powinien mieć płaski rozmiar 1), tak aby każdy oznaczony podtensor mógł zostać przekonwertowany na wynik wartości zmiennoprzecinkowej. Przykład: TensorLabel
o kształcie {2, 5, 3}
i osi 2 jest prawidłowy. Jeśli oś ma wartość 1 lub 0, nie można jej przekształcić w Category
.
getMapWithFloatValue()
jest alternatywą, ale w rezultacie zwraca Map
.
Rzuca
Wyjątek IllegalStateException | jeśli rozmiar podtensora na każdej etykiecie nie wynosi 1. |
---|
mapa publiczna < String , Float > getMapWithFloatValue ()
Pobiera mapę mapującą etykietę do pływającej. Zezwalaj na mapowanie tylko na pierwszej osi o rozmiarze większym niż 1, a oś powinna być faktycznie ostatnią osią (co oznacza, że każdy podtensor określony przez tę oś powinien mieć płaski rozmiar 1).
getCategoryList()
to alternatywny interfejs API umożliwiający uzyskanie wyniku.
Rzuca
Wyjątek IllegalStateException | jeśli rozmiar podtensora na każdej etykiecie nie wynosi 1. |
---|
public Map < String , TensorBuffer > getMapWithTensorBuffer ()
Pobiera mapę z parą etykiety i odpowiedniego TensorBuffer. Zezwalaj na mapowanie tylko na pierwszej osi o rozmiarze większym niż 1 obecnie.