diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD index 481fd8ecf..0e8a5002a 100644 --- a/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD @@ -16,6 +16,7 @@ cc_library( "//tensorflow_lite_support/cc/port:integral_types", "//tensorflow_lite_support/cc/port:status_macros", "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/vision/core:frame_buffer", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc index 6f3aa737b..18f56c7a5 100644 --- a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc @@ -87,6 +87,49 @@ absl::Status EncodeImageToPngFile(const ImageData& image_data, return absl::OkStatus(); } +absl::Status EncodeImageToPngFile(const FrameBuffer& image_buffer, + const std::string& image_path) { + + const int channels = [&image_buffer]() + { + switch(image_buffer.format()) { + case FrameBuffer::Format::kGRAY: + return 1; + case FrameBuffer::Format::kRGB: + return 3; + case FrameBuffer::Format::kRGBA: + return 4; + default: + return -1; + } + }(); + // Sanity check inputs. + if (image_buffer.dimension().width <= 0 || image_buffer.dimension().height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Expected positive image dimensions, found %d x %d.", + image_buffer.dimension().width, image_buffer.dimension().height)); + } + if (channels == -1) { + return absl::UnimplementedError( + absl::StrFormat("Expected image buffer with 1 (grayscale), 3 (RGB) or 4 " + "(RGBA) channels, found %d", + image_buffer.format())); + } + if (image_buffer.plane(0).buffer == nullptr) { + return absl::InvalidArgumentError( + "Expected plane buffer to be set, found nullptr."); + } + + if (stbi_write_png( + image_path.c_str(), image_buffer.dimension().width, image_buffer.dimension().height, + channels, image_buffer.plane(0).buffer, + /*stride_in_bytes=*/image_buffer.dimension().width * channels) == 0) { + return absl::InternalError("An error occurred while encoding image."); + } + + return absl::OkStatus(); +} + void ImageDataFree(ImageData* image) { stbi_image_free(image->pixel_data); } } // namespace vision diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h index a0b0c6bba..b0a7b7dea 100644 --- a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" // from @com_google_absl #include "tensorflow_lite_support/cc/port/integral_types.h" #include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h" namespace tflite { namespace task { @@ -48,6 +49,11 @@ tflite::support::StatusOr DecodeImageFromFile( absl::Status EncodeImageToPngFile(const ImageData& image_data, const std::string& image_path); +// Encodes the image provided as an FrameBuffer as lossless PNG to the provided +// path. +absl::Status EncodeImageToPngFile(const FrameBuffer& image_buffer, + const std::string& image_path); + // Releases image pixel data memory. void ImageDataFree(ImageData* image);