public final class Tensor<T> extends Object implements AutoCloseable
Instances of a Tensor are not thread-safe.
WARNING: Resources consumed by the Tensor object must be explicitly freed by invoking the close() method when the object is no longer needed. For example, using a try-with-resources block:
try (Tensor<?> t = Tensor.create(...)) {
  doSomethingWith(t);
}
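What try-with-resources guarantees can be seen without TensorFlow on the classpath. The sketch below uses a hypothetical stand-in class, FakeTensor (not part of TensorFlow), that only records whether close() ran. It shows that close() is invoked when the block exits, even when the body throws.

```java
final class TryWithResourcesDemo {
    // Hypothetical stand-in for a natively backed resource such as Tensor;
    // it only records whether close() has been called.
    static final class FakeTensor implements AutoCloseable {
        private boolean closed = false;

        boolean isClosed() {
            return closed;
        }

        @Override
        public void close() {
            closed = true; // a real Tensor would free its native memory here
        }
    }

    // Uses a resource in try-with-resources, fails inside the block, and returns
    // the resource so the caller can check that close() still ran.
    static FakeTensor useAndFail() {
        FakeTensor handle = new FakeTensor();
        try (FakeTensor t = handle) {
            throw new IllegalStateException("simulated failure while using " + t);
        } catch (IllegalStateException expected) {
            // Resources are closed before a catch clause runs, so t is closed here.
        }
        return handle;
    }

    public static void main(String[] args) {
        System.out.println(useAndFail().isClosed()); // prints "true"
    }
}
```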
| Modifier and Type | Method | Description |
|---|---|---|
| boolean | booleanValue() | Returns the value in a scalar Boolean tensor. |
| byte[] | bytesValue() | Returns the value in a scalar String tensor. |
| void | close() | Release resources associated with the Tensor. |
| <U> U | copyTo(U dst) | Copies the contents of the tensor to dst and returns dst. |
| static <T> Tensor<T> | create(Class<T> type, long[] shape, ByteBuffer data) | Create a Tensor of any type with data from the given buffer. |
| static Tensor<Double> | create(long[] shape, DoubleBuffer data) | Create a Double Tensor with data from the given buffer. |
| static Tensor<Float> | create(long[] shape, FloatBuffer data) | Create a Float Tensor with data from the given buffer. |
| static Tensor<Integer> | create(long[] shape, IntBuffer data) | Create an Integer Tensor with data from the given buffer. |
| static Tensor<Long> | create(long[] shape, LongBuffer data) | Create a Long Tensor with data from the given buffer. |
| static Tensor<?> | create(Object obj) | Creates a tensor from an object whose class is inspected to determine the underlying data type. |
| static <T> Tensor<T> | create(Object obj, Class<T> type) | Creates a Tensor from a Java object. |
| DataType | dataType() | Returns the DataType of elements stored in the Tensor. |
| double | doubleValue() | Returns the value in a scalar Double tensor. |
| <U> Tensor<U> | expect(Class<U> type) | Returns this Tensor object with the type Tensor<U>. |
| float | floatValue() | Returns the value in a scalar Float tensor. |
| int | intValue() | Returns the value in a scalar Integer tensor. |
| long | longValue() | Returns the value in a scalar Long tensor. |
| int | numBytes() | Returns the size, in bytes, of the tensor data. |
| int | numDimensions() | Returns the number of dimensions (sometimes referred to as rank) of the Tensor. |
| int | numElements() | Returns the number of elements in a flattened (1-D) view of the tensor. |
| long[] | shape() | Returns the shape of the Tensor, i.e., the sizes of each dimension. |
| String | toString() | Returns a string describing the type and shape of the Tensor. |
| void | writeTo(ByteBuffer dst) | Write the tensor data into the given buffer. |
| void | writeTo(DoubleBuffer dst) | Write the data of a Double tensor into the given buffer. |
| void | writeTo(FloatBuffer dst) | Write the data of a Float tensor into the given buffer. |
| void | writeTo(IntBuffer dst) | Write the data of an Integer tensor into the given buffer. |
| void | writeTo(LongBuffer dst) | Write the data of a Long tensor into the given buffer. |
public static <T> Tensor<T> create(Object obj, Class<T> type)
A Tensor is a multi-dimensional array of elements of a limited set of types. Not all Java objects can be converted to a Tensor. In particular, the argument obj must be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of one of those primitives. The argument type specifies how to interpret the first argument as a TensorFlow type. For example:
// Valid: A 64-bit integer scalar.
Tensor<Long> s = Tensor.create(42L, Long.class);

// Valid: A 3x2 matrix of floats.
float[][] matrix = new float[3][2];
Tensor<Float> m = Tensor.create(matrix, Float.class);

// Invalid: Will throw an IllegalArgumentException as an arbitrary Object
// does not fit into the TensorFlow type system.
Tensor<?> o = Tensor.create(new Object());

// Invalid: Will throw an IllegalArgumentException since the rows
// of this 2-D array have differing numbers of elements.
int[][] twoD = new int[2][];
twoD[0] = new int[1];
twoD[1] = new int[2];
Tensor<Integer> x = Tensor.create(twoD, Integer.class);
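The last invalid case fails because a Tensor needs a single size per dimension, and a ragged Java array has no single row length. The JDK-only sketch below isolates that check with a hypothetical helper named shapeOf (not a TensorFlow API): it returns the shape of a rectangular int[][] and rejects ragged input with an IllegalArgumentException, mirroring the documented behavior of Tensor.create.

```java
import java.util.Arrays;

final class RaggedCheck {
    // Hypothetical helper: returns {rows, cols} for a rectangular int[][] and
    // throws IllegalArgumentException if the rows differ in length.
    // Rows are assumed to be non-null.
    static long[] shapeOf(int[][] array) {
        int rows = array.length;
        int cols = rows == 0 ? 0 : array[0].length;
        for (int[] row : array) {
            if (row.length != cols) {
                throw new IllegalArgumentException(
                        "ragged array: expected rows of length " + cols + ", found " + row.length);
            }
        }
        return new long[] {rows, cols};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(shapeOf(new int[3][2]))); // prints "[3, 2]"

        // The twoD array from the example above: row lengths 1 and 2.
        int[][] twoD = new int[2][];
        twoD[0] = new int[1];
        twoD[1] = new int[2];
        try {
            shapeOf(twoD);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```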
String-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so they can be initialized from arrays of byte[] elements. For example:
import java.nio.charset.StandardCharsets;

// Valid: A String tensor.
Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);

// Java Strings need to be encoded into a byte sequence.
String mystring = "foo";
Tensor<String> s2 = Tensor.create(mystring.getBytes(StandardCharsets.UTF_8), String.class);

// Valid: A 2x2 matrix of strings, stored in a single String tensor.
// Each element might have a different length.
byte[][][] matrix = new byte[2][2][];
matrix[0][0] = "this".getBytes(StandardCharsets.UTF_8);
matrix[0][1] = "is".getBytes(StandardCharsets.UTF_8);
matrix[1][0] = "a".getBytes(StandardCharsets.UTF_8);
matrix[1][1] = "matrix".getBytes(StandardCharsets.UTF_8);
Tensor<String> m = Tensor.create(matrix, String.class);
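Because a String tensor stores raw bytes, the text encoding is the caller's choice. The JDK-only sketch below assumes UTF-8, as in the examples above. Its hypothetical helpers encode and decode (not TensorFlow APIs) convert a String[][] into the byte[][][] form passed to Tensor.create(matrix, String.class), and turn a single element (for example, the byte[] returned by bytesValue()) back into a String.

```java
import java.nio.charset.StandardCharsets;

final class StringTensorInput {
    // Hypothetical helper: encodes every element as UTF-8. The resulting byte[]
    // elements may differ in length, which String tensors allow.
    static byte[][][] encode(String[][] strings) {
        byte[][][] out = new byte[strings.length][][];
        for (int i = 0; i < strings.length; i++) {
            out[i] = new byte[strings[i].length][];
            for (int j = 0; j < strings[i].length; j++) {
                out[i][j] = strings[i][j].getBytes(StandardCharsets.UTF_8);
            }
        }
        return out;
    }

    // Hypothetical helper: decodes one element, assuming it was UTF-8 encoded.
    static String decode(byte[] element) {
        return new String(element, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        byte[][][] matrix = encode(new String[][] {{"this", "is"}, {"a", "matrix"}});
        System.out.println(matrix[1][1].length);  // prints "6"
        System.out.println(decode(matrix[0][0])); // prints "this"
    }
}
```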
Parameters:
obj - The object to convert to a Tensor<T>. Note that the type system does not check whether it is compatible with the type T. For type-safe creation of tensors, use Tensors.
type - The class object representing the type T.
Throws:
IllegalArgumentException - if obj is not compatible with the TensorFlow type system.

public static Tensor<?> create(Object obj)
Creates a tensor from an object whose class is inspected to determine the underlying data type.
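One way to picture inspecting the class of obj is to walk its array type with reflection. The JDK-only sketch below is an assumption about the general idea, not TensorFlow's implementation: a hypothetical describe method counts array nesting to get the rank, finds the Java element type, and rejects anything that is not one of the six primitives listed for create(Object obj, Class<T> type).

```java
final class ObjectInspection {
    // Hypothetical helper: reports the Java element type and rank of obj,
    // e.g. "float, rank 2" for new float[3][2]. It looks only at the class,
    // so it does not detect ragged arrays.
    static String describe(Object obj) {
        Class<?> c = obj.getClass();
        int rank = 0;
        while (c.isArray()) {
            rank++;
            c = c.getComponentType();
        }
        if (rank > 0 && !c.isPrimitive()) {
            throw new IllegalArgumentException("arrays must hold primitives, found " + c.getName());
        }
        return elementName(c) + ", rank " + rank;
    }

    // Accepts the six primitive types; boxed forms are accepted only because a
    // scalar such as 42L reaches an Object parameter as a java.lang.Long.
    private static String elementName(Class<?> c) {
        if (c == float.class || c == Float.class) return "float";
        if (c == double.class || c == Double.class) return "double";
        if (c == int.class || c == Integer.class) return "int";
        if (c == long.class || c == Long.class) return "long";
        if (c == boolean.class || c == Boolean.class) return "boolean";
        if (c == byte.class || c == Byte.class) return "byte";
        throw new IllegalArgumentException("not a supported element type: " + c.getName());
    }

    public static void main(String[] args) {
        System.out.println(describe(new float[3][2])); // prints "float, rank 2"
        System.out.println(describe(42L));             // prints "long, rank 0"
        try {
            describe(new Object());
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```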
Throws:
IllegalArgumentException - if obj is not compatible with the TensorFlow type system.

public static Tensor<Integer> create(long[] shape, IntBuffer data)
Create an Integer Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix), then the buffer must have 6 elements remaining, which this method consumes.
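The buffer contract (copying starts at the current position, and the remaining elements must match the shape) can be exercised with plain java.nio. The sketch below uses a hypothetical helper, copyForShape, that is not part of TensorFlow. It assumes an exact match between remaining() and the element count, as the {2,3} example describes, and shows the position advancing as the elements are consumed.

```java
import java.nio.IntBuffer;
import java.util.Arrays;

final class BufferConsumption {
    // Number of elements implied by a shape: the product of its dimensions.
    static long elementCount(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Hypothetical helper: copies exactly elementCount(shape) ints starting at the
    // buffer's current position, advancing that position (i.e. consuming them).
    static int[] copyForShape(long[] shape, IntBuffer data) {
        long n = elementCount(shape);
        if (data.remaining() != n) {
            throw new IllegalArgumentException("buffer has " + data.remaining()
                    + " elements remaining, shape " + Arrays.toString(shape) + " needs " + n);
        }
        int[] copy = new int[(int) n];
        data.get(copy); // relative bulk get: advances position by copy.length
        return copy;
    }

    public static void main(String[] args) {
        IntBuffer data = IntBuffer.wrap(new int[] {9, 9, 1, 2, 3, 4, 5, 6});
        data.position(2); // skip the first two values
        int[] copy = copyForShape(new long[] {2, 3}, data);
        System.out.println(Arrays.toString(copy)); // prints "[1, 2, 3, 4, 5, 6]"
        System.out.println(data.remaining());      // prints "0"
    }
}
```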
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Float> create(long[] shape, FloatBuffer data)
Create a Float Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix), then the buffer must have 6 elements remaining, which this method consumes.
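Data that starts out as a float[][] has to be laid out in a flat buffer first. The JDK-only sketch below assumes row-major order (row 0 first), which is how a {2,3} shape is conventionally read; the flatten helper is hypothetical and not part of TensorFlow.

```java
import java.nio.FloatBuffer;

final class FlattenForCreate {
    // Hypothetical helper: copies a rectangular float[][] into a FloatBuffer in
    // row-major order, ready to be consumed from position 0.
    static FloatBuffer flatten(float[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        FloatBuffer buf = FloatBuffer.allocate(rows * cols);
        for (float[] row : matrix) {
            if (row.length != cols) {
                throw new IllegalArgumentException("rows must all have length " + cols);
            }
            buf.put(row); // relative bulk put appends the whole row
        }
        buf.flip(); // position 0, limit rows * cols
        return buf;
    }

    public static void main(String[] args) {
        FloatBuffer data = flatten(new float[][] {{1f, 2f, 3f}, {4f, 5f, 6f}});
        System.out.println(data.remaining()); // prints "6", matching shape {2,3}
        System.out.println(data.get(3));      // prints "4.0", the start of row 1
    }
}
```

The flip() call matters: after relative put calls the position sits at the end of the data, so without it remaining() would be 0 and the buffer would not match the shape.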
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Double> create(long[] shape, DoubleBuffer data)
Create a Double Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix), then the buffer must have 6 elements remaining, which this method consumes.
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static Tensor<Long> create(long[] shape, LongBuffer data)
Create a Long Tensor with data from the given buffer.
Creates a Tensor with the given shape by copying elements from the buffer (starting from its current position) into the tensor. For example, if shape = {2,3} (which represents a 2x3 matrix), then the buffer must have 6 elements remaining, which this method consumes.
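Because only the elements between the buffer's position and its limit count, a window of a larger array can be passed without copying it first. LongBuffer.wrap(array, offset, length) sets the position to offset and the limit to offset + length, so remaining() equals length. The JDK-only sketch below checks this for a {2,3} window; the window helper is hypothetical.

```java
import java.nio.LongBuffer;

final class BufferWindow {
    // Hypothetical helper: a buffer over array[offset, offset + length) that shares
    // the array's storage. Its position is offset and its limit is offset + length.
    static LongBuffer window(long[] array, int offset, int length) {
        return LongBuffer.wrap(array, offset, length);
    }

    public static void main(String[] args) {
        long[] all = {0, 0, 10, 11, 12, 13, 14, 15, 0};
        LongBuffer data = window(all, 2, 6);
        System.out.println(data.position());  // prints "2"
        System.out.println(data.remaining()); // prints "6", enough for shape {2,3}
        System.out.println(data.get());       // prints "10", the first element in the window
    }
}
```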
Parameters:
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor shape is not compatible with the buffer.

public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data)
Creates a Tensor of any type with the provided shape, where the tensor's data has been encoded into data according to the specification of the TensorFlow C API.
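For fixed-width numeric types, this sketch assumes the encoding is the raw values in row-major order using the platform's native byte order (the same order that writeTo(ByteBuffer) documents for primitive types). String tensors use a different, variable-length layout that the sketch does not cover. The JDK-only helper encodeFloats below is hypothetical, not a TensorFlow API.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class NativeEncoding {
    // Hypothetical helper: packs floats (already in row-major order) into a
    // buffer of 4 bytes per element using the platform's native byte order.
    static ByteBuffer encodeFloats(float[] rowMajor) {
        ByteBuffer buf = ByteBuffer.allocate(rowMajor.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        for (float f : rowMajor) {
            buf.putFloat(f);
        }
        buf.flip(); // position 0, limit at the end of the written data
        return buf;
    }

    public static void main(String[] args) {
        ByteBuffer data = encodeFloats(new float[] {1f, 2f, 3f, 4f, 5f, 6f});
        System.out.println(data.remaining()); // prints "24" (6 floats x 4 bytes)
        // With TensorFlow on the classpath, this buffer would then be passed as
        // Tensor.create(Float.class, new long[] {2, 3}, data).
    }
}
```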
Type Parameters:
T - the tensor element type.
Parameters:
type - the tensor element type, represented as a class object.
shape - the tensor shape.
data - a buffer containing the tensor data.
Throws:
IllegalArgumentException - If the tensor data type or shape is not compatible with the buffer.

public <U> Tensor<U> expect(Class<U> type)
Returns this Tensor object with the type Tensor<U>. This method is useful when given a value of type Tensor<?>.
Parameters:
type - the (non-null) class object representing the expected element type U.
Throws:
IllegalArgumentException - if the actual data type of this object does not match the type U.

public void close()
Release resources associated with the Tensor.
WARNING: This must be invoked for all tensors that were not produced by an eager operation, or memory will be leaked.
The Tensor object is no longer usable after close returns.
Specified by:
close in interface AutoCloseable

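public int numDimensions()
Returns the number of dimensions (sometimes referred to as rank) of the Tensor.
Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor, etc.

public int numBytes()
Returns the size, in bytes, of the tensor data.

public int numElements()
Returns the number of elements in a flattened (1-D) view of the tensor.

public long[] shape()
Returns the shape of the Tensor, i.e., the sizes of each dimension.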
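For a fixed-width element type, these size queries are linked: numDimensions() is the length of shape(), numElements() is the product of its entries (1 for the empty shape of a scalar), and numBytes() is assumed here to be numElements() times the element width (4 bytes for float and int, 8 for double and long). String tensors are variable-length and do not follow the last rule. The JDK-only sketch below computes these quantities from a shape; the SizeQueries class and its methods are hypothetical.

```java
final class SizeQueries {
    // Hypothetical mirror of numDimensions(): the rank is the length of the shape.
    static int numDimensions(long[] shape) {
        return shape.length;
    }

    // Hypothetical mirror of numElements(): the product of the dimensions, which
    // is 1 for a scalar. long is used here to avoid overflow, although the Tensor
    // methods themselves return int.
    static long numElements(long[] shape) {
        long n = 1;
        for (long d : shape) {
            n *= d;
        }
        return n;
    }

    // Hypothetical mirror of numBytes() for fixed-width element types only;
    // elementBytes is e.g. Float.BYTES (4) or Double.BYTES (8).
    static long numBytes(long[] shape, int elementBytes) {
        return numElements(shape) * elementBytes;
    }

    public static void main(String[] args) {
        long[] shape = {2, 3, 4};
        System.out.println(numDimensions(shape));          // prints "3"
        System.out.println(numElements(shape));            // prints "24"
        System.out.println(numBytes(shape, Double.BYTES)); // prints "192"
    }
}
```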

public float floatValue()
Returns the value in a scalar Float tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a float scalar.

public double doubleValue()
Returns the value in a scalar Double tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a double scalar.

public int intValue()
Returns the value in a scalar Integer tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent an int scalar.

public long longValue()
Returns the value in a scalar Long tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a long scalar.

public boolean booleanValue()
Returns the value in a scalar Boolean tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a boolean scalar.

public byte[] bytesValue()
Returns the value in a scalar String tensor.
Throws:
IllegalArgumentException - if the Tensor does not represent a String scalar.

public <U> U copyTo(U dst)
Copies the contents of the tensor to dst and returns dst.
For non-scalar tensors, this method copies the contents of the underlying tensor to a Java array. For scalar tensors, use one of bytesValue(), floatValue(), doubleValue(), intValue(), longValue() or booleanValue() instead.
The type and shape of dst must be compatible with the tensor. For example:
int[][] matrix = {{1, 2}, {3, 4}};
try (Tensor<?> t = Tensor.create(matrix)) {
  // Succeeds and prints "3"
  int[][] copy = new int[2][2];
  System.out.println(t.copyTo(copy)[1][0]);

  // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
  int[][] dst = new int[4][1];
  t.copyTo(dst);
}
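Conceptually, copying a matrix tensor into dst places its flattened elements into the Java array in row-major order, so element [i][j] of a rows x cols matrix comes from flat index i * cols + j. The JDK-only sketch below models the 2x2 example with a hypothetical copyTo2D helper (not TensorFlow's implementation) that applies an analogous shape check.

```java
final class RowMajorCopy {
    // Hypothetical helper: fills dst from a flat row-major array after checking
    // that dst is exactly rows x cols, analogous to the shape check in copyTo.
    static int[][] copyTo2D(int[] flat, long[] shape, int[][] dst) {
        int rows = (int) shape[0];
        int cols = (int) shape[1];
        if (dst.length != rows) {
            throw new IllegalArgumentException("dst has " + dst.length + " rows, tensor has " + rows);
        }
        for (int i = 0; i < rows; i++) {
            if (dst[i].length != cols) {
                throw new IllegalArgumentException(
                        "dst row " + i + " has " + dst[i].length + " columns, tensor has " + cols);
            }
            for (int j = 0; j < cols; j++) {
                dst[i][j] = flat[i * cols + j];
            }
        }
        return dst;
    }

    public static void main(String[] args) {
        int[] flat = {1, 2, 3, 4}; // the data of {{1, 2}, {3, 4}}
        long[] shape = {2, 2};
        System.out.println(copyTo2D(flat, shape, new int[2][2])[1][0]); // prints "3"
        try {
            copyTo2D(flat, shape, new int[4][1]);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```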
Throws:
IllegalArgumentException - if the tensor is a scalar or if dst is not compatible with the tensor (for example, mismatched data types or shapes).

public void writeTo(IntBuffer dst)
Write the data of an Integer tensor into the given buffer.
Copies numElements() elements to the buffer.
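Sizing the destination so that its remaining() is at least numElements() avoids the BufferOverflowException documented for this method. The JDK-only sketch below uses a hypothetical writeInts helper as a stand-in for writeTo(IntBuffer); it relies on IntBuffer.put(int[]), which transfers nothing and throws BufferOverflowException when the destination has too little space.

```java
import java.nio.BufferOverflowException;
import java.nio.IntBuffer;

final class WriteToSizing {
    // Hypothetical stand-in for writeTo(IntBuffer): copies all elements into dst
    // with a relative bulk put, which throws BufferOverflowException (and writes
    // nothing) when dst.remaining() < data.length.
    static void writeInts(int[] data, IntBuffer dst) {
        dst.put(data);
    }

    public static void main(String[] args) {
        int[] tensorData = {1, 2, 3, 4, 5, 6}; // plays the role of numElements() == 6

        IntBuffer enough = IntBuffer.allocate(tensorData.length);
        writeInts(tensorData, enough);
        enough.flip(); // prepare the buffer for reading back
        System.out.println(enough.get(5)); // prints "6"

        try {
            writeInts(tensorData, IntBuffer.allocate(4));
        } catch (BufferOverflowException e) {
            System.out.println("destination too small");
        }
    }
}
```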
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor.
IllegalArgumentException - If the tensor data type is not Integer.

public void writeTo(FloatBuffer dst)
Write the data of a Float tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor.
IllegalArgumentException - If the tensor data type is not Float.

public void writeTo(DoubleBuffer dst)
Write the data of a Double tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor.
IllegalArgumentException - If the tensor data type is not Double.

public void writeTo(LongBuffer dst)
Write the data of a Long tensor into the given buffer.
Copies numElements() elements to the buffer.
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor.
IllegalArgumentException - If the tensor data type is not Long.

public void writeTo(ByteBuffer dst)
Write the tensor data into the given buffer.
Copies numBytes() bytes to the buffer in native byte order for primitive types.
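Because the bytes are written in native byte order, they should be read back with the same order. A ByteBuffer from ByteBuffer.allocate starts out big-endian on every platform, so on a little-endian machine, reading floats without first calling order(ByteOrder.nativeOrder()) would produce wrong values. The JDK-only sketch below simulates this method with a hypothetical writeNative helper (an assumption, not TensorFlow's code) and reads the data back with a hypothetical readNative helper.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;

final class ReadNativeBytes {
    // Hypothetical stand-in for writeTo(ByteBuffer): copies raw bytes that were
    // produced in native byte order. The byte order setting of dst is not changed.
    static void writeNative(float[] values, ByteBuffer dst) {
        ByteBuffer src = ByteBuffer.allocate(values.length * Float.BYTES)
                .order(ByteOrder.nativeOrder());
        for (float v : values) {
            src.putFloat(v);
        }
        src.flip();
        dst.put(src);
    }

    // Flips the written buffer and reads its contents as floats in native order.
    static float[] readNative(ByteBuffer written) {
        written.flip(); // expose the bytes just written
        FloatBuffer floats = written.order(ByteOrder.nativeOrder()).asFloatBuffer();
        float[] out = new float[floats.remaining()];
        floats.get(out);
        return out;
    }

    public static void main(String[] args) {
        ByteBuffer dst = ByteBuffer.allocate(16); // big-endian by default
        writeNative(new float[] {1.5f, -2f, 3f, 4.25f}, dst);
        System.out.println(Arrays.toString(readNative(dst))); // prints "[1.5, -2.0, 3.0, 4.25]"
    }
}
```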
Parameters:
dst - the destination buffer.
Throws:
BufferOverflowException - If there is insufficient space in the given buffer for the data in this tensor.

Copyright © 2022. All rights reserved.