Revision b08f3062c96677de266af26767634fd7c6e6611d authored by Alexander G. de G. Matthews on 09 September 2016, 10:59:46 UTC, committed by James Hensman on 09 September 2016, 10:59:46 UTC
* Renaming tf_hacks to tf_wraps * Changing tf_hacks to tf_wraps in code. * adding a tf_hacks file that raises DeprecationWarnings * release notes * bumping version on docs * importing tf_hacks, tf_wraps
1 parent 61b0659
tri_to_vec.cc
// Copyright 2016 Mark van der Wilk
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <cstdint>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
// Registers the "TriToVec" op interface: a single input "trimat" and a single
// output "vec", both of element type T, with T restricted to real number
// types. The Doc text is part of the registered op definition.
// NOTE(review): no shape function is attached here, so static shape inference
// for "vec" is not provided by this registration — confirm whether one is
// wanted.
REGISTER_OP("TriToVec")
.Attr("T: realnumbertype")
.Input("trimat: T")
.Output("vec: T")
.Doc(R"doc(
Converts a series of triangular matrices to a series of vectors (i.e. a matrix).
If the input is D x N x N, then the output is D x M, where the lower
triangle of each N x N matrix has been packed into an M-vector.
)doc");
// The kernel and kernel registration below use unqualified TensorFlow names
// (OpKernel, Tensor, TensorShape, errors, Name, DEVICE_CPU, ...).
using namespace tensorflow;
// CPU kernel for the TriToVec op.
//
// Given a D x N x N input, writes a D x M output with M = N * (N + 1) / 2.
// Row z of the output holds the lower triangle (diagonal included) of input
// matrix z, emitted row by row: (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), ...
template <typename T>
class TriToVecOp : public OpKernel {
 public:
  explicit TriToVecOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input_tensor = context->input(0);
    const TensorShape& input_shape = input_tensor.shape();
    const int rank = input_shape.dims();

    // Only a batch of matrices (a rank-3 tensor) is supported.
    OP_REQUIRES(context, rank == 3,
                errors::InvalidArgument("TriToVec expects a rank-3 tensor, received shape: ",
                                        input_shape.DebugString()));

    // dim_size() is 64-bit; keep the matrix size 64-bit so it is never
    // truncated and k * (k + 1) / 2 below cannot overflow an int.
    const std::int64_t k = input_shape.dim_size(rank - 1);  // Matrix size
    OP_REQUIRES(context, k == input_shape.dim_size(rank - 2),
                errors::InvalidArgument("input's last two dimensions must be equal, received shape: ",
                                        input_shape.DebugString()));

    auto f = input_tensor.flat_inner_dims<T, 3>();

    // One packed row of length k * (k + 1) / 2 per input matrix.
    TensorShape out_shape({input_shape.dim_size(rank - 3), k * (k + 1) / 2});
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor));
    auto output = output_tensor->template flat<T>();

    // The flat output is filled contiguously. The index spans D * M elements,
    // so it is 64-bit for the same reason as k.
    std::int64_t i = 0;
    for (std::int64_t z = 0; z < f.dimension(0); ++z) {
      for (std::int64_t y = 0; y < f.dimension(1); ++y) {
        // Only columns x <= y belong to the lower triangle. Iterating just
        // those avoids visiting and discarding the upper triangle. x stays
        // in range because dims 1 and 2 were checked equal above.
        for (std::int64_t x = 0; x <= y; ++x) {
          output(i++) = f(z, y, x);
        }
      }
    }
  }
};
// Instantiates and registers a CPU TriToVecOp kernel for every real number
// element type. TF_CALL_REAL_NUMBER_TYPES expands the macro once per type
// with no separator, so the macro body supplies its own terminating ';'.
// The macro is local to this file and is undefined right after use.
#define REGISTER_TRI_TO_VEC_CPU(T)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("TriToVec").Device(DEVICE_CPU).TypeConstraint<T>("T"),        \
      TriToVecOp<T>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_TRI_TO_VEC_CPU);
#undef REGISTER_TRI_TO_VEC_CPU

Computing file changes ...