diff --git a/tensorflow_lite_support/cc/task/processor/BUILD b/tensorflow_lite_support/cc/task/processor/BUILD
index c412dee5d..b65555010 100644
--- a/tensorflow_lite_support/cc/task/processor/BUILD
+++ b/tensorflow_lite_support/cc/task/processor/BUILD
@@ -40,6 +40,26 @@ cc_library_with_tflite(
     ],
 )
 
+cc_library_with_tflite(
+    name = "image_postprocessor",
+    srcs = ["image_postprocessor.cc"],
+    hdrs = ["image_postprocessor.h"],
+    tflite_deps = [
+        ":processor",
+        "//tensorflow_lite_support/cc/task/vision/utils:image_tensor_specs",
+    ],
+    deps = [
+        "//tensorflow_lite_support/cc:common",
+        "//tensorflow_lite_support/cc/port:status_macros",
+        "//tensorflow_lite_support/cc/port:statusor",
+        "//tensorflow_lite_support/cc/task/core:task_utils",
+        "//tensorflow_lite_support/cc/task/vision/core:frame_buffer",
+        "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_utils",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+)
+
 cc_library_with_tflite(
     name = "classification_postprocessor",
     srcs = ["classification_postprocessor.cc"],
diff --git a/tensorflow_lite_support/cc/task/processor/image_postprocessor.cc b/tensorflow_lite_support/cc/task/processor/image_postprocessor.cc
new file mode 100644
index 000000000..051badf22
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/processor/image_postprocessor.cc
@@ -0,0 +1,152 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/cc/task/processor/image_postprocessor.h"
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow_lite_support/cc/common.h"
+
+namespace tflite {
+namespace task {
+namespace processor {
+
+namespace {
+
+using ::absl::StatusCode;
+using ::tflite::metadata::ModelMetadataExtractor;
+using ::tflite::support::CreateStatusWithPayload;
+using ::tflite::support::StatusOr;
+using ::tflite::support::TfLiteSupportStatus;
+
+constexpr int kRgbPixelBytes = 3;
+
+}  // namespace
+
+/* static */
+tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
+ImagePostprocessor::Create(core::TfLiteEngine* engine, const int output_index,
+                           const int input_index) {
+  ASSIGN_OR_RETURN(auto processor,
+                   Processor::Create<ImagePostprocessor>(
+                       /* num_expected_tensors = */ 1, engine, {output_index},
+                       /* requires_metadata = */ false));
+
+  RETURN_IF_ERROR(processor->Init(input_index, output_index));
+  return processor;
+}
+
+absl::Status ImagePostprocessor::Init(const int input_index,
+                                      const int output_index) {
+  if (input_index == -1) {
+    return tflite::support::CreateStatusWithPayload(
+        absl::StatusCode::kInvalidArgument,
+        absl::StrFormat("Input image tensor not set. Input index found: %d",
+                        input_index),
+        tflite::support::TfLiteSupportStatus::kInputTensorNotFoundError);
+  }
+  const TensorMetadata* metadata = GetTensorMetadata(output_index);
+  // Fall back to the input metadata if the output metadata does not carry
+  // normalization parameters.
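+  // Note: `metadata` may also be null when the model carries no metadata for
+  // the output tensor at all; that case falls back to the input metadata too.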
+  const tflite::ProcessUnit* normalization_process_unit = nullptr;
+  if (metadata != nullptr) {
+    ASSIGN_OR_RETURN(
+        normalization_process_unit,
+        ModelMetadataExtractor::FindFirstProcessUnit(
+            *metadata, tflite::ProcessUnitOptions_NormalizationOptions));
+  }
+  if (normalization_process_unit == nullptr) {
+    metadata =
+        engine_->metadata_extractor()->GetInputTensorMetadata(input_index);
+  }
+  if (!GetTensor(output_index)->data.raw) {
+    return tflite::support::CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat("Output tensor (%s) has no raw data.",
+                        GetTensor(output_index)->name));
+  }
+  output_tensor_ = GetTensor(output_index);
+  ASSIGN_OR_RETURN(auto output_specs,
+                   vision::BuildImageTensorSpecs(
+                       *engine_->metadata_extractor(), metadata,
+                       output_tensor_));
+  // Normalization options are only present for float models; uint8 models
+  // need no denormalization.
+  if (output_specs.normalization_options.has_value()) {
+    options_ = std::make_unique<vision::NormalizationOptions>(
+        output_specs.normalization_options.value());
+  }
+  return absl::OkStatus();
+}
+
+absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
+  vision::FrameBuffer::Dimension to_buffer_dimension = {
+      output_tensor_->dims->data[2], output_tensor_->dims->data[1]};
+  size_t output_byte_size =
+      GetBufferByteSize(to_buffer_dimension, vision::FrameBuffer::Format::kRGB);
+  // Store the pixels in a member: the returned FrameBuffer only holds a
+  // pointer to its buffer, so a local buffer would dangle after return.
+  postprocessed_data_.assign(output_byte_size / sizeof(uint8), 0);
+
+  if (output_tensor_->type == kTfLiteUInt8) {  // No denormalization required.
+    RETURN_IF_ERROR(
+        core::PopulateVector(output_tensor_, &postprocessed_data_));
+  } else if (output_tensor_->type ==
+             kTfLiteFloat32) {  // Denormalize to the [0, 255] range.
+    uint8* denormalized_output_data = postprocessed_data_.data();
+    ASSIGN_OR_RETURN(const float* output_data,
+                     core::AssertAndReturnTypedTensor<float>(output_tensor_));
+    const auto norm_options = GetNormalizationOptions();
+
+    if (norm_options.num_values == 1) {
+      float mean_value = norm_options.mean_values[0];
+      float std_value = norm_options.std_values[0];
+
+      for (size_t i = 0; i < output_byte_size / sizeof(uint8);
+           ++i, ++denormalized_output_data, ++output_data) {
+        *denormalized_output_data = static_cast<uint8>(std::round(std::min(
+            255.f, std::max(0.f, (*output_data) * std_value + mean_value))));
+      }
+    } else {
+      for (size_t i = 0; i < output_byte_size / sizeof(uint8);
+           ++i, ++denormalized_output_data, ++output_data) {
+        *denormalized_output_data = static_cast<uint8>(std::round(std::min(
+            255.f,
+            std::max(0.f, (*output_data) * norm_options.std_values[i % 3] +
+                              norm_options.mean_values[i % 3]))));
+      }
+    }
+  }
+
+  vision::FrameBuffer::Plane postprocessed_plane = {
+      /*buffer=*/postprocessed_data_.data(),
+      /*stride=*/{output_tensor_->dims->data[2] * kRgbPixelBytes,
+                  kRgbPixelBytes}};
+  auto postprocessed_frame_buffer =
+      vision::FrameBuffer::Create({postprocessed_plane}, to_buffer_dimension,
+                                  vision::FrameBuffer::Format::kRGB,
+                                  vision::FrameBuffer::Orientation::kTopLeft);
+  return *postprocessed_frame_buffer;
+}
+
+}  // namespace processor
+}  // namespace task
+}  // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/processor/image_postprocessor.h b/tensorflow_lite_support/cc/task/processor/image_postprocessor.h
new file mode 100644
index 000000000..b08567ecb
--- /dev/null
+++ b/tensorflow_lite_support/cc/task/processor/image_postprocessor.h
@@ -0,0 +1,79 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_
+#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow_lite_support/cc/port/status_macros.h"
+#include "tensorflow_lite_support/cc/task/core/task_utils.h"
+#include "tensorflow_lite_support/cc/task/processor/processor.h"
+#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
+#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"
+
+namespace tflite {
+namespace task {
+namespace processor {
+
+// Processes the associated output image tensor and converts it to a
+// FrameBuffer.
+// Requirements for the output tensor:
+// (kTfLiteUInt8/kTfLiteFloat32)
+// - image tensor of size `[batch x height x width x channels]`.
+// - batch inference is not supported (`batch` is required to be 1).
+// - only RGB is supported (`channels` is required to be 3).
+// - if type is kTfLiteFloat32, NormalizationOptions are required to be
+//   attached to the metadata for output de-normalization. Uses the input
+//   metadata as a fallback in case the output metadata isn't provided.
+class ImagePostprocessor : public Postprocessor {
+ public:
+  static tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
+  Create(core::TfLiteEngine* engine,
+         const int output_index,
+         const int input_index = -1);
+
+  // Processes the output tensor into an RGB FrameBuffer. If the output
+  // tensor is of type kTfLiteFloat32, denormalizes it into the [0, 255]
+  // range via the normalization parameters.
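+  // Concretely, each float value `v` maps to
+  // `round(min(255.f, max(0.f, v * std + mean)))`, per channel.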
+  absl::StatusOr<vision::FrameBuffer> Postprocess();
+
+ private:
+  using Postprocessor::Postprocessor;
+
+  const TfLiteTensor* output_tensor_;
+
+  std::unique_ptr<vision::NormalizationOptions> options_;
+
+  // Backing storage for the FrameBuffer returned by `Postprocess`, which
+  // only holds a pointer to the pixel data.
+  std::vector<uint8> postprocessed_data_;
+
+  absl::Status Init(const int input_index, const int output_index);
+
+  const vision::NormalizationOptions& GetNormalizationOptions() {
+    return *options_.get();
+  }
+};
+
+}  // namespace processor
+}  // namespace task
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_
diff --git a/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc b/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
index 7ad4ad470..7e433395f 100644
--- a/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
+++ b/tensorflow_lite_support/cc/task/processor/image_preprocessor.cc
@@ -72,9 +72,9 @@ absl::Status ImagePreprocessor::Init(
     const vision::FrameBufferUtils::ProcessEngine& process_engine) {
   frame_buffer_utils_ = vision::FrameBufferUtils::Create(process_engine);
 
-  ASSIGN_OR_RETURN(input_specs_, vision::BuildInputImageTensorSpecs(
-                                     *engine_->interpreter(),
-                                     *engine_->metadata_extractor()));
+  ASSIGN_OR_RETURN(input_specs_, vision::BuildImageTensorSpecs(
+                                     *engine_->metadata_extractor(),
+                                     GetTensorMetadata(), GetTensor()));
 
   if (input_specs_.color_space != tflite::ColorSpaceType_RGB) {
     return tflite::support::CreateStatusWithPayload(
diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
index afbe07dd9..92ff74ac8 100644
--- a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
+++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.cc
@@ -39,28 +39,21 @@ using ::tflite::support::StatusOr;
 using ::tflite::support::TfLiteSupportStatus;
 using ::tflite::task::core::TfLiteEngine;
 
-StatusOr<const TensorMetadata*> GetInputTensorMetadataIfAny(
-    const ModelMetadataExtractor& metadata_extractor) {
+StatusOr<const TensorMetadata*> GetTensorMetadataIfAny(
+    const ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata* tensor_metadata) {
   if (metadata_extractor.GetModelMetadata() == nullptr ||
       metadata_extractor.GetModelMetadata()->subgraph_metadata() == nullptr) {
     // Some models have no metadata at all (or very partial), so exit early.
     return nullptr;
-  } else if (metadata_extractor.GetInputTensorCount() != 1) {
-    return CreateStatusWithPayload(
-        StatusCode::kInvalidArgument,
-        "Models are assumed to have a single input TensorMetadata.",
-        TfLiteSupportStatus::kInvalidNumInputTensorsError);
   }
-  const TensorMetadata* metadata = metadata_extractor.GetInputTensorMetadata(0);
-
-  if (metadata == nullptr) {
+  if (tensor_metadata == nullptr) {
     // Should never happen.
     return CreateStatusWithPayload(StatusCode::kInternal,
-                                   "Input TensorMetadata is null.");
+                                   "Provided TensorMetadata is null.");
   }
-
-  return metadata;
+  return tensor_metadata;
 }
 
 StatusOr<const ImageProperties*> GetImagePropertiesIfAny(
@@ -134,53 +127,43 @@ StatusOr<absl::optional<NormalizationOptions>> GetNormalizationOptionsIfAny(
 
 }  // namespace
 
-StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
-    const TfLiteEngine::Interpreter& interpreter,
-    const tflite::metadata::ModelMetadataExtractor& metadata_extractor) {
-  ASSIGN_OR_RETURN(const TensorMetadata* metadata,
-                   GetInputTensorMetadataIfAny(metadata_extractor));
-
+StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
+    const ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata* tensor_metadata, const TfLiteTensor* tensor) {
   const ImageProperties* props = nullptr;
   absl::optional<NormalizationOptions> normalization_options;
+  ASSIGN_OR_RETURN(const TensorMetadata* metadata,
+                   GetTensorMetadataIfAny(metadata_extractor, tensor_metadata));
   if (metadata != nullptr) {
     ASSIGN_OR_RETURN(props, GetImagePropertiesIfAny(*metadata));
     ASSIGN_OR_RETURN(normalization_options,
                      GetNormalizationOptionsIfAny(*metadata));
   }
 
-  if (TfLiteEngine::InputCount(&interpreter) != 1) {
-    return CreateStatusWithPayload(
-        StatusCode::kInvalidArgument,
-        "Models are assumed to have a single input.",
-        TfLiteSupportStatus::kInvalidNumInputTensorsError);
-  }
-
-  // Input-related specifications.
-  const TfLiteTensor* input_tensor = TfLiteEngine::GetInput(&interpreter, 0);
-  if (input_tensor->dims->size != 4) {
+  if (tensor->dims->size != 4) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
         "Only 4D tensors in BHWD layout are supported.",
         TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
   }
   static constexpr TfLiteType valid_types[] = {kTfLiteUInt8, kTfLiteFloat32};
-  TfLiteType input_type = input_tensor->type;
-  if (!absl::c_linear_search(valid_types, input_type)) {
+  TfLiteType tensor_type = tensor->type;
+  if (!absl::c_linear_search(valid_types, tensor_type)) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
         absl::StrCat(
-            "Type mismatch for input tensor ", input_tensor->name,
+            "Type mismatch for tensor ", tensor->name,
             ". Requested one of these types: kTfLiteUint8/kTfLiteFloat32, got ",
-            TfLiteTypeGetName(input_type), "."),
+            TfLiteTypeGetName(tensor_type), "."),
         TfLiteSupportStatus::kInvalidInputTensorTypeError);
   }
 
   // The expected layout is BHWD, i.e. batch x height x width x color
   // See https://www.tensorflow.org/guide/tensors
-  const int batch = input_tensor->dims->data[0];
-  const int height = input_tensor->dims->data[1];
-  const int width = input_tensor->dims->data[2];
-  const int depth = input_tensor->dims->data[3];
+  const int batch = tensor->dims->data[0];
+  const int height = tensor->dims->data[1];
+  const int width = tensor->dims->data[2];
+  const int depth = tensor->dims->data[3];
 
   if (props != nullptr && props->color_space() != ColorSpaceType_RGB) {
     return CreateStatusWithPayload(StatusCode::kInvalidArgument,
@@ -190,47 +173,48 @@ StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
   if (batch != 1 || depth != 3) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
-        absl::StrCat("The input tensor should have dimensions 1 x height x "
+        absl::StrCat("The tensor should have dimensions 1 x height x "
                      "width x 3. Got ", batch, " x ", height, " x ", width,
                      " x ", depth, "."),
         TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
   }
-  int bytes_size = input_tensor->bytes;
+  int bytes_size = tensor->bytes;
   size_t byte_depth =
-      input_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);
+      tensor_type == kTfLiteFloat32 ? sizeof(float) : sizeof(uint8);
 
   // Sanity checks.
-  if (input_type == kTfLiteFloat32) {
+  if (tensor_type == kTfLiteFloat32) {
     if (!normalization_options.has_value()) {
       return CreateStatusWithPayload(
           absl::StatusCode::kNotFound,
-          "Input tensor has type kTfLiteFloat32: it requires specifying "
-          "NormalizationOptions metadata to preprocess input images.",
+          "Tensor has type kTfLiteFloat32: it requires specifying "
+          "NormalizationOptions metadata to process images.",
          TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
     } else if (bytes_size / sizeof(float) %
                normalization_options.value().num_values != 0) {
       return CreateStatusWithPayload(
           StatusCode::kInvalidArgument,
-          "The number of elements in the input tensor must be a multiple of "
+          "The number of elements in the tensor must be a multiple of "
           "the number of normalization parameters.",
           TfLiteSupportStatus::kInvalidArgumentError);
     }
   }
   if (width <= 0) {
     return CreateStatusWithPayload(
-        StatusCode::kInvalidArgument, "The input width should be positive.",
+        StatusCode::kInvalidArgument, "The width should be positive.",
         TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
   }
   if (height <= 0) {
     return CreateStatusWithPayload(
-        StatusCode::kInvalidArgument, "The input height should be positive.",
+        StatusCode::kInvalidArgument, "The height should be positive.",
         TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
   }
   if (bytes_size != height * width * depth * byte_depth) {
     return CreateStatusWithPayload(
         StatusCode::kInvalidArgument,
-        "The input size in bytes does not correspond to the expected number of "
-        "pixels.",
+        "The tensor size in bytes does not correspond to the expected number "
+        "of pixels.",
         TfLiteSupportStatus::kInvalidInputTensorSizeError);
   }
@@ -243,7 +227,7 @@ StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
   result.image_width = width;
   result.image_height = height;
   result.color_space = ColorSpaceType_RGB;
-  result.tensor_type = input_type;
+  result.tensor_type = tensor_type;
   result.normalization_options = normalization_options;
 
   return result;
diff --git a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
index d15be3f8e..81e6bb202 100644
--- a/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
+++ b/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h
@@ -73,7 +73,7 @@ struct ImageTensorSpecs {
   absl::optional<NormalizationOptions> normalization_options;
 };
 
-// Performs sanity checks on the expected input tensor including consistency
+// Performs sanity checks on the expected input/output tensor including consistency
 // checks against model metadata, if any. For now, a single RGB input with BHWD
 // layout, where B = 1 and D = 3, is expected. Returns the corresponding input
 // specifications if they pass, or an error otherwise (too many input tensors,
@@ -82,9 +82,9 @@ struct ImageTensorSpecs {
 // initialized before calling this function by means of (respectively):
 // - `tflite::InterpreterBuilder`,
 // - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
-tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
-    const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
-    const tflite::metadata::ModelMetadataExtractor& metadata_extractor);
+tflite::support::StatusOr<ImageTensorSpecs> BuildImageTensorSpecs(
+    const tflite::metadata::ModelMetadataExtractor& metadata_extractor,
+    const TensorMetadata* tensor_metadata, const TfLiteTensor* tensor);
 
 }  // namespace vision
 }  // namespace task
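
Reviewer note: a minimal usage sketch, not part of this change, showing how a vision task might wire up the new postprocessor. It assumes a `TfLiteEngine` that has already loaded a model with an image output at tensor index 0 and its image input at tensor index 0, and that inference has been run; the helper name and both indices are illustrative and model-dependent.

```cpp
#include <memory>

#include "absl/status/statusor.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/processor/image_postprocessor.h"

namespace {

using ::tflite::task::processor::ImagePostprocessor;
using ::tflite::task::vision::FrameBuffer;

// Converts the model's image output tensor into an RGB FrameBuffer.
absl::StatusOr<FrameBuffer> DecodeOutputImage(
    tflite::task::core::TfLiteEngine* engine) {
  // `input_index` lets normalization parameters fall back to the input
  // tensor's metadata when the output tensor carries none.
  ASSIGN_OR_RETURN(std::unique_ptr<ImagePostprocessor> postprocessor,
                   ImagePostprocessor::Create(engine, /*output_index=*/0,
                                              /*input_index=*/0));
  // Must be called after engine->interpreter()->Invoke() has succeeded.
  return postprocessor->Postprocess();
}

}  // namespace
```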