tfm.vision.heads.DetectionHead

Creates a detection head.

num_classes An int for the number of classes.
num_convs An int for the number of intermediate convolution layers before the FC layers.
num_filters An int for the number of filters in each intermediate convolution layer.
use_separable_conv A bool that indicates whether separable convolution layers are used.
num_fcs An int for the number of FC layers before the predictions.
fc_dims An int for the number of dimensions (units) of each FC layer.
class_agnostic_bbox_pred A bool that indicates whether box prediction is class-agnostic, i.e. one box is predicted per ROI instead of one box per class.
activation A str that indicates which activation is used, e.g. 'relu', 'swish', etc.
use_sync_bn A bool that indicates whether to use synchronized batch normalization across different replicas.
norm_momentum A float of normalization momentum for the moving average.
norm_epsilon A float added to variance to avoid dividing by zero.
kernel_regularizer A tf.keras.regularizers.Regularizer object for Conv2D. Default is None.
bias_regularizer A tf.keras.regularizers.Regularizer object for Conv2D. Default is None.
**kwargs Additional keyword arguments to be passed.
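
As a rough illustration of how the arguments above fit together, the sketch below collects them in a plain dict and defers construction to a small helper. The values are illustrative assumptions, not documented defaults (this page only states a default for kernel_regularizer), and head_kwargs and build_detection_head are hypothetical names. Calling the helper assumes the tf-models-official package is installed and importable as tensorflow_models.

```python
# Illustrative values only: these are assumptions, not the documented defaults.
head_kwargs = {
    "num_classes": 91,                  # e.g. a COCO-style label space
    "num_convs": 0,                     # no intermediate conv layers before the FC layers
    "num_filters": 256,                 # only relevant when num_convs > 0
    "use_separable_conv": False,
    "num_fcs": 2,                       # two FC layers before the predictions
    "fc_dims": 1024,
    "class_agnostic_bbox_pred": False,
    "activation": "relu",
    "use_sync_bn": False,
    "norm_momentum": 0.99,
    "norm_epsilon": 0.001,
    "kernel_regularizer": None,
    "bias_regularizer": None,
}


def build_detection_head(kwargs):
    """Builds a DetectionHead from a kwargs dict (hypothetical helper).

    The import is done lazily so the dict can be inspected or serialized
    on a machine without TensorFlow installed.
    """
    import tensorflow_models as tfm  # assumes tf-models-official is installed

    return tfm.vision.heads.DetectionHead(**kwargs)
```

Keeping the arguments in a dict makes it straightforward to log or override individual settings before the layer is created.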

Methods

call

Forward pass of box and class branches for the Mask-RCNN model.

Args
inputs A tf.Tensor of the shape [batch_size, num_instances, roi_height, roi_width, roi_channels], representing the ROI features.
training A bool indicating whether the layer is in training mode.

Returns
class_outputs A tf.Tensor of the shape [batch_size, num_rois, num_classes], representing the class predictions. Here num_rois corresponds to the num_instances dimension of inputs.
box_outputs A tf.Tensor of the shape [batch_size, num_rois, num_classes * 4], representing the box predictions.
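
To make the shape contract above concrete, here is a minimal pure-Python sketch that derives the documented output shapes from an input shape. expected_output_shapes is a hypothetical helper, not part of the library. It assumes that num_rois in the outputs equals num_instances in the inputs, and it only mirrors the shapes documented on this page, so it does not model how class_agnostic_bbox_pred may change the box output.

```python
def expected_output_shapes(input_shape, num_classes):
    """Returns (class_outputs_shape, box_outputs_shape) per the documented contract.

    input_shape is [batch_size, num_instances, roi_height, roi_width, roi_channels].
    """
    if len(input_shape) != 5:
        raise ValueError(
            "inputs must be rank 5: [batch, rois, height, width, channels]")
    batch_size, num_rois = input_shape[0], input_shape[1]
    class_shape = (batch_size, num_rois, num_classes)
    box_shape = (batch_size, num_rois, num_classes * 4)  # 4 box coordinates per class
    return class_shape, box_shape


# Example: 2 images, 8 ROIs each, 7x7 ROI features with 256 channels.
class_shape, box_shape = expected_output_shapes((2, 8, 7, 7, 256), num_classes=91)
print(class_shape)  # (2, 8, 91)
print(box_shape)    # (2, 8, 364)
```

A check like this can be useful in a unit test to catch a mismatch between the ROI feature shape and the configured number of classes.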