public final class Tensor<T> extends Object implements AutoCloseable
Instances of a Tensor are not thread-safe.
WARNING: Resources consumed by the Tensor object must be explicitly freed by
invoking the close() method when the object is no longer needed. For example, using a
try-with-resources block:
try (Tensor<?> t = Tensor.create(...)) {
    doSomethingWith(t);
}
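The try-with-resources statement invokes close() when the block exits, whether the exit is normal or caused by an exception. That guarantee is why the pattern is recommended here. The sketch below demonstrates it with a hypothetical class, NativeResource, which only records whether close() ran. It is not the TensorFlow Tensor and frees no native memory.

```java
// Hypothetical stand-in for a Tensor (not a TensorFlow class): it only records
// whether close() has been called.
final class NativeResource implements AutoCloseable {
    private boolean closed = false;

    boolean isClosed() {
        return closed;
    }

    void use() {
        if (closed) {
            throw new IllegalStateException("resource already closed");
        }
    }

    @Override
    public void close() {
        closed = true; // a real Tensor would release its native memory here
    }

    // Uses a resource inside try-with-resources, optionally failing mid-block,
    // and returns it so the caller can check that close() ran either way.
    static NativeResource runAndClose(boolean failInside) {
        NativeResource r = new NativeResource();
        try (NativeResource res = r) {
            res.use();
            if (failInside) {
                throw new RuntimeException("simulated failure");
            }
        } catch (RuntimeException e) {
            // By the time this handler runs, close() has already been invoked.
        }
        return r;
    }

    public static void main(String[] args) {
        System.out.println(runAndClose(false).isClosed()); // true
        System.out.println(runAndClose(true).isClosed());  // true
    }
}
```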
| Modifier and Type | Method and Description |
|---|---|
| `boolean` | `booleanValue()`: Returns the value in a scalar Boolean tensor. |
| `byte[]` | `bytesValue()`: Returns the value in a scalar String tensor. |
| `void` | `close()`: Release resources associated with the Tensor. |
| `<U> U` | `copyTo(U dst)`: Copies the contents of the tensor to dst and returns dst. |
| `static <T> Tensor<T>` | `create(Class<T> type, long[] shape, ByteBuffer data)`: Create a Tensor of any type with data from the given buffer. |
| `static Tensor<Double>` | `create(long[] shape, DoubleBuffer data)`: Create a Double Tensor with data from the given buffer. |
| `static Tensor<Float>` | `create(long[] shape, FloatBuffer data)`: Create a Float Tensor with data from the given buffer. |
| `static Tensor<Integer>` | `create(long[] shape, IntBuffer data)`: Create an Integer Tensor with data from the given buffer. |
| `static Tensor<Long>` | `create(long[] shape, LongBuffer data)`: Create a Long Tensor with data from the given buffer. |
| `static Tensor<?>` | `create(Object obj)`: Creates a tensor from an object whose class is inspected to figure out what the underlying data type should be. |
| `static <T> Tensor<T>` | `create(Object obj, Class<T> type)`: Creates a Tensor from a Java object. |
| `DataType` | `dataType()`: Returns the DataType of elements stored in the Tensor. |
| `double` | `doubleValue()`: Returns the value in a scalar Double tensor. |
| `<U> Tensor<U>` | `expect(Class<U> type)`: Returns this Tensor object with the type `Tensor<U>`. |
| `float` | `floatValue()`: Returns the value in a scalar Float tensor. |
| `int` | `intValue()`: Returns the value in a scalar Integer tensor. |
| `long` | `longValue()`: Returns the value in a scalar Long tensor. |
| `int` | `numBytes()`: Returns the size, in bytes, of the tensor data. |
| `int` | `numDimensions()`: Returns the number of dimensions (sometimes referred to as rank) of the Tensor. |
| `int` | `numElements()`: Returns the number of elements in a flattened (1-D) view of the tensor. |
| `long[]` | `shape()`: Returns the shape of the Tensor, i.e., the sizes of each dimension. |
| `String` | `toString()`: Returns a string describing the type and shape of the Tensor. |
| `void` | `writeTo(ByteBuffer dst)`: Write the tensor data into the given buffer. |
| `void` | `writeTo(DoubleBuffer dst)`: Write the data of a Double tensor into the given buffer. |
| `void` | `writeTo(FloatBuffer dst)`: Write the data of a Float tensor into the given buffer. |
| `void` | `writeTo(IntBuffer dst)`: Write the data of an Integer tensor into the given buffer. |
| `void` | `writeTo(LongBuffer dst)`: Write the data of a Long tensor into the given buffer. |

public static <T> Tensor<T> create(Object obj, Class<T> type)
Creates a Tensor from a Java object.
A Tensor is a multi-dimensional array of elements of a limited set of types. Not all
Java objects can be converted to a Tensor. In particular, the argument obj must
be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of
one of those primitives. The argument type specifies how to interpret the first
argument as a TensorFlow type. For example:
// Valid: A 64-bit integer scalar.
Tensor<Long> s = Tensor.create(42L, Long.class);
// Valid: A 3x2 matrix of floats.
float[][] matrix = new float[3][2];
Tensor<Float> m = Tensor.create(matrix, Float.class);
// Invalid: Will throw an IllegalArgumentException as an arbitrary Object
// does not fit into the TensorFlow type system.
Tensor<?> o = Tensor.create(new Object());
// Invalid: Will throw an IllegalArgumentException since there are
// a differing number of elements in each row of this 2-D array.
int[][] twoD = new int[2][];
twoD[0] = new int[1];
twoD[1] = new int[2];
Tensor<Integer> x = Tensor.create(twoD, Integer.class);
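The rules above amount to two checks:

- The element type must be one of the listed primitives.
- Every sub-array at a given depth must have the same length.

The sketch below shows one way to perform both with java.lang.reflect.Array. ShapeCheck and shapeOf are hypothetical names, and this is an illustration of the documented constraint, not TensorFlow's own validation code. It also ignores the special case where byte[] elements encode String tensors.

```java
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Set;

// Hypothetical helper (not part of TensorFlow): infers the shape of a scalar or a
// multi-dimensional primitive array and rejects unsupported or ragged input.
final class ShapeCheck {
    // Element types named in the documentation; boxed forms cover scalars such as 42L.
    private static final Set<Class<?>> SUPPORTED = Set.of(
            float.class, double.class, int.class, long.class, boolean.class, byte.class,
            Float.class, Double.class, Integer.class, Long.class, Boolean.class, Byte.class);

    static long[] shapeOf(Object obj) {
        if (obj == null) {
            throw new IllegalArgumentException("obj must not be null");
        }
        // The rank is the array nesting depth of the runtime class, e.g. 2 for float[][].
        Class<?> base = obj.getClass();
        int rank = 0;
        while (base.isArray()) {
            rank++;
            base = base.getComponentType();
        }
        if (!SUPPORTED.contains(base)) {
            throw new IllegalArgumentException("unsupported element type: " + base.getName());
        }
        // Take each dimension's size from the first sub-array at that depth.
        long[] shape = new long[rank];
        Object cur = obj;
        for (int d = 0; d < rank; d++) {
            if (cur == null) {
                throw new IllegalArgumentException("null sub-array at depth " + d);
            }
            int len = Array.getLength(cur);
            shape[d] = len;
            if (len == 0) {
                break; // deeper sizes cannot be observed; this sketch leaves them at 0
            }
            cur = Array.get(cur, 0);
        }
        verify(obj, shape, 0);
        return shape;
    }

    // Recursively checks that every sub-array at depth d has exactly shape[d] elements.
    private static void verify(Object node, long[] shape, int d) {
        if (d == shape.length) {
            return;
        }
        if (node == null || Array.getLength(node) != shape[d]) {
            throw new IllegalArgumentException("ragged array at depth " + d);
        }
        for (int i = 0; i < shape[d]; i++) {
            verify(Array.get(node, i), shape, d + 1);
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(shapeOf(new float[3][2]))); // [3, 2]
        System.out.println(shapeOf(42L).length);                       // 0 (a scalar)
    }
}
```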
String-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so they can
be initialized from arrays of byte[] elements. For example:
// Valid: A String tensor.
Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
// Java Strings need to be encoded into a byte sequence first
// (StandardCharsets is java.nio.charset.StandardCharsets).
String mystring = "foo";
Tensor<String> s2 = Tensor.create(mystring.getBytes(StandardCharsets.UTF_8), String.class);
// Valid: A 2x2 String tensor.
// Each element may have a different length.
byte[][][] matrix = new byte[2][2][];
matrix[0][0] = "this".getBytes(StandardCharsets.UTF_8);
matrix[0][1] = "is".getBytes(StandardCharsets.UTF_8);
matrix[1][0] = "a".getBytes(StandardCharsets.UTF_8);
matrix[1][1] = "matrix".getBytes(StandardCharsets.UTF_8);
Tensor<String> m = Tensor.create(matrix, String.class);
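Because String tensors hold raw bytes, every Java String must be encoded before the array is handed to create. The sketch below performs that step for a 2-D array. Utf8Matrix and toUtf8 are hypothetical names. The sketch uses StandardCharsets.UTF_8 rather than the "UTF-8" charset name, which avoids the checked UnsupportedEncodingException.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper: converts a 2-D array of Java Strings into the byte[][][] layout
// shown above for a 2-D String tensor (one UTF-8 byte sequence per element).
final class Utf8Matrix {
    static byte[][][] toUtf8(String[][] strings) {
        byte[][][] out = new byte[strings.length][][];
        for (int i = 0; i < strings.length; i++) {
            out[i] = new byte[strings[i].length][];
            for (int j = 0; j < strings[i].length; j++) {
                // Byte lengths may differ per element; the String[][] itself should be rectangular.
                out[i][j] = strings[i][j].getBytes(StandardCharsets.UTF_8);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        byte[][][] m = toUtf8(new String[][] {{"this", "is"}, {"a", "matrix"}});
        System.out.println(m[1][1].length); // 6
    }
}
```

The result could then be passed as `Tensor.create(Utf8Matrix.toUtf8(words), String.class)`.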

Parameters:
obj - The object to convert to a Tensor<T>. Note that whether it is compatible
with the type T is not checked by the type system. For type-safe creation of tensors, use
Tensors.
type - The class object representing the type T.
Throws:
IllegalArgumentException - if obj is not compatible with the TensorFlow type
system.

public static Tensor<?> create(Object obj)
Creates a tensor from an object whose class is inspected to figure out what the underlying data
type should be.
Throws:
IllegalArgumentException - if obj is not compatible with the TensorFlow type
system.

public static Tensor<Integer> create(long[] shape, IntBuffer data)
Create an Integer Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its
current position) into the tensor. For example, if shape = {2,3} (which represents a
2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
method.
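The buffer contract can be reproduced with plain java.nio:

- The element count is the product of the shape.
- The buffer's remaining elements must match that count.
- Reading the elements advances the buffer's position.

The sketch below uses a hypothetical helper, BufferContract.take, that copies into an int[] instead of a tensor. It illustrates the documented behavior and is not TensorFlow's implementation. It also assumes that "must have 6 elements remaining" means exactly 6.

```java
import java.nio.IntBuffer;
import java.util.Arrays;

// Hypothetical helper mirroring the documented contract of create(long[] shape, IntBuffer data).
final class BufferContract {
    static long numElements(long[] shape) {
        long n = 1; // an empty shape (a scalar) holds one element
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("negative dimension: " + dim);
            }
            n = Math.multiplyExact(n, dim);
        }
        return n;
    }

    static int[] take(long[] shape, IntBuffer data) {
        long n = numElements(shape);
        // Assumption of this sketch: the remaining count must equal the element count exactly.
        if (data.remaining() != n) {
            throw new IllegalArgumentException(
                    "buffer has " + data.remaining() + " elements remaining, shape needs " + n);
        }
        int[] copy = new int[(int) n];
        data.get(copy); // relative bulk get: starts at position() and advances it by n
        return copy;
    }

    public static void main(String[] args) {
        IntBuffer buf = IntBuffer.wrap(new int[] {0, 1, 2, 3, 4, 5, 6});
        buf.position(1); // start at the second element, so six elements remain
        int[] got = take(new long[] {2, 3}, buf);
        System.out.println(Arrays.toString(got)); // [1, 2, 3, 4, 5, 6]
        System.out.println(buf.remaining());      // 0
    }
}
```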
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Float> create(long[] shape, FloatBuffer data)
Create a Float Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its
current position) into the tensor. For example, if shape = {2,3} (which represents a
2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Double> create(long[] shape, DoubleBuffer data)
Create a Double Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its
current position) into the tensor. For example, if shape = {2,3} (which represents a
2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Long> create(long[] shape, LongBuffer data)
Create a Long Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its
current position) into the tensor. For example, if shape = {2,3} (which represents a
2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
method.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data)
Create a Tensor of any type with data from the given buffer.
Creates a Tensor with the provided shape of any type where the tensor's data has been
encoded into data as per the specification of the TensorFlow C
API.
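For fixed-size numeric types, this document does not spell out the C API encoding. We assume it is the elements laid out contiguously in row-major order, in the host's native byte order. That assumption is consistent with writeTo(ByteBuffer) being documented as copying data in native byte order. Under it, a buffer for a 2x2 float tensor could be prepared as follows. FloatPacking and pack are hypothetical names. String tensors use a different encoding that this sketch does not cover.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical helper: packs a rectangular float[][] into a direct ByteBuffer, row-major,
// 4 bytes per element, in the platform's native byte order (an assumed C API layout).
final class FloatPacking {
    static ByteBuffer pack(float[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        ByteBuffer buf = ByteBuffer.allocateDirect(rows * cols * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        for (float[] row : matrix) {
            if (row.length != cols) {
                throw new IllegalArgumentException("matrix is not rectangular");
            }
            for (float v : row) {
                buf.putFloat(v); // row-major: the last index varies fastest
            }
        }
        buf.flip(); // position 0, limit = bytes written, ready to be consumed
        return buf;
    }

    public static void main(String[] args) {
        ByteBuffer buf = pack(new float[][] {{1f, 2f}, {3f, 4f}});
        System.out.println(buf.remaining());               // 16
        System.out.println(buf.getFloat(2 * Float.BYTES)); // 3.0, i.e. element [1][0]
    }
}
```

Under that assumption, the result would be passed as `Tensor.create(Float.class, new long[] {2, 2}, FloatPacking.pack(m))`.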
Type Parameters:
T - the tensor element type.
Parameters:
type - the tensor element type, represented as a class object.
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor data type or shape is not compatible with the
buffer.

public <U> Tensor<U> expect(Class<U> type)
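Returns this Tensor object with the type Tensor<U>. This method is useful when given a
value of type Tensor<?>.
Parameters:
type - the (non-null) class object representing the expected element type U.
Throws:
IllegalArgumentException - if the actual data type of this object does not match the type
U.

public void close()
Release resources associated with the Tensor.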
WARNING: This must be invoked for all tensors that were not produced by an eager operation, or memory will be leaked.
The Tensor object is no longer usable after close returns.
Specified by:
close in interface AutoCloseable

public int numDimensions()
Returns the number of dimensions (sometimes referred to as rank) of the Tensor.
Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
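
public int numBytes()
Returns the size, in bytes, of the tensor data.

public int numElements()
Returns the number of elements in a flattened (1-D) view of the tensor.

public long[] shape()
Returns the shape of the Tensor, i.e., the sizes of each dimension.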
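For fixed-size element types, the shape accessors are related by simple arithmetic:

- numDimensions() equals shape().length.
- numElements() is the product of the dimensions, which is 1 for a scalar.
- numBytes() is numElements() multiplied by the element size.

The sketch below spells this out with hypothetical helpers (TensorSize) and a caller-supplied element size. It does not apply to String tensors, whose byte size depends on their contents.

```java
// Hypothetical helpers relating shape, rank, element count and byte size for a tensor
// whose elements all have the same fixed size (e.g. Float.BYTES for a Float tensor).
final class TensorSize {
    static int numDimensions(long[] shape) {
        return shape.length; // 0 for a scalar, 1 for a vector, 2 for a matrix, ...
    }

    static int numElements(long[] shape) {
        long n = 1; // the empty product: a scalar holds exactly one element
        for (long dim : shape) {
            n = Math.multiplyExact(n, dim);
        }
        return Math.toIntExact(n); // the Tensor accessors return int
    }

    static int numBytes(long[] shape, int bytesPerElement) {
        return Math.multiplyExact(numElements(shape), bytesPerElement);
    }

    public static void main(String[] args) {
        long[] shape = {2, 3, 4};
        System.out.println(numDimensions(shape));          // 3
        System.out.println(numElements(shape));            // 24
        System.out.println(numBytes(shape, Float.BYTES));  // 96
    }
}
```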

public float floatValue()
Returns the value in a scalar Float tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a float scalar.

public double doubleValue()
Returns the value in a scalar Double tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a double scalar.

public int intValue()
Returns the value in a scalar Integer tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent an int scalar.

public long longValue()
Returns the value in a scalar Long tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a long scalar.

public boolean booleanValue()
Returns the value in a scalar Boolean tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a boolean scalar.

public byte[] bytesValue()
Returns the value in a scalar String tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a String scalar.

public <U> U copyTo(U dst)
Copies the contents of the tensor to dst and returns dst.
For non-scalar tensors, this method copies the contents of the underlying tensor to a Java
array. For scalar tensors, use one of bytesValue(), floatValue(), doubleValue(), intValue(), longValue() or booleanValue() instead.
The type and shape of dst must be compatible with the tensor. For example:
int[][] matrix = {{1, 2}, {3, 4}};
try (Tensor<?> t = Tensor.create(matrix)) {
    // Succeeds and prints "3".
    int[][] copy = new int[2][2];
    System.out.println(t.copyTo(copy)[1][0]);

    // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
    int[][] dst = new int[4][1];
    t.copyTo(dst);
}
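Conceptually, copyTo checks that the destination's shape agrees with the tensor's shape and then fills it from the tensor's row-major data. The sketch below shows that behavior for the 2-D int case. The tensor is represented by a flat int[] plus a shape, and RowMajorCopy is a hypothetical name. The real method handles any supported rank and element type.

```java
// Hypothetical helper: copies row-major data of a 2-D shape into dst after checking that
// dst has exactly that shape, in the spirit of Tensor.copyTo's documented contract.
final class RowMajorCopy {
    static int[][] copyTo(int[] flat, long[] shape, int[][] dst) {
        if (shape.length != 2) {
            throw new IllegalArgumentException("this sketch only handles 2-D shapes");
        }
        int rows = Math.toIntExact(shape[0]);
        int cols = Math.toIntExact(shape[1]);
        if (flat.length != rows * cols) {
            throw new IllegalArgumentException("data length does not match shape");
        }
        // Validate the whole destination first so a mismatch leaves dst untouched.
        if (dst.length != rows) {
            throw new IllegalArgumentException("dst has " + dst.length + " rows, expected " + rows);
        }
        for (int[] row : dst) {
            if (row == null || row.length != cols) {
                throw new IllegalArgumentException("every dst row must have " + cols + " columns");
            }
        }
        for (int i = 0; i < rows; i++) {
            // Row i occupies flat[i * cols] through flat[i * cols + cols - 1].
            System.arraycopy(flat, i * cols, dst[i], 0, cols);
        }
        return dst;
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4}; // row-major contents of {{1, 2}, {3, 4}}
        int[][] copy = copyTo(data, new long[] {2, 2}, new int[2][2]);
        System.out.println(copy[1][0]); // 3
    }
}
```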
Throws:
IllegalArgumentException - if the tensor is a scalar or if dst is not compatible
with the tensor (for example, mismatched data types or shapes).

public void writeTo(IntBuffer dst)
Write the data of an Integer tensor into the given buffer.
Copies numElements() elements to the buffer.
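The overflow rule matches the behavior of IntBuffer's own bulk put. IntBuffer.put(int[]) throws BufferOverflowException, and transfers nothing, when fewer elements remain than it would write. A destination with at least numElements() elements remaining therefore avoids the exception. The sketch below uses a hypothetical stand-in, WriteToSketch, that writes an int[] of tensor data into the caller's buffer.

```java
import java.nio.BufferOverflowException;
import java.nio.IntBuffer;

final class WriteToSketch {
    // Hypothetical stand-in for writeTo(IntBuffer): copies all of data into dst starting at
    // dst's position. IntBuffer.put(int[]) throws BufferOverflowException, transferring
    // nothing, when dst.remaining() < data.length.
    static void writeTo(int[] data, IntBuffer dst) {
        dst.put(data);
    }

    // Returns true if the write overflowed instead of completing.
    static boolean overflows(int[] data, IntBuffer dst) {
        try {
            writeTo(data, dst);
            return false;
        } catch (BufferOverflowException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4}; // stands in for a tensor with numElements() == 4
        System.out.println(overflows(data, IntBuffer.allocate(3)));           // true
        System.out.println(overflows(data, IntBuffer.allocate(data.length))); // false
    }
}
```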
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data
in this tensor.
IllegalArgumentException - If the tensor data type is not Integer.

public void writeTo(FloatBuffer dst)
Write the data of a Float tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data
in this tensor.
IllegalArgumentException - If the tensor data type is not Float.

public void writeTo(DoubleBuffer dst)
Write the data of a Double tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data
in this tensor.
IllegalArgumentException - If the tensor data type is not Double.

public void writeTo(LongBuffer dst)
Write the data of a Long tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data
in this tensor.
IllegalArgumentException - If the tensor data type is not Long.

public void writeTo(ByteBuffer dst)
Write the tensor data into the given buffer.
Copies numBytes() bytes to the buffer in native byte order for primitive types.
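The bytes are written in native byte order, so a caller reading the values back should view the buffer with ByteOrder.nativeOrder(). A freshly allocated ByteBuffer defaults to big-endian regardless of platform, and so does a duplicate() of one. The sketch below simulates the write with a hypothetical writer that emits native-order ints, then reads them back through a correctly ordered view. NativeOrderDemo, writeNative and readNative are illustrative names only.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

final class NativeOrderDemo {
    // Hypothetical stand-in for tensor.writeTo(dst): appends each value as raw bytes
    // in the platform's native byte order, then restores the caller's order setting.
    static void writeNative(int[] values, ByteBuffer dst) {
        ByteOrder previous = dst.order();
        dst.order(ByteOrder.nativeOrder());
        for (int v : values) {
            dst.putInt(v);
        }
        dst.order(previous);
    }

    // Reads n ints from the start of buf, interpreting the bytes in native order.
    // duplicate() resets the order to BIG_ENDIAN, so the order is set explicitly.
    static int[] readNative(ByteBuffer buf, int n) {
        ByteBuffer view = buf.duplicate().order(ByteOrder.nativeOrder());
        int[] out = new int[n];
        for (int i = 0; i < n; i++) {
            out[i] = view.getInt(i * Integer.BYTES); // absolute get, ignores position
        }
        return out;
    }

    public static void main(String[] args) {
        ByteBuffer dst = ByteBuffer.allocate(2 * Integer.BYTES); // BIG_ENDIAN by default
        writeNative(new int[] {1, 258}, dst);
        System.out.println(Arrays.toString(readNative(dst, 2))); // [1, 258]
    }
}
```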
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data
in this tensor.

Copyright © 2022. All rights reserved.