tfm.vision.heads.SegmentationHead

Creates a segmentation head.

Args
num_classes An int number of mask classification categories. The number of classes does not include the background class.
level An int or str specifying the feature level used to build the segmentation head.
num_convs An int number of stacked convolutions before the last prediction layer.
num_filters An int number to specify the number of filters used. Default is 256.
use_depthwise_convolution A bool specifying whether to use depthwise separable convolutions.
prediction_kernel_size An int number to specify the kernel size of the prediction layer.
upsample_factor An int number to specify the upsampling factor used to generate a finer mask. The default of 1 means no upsampling is applied.
feature_fusion One of deeplabv3plus, pyramid_fusion, panoptic_fpn_fusion, or None. If deeplabv3plus, features from decoder_features[level] are fused with low-level feature maps from the backbone. If pyramid_fusion, multiscale features are resized and fused at the target level. If panoptic_fpn_fusion, decoder features from decoder_min_level to decoder_max_level are fused. If None, no feature fusion is applied.
decoder_min_level An int specifying the minimum decoder level to use in feature fusion. It is only used when feature_fusion is set to panoptic_fpn_fusion.
decoder_max_level An int specifying the maximum decoder level to use in feature fusion. It is only used when feature_fusion is set to panoptic_fpn_fusion.
low_level An int specifying the backbone level to use for feature fusion. It is only used when feature_fusion is set to deeplabv3plus.
low_level_num_filters An int specifying the reduced number of filters applied to the low-level features before fusing them with higher-level features. It is only used when feature_fusion is set to deeplabv3plus.
num_decoder_filters An int specifying the number of filters in the decoder outputs. It is only used when feature_fusion is set to panoptic_fpn_fusion.
activation A str that indicates which activation is used, e.g. 'relu', 'swish', etc.
use_sync_bn A bool that indicates whether to use synchronized batch normalization across different replicas.
norm_momentum A float of normalization momentum for the moving average.
norm_epsilon A float added to the variance to avoid division by zero.
kernel_regularizer A tf.keras.regularizers.Regularizer object for Conv2D. Default is None.
bias_regularizer A tf.keras.regularizers.Regularizer object for Conv2D. Default is None.
**kwargs Additional keyword arguments passed to the base tf.keras.layers.Layer.
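To illustrate why use_depthwise_convolution can matter, the sketch below compares weight counts for a regular convolution and a depthwise separable one. It assumes the usual definition of a depthwise separable convolution (a k x k depthwise convolution followed by a 1 x 1 pointwise convolution) and 3 x 3 kernels; whether the head arranges its layers exactly this way is not stated here. The helper names are hypothetical and not part of tfm, biases are ignored, and num_filters is set to its documented default of 256.

```python
# Hypothetical helpers (not part of tfm): count convolution weights, biases excluded.


def standard_conv_weights(kernel_size: int, in_channels: int, out_channels: int) -> int:
    """Weights in a regular k x k Conv2D."""
    return kernel_size * kernel_size * in_channels * out_channels


def depthwise_separable_weights(kernel_size: int, in_channels: int, out_channels: int) -> int:
    """Weights in a k x k depthwise conv followed by a 1 x 1 pointwise conv."""
    depthwise = kernel_size * kernel_size * in_channels  # one k x k filter per input channel
    pointwise = in_channels * out_channels               # 1 x 1 conv that mixes channels
    return depthwise + pointwise


if __name__ == "__main__":
    k, filters = 3, 256  # assumed 3 x 3 kernels; 256 is the documented num_filters default
    print(standard_conv_weights(k, filters, filters))        # 589824
    print(depthwise_separable_weights(k, filters, filters))  # 67840
```

Under these assumptions each stacked convolution uses roughly 8.7 times fewer weights in the depthwise separable variant, which is the usual motivation for the option.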

Methods

call

Forward pass of the segmentation head.

It accepts either a tuple of 2 tensors or a tuple of 2 dictionaries. The first element holds the backbone endpoints and the second holds the decoder endpoints. Tensor inputs are feature maps from a single level; dictionary inputs contain feature maps from multiple levels, keyed by level.

Args
inputs A tuple of 2 elements: the backbone endpoints first and the decoder endpoints second. Each element is either a feature map tensor of shape [batch, height_l, width_l, channels] or a dictionary of tensors with:

  • key: A str of the level of the multilevel features.
  • value: A tf.Tensor feature map of shape [batch, height_l, width_l, channels].
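The input contract above (a 2-tuple whose elements are either single tensors or dictionaries keyed by str level) can be sketched as follows. The helpers pick_level and unpack_inputs are hypothetical and only mirror the documented contract; how the real layer unpacks its inputs internally is an assumption. Plain strings stand in for tensors so the sketch runs without TensorFlow, and low_level reflects the backbone level used for deeplabv3plus fusion.

```python
from typing import Dict, Optional, Tuple, Union

# Stand-in types: in the real layer these would be tf.Tensor objects.
FeatureMap = object
Endpoints = Union[FeatureMap, Dict[str, FeatureMap]]


def pick_level(endpoints: Endpoints, level: Union[int, str]) -> FeatureMap:
    """Return the feature map for `level` from a dict, or the single map itself."""
    if isinstance(endpoints, dict):
        return endpoints[str(level)]  # multilevel inputs are keyed by the str level
    return endpoints  # single-level input is already the feature map


def unpack_inputs(
    inputs: Tuple[Endpoints, Endpoints],
    level: Union[int, str],
    low_level: Optional[Union[int, str]] = None,
) -> Tuple[FeatureMap, Optional[FeatureMap]]:
    """Split (backbone, decoder) inputs into the maps a head would consume."""
    backbone_output, decoder_output = inputs  # documented order: backbone, then decoder
    decoder_feature = pick_level(decoder_output, level)
    backbone_feature = None
    if low_level is not None:  # e.g. the low-level map for deeplabv3plus fusion
        backbone_feature = pick_level(backbone_output, low_level)
    return decoder_feature, backbone_feature


if __name__ == "__main__":
    multilevel = ({"2": "backbone_l2", "3": "backbone_l3"}, {"3": "decoder_l3", "4": "decoder_l4"})
    print(unpack_inputs(multilevel, level=3, low_level=2))  # ('decoder_l3', 'backbone_l2')
    print(unpack_inputs(("backbone", "decoder"), level=3))  # ('decoder', None)
```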

Returns
segmentation prediction mask: A tf.Tensor of the segmentation mask scores predicted from input features.
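As a rough guide to the spatial size of the returned mask, the sketch below assumes the common convention that a level-l feature map has stride 2**l relative to the input image, and that upsample_factor scales the prediction by that factor. Which level the prediction is made at depends on feature_fusion (for example, deeplabv3plus fusion presumably produces a map at the backbone low_level resolution), so output_level is an assumption supplied by the caller. The helper mask_spatial_size is hypothetical and does not compute the channel dimension.

```python
from typing import Tuple, Union


def mask_spatial_size(
    image_size: Tuple[int, int],
    output_level: Union[int, str],
    upsample_factor: int = 1,
) -> Tuple[int, int]:
    """Hypothetical helper: (height, width) of the mask, assuming stride 2**level."""
    stride = 2 ** int(output_level)  # level l is assumed to have stride 2**l
    height, width = image_size
    if height % stride or width % stride:
        raise ValueError("sketch assumes image sides divisible by 2**output_level")
    # Feature-map size at the output level, then scaled by upsample_factor.
    return (height // stride * upsample_factor, width // stride * upsample_factor)


if __name__ == "__main__":
    print(mask_spatial_size((512, 512), output_level=3))                         # (64, 64)
    print(mask_spatial_size((512, 1024), output_level="3", upsample_factor=4))   # (256, 512)
```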