Skip to content

Commit fe089a3

Browse files
committed
Add cc and h files.
1 parent b611580 commit fe089a3

File tree

2 files changed

+343
-0
lines changed

2 files changed

+343
-0
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_lite_support/cc/task/vision/image_transformer.h"
17+
18+
#include "external/com_google_absl/absl/algorithm/container.h"
19+
#include "external/com_google_absl/absl/strings/str_format.h"
20+
#include "external/com_google_absl/absl/strings/string_view.h"
21+
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
22+
#include "tensorflow_lite_support/cc/common.h"
23+
#include "tensorflow_lite_support/cc/port/integral_types.h"
24+
#include "tensorflow_lite_support/cc/port/status_macros.h"
25+
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
26+
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
27+
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
28+
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_utils.h"
29+
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
30+
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
31+
32+
namespace tflite {
33+
namespace task {
34+
namespace vision {
35+
36+
namespace {
37+
38+
using ::absl::StatusCode;
39+
using ::tflite::metadata::ModelMetadataExtractor;
40+
using ::tflite::support::CreateStatusWithPayload;
41+
using ::tflite::support::StatusOr;
42+
using ::tflite::support::TfLiteSupportStatus;
43+
using ::tflite::task::core::AssertAndReturnTypedTensor;
44+
using ::tflite::task::core::TaskAPIFactory;
45+
using ::tflite::task::core::TfLiteEngine;
46+
47+
} // namespace
48+
49+
/* static */
50+
// Creates an ImageTransformer from `options`.
//
// The options are validated with SanityCheckOptions() and then copied so that
// the model file they reference outlives the constructed object. The model is
// loaded from whichever of `model_file_with_metadata` or
// `base_options.model_file` is set, and the new object is initialized with
// the copied options.
//
// Returns the initialized ImageTransformer, or the first error reported by
// option validation, model loading or Init().
StatusOr<std::unique_ptr<ImageTransformer>> ImageTransformer::CreateFromOptions(
    const ImageTransformerOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  RETURN_IF_ERROR(SanityCheckOptions(options));

  // Copy options to ensure the ExternalFile outlives the constructed object.
  auto options_copy = absl::make_unique<ImageTransformerOptions>(options);

  std::unique_ptr<ImageTransformer> image_transformer;
  if (options_copy->has_model_file_with_metadata()) {
    // This branch reads `model_file_with_metadata`, so it must be guarded by
    // the matching presence check (the same one SanityCheckOptions() uses).
    // NOTE(review): this path reads `num_threads` / `compute_settings` from
    // the top-level options, whereas SanityCheckOptions() validates the
    // thread count found under `base_options` - confirm this is intended.
    ASSIGN_OR_RETURN(
        image_transformer,
        TaskAPIFactory::CreateFromExternalFileProto<ImageTransformer>(
            &options_copy->model_file_with_metadata(), std::move(resolver),
            options_copy->num_threads(), options_copy->compute_settings()));
  } else if (options_copy->base_options().has_model_file()) {
    ASSIGN_OR_RETURN(image_transformer,
                     TaskAPIFactory::CreateFromBaseOptions<ImageTransformer>(
                         &options_copy->base_options(), std::move(resolver)));
  } else {
    // Should never happen because of SanityCheckOptions.
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "Expected exactly one of `base_options.model_file` or "
        "`model_file_with_metadata` to be provided, found 0.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  RETURN_IF_ERROR(image_transformer->Init(std::move(options_copy)));

  return image_transformer;
}
83+
84+
/* static */
85+
// Validates `options` before any model is loaded.
//
// Returns an InvalidArgument status when:
//   - the number of model sources set among `base_options.model_file` and
//     `model_file_with_metadata` is not exactly one, or
//   - the CPU thread count under `base_options` is 0 or smaller than -1.
// Returns OK otherwise.
absl::Status ImageTransformer::SanityCheckOptions(
    const ImageTransformerOptions& options) {
  // Count how many of the two mutually exclusive model sources are present.
  int num_input_models = 0;
  if (options.base_options().has_model_file()) {
    ++num_input_models;
  }
  if (options.has_model_file_with_metadata()) {
    ++num_input_models;
  }

  if (num_input_models != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected exactly one of `base_options.model_file` or "
                        "`model_file_with_metadata` to be provided, found %d.",
                        num_input_models),
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  // Read the thread count once instead of walking the accessor chain twice.
  const auto num_threads = options.base_options()
                               .compute_settings()
                               .tflite_settings()
                               .cpu_settings()
                               .num_threads();
  const bool valid_num_threads = num_threads > 0 || num_threads == -1;
  if (!valid_num_threads) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        "`num_threads` must be greater than 0 or equal to -1.",
        TfLiteSupportStatus::kInvalidArgumentError);
  }

  return absl::OkStatus();
}
107+
108+
// Takes ownership of `options` and runs the initialization sequence:
// PreInit(), input checks, output checks, then PostInit(). The sequence stops
// at the first step that fails and that step's status is returned.
absl::Status ImageTransformer::Init(
    std::unique_ptr<ImageTransformerOptions> options) {
  // Store the options first so the steps below can read them.
  options_ = std::move(options);

  // Pre-initialization hook; the default implementation selects kLibyuv as
  // the process engine.
  RETURN_IF_ERROR(PreInit());

  // Validate the model's input and output tensors.
  RETURN_IF_ERROR(CheckAndSetInputs());
  RETURN_IF_ERROR(CheckAndSetOutputs());

  // Post-initialization hook; its status is the overall result.
  return PostInit();
}
125+
126+
// Pre-initialization hook, called by Init() before the tensor checks.
// Selects kLibyuv as the process engine and always returns OK.
absl::Status ImageTransformer::PreInit() {
  const auto default_engine = FrameBufferUtils::ProcessEngine::kLibyuv;
  SetProcessEngine(default_engine);
  return absl::OkStatus();
}
130+
131+
// Post-initialization hook, called by Init() as its last step.
// This implementation has no work to perform and always returns OK.
absl::Status ImageTransformer::PostInit() { return absl::OkStatus(); }
135+
136+
// Validates the model's output tensors and records their properties.
//
// Requires exactly one output tensor with 4 dimensions whose first (batch)
// dimension is 1. On success, stores the output count in `num_outputs_` and
// whether the output tensor is uint8 in `has_uint8_outputs_`.
//
// Returns an InvalidArgument status describing the first violated
// requirement, or OK if all checks pass.
absl::Status ImageTransformer::CheckAndSetOutputs() {
  // First, sanity checks on the model itself.
  const TfLiteEngine::Interpreter* interpreter =
      GetTfLiteEngine()->interpreter();

  // Check the number of output tensors.
  const auto output_count = TfLiteEngine::OutputCount(interpreter);
  if (output_count != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Image transformer models are expected to have only 1 "
                        "output, found %d",
                        output_count),
        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
  }
  // Record the validated count so the member is never read uninitialized.
  num_outputs_ = output_count;

  const TfLiteTensor* output_tensor = TfLiteEngine::GetOutput(interpreter, 0);

  // Check tensor dimensions.
  if (output_tensor->dims->size != 4) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat(
            "Output tensor is expected to have 4 dimensions, found %d.",
            output_tensor->dims->size),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }

  // Batch inference is not supported: the leading dimension must be 1.
  if (output_tensor->dims->data[0] != 1) {
    return CreateStatusWithPayload(
        StatusCode::kInvalidArgument,
        absl::StrFormat("Expected batch size of 1, found %d.",
                        output_tensor->dims->data[0]),
        TfLiteSupportStatus::kInvalidOutputTensorDimensionsError);
  }

  // TODO: Will the output be float and need converting, or be directly
  // available? The example had float output that had to be converted.
  // NOTE(review): the output tensor type is only recorded here, not
  // validated; types other than uint8 are accepted by this check.
  has_uint8_outputs_ = (output_tensor->type == kTfLiteUInt8);
  return absl::OkStatus();
}
175+
176+
// Transforms the whole frame: builds a region of interest whose width and
// height equal those of `frame_buffer` and forwards to the two-argument
// overload.
StatusOr<TransformationResult> ImageTransformer::Transform(
    const FrameBuffer& frame_buffer) {
  const auto frame_dims = frame_buffer.dimension();

  // Only width and height are set; the other fields keep their defaults.
  BoundingBox full_frame;
  full_frame.set_width(frame_dims.width);
  full_frame.set_height(frame_dims.height);

  return Transform(frame_buffer, full_frame);
}
183+
184+
// Transforms the part of `frame_buffer` designated by `roi`. The work is
// delegated entirely to InferWithFallback(), whose result is returned as is.
StatusOr<TransformationResult> ImageTransformer::Transform(
    const FrameBuffer& frame_buffer, const BoundingBox& roi) {
  StatusOr<TransformationResult> result = InferWithFallback(frame_buffer, roi);
  return result;
}
188+
189+
// Converts the raw model outputs into a TransformationResult.
//
// The return type matches the `override` declared in the class definition
// (StatusOr<TransformationResult>).
//
// TODO: implement the conversion of the output tensor. Until then an explicit
// Unimplemented status is returned, since flowing off the end of a non-void
// function is undefined behavior.
StatusOr<TransformationResult> ImageTransformer::Postprocess(
    const std::vector<const TfLiteTensor*>& /*output_tensors*/,
    const FrameBuffer& /*frame_buffer*/, const BoundingBox& /*roi*/) {
  return CreateStatusWithPayload(
      StatusCode::kUnimplemented,
      "ImageTransformer::Postprocess is not implemented yet.",
      TfLiteSupportStatus::kError);
}
193+
} // namespace vision
194+
} // namespace task
195+
} // namespace tflite
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_
17+
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_
18+
19+
#include <memory>
20+
#include <vector>
21+
22+
#include "external/com_google_absl/absl/container/flat_hash_set.h"
23+
#include "external/com_google_absl/absl/status/status.h"
24+
#include "tensorflow/lite/c/common.h"
25+
#include "tensorflow/lite/core/api/op_resolver.h"
26+
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
27+
#include "tensorflow_lite_support/cc/port/integral_types.h"
28+
#include "tensorflow_lite_support/cc/port/statusor.h"
29+
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
30+
#include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
31+
#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
32+
#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
33+
34+
namespace tflite {
35+
namespace task {
36+
namespace vision {
37+
38+
// Performs transformation on images.
39+
//
40+
// The API expects a TFLite model with optional, but strongly recommended,
41+
// TFLite Model Metadata.
42+
//
43+
// Input tensor:
44+
// (kTfLiteUInt8/kTfLiteFloat32)
45+
// - image input of size `[batch x height x width x channels]`.
46+
// - batch inference is not supported (`batch` is required to be 1).
47+
// - only RGB inputs are supported (`channels` is required to be 3).
48+
// - if type is kTfLiteFloat32, NormalizationOptions are required to be
49+
// attached to the metadata for input normalization.
50+
// At least one output tensor with:
51+
// (kTfLiteUInt8/kTfLiteFloat32)
52+
// - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
53+
// `[1 x 1 x 1 x N]`
54+
// - optional (but recommended) label map(s) as AssociatedFile-s with type
55+
// TENSOR_AXIS_LABELS, containing one label per line. The first such
56+
// AssociatedFile (if any) is used to fill the `class_name` field of the
57+
// results. The `display_name` field is filled from the AssociatedFile (if
58+
// any) whose locale matches the `display_names_locale` field of the
59+
// `ImageTransformerOptions` used at creation time ("en" by default, i.e.
60+
// English). If none of these are available, only the `index` field of the
61+
// results will be filled.
62+
//
63+
// An example of such model can be found at:
64+
// https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
65+
//
66+
// A CLI demo tool is available for easily trying out this API, and provides
67+
// example usage. See:
68+
// examples/task/vision/desktop/image_classifier_demo.cc
69+
class ImageTransformer : public BaseVisionTaskApi<TransformationResult> {
70+
public:
71+
using BaseVisionTaskApi::BaseVisionTaskApi;
72+
73+
// Creates an ImageTransformer from the provided options. A non-default
74+
// OpResolver can be specified in order to support custom Ops or specify a
75+
// subset of built-in Ops.f
76+
static tflite::support::StatusOr<std::unique_ptr<ImageTransformer>>
77+
CreateFromOptions(
78+
const ImageTransformerOptions& options,
79+
std::unique_ptr<tflite::OpResolver> resolver =
80+
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>());
81+
82+
// Performs actual transformation on the provided FrameBuffer.
83+
//
84+
// The FrameBuffer can be of any size and any of the supported formats, i.e.
85+
// RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
86+
// inference in order to (and in this order):
87+
// - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
88+
// the dimensions of the model input tensor,
89+
// - convert it to the colorspace of the input tensor (i.e. RGB, which is the
90+
// only supported colorspace for now),
91+
// - rotate it according to its `Orientation` so that inference is performed
92+
// on an "upright" image.
93+
tflite::support::StatusOr<TransformationResult> Transform(
94+
const FrameBuffer& frame_buffer);
95+
96+
// Same as above, except that the transformation is performed based on the
97+
// input region of interest. Cropping according to this region of interest is
98+
// prepended to the pre-processing operations.
99+
//
100+
// IMPORTANT: as a consequence of cropping occurring first, the provided
101+
// region of interest is expressed in the unrotated frame of reference
102+
// coordinates system, i.e. in `[0, frame_buffer.width) x [0,
103+
// frame_buffer.height)`, which are the dimensions of the underlying
104+
// `frame_buffer` data before any `Orientation` flag gets applied. Also, the
105+
// region of interest is not clamped, so this method will return a non-ok
106+
// status if the region is out of these bounds.
107+
tflite::support::StatusOr<TransformationResult> Transform(
108+
const FrameBuffer& frame_buffer, const BoundingBox& roi);
109+
110+
protected:
111+
// The options used to build this ImageTransformer.
112+
std::unique_ptr<ImageTransformerOptions> options_;
113+
114+
// Post-processing to transform the raw model outputs into classification
115+
// results.
116+
tflite::support::StatusOr<TransformationResult> Postprocess(
117+
const std::vector<const TfLiteTensor*>& output_tensors,
118+
const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
119+
120+
// Performs sanity checks on the provided ImageTransformerOptions.
121+
static absl::Status SanityCheckOptions(const ImageTransformerOptions& options);
122+
123+
// Initializes the ImageTransformer from the provided ImageTransformerOptions,
124+
// whose ownership is transferred to this object.
125+
absl::Status Init(std::unique_ptr<ImageTransformerOptions> options);
126+
127+
// Performs pre-initialization actions.
128+
virtual absl::Status PreInit();
129+
// Performs post-initialization actions.
130+
virtual absl::Status PostInit();
131+
132+
private:
133+
// Performs sanity checks on the model outputs and extracts their metadata.
134+
absl::Status CheckAndSetOutputs();
135+
136+
// The number of output tensors. This corresponds to the number of
137+
// classification heads.
138+
int num_outputs_;
139+
// Whether the model features quantized inference type (QUANTIZED_UINT8). This
140+
// is currently detected by checking if all output tensors data type is uint8.
141+
bool has_uint8_outputs_;
142+
};
143+
144+
} // namespace vision
145+
} // namespace task
146+
} // namespace tflite
147+
148+
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_TRANSFORMER_H_

0 commit comments

Comments
 (0)