Add Image Transformer Library #679
Closed
Commits (42)
All 42 commits are by jonpsy:

- f6bda02 proto for input options.
- 79ca011 All except base_options + added \n
- f12ce03 Add \n in transformatinos.proto
- 58a7d54 Yet another \n
- bcefc49 add transformation result message
- b43ef89 add proto include header
- 5499ba4 Add cc and h files.
- f284452 minor adjust
- 540c5d5 no need to check num threads
- c9bfd13 rm uint8t comment
- 18d6bec no need transformation data structure
- ea30bae * has model file check already handled: TaskAPIFac
- ebecdfc postprocess logic done.
- 772ea3a Added RGB check
- dea4922 Add BUILD dep
- eb81aac Remove redundant includes.
- 8644dc9 remove redundant deps in proto
- 6567129 introduce rgbPixelBytes
- b939dd5 1. Ditch std::unique_ptr, use FrameBuffer directly.
- eed46c5 Add postprocessor class and delegate task there.
- bf41eb9 GetNormalizationOptions is public
- a4cb45a move postprocess to cc
- c50d657 Init done.
- 4612e36 Just pass input tensor indices
- ecb411d updated header as per new API
- fe248ca use GetTFLiteEngine()
- 28500ba copy GetNormalizationOptionsIfAny() code into .cc
- 1586155 hold NormalizationOptions state.
- 82dc6ec Further document image_postprocessor.h
- df9fc0a Move output count to postprocess
- ce176e8 Check in a single line.
- 6706c1e end() => begin()
- 08c536b Use the latest API
- d93e81d ESR-GAN models with metadata.
- 745618c Add fox images.
- beed843 minor comment fix.
- ce53d49 Add unit tests.
- 0181b1c use husky
- f1309d0 output_meta 0, 1
- cc95019 Use the correct model.
- 1c0409d enhanced husky
- 878aef9 test
tensorflow_lite_support/cc/task/processor/image_postprocessor.cc (229 additions, 0 deletions):
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/processor/image_postprocessor.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <vector>

namespace tflite {
namespace task {
namespace processor {

namespace {

using ::absl::StatusCode;
using ::tflite::metadata::ModelMetadataExtractor;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;

StatusOr<absl::optional<vision::NormalizationOptions>>
GetNormalizationOptionsIfAny(const TensorMetadata& tensor_metadata) {
  ASSIGN_OR_RETURN(
      const tflite::ProcessUnit* normalization_process_unit,
      ModelMetadataExtractor::FindFirstProcessUnit(
          tensor_metadata, tflite::ProcessUnitOptions_NormalizationOptions));
  if (normalization_process_unit == nullptr) {
    return {absl::nullopt};
  }
  const tflite::NormalizationOptions* tf_normalization_options =
      normalization_process_unit->options_as_NormalizationOptions();
  const auto mean_values = tf_normalization_options->mean();
  const auto std_values = tf_normalization_options->std();
  if (mean_values->size() != std_values->size()) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("NormalizationOptions: expected mean and std of same "
                     "dimension, got ",
                     mean_values->size(), " and ", std_values->size(), "."),
        TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
  }
  absl::optional<vision::NormalizationOptions> normalization_options;
  if (mean_values->size() == 1) {
    normalization_options = vision::NormalizationOptions{
        .mean_values = {mean_values->Get(0), mean_values->Get(0),
                        mean_values->Get(0)},
        .std_values = {std_values->Get(0), std_values->Get(0),
                       std_values->Get(0)},
        .num_values = 1};
  } else if (mean_values->size() == 3) {
    normalization_options = vision::NormalizationOptions{
        .mean_values = {mean_values->Get(0), mean_values->Get(1),
                        mean_values->Get(2)},
        .std_values = {std_values->Get(0), std_values->Get(1),
                       std_values->Get(2)},
        .num_values = 3};
  } else {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrCat("NormalizationOptions: only 1 or 3 mean and std "
                     "values are supported, got ",
                     mean_values->size(), "."),
        TfLiteSupportStatus::kMetadataInvalidProcessUnitsError);
  }
  return normalization_options;
}
}  // namespace

/* static */
tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
ImagePostprocessor::Create(core::TfLiteEngine* engine,
                           const std::initializer_list<int> output_indices,
                           const std::initializer_list<int> input_indices) {
  ASSIGN_OR_RETURN(auto processor,
                   Processor::Create<ImagePostprocessor>(
                       /* num_expected_tensors = */ 1, engine, output_indices,
                       /* requires_metadata = */ false));

  RETURN_IF_ERROR(processor->Init(input_indices));
  return processor;
}

absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
  if (core::TfLiteEngine::OutputCount(engine_->interpreter()) != 1) {
    return tflite::support::CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Image postprocessing models are expected to have only 1 "
            "output, found %d",
            core::TfLiteEngine::OutputCount(engine_->interpreter())),
        tflite::support::TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }

  if (GetTensor()->type != kTfLiteUInt8 &&
      GetTensor()->type != kTfLiteFloat32) {
    return tflite::support::CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Type mismatch for output tensor %s. Requested one "
                        "of these types: "
                        "kTfLiteUint8/kTfLiteFloat32, got %s.",
                        GetTensor()->name,
                        TfLiteTypeGetName(GetTensor()->type)),
        tflite::support::TfLiteSupportStatus::kInvalidOutputTensorTypeError);
  }

  if (GetTensor()->dims->data[0] != 1 || GetTensor()->dims->data[3] != 3) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrCat("The output tensor should have dimensions 1 x height x "
                     "width x 3. Got ",
                     GetTensor()->dims->data[0], " x ",
                     GetTensor()->dims->data[1], " x ",
                     GetTensor()->dims->data[2], " x ",
                     GetTensor()->dims->data[3], "."),
        tflite::support::TfLiteSupportStatus::
            kInvalidInputTensorDimensionsError);
  }

  // Gather metadata.
  auto* output_metadata =
      engine_->metadata_extractor()->GetOutputTensorMetadata(
          tensor_indices_.at(0));
  auto* input_metadata = engine_->metadata_extractor()->GetInputTensorMetadata(
      input_indices.at(0));

  // Use input metadata for normalization as fallback.
  auto* processing_metadata =
      output_metadata != nullptr ? output_metadata : input_metadata;

  absl::optional<vision::NormalizationOptions> normalization_options;
  ASSIGN_OR_RETURN(normalization_options,
                   GetNormalizationOptionsIfAny(*processing_metadata));

  if (GetTensor()->type == kTfLiteFloat32) {
    if (!normalization_options.has_value()) {
      return CreateStatusWithPayload(
          absl::StatusCode::kNotFound,
          "Output tensor has type kTfLiteFloat32: it requires specifying "
          "NormalizationOptions metadata to preprocess output images.",
          TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
    } else if (GetTensor()->bytes / sizeof(float) %
                   normalization_options.value().num_values !=
               0) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument,
          "The number of elements in the output tensor must be a multiple "
          "of the number of normalization parameters.",
          TfLiteSupportStatus::kInvalidArgumentError);
    }

    options_ = std::make_unique<vision::NormalizationOptions>(
        normalization_options.value());
  }

  return absl::OkStatus();
}

absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
  has_uint8_outputs_ = GetTensor()->type == kTfLiteUInt8;
  const int kRgbPixelBytes = 3;

  vision::FrameBuffer::Dimension to_buffer_dimension = {
      GetTensor()->dims->data[2], GetTensor()->dims->data[1]};
  size_t output_byte_size =
      GetBufferByteSize(to_buffer_dimension, vision::FrameBuffer::Format::kRGB);
  std::vector<uint8> postprocessed_data(output_byte_size / sizeof(uint8), 0);

  if (has_uint8_outputs_) {  // No denormalization required.
    if (GetTensor()->bytes != output_byte_size) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Size mismatch or unsupported padding bytes between pixel data "
          "and output tensor.");
    }
    const uint8* output_data =
        core::AssertAndReturnTypedTensor<uint8>(GetTensor()).value();
    // Copy into the pre-sized buffer. Using insert() here would prepend the
    // tensor data to the zero-filled vector and double its size.
    std::copy(output_data, output_data + output_byte_size / sizeof(uint8),
              postprocessed_data.begin());
  } else {  // Denormalize to [0, 255] range.
    if (GetTensor()->bytes / sizeof(float) !=
        output_byte_size / sizeof(uint8)) {
      return tflite::support::CreateStatusWithPayload(
          absl::StatusCode::kInternal,
          "Size mismatch or unsupported padding bytes between pixel data "
          "and output tensor.");
    }

    uint8* denormalized_output_data = postprocessed_data.data();
    const float* output_data =
        core::AssertAndReturnTypedTensor<float>(GetTensor()).value();
    const auto norm_options = GetNormalizationOptions();

    if (norm_options.num_values == 1) {
      float mean_value = norm_options.mean_values[0];
      float std_value = norm_options.std_values[0];

      for (size_t i = 0; i < output_byte_size / sizeof(uint8);
           ++i, ++denormalized_output_data, ++output_data) {
        *denormalized_output_data = static_cast<uint8>(std::round(std::min(
            255.f, std::max(0.f, (*output_data) * std_value + mean_value))));
      }
    } else {
      for (size_t i = 0; i < output_byte_size / sizeof(uint8);
           ++i, ++denormalized_output_data, ++output_data) {
        *denormalized_output_data = static_cast<uint8>(std::round(std::min(
            255.f,
            std::max(0.f, (*output_data) * norm_options.std_values[i % 3] +
                              norm_options.mean_values[i % 3]))));
      }
    }
  }

  vision::FrameBuffer::Plane postprocessed_plane = {
      /*buffer=*/postprocessed_data.data(),
      /*stride=*/{GetTensor()->dims->data[2] * kRgbPixelBytes,
                  kRgbPixelBytes}};
  auto postprocessed_frame_buffer =
      vision::FrameBuffer::Create({postprocessed_plane}, to_buffer_dimension,
                                  vision::FrameBuffer::Format::kRGB,
                                  vision::FrameBuffer::Orientation::kTopLeft);

  vision::FrameBuffer postprocessed_result = *postprocessed_frame_buffer.get();
  return postprocessed_result;
}

}  // namespace processor
}  // namespace task
}  // namespace tflite
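The float path in Postprocess() above inverts the metadata normalization (normalized = (pixel - mean) / std, so pixel = normalized * std + mean) and clamps the result into the uint8 range. A minimal standalone sketch of that per-pixel step, with an illustrative helper name that is not part of the PR:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Invert (v - mean) / std and clamp into [0, 255], mirroring the
// denormalization loop in ImagePostprocessor::Postprocess().
inline std::uint8_t DenormalizePixel(float v, float mean, float stddev) {
  float denorm = v * stddev + mean;  // inverse of the normalization
  return static_cast<std::uint8_t>(
      std::round(std::min(255.f, std::max(0.f, denorm))));
}
```

With the common mean = std = 127.5 metadata, a normalized value of 1.0 maps back to 255 and out-of-range values saturate instead of wrapping.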
tensorflow_lite_support/cc/task/processor/image_postprocessor.h (70 additions, 0 deletions):
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_

#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/processor/processor.h"
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
#include "tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h"

namespace tflite {
namespace task {
namespace processor {

// Processes the associated output image tensor and converts it to a
// FrameBuffer.
// Requirements for the output tensor (kTfLiteUInt8/kTfLiteFloat32):
//   - image output of size `[batch x height x width x channels]`;
//   - batch inference is not supported (`batch` is required to be 1);
//   - only RGB outputs are supported (`channels` is required to be 3);
//   - if type is kTfLiteFloat32, NormalizationOptions are required to be
//     attached to the metadata for output de-normalization. Input metadata
//     is used as a fallback in case output metadata isn't provided.
class ImagePostprocessor : public Postprocessor {
 public:
  static tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
  Create(core::TfLiteEngine* engine,
         const std::initializer_list<int> output_indices,
         const std::initializer_list<int> input_indices);

  // Processes the output tensor into an RGB FrameBuffer.
  // If the output tensor is of type kTfLiteFloat32, denormalizes it into
  // the [0, 255] range via the normalization parameters.
  absl::StatusOr<vision::FrameBuffer> Postprocess();

 private:
  using Postprocessor::Postprocessor;

  // Whether the model features quantized inference type (QUANTIZED_UINT8).
  // This is currently detected by checking if all output tensors' data type
  // is uint8.
  bool has_uint8_outputs_;

  std::unique_ptr<vision::NormalizationOptions> options_;

  absl::Status Init(const std::vector<int>& input_indices);

  const vision::NormalizationOptions& GetNormalizationOptions() {
    return *options_.get();
  }
};
}  // namespace processor
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_PROCESSOR_IMAGE_POSTPROCESSOR_H_
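The 1 x height x width x 3 requirement documented above fixes the buffer size and row stride used when Postprocess() builds the tightly packed RGB FrameBuffer plane. A small sketch of that arithmetic; the helper names are illustrative, not part of the PR:

```cpp
#include <cstddef>

// Byte-size arithmetic for a tightly packed 1 x H x W x 3 RGB tensor,
// matching the plane setup in ImagePostprocessor::Postprocess().
constexpr int kRgbPixelBytes = 3;

constexpr std::size_t RgbBufferBytes(int height, int width) {
  // One byte per channel, three channels per pixel.
  return static_cast<std::size_t>(height) * width * kRgbPixelBytes;
}

constexpr int RgbRowStrideBytes(int width) {
  // No padding between rows for a tightly packed plane.
  return width * kRgbPixelBytes;
}
```

For the 50 x 50 ESR-GAN test inputs in this PR, that is 7500 bytes per image with a 150-byte row stride.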
Review discussion:

- Since `GetNormalizationOptionsIfAny` was wrapped inside an unnamed namespace in `image_tensor_specs.cc`, we might need to copy-paste the code here, unfortunately.
- Please share as much code as possible between `ImagePostprocessor::Init` and `ImageTensorSpecs::BuildInputImageTensorSpecs`. You can put a TODO and implement it in a follow-up PR.
- Just so we're clear, do you mean copy code from `BuildInputImageTensorSpecs` when you said "share"?
- "Share code" means `ImagePostprocessor::Init` and `ImageTensorSpecs::BuildInputImageTensorSpecs` use the same piece of code to do processing or validation.
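For context, the helper under discussion broadcasts a single metadata mean/std pair to all three RGB channels, and passes three values through per channel. A standalone sketch of that broadcast; the struct and function names are illustrative, not from the library:

```cpp
#include <array>
#include <vector>

// Mimics how GetNormalizationOptionsIfAny() expands 1 metadata value to
// 3 per-channel values, and forwards 3 values unchanged.
struct NormParams {
  std::array<float, 3> mean;
  std::array<float, 3> stddev;
};

inline NormParams BroadcastNorm(const std::vector<float>& mean,
                                const std::vector<float>& stddev) {
  if (mean.size() == 1) {
    // Single value in metadata: apply it to R, G, and B alike.
    return {{mean[0], mean[0], mean[0]},
            {stddev[0], stddev[0], stddev[0]}};
  }
  // Three values: one per channel. (The real helper rejects other sizes.)
  return {{mean[0], mean[1], mean[2]}, {stddev[0], stddev[1], stddev[2]}};
}
```

Keeping this logic in one shared place is exactly what the reviewer's "share code" request is about: both input preprocessing and output denormalization need the same expanded parameters.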