feat: Add audio_classification #1
scripts/improved_model_converter.py
Outdated
This script needs to be refactored:
- Format and validate with Pylint or Ruff; imports should be sorted.
- Remove print and use the standard logging library, following the logging format used in the other scripts.
- No need for header generation, xxd is enough for test demo.
- Use argparse and pass files as arguments.
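As a rough illustration of the last two points, the following is a minimal sketch of a script entry point that takes its files as arguments and reports through the standard logging library. The flag names, the logger name, and the log format are assumptions made for the example; they are not taken from the script under review or from the project's other scripts.

```python
import argparse
import logging

logger = logging.getLogger("model_converter")  # hypothetical logger name


def build_parser() -> argparse.ArgumentParser:
    """Build the CLI; the flag names are illustrative, not the script's real ones."""
    parser = argparse.ArgumentParser(description="Convert a TF.js model to TFLite.")
    parser.add_argument("--tfjs-model", required=True, help="Path to the TF.js model.json")
    parser.add_argument("--output-dir", default="output", help="Directory for generated files")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable debug logging")
    return parser


def configure_logging(verbose: bool) -> None:
    """Configure the root logger once; the format string is an assumed placeholder."""
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )


def main(argv=None) -> int:
    """Parse arguments, set up logging, and report what would be converted."""
    args = build_parser().parse_args(argv)
    configure_logging(args.verbose)
    logger.info("Converting %s into %s", args.tfjs_model, args.output_dir)
    return 0
```

Passing `argv` explicitly keeps `main` easy to call from tests, for example `main(["--tfjs-model", "model.json"])`.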
components/acoustics-porting/porting/driver/microphone/idf_component.yml
Outdated
The pipeline should implement the abstract class first. The abstract class is not meant only for audio tasks, so the configuration options and their contents need to be adjusted. We can discuss the concrete design together before deciding.
- Apply Ruff formatting and linting rules
- Sort imports according to Python standards
- Replace print statements with proper logging
- Add command line argument parsing
- Improve code organization and readability

- Move esp-tflite-micro and esp-nn from examples to acoustics-porting
- Update component dependency configuration
- Centralize dependency management for better organization
077c96a to 84faaa9
I did a quick first-pass review, and there is still a lot that needs adjusting. If you have any questions, suggestions, or anything you are unsure about, please reach out to me promptly.
#include "hal/engine.hpp"
#include "hal/engines/tflite_engine.hpp"
#include "hal/processor.hpp"
#include "hal/processors/feature_extractor.hpp"

namespace bridge {

void __REGISTER_PROCESSORS__()
{
    [[maybe_unused]] static hal::ProcessorFeatureExtractor processor_feature_extractor;
}

void __REGISTER_ENGINES__()
{
    [[maybe_unused]] static hal::EngineTFLite engine_tflite;
}

} // namespace bridge
The preprocessing design is debatable; the definition of the Engine registration function should not appear in the preprocessing part.
Preprocessing now lives in examples.
scripts/improved_model_converter.py
Outdated
        self.config = config
        self.model_info = {}
        self.labels = []

        # Set up file paths based on config
        self.tfjs_model_json = config.tfjs_model
        self.output_dir = config.output_dir
        self.samples_dir = config.samples_dir
        self.metadata_file = config.metadata

        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

        # Derived file paths
        model_name = os.path.splitext(os.path.basename(self.tfjs_model_json))[0]
        self.keras_h5_file = os.path.join(self.output_dir, f"{model_name}.h5")
        self.tflite_file = os.path.join(self.output_dir, f"{model_name}.tflite")
        self.model_data_header = os.path.join(self.output_dir, "model_data.h")
        self.model_config_header = os.path.join(self.output_dir, "model_config.h")
        self.acousticslab_header = os.path.join(self.output_dir, "acousticslab_model.h")
Variables that are only used internally should be marked protected with a leading underscore (._).
Done. Internal variables are now protected, while the paths remain exposed publicly.
scripts/improved_model_converter.py
Outdated
        logger.info(
            "Step 1: Converting %s to %s", self.tfjs_model_json, self.keras_h5_file
        )
        if not os.path.exists(self.tfjs_model_json):
            logger.error("TF.js model file '%s' not found.", self.tfjs_model_json)
            sys.exit(1)

        command = [
            "tensorflowjs_converter",
            "--input_format=tfjs_layers_model",
            "--output_format=keras",
            self.tfjs_model_json,
            self.keras_h5_file,
        ]

        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error("tensorflowjs_converter failed.")
            logger.error("STDOUT: %s", result.stdout)
            logger.error("STDERR: %s", result.stderr)
            sys.exit(1)
This part needs the following improvements:
- The implicit dependency on tfjs is not flagged for installation.
- There is no need to call a Python package through subprocess and a command line; just import the corresponding module and use it.
1. Added to requirements.txt.
2. tensorflowjs_converter does not expose a public API for converting to Keras. https://github.com/tensorflow/tfjs/tree/master/tfjs-converter does have a converter function, but it is meant for Flax/JAX.
Regarding point 2, whether such an API exists is something you need to check for yourself:
from tensorflowjs.converters.converter import (
    dispatch_tensorflowjs_to_keras_h5_conversion,
)

tfjs_model_path = "/workspaces/tfjs/tmp/model.json"
keras_model_path = "/workspaces/tfjs/tmp/keras_model.h5"
dispatch_tensorflowjs_to_keras_h5_conversion(
    tfjs_model_path,
    keras_model_path,
)
The command-line step has been replaced with the API.
scripts/improved_model_converter.py
Outdated
        if not self.metadata_file:
            logger.info("Step 2: No metadata file specified, using default labels")
            self.labels = ["Class_0", "Class_1", "Class_2"]  # Default fallback
            return
Even without metadata, the number of classes can be determined from the model's output shape. This fallback design needs to be improved (personally, I think raising an error directly would be better than this).
Improved.
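One way to read the suggestion above is sketched below: the class count comes from the last dimension of the model's output shape, and metadata labels are only validated against it. The function name `resolve_labels` and its behaviour on a mismatch are assumptions for illustration, not the code that was actually merged.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_labels(output_shape, metadata_labels=None):
    """Return one label per class, treating the model output shape as the source of truth.

    output_shape: a tuple such as (None, 3), as reported by a Keras model's output_shape.
    metadata_labels: optional list of label strings read from a metadata file.
    """
    num_classes = output_shape[-1]
    if not isinstance(num_classes, int) or num_classes <= 0:
        raise ValueError(f"Cannot infer class count from output shape {output_shape}")

    if metadata_labels is None:
        # No metadata: fall back to the class ids instead of hard-coded names.
        logger.warning("No metadata given; using class ids as labels")
        return [str(i) for i in range(num_classes)]

    if len(metadata_labels) != num_classes:
        # Fail loudly rather than silently padding or truncating labels.
        raise ValueError(
            f"Metadata has {len(metadata_labels)} labels "
            f"but the model outputs {num_classes} classes"
        )
    return list(metadata_labels)
```

With this shape, `resolve_labels((None, 3))` yields `["0", "1", "2"]`, and a label list of the wrong length raises instead of producing a mislabeled model.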
scripts/improved_model_converter.py
Outdated
            f.write(content)

        logger.info("✅ Model data header: %s", self.model_data_header)

    def _generate_model_config_header(self):
        """Generate C++ header with model configuration."""
        content = f"""#ifndef MODEL_CONFIG_H_
#define MODEL_CONFIG_H_

// Auto-generated model configuration for AcousticsLab

// Model specifications
static constexpr int kModelInputHeight = {self.model_info["input_shape"][1]};
static constexpr int kModelInputWidth = {self.model_info["input_shape"][2]};
static constexpr int kModelInputChannels = {self.model_info["input_shape"][3]};
static constexpr int kModelInputSize = {np.prod(self.model_info["input_shape"][1:])};

static constexpr int kModelOutputSize = {self.model_info["output_shape"][1]};
static constexpr int kNumClasses = {self.model_info["num_classes"]};

// Quantization parameters
static constexpr float kInputScale = {self.model_info["input_scale"]}f;
static constexpr int kInputZeroPoint = {self.model_info["input_zero_point"]};
static constexpr float kOutputScale = {self.model_info["output_scale"]}f;
static constexpr int kOutputZeroPoint = {self.model_info["output_zero_point"]};

// Class labels
static const char* kClassLabels[kNumClasses] = {{
"""

        for label in self.labels:
            content += f' "{label}",\n'

        content += """};

#endif // MODEL_CONFIG_H_
"""

        with open(self.model_config_header, "w") as f:
            f.write(content)

        logger.info("✅ Model config header: %s", self.model_config_header)

    def _generate_acousticslab_header(self):
        """Generate unified header for easy AcousticsLab integration."""
        content = """#ifndef ACOUSTICSLAB_MODEL_H_
#define ACOUSTICSLAB_MODEL_H_

// Auto-generated unified header for AcousticsLab integration
// Include this file in your AcousticsLab project for automatic model loading

#include "model_data.h"
#include "model_config.h"

// Convenience functions for AcousticsLab integration
namespace acousticslab {

inline const unsigned char* getModelData() {
    return g_model_data;
}

inline unsigned int getModelSize() {
    return g_model_data_len;
}

inline int getNumClasses() {
    return kNumClasses;
}

inline const char* getClassName(int class_id) {
    if (class_id >= 0 && class_id < kNumClasses) {
        return kClassLabels[class_id];
    }
    return "Unknown";
}

inline int getModelInputSize() {
    return kModelInputSize;
}

inline float getInputScale() {
    return kInputScale;
}

inline int getInputZeroPoint() {
    return kInputZeroPoint;
}

inline float getOutputScale() {
    return kOutputScale;
}

inline int getOutputZeroPoint() {
    return kOutputZeroPoint;
}

} // namespace acousticslab

#endif // ACOUSTICSLAB_MODEL_H_
"""
I suggest removing this part.
Removed. All of the header generation has been removed.
scripts/improved_model_converter.py
Outdated
        self.convert_tfjs_to_keras()
        self.load_metadata()
        self.create_models()
        self.transfer_weights()
        self.quantize_and_convert()
        self.generate_headers()
Is a status check missing after each step?
Status checks have been added.
5940e49 to 9c0736a
#include <esp_log.h>

static const char *TAG = "FeatureGenerator";
static const float g_hanning_window[CONFIG_FRAME_LEN]
1. The window array holds the original data; naming it a Hann window was a mistake.
2. A window-generation function has been initialized. Because the original data is asymmetric, the window function is
@@ -15,7 +15,7 @@ set(ACOUSTICS_PORTING_LIB_TFLM_ENABLE ON CACHE BOOL "Enable TFLM support in acou
 idf_component_register(
     SRCS ${PROJ_SRCS}
     INCLUDE_DIRS ${PROJ_INCLUDE_DIRS}
-    REQUIRES freertos esp_psram esp_system acoustics-porting
+    REQUIRES freertos esp_psram esp_system acoustics-porting espressif__dl_fft
dl_fft should later be moved into porting's idf_component.yml. Also, the espressif__ prefix can probably be dropped.
espressif/esp-nn: '*'
espressif/esp-dsp: '*'
espressif/dl_fft: '*'
esp-nn and esp-dsp can be removed. When refactoring the node, please move dl_fft into porting.
tensorflow
tensorflowjs
scipy
tflite
What exactly is the reason for pulling in the tflite package? It is not an official TensorFlow package on PyPI.
2. **Architecture Optimization**: Reconstructs the Keras model architecture by
   replacing computationally expensive Dense layers with equivalent,
   but more efficient, 1x1 Convolutional layers. This is a critical
   optimization for on-device performance.
Has this been verified on the device to actually improve performance?
class ConversionState(Enum):
    INITIALIZED = "initialized"
    TFJS_CONVERTED = "tfjs_converted"
    METADATA_LOADED = "metadata_loaded"
    MODELS_CREATED = "models_created"
    WEIGHTS_TRANSFERRED = "weights_transferred"
    QUANTIZED = "quantized"
    COMPLETED = "completed"
    FAILED = "failed"
What is the purpose of this state mechanism? Can it be simplified?
        paths = {
            "tfjs_model_path": tfjs_model_path,
            "metadata_path": metadata_path,
            "keras_output_path": keras_output_path,
            "tflite_output_path": tflite_output_path,
            "wav_input_dir": wav_input_dir,
            "preprocessing_model_path": preprocessing_model_path,
        }

        for param_name, path in paths.items():
            if not isinstance(path, str):
                raise TypeError(f"{param_name} must be a string, got {type(path)}")
            if not path or not path.strip():
                raise ValueError(f"{param_name} cannot be empty or None")

        # Public file path attributes (assigned early for easy access)
        self.tfjs_model_path = tfjs_model_path.strip()
        self.metadata_path = metadata_path.strip()
        self.keras_output_path = keras_output_path.strip()
        self.tflite_output_path = tflite_output_path.strip()
        self.wav_input_dir = wav_input_dir.strip()
        self.preprocessing_model_path = preprocessing_model_path.strip()
This part can be optimized further.
    def _validate_state(self, expected_state: ConversionState) -> None:
        if self._conversion_state != expected_state:
            error_msg = f"Invalid operation: expected state '{expected_state.value}', but current state is '{self._conversion_state.value}'"

            if self._conversion_state == ConversionState.FAILED:
                error_msg += "\nThe conversion process has failed. Please create a new ModelConverter instance."
            elif self._conversion_state == ConversionState.COMPLETED:
                error_msg += (
                    "\nThe conversion process has already completed successfully."
                )
            else:
                state_order = [
                    ConversionState.INITIALIZED,
                    ConversionState.TFJS_CONVERTED,
                    ConversionState.METADATA_LOADED,
                    ConversionState.MODELS_CREATED,
                    ConversionState.WEIGHTS_TRANSFERRED,
                    ConversionState.QUANTIZED,
                    ConversionState.COMPLETED,
                ]

                try:
                    current_idx = state_order.index(self._conversion_state)
                    if current_idx < len(state_order) - 1:
                        next_state = state_order[current_idx + 1]
                        error_msg += f"\nNext expected operation corresponds to state: '{next_state.value}'"
                except ValueError:
                    pass

            raise RuntimeError(error_msg)
Same as above: what is the purpose of this state mechanism?
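For comparison, a lighter alternative to the state enum might look like the sketch below, where the order of the steps is simply the order of a list and a failure stops the run. This is an assumed design for illustration (the helper `run_steps` is hypothetical), not what the reviewer or author settled on.

```python
import logging

logger = logging.getLogger(__name__)


def run_steps(steps):
    """Run (name, callable) pairs in order and stop at the first failure.

    Ordering is enforced by the list itself, so no explicit state enum is needed.
    Returns the names of the steps that completed.
    """
    completed = []
    for name, step in steps:
        logger.info("Running step: %s", name)
        try:
            step()
        except Exception as exc:
            # Chain the original exception so the root cause stays visible,
            # without inspecting or rewriting its message.
            raise RuntimeError(f"Step '{name}' failed") from exc
        completed.append(name)
    return completed
```

A converter could then call something like `run_steps([("convert", self.convert_tfjs_to_keras), ("load metadata", self.load_metadata)])`, with the method names taken from the earlier revision of the script.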
    def _generate_features_from_wavs(self) -> Iterator[np.ndarray]:
        logging.info(
            "--- Starting feature generation from WAV files for quantization ---"
        )

        if not os.path.exists(self.wav_input_dir):
            raise FileNotFoundError(
                f"WAV input directory not found: '{self.wav_input_dir}'"
            )

        wav_files = glob.glob(os.path.join(self.wav_input_dir, "*.wav"))
        wav_files += glob.glob(os.path.join(self.wav_input_dir, "*.WAV"))

        if not wav_files:
            raise FileNotFoundError(f"No .wav files found in '{self.wav_input_dir}'")

        logging.info(
            f"Found {len(wav_files)} WAV files. Processing them to generate representative data..."
        )
A better solution here is to use MIME types.
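A minimal sketch of that idea, using the standard-library `mimetypes` module, is shown below. Note that `mimetypes.guess_type` works from the file name rather than the file contents, and the MIME string it returns for WAV depends on the Python version and the system MIME database, which is why a small set of values is accepted here. The helper names and the accepted set are assumptions for the example.

```python
import mimetypes
from pathlib import Path

# MIME types commonly reported for WAV audio; several are accepted because the
# exact value varies between Python versions and system MIME databases.
WAV_MIME_TYPES = {"audio/wav", "audio/x-wav", "audio/wave", "audio/vnd.wave"}


def is_wav(path) -> bool:
    """Return True if the file name maps to a WAV MIME type (extension-based guess)."""
    mime, _encoding = mimetypes.guess_type(str(path))
    return mime in WAV_MIME_TYPES


def find_wav_files(directory):
    """List WAV files directly inside `directory`, regardless of extension case."""
    return sorted(p for p in Path(directory).iterdir() if p.is_file() and is_wav(p))
```

This replaces the two case-specific glob patterns with a single check that treats `.wav` and `.WAV` alike.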
        wav_files = glob.glob(os.path.join(self.wav_input_dir, "*.wav"))
        wav_files += glob.glob(os.path.join(self.wav_input_dir, "*.WAV"))
The number of samples needs to be capped, for example by randomly drawing 100. Doing this on a large dataset could run out of memory (OOM).
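The cap could be applied to the list of file paths before any audio is loaded, roughly as in the sketch below. The function name `sample_files`, the default of 100, and the optional seed are illustrative assumptions.

```python
import random


def sample_files(paths, max_samples=100, seed=None):
    """Return at most `max_samples` paths, chosen uniformly at random without replacement.

    If there are no more than `max_samples` paths, they are returned unchanged.
    A fixed `seed` makes the selection reproducible.
    """
    paths = list(paths)
    if len(paths) <= max_samples:
        return paths
    return random.Random(seed).sample(paths, max_samples)
```

Sampling paths rather than decoded audio keeps memory use bounded by the cap, since only the selected files are ever read.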
        output_dir = os.path.dirname(self.keras_output_path)
        if output_dir:
            try:
                os.makedirs(output_dir, exist_ok=True)
            except PermissionError:
                self._set_state(ConversionState.FAILED)
                raise PermissionError(
                    f"Cannot create output directory: '{output_dir}'\n"
                    f"Please check directory permissions."
                )

        try:
            logging.info("🔄 Converting TF.js model using tensorflowjs converter...")

            dispatch_tensorflowjs_to_keras_h5_conversion(
                self.tfjs_model_path,
                self.keras_output_path,
            )

            logging.info("TensorFlow.js to Keras conversion completed successfully")

            # Verify the output file was created successfully
            if not os.path.exists(self.keras_output_path):
                self._set_state(ConversionState.FAILED)
                raise RuntimeError(
                    f"Keras model file was not created: '{self.keras_output_path}'\n"
                    f"The conversion process completed but no output file was generated."
                )

            # Verify the output file is not empty
            if os.path.getsize(self.keras_output_path) == 0:
                self._set_state(ConversionState.FAILED)
                raise RuntimeError(
                    f"Keras model file is empty: '{self.keras_output_path}'\n"
                    f"The conversion process may have failed silently."
                )

            self._set_state(ConversionState.TFJS_CONVERTED)
            file_size = os.path.getsize(self.keras_output_path)
            logging.info(
                f"✅ TF.js to Keras conversion successful! Output size: {file_size:,} bytes"
            )
These steps can still be simplified and optimized.
        except Exception as e:
            self._set_state(ConversionState.FAILED)
            # Check if it's a TensorFlow.js specific error based on the error message
            error_msg = str(e).lower()
            if any(
                keyword in error_msg
                for keyword in ["tensorflowjs", "tfjs", "model.json", "invalid model"]
            ):
                raise RuntimeError(
                    f"TensorFlow.js conversion failed: {str(e)}\n"
                    f"Please ensure the input file is a valid TF.js model."
                ) from e
            else:
                raise RuntimeError(
                    f"Unexpected error during TF.js to Keras conversion: {str(e)}\n"
                    f"Please check the input file format and try again."
                ) from e
Parsing the error message of a generic exception is, in my opinion, not good design.
        if not os.access(self.metadata_path, os.R_OK):
            self._set_state(ConversionState.FAILED)
            raise PermissionError(
                f"Metadata file is not readable: '{self.metadata_path}'\n"
                f"Please check file permissions."
            )
open will raise an exception if the file is not readable, so there is no need for an extra check here.
            if not isinstance(word_labels, list):
                raise ValueError(
                    f"'wordLabels' must be a list, got {type(word_labels)}\n"
                    f"Please ensure 'wordLabels' is an array of strings."
                )

            if not word_labels:
                raise ValueError(
                    "'wordLabels' list cannot be empty.\n"
                    "Please provide at least one class label."
                )
            for i, label in enumerate(word_labels):
                if not isinstance(label, str):
                    raise ValueError(
                        f"All labels must be strings. Label at index {i} is {type(label)}: {label}"
                    )
                if not label.strip():
                    raise ValueError(
                        f"Label at index {i} is empty or contains only whitespace"
                    )
From a user-experience standpoint, if a label is empty it would be better to log a warning and use the string form of the corresponding id.
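A small sketch of that behaviour follows. The function name `normalize_labels` is hypothetical, and treating non-string entries the same way as empty ones is an extra assumption that goes slightly beyond what the comment asks for.

```python
import logging

logger = logging.getLogger(__name__)


def normalize_labels(word_labels):
    """Return cleaned labels, replacing empty entries with their index as a string.

    Non-string entries are treated like empty ones (an assumption for this sketch).
    """
    labels = []
    for index, label in enumerate(word_labels):
        text = label.strip() if isinstance(label, str) else ""
        if not text:
            logger.warning("Label at index %d is empty; using '%d' instead", index, index)
            text = str(index)
        labels.append(text)
    return labels
```

For instance, `normalize_labels(["one", "  ", "two "])` returns `["one", "1", "two"]` and logs one warning instead of aborting the conversion.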
        except json.JSONDecodeError as e:
            self._set_state(ConversionState.FAILED)
            raise ValueError(
                f"Invalid JSON format in metadata file: {str(e)}\n"
                f"Please ensure the file contains valid JSON."
            ) from e
        except (ValueError, TypeError) as e:
            self._set_state(ConversionState.FAILED)
            raise ValueError(f"Metadata validation failed: {str(e)}") from e
        except Exception as e:
            self._set_state(ConversionState.FAILED)
            raise RuntimeError(
                f"Unexpected error while processing metadata: {str(e)}\n"
                f"Please check the file format and try again."
            ) from e
There is no need to check exception types this finely.
        if self._original_model is None:
            raise RuntimeError("Failed to create original model architecture")
        if self._final_model is None:
            raise RuntimeError("Failed to create optimized model architecture")

        expected_output_shape = (None, num_classes)
        if self._original_model.output_shape != expected_output_shape:
            raise RuntimeError(
                f"Original model output shape mismatch. Expected {expected_output_shape}, "
                f"got {self._original_model.output_shape}"
            )
        if self._final_model.output_shape != expected_output_shape:
            raise RuntimeError(
                f"Optimized model output shape mismatch. Expected {expected_output_shape}, "
                f"got {self._final_model.output_shape}"
            )
My suggestion here is to determine the classes primarily from the original model's output shape, with metadata as an optional input.
    parser.add_argument(
        "--files",
        nargs="+",
        metavar="FILE",
        help="Simplified usage: specify TF.js model and metadata files in order. "
        "Example: --files model.json metadata.json",
    )
This can be removed.
    _validate_path_exists(args.tfjs_model, "TF.js model file")
    _validate_path_exists(args.metadata, "Metadata file")
    _validate_path_exists(args.wav_dir, "WAV input directory")
    _validate_path_exists(args.preprocessing_model, "Preprocessing model file")
    _validate_output_path(args.keras_output, "Keras output file")
    _validate_output_path(args.tflite_output, "TFLite output file")

    logging.info("✅ All input paths validated successfully")
    logging.info("📁 Input files:")
    logging.info(f" - TF.js model: {args.tfjs_model}")
    logging.info(f" - Metadata: {args.metadata}")
    logging.info(f" - WAV directory: {args.wav_dir}")
    logging.info(f" - Preprocessing model: {args.preprocessing_model}")
    logging.info("📁 Output files:")
    logging.info(f" - Keras model: {args.keras_output}")
    logging.info(f" - TFLite model: {args.tflite_output}")
This is somewhat over-checked. Output paths do not need to exist beforehand; just create them on demand. Input paths do not need to be checked either, because an error will be raised anyway when execution reaches the line that uses them (the time between receiving these input paths and using them is very short).
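Creating output locations on demand can be as small as the sketch below; the helper name `write_output` is made up for the example. Inputs need no matching helper: calling `open` on a missing or unreadable file already raises `FileNotFoundError` or `PermissionError` at the point of use.

```python
from pathlib import Path


def write_output(path, data: bytes) -> Path:
    """Write `data` to `path`, creating any missing parent directories first."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(data)
    return target
```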
models/preprocessing.tflite
Outdated
The license of the original model repository needs to be provided.
I have commented on the more obvious problems overall. For the remaining details that I did not call out, please optimize them along the same lines as my comments.
Implement this part using node and graph.
✨ Feat: Add Audio Classification
Description
This PR implements a complete, memory-efficient audio classification pipeline for real-time inference on ESP32. It processes audio from an I2S microphone, extracts features, and runs a TFLite model to classify the audio.
🚀 How It Works
The data flows through a simple, 4-stage pipeline:
- 44032 16-bit PCM samples from an I2S microphone.
- 9976-element feature vector (43 frames × 232 features).
- int8 quantized inference on an embedded model. It internally handles all float32 ↔ int8 conversions.
- float32 probabilities for 3 classes: "Background Noise", "one", and "two".
🛠️ Key Architectural Features
- Sensor, Processor, and Engine components for easy extension and testing.
- (heap_caps), employs RAII for automatic cleanup, and minimizes data copies.
- core::Status error handling and C++ template-based type safety.
📂 Core Components Changed
- InferencePipeline: Orchestrates the entire data flow.
- FeatureExtractor: Implements the STFT-based feature generation logic.
- TFLiteEngine: Wraps the TFLite Micro interpreter, handling model loading and inference.
- MicrophoneI2S: Provides the hardware audio capture interface.
- main.cpp (Example): Demonstrates the end-to-end usage of the pipeline.
✅ Verification
The full pipeline has been tested on hardware and successfully classifies audio in real-time. The architecture is now stable and ready for further feature development.
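To make the float32 ↔ int8 step and the feature-vector size from the pipeline description concrete, here is a small sketch of affine quantization as it is commonly defined for TFLite models (real value = (q - zero_point) × scale). The scale and zero-point values used in the usage note are made-up placeholders, not the parameters of the model in this PR, and Python's `round` is used for simplicity, which may differ from the engine's rounding at exact halves.

```python
FRAMES = 43
FEATURES_PER_FRAME = 232
# 43 frames x 232 features per frame gives the 9976-element vector from the description.
FEATURE_VECTOR_SIZE = FRAMES * FEATURES_PER_FRAME


def quantize(values, scale, zero_point):
    """Map floats to int8: q = round(x / scale) + zero_point, clamped to [-128, 127]."""
    return [max(-128, min(127, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """Map int8 values back to floats: x = (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in q_values]
```

With a hypothetical scale of 0.5 and zero point of -10, `quantize([0.0, 1.0], 0.5, -10)` gives `[-10, -8]`, and `dequantize([-10, -8], 0.5, -10)` recovers `[0.0, 1.0]`; values outside the representable range saturate at the int8 limits.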