@Operator public class Gradients extends Object implements Op, Iterable<Operand<?>>
Adds operations to compute the partial derivatives of the sum of ys w.r.t. xs, i.e., d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2, ...
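The sentence above, written as a formula. This treats the ys and xs as scalars, which is a simplification: for tensors, the sum also runs over the elements of each y_i, and each result has the shape of the corresponding x_j.

```latex
dy_j = \frac{\partial}{\partial x_j}\left(y_1 + y_2 + \cdots\right) = \sum_{i} \frac{\partial y_i}{\partial x_j}
```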
If Options.dx() values are set, they are used as the initial symbolic partial derivatives of some loss function L w.r.t. y. Options.dx() must have the size of y.

If Options.dx() is not set, the implementation uses OnesLike as dx, i.e. a tensor of ones for each shape in y.

The partial derivatives are returned in the output dy, which has the size of x.
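The rules above (a sum over the ys, optional dx weights, a ones default, one result per x) can be mimicked in plain Java. This is a minimal sketch, not TensorFlow code: it assumes scalar ys and xs, estimates derivatives with central finite differences instead of adding symbolic gradient ops, and the names GradientSemanticsSketch and gradients are hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.ToDoubleFunction;

// Hypothetical, TensorFlow-free sketch: scalar ys and xs, derivatives estimated
// with central finite differences instead of symbolic gradient ops.
class GradientSemanticsSketch {

    // Returns dy[j] = sum over i of dx[i] * d(y_i)/d(x_j).
    // A null dx means "not set": every weight is 1.0 (the OnesLike analogue),
    // so dy[j] is d(y_1 + y_2 + ...)/dx_j.
    static double[] gradients(List<ToDoubleFunction<double[]>> ys, double[] x, double[] dx) {
        if (dx != null && dx.length != ys.size()) {
            throw new IllegalArgumentException("dx must have the size of y");
        }
        final double h = 1e-5;              // finite-difference step
        double[] dy = new double[x.length]; // result has the size of x
        for (int j = 0; j < x.length; j++) {
            double[] plus = x.clone();
            double[] minus = x.clone();
            plus[j] += h;
            minus[j] -= h;
            for (int i = 0; i < ys.size(); i++) {
                double weight = (dx == null) ? 1.0 : dx[i];
                ToDoubleFunction<double[]> y = ys.get(i);
                double partial = (y.applyAsDouble(plus) - y.applyAsDouble(minus)) / (2 * h);
                dy[j] += weight * partial;
            }
        }
        return dy;
    }

    public static void main(String[] args) {
        // y_1 = x_1 * x_2 and y_2 = x_1 + 3 * x_2, evaluated at x = (2, 5)
        List<ToDoubleFunction<double[]>> ys = Arrays.asList(
                v -> v[0] * v[1],
                v -> v[0] + 3 * v[1]);
        double[] x = {2.0, 5.0};
        // Default dx: values close to [6.0, 5.0]
        System.out.println(Arrays.toString(gradients(ys, x, null)));
        // dx = (2, 0.5): values close to [10.5, 5.5]
        System.out.println(Arrays.toString(gradients(ys, x, new double[] {2.0, 0.5})));
    }
}
```

With the default dx, each result is the plain sum of partial derivatives; setting dx turns it into a weighted sum, which is how a loss L upstream of y feeds its own derivatives in.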
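Example of usage:

```java
// loss, w and b are operands already defined in the graph that scope belongs to;
// ops is an operator factory bound to the same graph (not defined in this snippet)
Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
Constant<Float> alpha = ops.constant(1.0f, Float.class);
// dy(0) is d(loss)/dw and dy(1) is d(loss)/db, following the order of x
ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
```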
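ApplyGradientDescent updates a variable as var -= alpha * delta. The plain-Java sketch below (hypothetical class GradientDescentSketch, no TensorFlow) repeats the example's training step for a toy model t = w * x + b with a mean-squared-error loss. The two gradients are written out by hand where the graph version reads them from dy(0) and dy(1); the data, learning rate and step count are illustrative assumptions.

```java
// Hypothetical plain-Java sketch of the training step in the example:
// both parameters are updated as var -= alpha * delta from the same gradients.
class GradientDescentSketch {

    // Fits t = w * x + b by gradient descent on the mean squared error; returns {w, b}
    static double[] fit(double[] xs, double[] ts, double alpha, int steps) {
        double w = 0.0;
        double b = 0.0;
        int n = xs.length;
        for (int step = 0; step < steps; step++) {
            double dw = 0.0; // dLoss/dw, the counterpart of dy(0)
            double db = 0.0; // dLoss/db, the counterpart of dy(1)
            for (int k = 0; k < n; k++) {
                double residual = w * xs[k] + b - ts[k];
                dw += 2.0 * residual * xs[k] / n;
                db += 2.0 * residual / n;
            }
            // One update per parameter, like the two ApplyGradientDescent ops
            w -= alpha * dw;
            b -= alpha * db;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] xs = {0, 1, 2, 3};
        double[] ts = {1, 3, 5, 7}; // generated from t = 2x + 1
        double[] wb = fit(xs, ts, 0.05, 2000);
        System.out.printf("w = %.4f, b = %.4f%n", wb[0], wb[1]); // w = 2.0000, b = 1.0000
    }
}
```

Both gradients are computed before either parameter changes, which matches the example: both ApplyGradientDescent ops consume outputs of the same Gradients operation.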
Modifier and Type | Class and Description |
---|---|
`static class` | `Gradients.Options`: Optional attributes for `Gradients` |
Modifier and Type | Method and Description |
---|---|
`static Gradients` | `create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Gradients.Options... options)`: Adds gradients computation ops to the graph according to scope. |
`static Gradients` | `create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Gradients.Options... options)`: Adds gradients computation ops to the graph according to scope. |
`static Gradients.Options` | `dx(Iterable<? extends Operand<?>> dx)`: Sets dx, the partial derivatives of some loss function L w.r.t. y |
`List<Output<?>>` | `dy()`: Partial derivatives of ys w.r.t. xs |
`<T> Output<T>` | `dy(int index)`: Returns a symbolic handle to one of the gradient operation outputs |
`Iterator<Operand<?>>` | `iterator()` |
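Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface java.lang.Iterable: forEach, spliterator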
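Because Gradients implements Iterable<Operand<?>>, an instance can be used in an enhanced for loop, with the inherited forEach, or wherever an Iterable of operands is accepted. The sketch below shows only that Java mechanism; the class OutputsSketch is hypothetical, strings stand in for operands, and the assumption that iterator() yields the dy outputs is not confirmed by this page.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for an op that exposes its outputs through Iterable, as
// Gradients does via Iterable<Operand<?>>. Strings stand in for operands, and
// yielding the dy outputs from iterator() is an assumption of this sketch.
class OutputsSketch implements Iterable<String> {
    private final List<String> dy;

    OutputsSketch(List<String> dy) {
        this.dy = Collections.unmodifiableList(new ArrayList<>(dy));
    }

    @Override
    public Iterator<String> iterator() {
        return dy.iterator(); // read-only: remove() throws UnsupportedOperationException
    }

    // Accepts any Iterable, much as create(...) accepts Iterable<? extends Operand<?>>
    static int count(Iterable<?> operands) {
        int n = 0;
        for (Object ignored : operands) {
            n++;
        }
        return n;
    }

    public static void main(String[] args) {
        OutputsSketch grads = new OutputsSketch(Arrays.asList("dy0", "dy1"));
        for (String output : grads) {       // enhanced for loop uses iterator()
            System.out.println(output);
        }
        grads.forEach(System.out::println); // forEach is a default method of Iterable
        System.out.println(count(grads));   // 2
    }
}
```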
`public static Gradients create(Scope scope, Iterable<? extends Operand<?>> y, Iterable<? extends Operand<?>> x, Gradients.Options... options)`

Parameters:
- scope: current graph scope
- y: outputs of the function to derive
- x: inputs of the function for which partial derivatives are computed
- options: carries optional attribute values

Returns: a Gradients instance

Throws:
- IllegalArgumentException: if the execution environment is not a graph

`public static Gradients create(Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Gradients.Options... options)`
This is a simplified version of create(Scope, Iterable, Iterable, Options...) where y is a single output.
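A single-output overload like this one is commonly implemented by wrapping y in a one-element list and forwarding to the general form. The sketch below illustrates that pattern only; the class OverloadSketch is hypothetical, strings stand in for operands, and it is not the TensorFlow source.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch of a single-output overload that wraps y in a one-element
// list and forwards to the general form. Strings stand in for operands; this
// is not the TensorFlow source.
class OverloadSketch {

    // General form: several ys. Returns a description of what would be differentiated.
    static String create(Iterable<String> y, Iterable<String> x) {
        return "d(sum of " + y + ")/d" + x;
    }

    // Simplified form: a single y.
    static String create(String y, Iterable<String> x) {
        return create(Collections.singletonList(y), x);
    }

    public static void main(String[] args) {
        List<String> x = Arrays.asList("w", "b");
        System.out.println(create("loss", x)); // d(sum of [loss])/d[w, b]
    }
}
```

With a single y, the "sum of ys" is just y itself, so both forms compute the same partial derivatives.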
Parameters:
- scope: current graph scope
- y: output of the function to derive
- x: inputs of the function for which partial derivatives are computed
- options: carries optional attribute values

Returns: a Gradients instance

Throws:
- IllegalArgumentException: if the execution environment is not a graph

`public static Gradients.Options dx(Iterable<? extends Operand<?>> dx)`
Parameters:
- dx: partial derivatives of some loss function L w.r.t. y
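The static dx(...) factory and the Gradients.Options... varargs of create(...) follow an optional-attribute pattern: the factory returns an Options holder, and create() scans the varargs for values that were set. The sketch below is a hypothetical illustration of that pattern, not TensorFlow code: doubles stand in for operands, the names OptionsSketch and resolveDx are invented, and the "later option wins" rule and the exception type are choices of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the optional-attribute pattern behind
// create(..., Gradients.Options... options) and the static dx(...) factory.
// Doubles stand in for operands; names and error handling are illustrative.
class OptionsSketch {

    static final class Options {
        private List<Double> dx; // null means "not set"

        private Options() {
        }

        Options dx(List<Double> dx) {
            this.dx = dx;
            return this;
        }
    }

    // Static factory in the style of Gradients.dx(Iterable)
    static Options dx(List<Double> dx) {
        return new Options().dx(dx);
    }

    // Returns the dx that would seed the gradients for ySize outputs
    static List<Double> resolveDx(int ySize, Options... options) {
        List<Double> dx = null;
        for (Options opt : options) {
            if (opt.dx != null) {
                dx = opt.dx; // in this sketch, a later option overrides an earlier one
            }
        }
        if (dx == null) {
            // Not set: one weight of 1.0 per y, the OnesLike analogue
            dx = new ArrayList<>();
            for (int i = 0; i < ySize; i++) {
                dx.add(1.0);
            }
        } else if (dx.size() != ySize) {
            throw new IllegalArgumentException("dx must have the size of y");
        }
        return dx;
    }

    public static void main(String[] args) {
        System.out.println(resolveDx(2));                              // [1.0, 1.0]
        System.out.println(resolveDx(2, dx(Arrays.asList(2.0, 0.5)))); // [2.0, 0.5]
    }
}
```

Varargs keep the common call short (no options at all) while still letting callers pass settings through a single, extensible holder type.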
`public <T> Output<T> dy(int index)`

Warning: this method does not check that the type of the tensor matches T. It is recommended to call it with an explicit type parameter rather than letting the type be inferred, e.g. `gradients.<Float>dy(0)`.
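The warning follows from Java type erasure. The plain-Java sketch below (hypothetical class UncheckedSketch, not TensorFlow code) shows that a method shaped like dy(int) cannot verify T itself: its unchecked cast always succeeds, and a mismatch only surfaces as a ClassCastException where the caller uses the result as that type. Writing the type argument explicitly makes the expected type visible at the call site but adds no runtime check. In the real API, Output<T> is only a handle, so a mismatch may surface even later.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the Java mechanism behind the warning: T is erased, so
// the cast inside dy is unchecked and a wrong T is only caught where the caller
// uses the result as that type.
class UncheckedSketch {
    private final List<Object> outputs;

    UncheckedSketch(Object... values) {
        outputs = new ArrayList<>(Arrays.asList(values));
    }

    @SuppressWarnings("unchecked")
    <T> T dy(int index) {
        return (T) outputs.get(index); // no runtime check happens here
    }

    public static void main(String[] args) {
        UncheckedSketch g = new UncheckedSketch(1.5f, "not a float");
        Float ok = g.<Float>dy(0); // explicit type argument, as recommended
        System.out.println(ok);    // 1.5
        try {
            Float bad = g.<Float>dy(1); // dy returns normally; the cast at this assignment fails
            System.out.println(bad);
        } catch (ClassCastException e) {
            System.out.println("ClassCastException at the call site");
        }
    }
}
```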
Type Parameters:
- T: the expected element type of the tensors produced by this output

Parameters:
- index: the index of the output among the gradients added by this operation

Copyright © 2022. All rights reserved.