Adds operations to compute the partial derivatives of the sum of ys w.r.t. xs, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...
If Options.dx() values are set, they are used as the initial symbolic partial derivatives of some loss function L w.r.t. y. Options.dx() must have the size of y.
If Options.dx() is not set, the implementation will use dx of OnesLike for all shapes in y.
The partial derivatives are returned in output dy, with the size of x.
Example of usage:
Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
Constant<Float> alpha = ops.constant(1.0f, Float.class);
ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
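If initial symbolic partial derivatives of a loss function w.r.t. the ys should be supplied instead of the default OnesLike values, they can be passed through the Options.dx() builder. A minimal sketch, reusing scope, loss, w and b from the example above and assuming dLoss is an Operand<Float> of the same shape as loss holding an externally computed dL/d(loss) (the name is illustrative only, not part of the API):
// dLoss: illustrative operand holding dL/d(loss), shaped like loss
Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b),
    Gradients.dx(Arrays.asList(dLoss)));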
Nested Classes
class Gradients.Options | Optional attributes for Gradients
Public Methods
static Gradients | create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options)
Adds gradients computation ops to the graph according to scope.
static Gradients | create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Options... options)
Adds gradients computation ops to the graph according to scope.
static Gradients.Options | dx(Iterable<? extends Operand<?>> dx)
<T> Output<T> | dy(int index)
Returns a symbolic handle to one of the gradient operation outputs. Warning: Does not check that the type of the tensor matches T.
List<Output<?>> | dy()
Partial derivatives of ys w.r.t. xs.
Iterator<Operand<?>> | iterator()
Inherited Methods
Public Methods
public static Gradients create (Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options)
Adds gradients computation ops to the graph according to scope.
This is a simplified version of create(Scope, Iterable, Iterable, Options...) where y is a single output.
Parameters
scope | current graph scope
y | output of the function to derive
x | inputs of the function for which partial derivatives are computed
options | carries optional attribute values
Returns
- a new instance of Gradients
Throws
IllegalArgumentException | if execution environment is not a graph
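Since this overload takes a single Operand for y, the loss output from the usage example above can be passed directly, without wrapping it in a list. A minimal sketch:
// Single-output form: y is passed as one operand rather than a list
Gradients gradients = Gradients.create(scope, loss, Arrays.asList(w, b));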
public static Gradients create (Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Options... options)
Adds gradients computation ops to the graph according to scope.
Parameters
scope | current graph scope
y | outputs of the function to derive
x | inputs of the function for which partial derivatives are computed
options | carries optional attribute values
Returns
- a new instance of Gradients
Throws
IllegalArgumentException | if execution environment is not a graph
public static Gradients.Options dx (Iterable<? extends Operand<?>> dx)
Parameters
dx | partial derivatives of some loss function L w.r.t. y
Returns
- builder to add more options to this operation
public Output<T> dy (int index)
Returns a symbolic handle to one of the gradient operation outputs.
Warning: Does not check that the type of the tensor matches T. It is recommended to call this method with an explicit type parameter rather than letting it be inferred, e.g. gradients.<Float>dy(0).
Parameters
index | the index of the output among the gradients added by this operation