Module: tf.contrib.model_pruning

Model pruning implementation in TensorFlow.

Classes

class MaskedBasicLSTMCell: Basic LSTM recurrent network cell with pruning.

class MaskedLSTMCell: LSTMCell with pruning.

class Pruning

Functions

apply_mask(...): Apply mask to a given weight tensor.

get_masked_weights(...)

get_masks(...)

get_pruning_hparams(...): Get a tf.HParams object with the default values for the hyperparameters.

get_thresholds(...)

get_weight_sparsity(...): Get sparsity of the weights.

get_weights(...)

graph_def_from_checkpoint(...): Converts checkpoint data to GraphDef.

masked_conv2d(...): Adds a 2D convolution followed by an optional batch_norm layer.

masked_convolution(...): Adds a 2D convolution followed by an optional batch_norm layer.

masked_fully_connected(...): Adds a sparse fully connected layer. The weight matrix is masked.

strip_pruning_vars_fn(...): Removes mask variables from the graph.

train(...): Wrapper around tf-slim's train function.
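The apply_mask(...) entry above is described only as applying a mask to a weight tensor. A minimal sketch, assuming the mask is a 0/1 tensor with the same shape as the weights and that applying it means element-wise multiplication. Plain Python lists stand in for tensors, and apply_mask_sketch is a hypothetical name, not the module's API:

```python
def apply_mask_sketch(weights, mask):
    """Hypothetical helper: element-wise product of a 2D weight matrix and a 0/1 mask."""
    if len(weights) != len(mask) or any(len(w) != len(m) for w, m in zip(weights, mask)):
        raise ValueError("weights and mask must have the same shape")
    return [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weights, mask)]


weights = [[0.5, 0.1], [0.02, -1.2]]
mask = [[1.0, 0.0], [0.0, 1.0]]
print(apply_mask_sketch(weights, mask))  # [[0.5, 0.0], [0.0, -1.2]]
```

The sketch returns a new matrix and leaves `weights` unchanged, so the same weights can be re-masked if the mask is later updated.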
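The get_masks(...) and get_thresholds(...) entries imply masks that are derived from per-weight thresholds. A minimal sketch of magnitude pruning, one common way to do this, assuming a weight is kept only when its absolute value exceeds a threshold chosen so that a target fraction of the weights is pruned. The helper names and the target_sparsity parameter are hypothetical, and the Pruning class may compute its thresholds differently:

```python
def magnitude_threshold(weights, target_sparsity):
    """Hypothetical helper: magnitude at or below which weights are pruned."""
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError("target_sparsity must be in [0, 1]")
    magnitudes = sorted(abs(w) for row in weights for w in row)
    k = int(target_sparsity * len(magnitudes))  # number of weights to prune
    if k == 0:
        return float("-inf")  # prune nothing
    return magnitudes[k - 1]


def magnitude_mask(weights, threshold):
    """Hypothetical helper: 1.0 where |w| > threshold, else 0.0 (ties are pruned)."""
    return [[1.0 if abs(w) > threshold else 0.0 for w in row] for row in weights]


weights = [[0.5, -0.1], [0.02, 1.2]]
threshold = magnitude_threshold(weights, 0.5)  # prune the 2 smallest of 4 magnitudes
mask = magnitude_mask(weights, threshold)
print(threshold, mask)  # 0.1 [[1.0, 0.0], [0.0, 1.0]]
```

Because ties at the threshold are pruned, the achieved sparsity can exceed the target when several weights share the same magnitude.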
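get_weight_sparsity(...) reports how sparse the weights are. A minimal sketch, assuming sparsity means the fraction of exactly-zero entries in a tensor (the quantity TensorFlow's tf.math.zero_fraction computes for a single tensor); weight_sparsity_sketch is a hypothetical name, and the real function may return one value per masked tensor:

```python
def weight_sparsity_sketch(weights):
    """Hypothetical helper: fraction of entries in a 2D matrix that are exactly zero."""
    flat = [w for row in weights for w in row]
    if not flat:
        raise ValueError("weights must be non-empty")
    return sum(1 for w in flat if w == 0.0) / len(flat)


masked = [[0.5, 0.0], [0.0, -1.2]]
print(weight_sparsity_sketch(masked))  # 0.5
```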
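masked_fully_connected(...) is described as a fully connected layer whose weight matrix is masked. A minimal sketch of the forward pass, assuming y = x · (W ⊙ M) + b with M a 0/1 mask of the same shape as W. The activation function and variable creation are omitted, and masked_dense_sketch is a hypothetical name rather than the layer's signature:

```python
def masked_dense_sketch(x, weights, mask, bias):
    """Hypothetical forward pass: x has n_in features, weights/mask are n_in x n_out."""
    out = []
    for j in range(len(bias)):
        acc = bias[j]
        for i, xi in enumerate(x):
            # A zero in the mask removes the connection from input i to output j.
            acc += xi * weights[i][j] * mask[i][j]
        out.append(acc)
    return out


x = [1.0, 2.0]
weights = [[0.5, 0.1], [0.25, -1.0]]
mask = [[1.0, 0.0], [0.0, 1.0]]
bias = [0.0, 0.5]
print(masked_dense_sketch(x, weights, mask, bias))  # [0.5, -1.5]
```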
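strip_pruning_vars_fn(...) removes the mask variables from the graph. A minimal sketch of the idea, assuming each mask is folded into its weights (weights * mask) before the mask is dropped, so the remaining values match the masked weights. Variables are modeled as a dict keyed by name; the "<scope>/weights" and "<scope>/mask" naming and strip_masks_sketch are hypothetical, and the real function operates on a TensorFlow graph:

```python
def strip_masks_sketch(variables):
    """Hypothetical: fold each '<scope>/mask' into '<scope>/weights' and drop the mask."""
    stripped = {}
    for name, value in variables.items():
        if name.endswith("/mask"):
            continue  # mask variables do not survive the export
        if name.endswith("/weights"):
            mask = variables.get(name[: -len("weights")] + "mask")
            if mask is not None:
                value = [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(value, mask)]
        stripped[name] = value
    return stripped


variables = {"fc/weights": [[0.5, 0.1]], "fc/mask": [[1.0, 0.0]], "fc/biases": [0.0]}
print(strip_masks_sketch(variables))  # {'fc/weights': [[0.5, 0.0]], 'fc/biases': [0.0]}
```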