TensorLabel — это служебная оболочка для TensorBuffers со значимыми метками на оси.
Например, модель классификации изображений может иметь выходной тензор с формой {1, 10}, где 1 — размер пакета, а 10 — количество категорий. Фактически, на второй оси мы могли бы пометить каждый субтензор именем или описанием каждой соответствующей категории. TensorLabel
может помочь преобразовать простой Tensor в TensorBuffer
в карту из предопределенных меток в субтензоры. В этом случае, если для второй оси предоставлено 10 меток, TensorLabel
может преобразовать исходный тензор {1, 10} в карту из 10 элементов, каждое значение которой представляет собой тензор в форме {} (скаляр). Пример использования:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Примечание. В настоящее время мы поддерживаем преобразование тензора в карту только для первой метки с размером больше 1.
Публичные конструкторы
TensorLabel ( Map < Integer , List < String >> axisLabels, TensorBuffer tensorBuffer) Создает объект TensorLabel, который может размечать оси многомерных тензоров. | |
TensorLabel ( List < String > axisLabels, TensorBuffer tensorBuffer) Создает объект TensorLabel, который может маркировать по одной оси многомерные тензоры. |
Публичные методы
Список < Категория > | получитьСписокКатегорий () Получает список Category из объекта TensorLabel . |
Карта < Строка , Плавающее число > | getMapWithFloatValue () Получает карту, которая отображает метку как плавающую. |
Map < String , TensorBuffer > | getMapWithTensorBuffer () Получает карту с парой метки и соответствующим TensorBuffer. |
Унаследованные методы
Публичные конструкторы
public TensorLabel ( Map < Integer , List < String >> axisLabels, TensorBuffer tensorBuffer)
Создает объект TensorLabel, который может размечать оси многомерных тензоров.
Параметры
Осевые метки | Карта, ключом которой является идентификатор оси (начиная с 0), а значением — соответствующие метки. Примечание. Размер меток должен совпадать с размером тензора на этой оси. |
---|---|
ТензорБуфер | TensorBuffer, который нужно пометить. |
Броски
Исключение нулевого указателя | если axisLabels или tensorBuffer имеют значение null или любое значение в axisLabels имеет значение null. |
---|---|
IllegalArgumentException | если какой-либо ключ в axisLabels выходит за пределы диапазона (по сравнению с формой tensorBuffer или любое значение (метки) имеет другой размер с tensorBuffer в данном измерении. |
public TensorLabel ( List < String > axisLabels, TensorBuffer tensorBuffer)
Создает объект TensorLabel, который может маркировать по одной оси многомерные тензоры.
Примечание. Метки применяются к первой оси, размер которой больше 1. Например, если форма тензора равна [1, 10, 3], метки будут применены к оси 1 (идентификатор начинается с 0), и размер axisLabels
также должен быть 10.
Параметры
Осевые метки | Список меток, размер которых должен совпадать с размером тензора на оси, подлежащей маркировке. |
---|---|
ТензорБуфер | TensorBuffer, который нужно пометить. |
Публичные методы
общедоступный список < Категория > getCategoryList ()
Получает список Category
из объекта TensorLabel
.
Ось метки должна фактически быть последней осью (это означает, что каждый субтензор, указанный этой осью, должен иметь плоский размер, равный 1), чтобы каждый помеченный субтензор можно было преобразовать в оценку с плавающей запятой. Пример: допустима TensorLabel
с формой {2, 5, 3}
и осью 2. Если ось равна 1 или 0, ее нельзя преобразовать в Category
.
getMapWithFloatValue()
является альтернативой, но в качестве результата возвращает Map
.
Броски
IllegalStateException | если размер субтензора на каждой метке не равен 1. |
---|
public Map < String , Float > getMapWithFloatValue ()
Получает карту, которая отображает метку как плавающую. Разрешить отображение только на первой оси с размером больше 1, и эта ось должна быть фактически последней осью (это означает, что каждый субтензор, указанный этой осью, должен иметь плоский размер 1).
getCategoryList()
— альтернативный API для получения результата.
Броски
IllegalStateException | если размер субтензора на каждой метке не равен 1. |
---|
общедоступная карта < String , TensorBuffer > getMapWithTensorBuffer ()
Получает карту с парой метки и соответствующим TensorBuffer. В настоящее время разрешено отображение только на первой оси с размером больше 1.