MinDiffKernel abstract base class.
model_remediation.min_diff.losses.MinDiffKernel(
tile_input: bool = True
)
Arguments | |
---|---|
`tile_input` | Boolean indicating whether to tile inputs before computing the kernel (see below for details).
To be implemented by subclasses:

- `call()`: contains the logic for the kernel tensor calculation.
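Example subclass implementation:

```python
import tensorflow as tf
from tensorflow_model_remediation.min_diff.losses import MinDiffKernel

class GaussKernel(MinDiffKernel):
    def call(self, x, y):  # x, y: [N, M, D]; the result is [N, M]
        return tf.exp(-tf.reduce_sum(tf.square(x - y), axis=2) / 0.01)
```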
"Tiling" is a way of expanding the rank of the input tensors so that their dimensions work for the operations we need.
If `x` and `y` have shapes `[N, D]` and `[M, D]` respectively, tiling expands them to `[N, ?, D]` and `[?, M, D]`, where TensorFlow broadcasting ensures that element-wise operations between them work across every pair of rows from `x` and `y`.
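The shape bookkeeping can be sketched in NumPy, which follows the same broadcasting rules as TensorFlow. This is a minimal sketch, not the library's code: it assumes each `?` is a size-1 axis inserted with `np.newaxis`, and the helper name `tile_inputs` is hypothetical.

```python
import numpy as np


def tile_inputs(x, y):
    """Hypothetical helper: expand x [N, D] and y [M, D] to [N, 1, D] and [1, M, D].

    Assumption: each '?' axis is a size-1 axis that broadcasting stretches.
    """
    return x[:, np.newaxis, :], y[np.newaxis, :, :]


x = np.random.rand(4, 3)  # N=4, D=3
y = np.random.rand(5, 3)  # M=5, D=3
x_tiled, y_tiled = tile_inputs(x, y)
print(x_tiled.shape, y_tiled.shape)  # (4, 1, 3) (1, 5, 3)
# Broadcasting stretches both size-1 axes, so the difference is [N, M, D].
print((x_tiled - y_tiled).shape)  # (4, 5, 3)
```

Inserting size-1 axes avoids copying data up front; an implementation could instead materialize the copies explicitly (for example with `tf.tile`) at the cost of extra memory.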
Methods
call
@abc.abstractmethod
call(
    x: types.TensorType, y: types.TensorType
)
Invokes the `MinDiffKernel` instance.
Arguments | |
---|---|
`x` | `tf.Tensor` of shape `[N, M, D]`.
`y` | `tf.Tensor` of shape `[N, M, D]`.
This method contains the logic for computing the kernel. It must be implemented by subclasses.
Returns | |
---|---|
`tf.Tensor` of shape `[N, M]`. |
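To make the shape contract of `call` concrete, the Gaussian kernel from the example subclass can be mirrored in NumPy. This is an illustrative sketch only: `gauss_call` and its `scale` parameter are hypothetical names, and the default `0.01` simply matches the constant in the example. Reducing over `axis=2` removes the `D` dimension, which is what turns two `[N, M, D]` inputs into an `[N, M]` result.

```python
import numpy as np


def gauss_call(x, y, scale=0.01):
    """NumPy mirror of the example kernel's call(): [N, M, D] x 2 -> [N, M].

    gauss_call and scale are hypothetical; 0.01 matches the example subclass.
    """
    squared_dist = np.sum(np.square(x - y), axis=2)  # reduce over D
    return np.exp(-squared_dist / scale)


rng = np.random.default_rng(0)
x = rng.random((2, 3, 4))  # [N, M, D] with N=2, M=3, D=4
y = rng.random((2, 3, 4))
k = gauss_call(x, y)
print(k.shape)  # (2, 3)
print(gauss_call(x, x))  # identical inputs give exp(0), i.e. all ones
```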
from_config
@classmethod
from_config(config)
Creates a `MinDiffKernel` instance from the config.
Any subclass with additional attributes or a different initialization signature will need to override this method or `get_config`.
Returns | |
---|---|
A new `MinDiffKernel` instance corresponding to `config`. |
get_config
get_config()
Creates a config dictionary for the `MinDiffKernel` instance.
Any subclass with additional attributes will need to override this method. When doing so, users will most likely want to first call `super().get_config()`.
Returns | |
---|---|
A config dictionary for the `MinDiffKernel` instance. |
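The override pattern described for `get_config` and `from_config` can be sketched without TensorFlow. `MiniKernel` and `ScaledGaussKernel` below are hypothetical classes, not part of the library, and the sketch assumes the base `from_config` rebuilds an instance as `cls(**config)`. Under that assumption, a subclass whose extra constructor arguments are added to `get_config` round-trips without also overriding `from_config`.

```python
class MiniKernel:
    """Hypothetical stand-in that mimics the documented config contract."""

    def __init__(self, tile_input=True):
        self.tile_input = tile_input

    def get_config(self):
        return {"tile_input": self.tile_input}

    @classmethod
    def from_config(cls, config):
        # Assumption: the base class rebuilds the instance as cls(**config).
        return cls(**config)


class ScaledGaussKernel(MiniKernel):
    """Hypothetical subclass with an extra attribute, so get_config is extended."""

    def __init__(self, kernel_length=0.1, **kwargs):
        super().__init__(**kwargs)
        self.kernel_length = kernel_length

    def get_config(self):
        config = super().get_config()  # start from the base-class keys
        config.update({"kernel_length": self.kernel_length})
        return config


kernel = ScaledGaussKernel(kernel_length=0.5, tile_input=False)
config = kernel.get_config()
print(config)  # {'tile_input': False, 'kernel_length': 0.5}
restored = ScaledGaussKernel.from_config(config)
print(restored.kernel_length, restored.tile_input)  # 0.5 False
```

If a subclass's constructor took arguments that could not be recovered from the dictionary by keyword, overriding `from_config` would be the alternative.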
__call__
__call__(
x: types.TensorType, y: Optional[types.TensorType] = None
) -> types.TensorType
Invokes the kernel instance.
Arguments | |
---|---|
`x` | `tf.Tensor` of shape `[N, D]` (if tiling input) or `[N, M, D]` (if not tiling input).
`y` | Optional `tf.Tensor` of shape `[M, D]` (if tiling input) or `[N, M, D]` (if not tiling input).
If `y` is `None`, it is set to be the same as `x`:

```python
if y is None:
    y = x
```

Inputs are tiled if `self.tile_input` is `True` and left as is otherwise.
Returns | |
---|---|
`tf.Tensor` of shape `[N, M]`. |
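The documented `__call__` flow (default `y` to `x`, tile if `tile_input` is set, then delegate to `call`) can be sketched end to end in NumPy. `SketchGaussKernel` is a hypothetical class, not the library's implementation, and it assumes tiling inserts size-1 axes rather than copying data.

```python
import numpy as np


class SketchGaussKernel:
    """Hypothetical NumPy mirror of the documented __call__ flow."""

    def __init__(self, tile_input=True):
        self.tile_input = tile_input

    def call(self, x, y):
        # Same Gaussian formula as the example subclass; reduces over D.
        return np.exp(-np.sum(np.square(x - y), axis=2) / 0.01)

    def __call__(self, x, y=None):
        if y is None:
            y = x  # documented default: compare x against itself
        if self.tile_input:
            # Assumption: tiling inserts size-1 axes ([N, 1, D] and [1, M, D]).
            x = x[:, np.newaxis, :]
            y = y[np.newaxis, :, :]
        return self.call(x, y)


kernel = SketchGaussKernel()
x = np.random.rand(4, 3)  # [N, D]
y = np.random.rand(6, 3)  # [M, D]
print(kernel(x, y).shape)  # (4, 6)
print(kernel(x).shape)  # (4, 4): y defaults to x
untiled = SketchGaussKernel(tile_input=False)
z = np.random.rand(4, 6, 3)  # already [N, M, D]
print(untiled(z, z).shape)  # (4, 6)
```

With `y` omitted, each row is compared with itself on the diagonal, so those entries are exactly 1.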