feat: Add audio_classification #1


Draft
wants to merge 22 commits into base: main

Conversation


@Wkstr Wkstr commented Jul 16, 2025

✨ Feat: Add Audio Classification

Description

This PR implements a complete, memory-efficient audio classification pipeline for real-time inference on ESP32. It processes audio from an I2S microphone, extracts features, and runs a TFLite model to classify the audio.

🚀 How It Works

The data flows through a simple, 4-stage pipeline:

  1. 🎤 Audio Capture: Reads 44032 16-bit PCM samples from an I2S microphone.
  2. 🔬 Feature Extraction: Converts the audio into a 9976-element feature vector (43 frames × 232 features; see the sketch after this list).
  3. 🧠 Model Inference: The TFLite engine performs int8 quantized inference on an embedded model, internally handling all float32 ↔ int8 conversions.
  4. 📊 Results: Outputs float32 probabilities for 3 classes: "Background Noise", "one", and "two".
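
For reference, the shape arithmetic in step 2 can be reproduced offline with a short NumPy sketch. This is only an approximation under assumed parameters (non-overlapping 1024-sample frames, keeping the lowest 232 bins of a 1024-point FFT, no windowing); the on-device FeatureExtractor may differ in windowing and scaling:

import numpy as np

SAMPLES = 44032       # 16-bit PCM samples captured per inference
FRAME_LEN = 1024      # assumed frame length (44032 / 1024 = 43 frames)
NUM_FEATURES = 232    # assumed number of spectral bins kept per frame

def extract_features(pcm: np.ndarray) -> np.ndarray:
    """Rough offline approximation of the on-device STFT feature layout."""
    frames = pcm[:SAMPLES].reshape(-1, FRAME_LEN)       # (43, 1024)
    spectra = np.abs(np.fft.rfft(frames, n=FRAME_LEN))  # (43, 513) magnitude bins
    return spectra[:, :NUM_FEATURES]                    # (43, 232) -> 9976 values

features = extract_features(np.zeros(SAMPLES, dtype=np.float32))
print(features.shape)  # (43, 232)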

🛠️ Key Architectural Features

  • Modular Design: Built with standardized Sensor, Processor, and Engine components for easy extension and testing.
  • Memory-Efficient: Leverages PSRAM for large buffers (heap_caps), employs RAII for automatic cleanup, and minimizes data copies.
  • Optimized Performance: Uses KissFFT for fast spectral analysis and Int8 Quantization for accelerated TFLite inference.
  • Robust & Safe: Implements clear core::Status error handling and C++ template-based type safety.

📂 Core Components Changed

  • InferencePipeline: Orchestrates the entire data flow.
  • FeatureExtractor: Implements the STFT-based feature generation logic.
  • TFLiteEngine: Wraps the TFLite Micro interpreter, handling model loading and inference.
  • MicrophoneI2S: Provides the hardware audio capture interface.
  • main.cpp (Example): Demonstrates the end-to-end usage of the pipeline.

Verification

The full pipeline has been tested on hardware and successfully classifies audio in real-time. The architecture is now stable and ready for further feature development.

Collaborator

This script needs to be refactored:

  1. Format and validate it with Pylint or Ruff; imports should be sorted.
  2. Remove print statements and use the standard logging library; follow the logging format used in the other scripts.
  3. No need for header generation; xxd is enough for the test demo.
  4. Use argparse and pass the files as arguments (see the sketch below).
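
For items 2 and 4, a minimal sketch of the suggested structure; the logging format string and option names below are placeholders, not the script's actual interface:

import argparse
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("model_converter")

def main() -> None:
    parser = argparse.ArgumentParser(description="Convert a TF.js model to TFLite.")
    parser.add_argument("--tfjs-model", required=True, help="Path to model.json")
    parser.add_argument("--output-dir", default="out", help="Directory for outputs")
    args = parser.parse_args()
    logger.info("Converting %s into %s", args.tfjs_model, args.output_dir)

if __name__ == "__main__":
    main()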

Collaborator

The Pipeline should first be implemented as an abstract class. The abstraction is not only for audio tasks, so the configuration items and their contents need to be adjusted; the concrete design can be decided after we discuss it together.

Wkstr added 6 commits July 18, 2025 09:29
- Apply Ruff formatting and linting rules
- Sort imports according to Python standards
- Replace print statements with proper logging
- Add command line argument parsing
- Improve code organization and readability
- Move esp-tflite-micro and esp-nn from examples to acoustics-porting
- Update component dependency configuration
- Centralize dependency management for better organization
@Wkstr Wkstr force-pushed the feature/audio_classificaton branch from 077c96a to 84faaa9 on July 18, 2025 01:34
Collaborator

@iChizer0 iChizer0 left a comment

I have done a quick initial review; there are still many places that need adjustment. If you have any questions, suggestions, or anything you are unsure about, please reach out to me promptly.

Comment on lines 1 to 18
#include "hal/engine.hpp"
#include "hal/engines/tflite_engine.hpp"
#include "hal/processor.hpp"
#include "hal/processors/feature_extractor.hpp"

namespace bridge {

void __REGISTER_PROCESSORS__()
{
[[maybe_unused]] static hal::ProcessorFeatureExtractor processor_feature_extractor;
}

void __REGISTER_ENGINES__()
{
[[maybe_unused]] static hal::EngineTFLite engine_tflite;
}

} // namespace bridge
Collaborator

The preprocessing design is debatable; the definitions of the Engine registration functions should not appear in the preprocessing part.

Author

The preprocessing is now written in examples.

Comment on lines 60 to 79
self.config = config
self.model_info = {}
self.labels = []

# Set up file paths based on config
self.tfjs_model_json = config.tfjs_model
self.output_dir = config.output_dir
self.samples_dir = config.samples_dir
self.metadata_file = config.metadata

# Create output directory if it doesn't exist
os.makedirs(self.output_dir, exist_ok=True)

# Derived file paths
model_name = os.path.splitext(os.path.basename(self.tfjs_model_json))[0]
self.keras_h5_file = os.path.join(self.output_dir, f"{model_name}.h5")
self.tflite_file = os.path.join(self.output_dir, f"{model_name}.tflite")
self.model_data_header = os.path.join(self.output_dir, "model_data.h")
self.model_config_header = os.path.join(self.output_dir, "model_config.h")
self.acousticslab_header = os.path.join(self.output_dir, "acousticslab_model.h")
Collaborator

Variables that are only used internally should be protected with a leading underscore (._).

Author

Now protected; the file paths remain exposed as public attributes.

Comment on lines 83 to 103
        logger.info(
            "Step 1: Converting %s to %s", self.tfjs_model_json, self.keras_h5_file
        )
        if not os.path.exists(self.tfjs_model_json):
            logger.error("TF.js model file '%s' not found.", self.tfjs_model_json)
            sys.exit(1)

        command = [
            "tensorflowjs_converter",
            "--input_format=tfjs_layers_model",
            "--output_format=keras",
            self.tfjs_model_json,
            self.keras_h5_file,
        ]

        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0:
            logger.error("tensorflowjs_converter failed.")
            logger.error("STDOUT: %s", result.stdout)
            logger.error("STDERR: %s", result.stderr)
            sys.exit(1)
Collaborator

This part has the following issues to improve:

  1. The implicit dependency on tfjs is not flagged for installation.
  2. There is no need to shell out via subprocess to invoke a Python package; just import the corresponding module and use it directly.

Author

@Wkstr Wkstr Jul 25, 2025

1. Added to requirements.txt.
2. tensorflowjs_converter does not expose a public API that can be called to convert to Keras; https://github.com/tensorflow/tfjs/tree/master/tfjs-converter has a Converter Function, but that applies to Flax/JAX.

Collaborator

Regarding point 2, you can check for yourself whether such an API exists:

from tensorflowjs.converters.converter import (
    dispatch_tensorflowjs_to_keras_h5_conversion,
)

tfjs_model_path = "/workspaces/tfjs/tmp/model.json"
keras_model_path = "/workspaces/tfjs/tmp/keras_model.h5"

dispatch_tensorflowjs_to_keras_h5_conversion(
    tfjs_model_path,
    keras_model_path,
)

Author

Replaced the CLI step with the API.

Comment on lines 109 to 112
        if not self.metadata_file:
            logger.info("Step 2: No metadata file specified, using default labels")
            self.labels = ["Class_0", "Class_1", "Class_2"]  # Default fallback
            return
Collaborator

Even without metadata, the number of classes can be determined from the model's output shape; this fallback design needs improvement (personally I think raising an error directly would be better than this).
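
A minimal sketch of that idea, assuming TensorFlow is available (it is already in requirements.txt); the helper name is hypothetical:

import tensorflow as tf

def infer_num_classes(keras_h5_path: str) -> int:
    """Derive the class count from the model's output shape instead of a hard-coded fallback."""
    model = tf.keras.models.load_model(keras_h5_path)
    return int(model.output_shape[-1])  # e.g. (None, 3) -> 3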

Author

Optimized.

Comment on lines 338 to 478
            f.write(content)

        logger.info("✅ Model data header: %s", self.model_data_header)

    def _generate_model_config_header(self):
        """Generate C++ header with model configuration."""
        content = f"""#ifndef MODEL_CONFIG_H_
#define MODEL_CONFIG_H_

// Auto-generated model configuration for AcousticsLab

// Model specifications
static constexpr int kModelInputHeight = {self.model_info["input_shape"][1]};
static constexpr int kModelInputWidth = {self.model_info["input_shape"][2]};
static constexpr int kModelInputChannels = {self.model_info["input_shape"][3]};
static constexpr int kModelInputSize = {np.prod(self.model_info["input_shape"][1:])};

static constexpr int kModelOutputSize = {self.model_info["output_shape"][1]};
static constexpr int kNumClasses = {self.model_info["num_classes"]};

// Quantization parameters
static constexpr float kInputScale = {self.model_info["input_scale"]}f;
static constexpr int kInputZeroPoint = {self.model_info["input_zero_point"]};
static constexpr float kOutputScale = {self.model_info["output_scale"]}f;
static constexpr int kOutputZeroPoint = {self.model_info["output_zero_point"]};

// Class labels
static const char* kClassLabels[kNumClasses] = {{
"""

        for label in self.labels:
            content += f' "{label}",\n'

        content += """};

#endif // MODEL_CONFIG_H_
"""

        with open(self.model_config_header, "w") as f:
            f.write(content)

        logger.info("✅ Model config header: %s", self.model_config_header)

    def _generate_acousticslab_header(self):
        """Generate unified header for easy AcousticsLab integration."""
        content = """#ifndef ACOUSTICSLAB_MODEL_H_
#define ACOUSTICSLAB_MODEL_H_

// Auto-generated unified header for AcousticsLab integration
// Include this file in your AcousticsLab project for automatic model loading

#include "model_data.h"
#include "model_config.h"

// Convenience functions for AcousticsLab integration
namespace acousticslab {

inline const unsigned char* getModelData() {
return g_model_data;
}

inline unsigned int getModelSize() {
return g_model_data_len;
}

inline int getNumClasses() {
return kNumClasses;
}

inline const char* getClassName(int class_id) {
if (class_id >= 0 && class_id < kNumClasses) {
return kClassLabels[class_id];
}
return "Unknown";
}

inline int getModelInputSize() {
return kModelInputSize;
}

inline float getInputScale() {
return kInputScale;
}

inline int getInputZeroPoint() {
return kInputZeroPoint;
}

inline float getOutputScale() {
return kOutputScale;
}

inline int getOutputZeroPoint() {
return kOutputZeroPoint;
}

} // namespace acousticslab

#endif // ACOUSTICSLAB_MODEL_H_
"""
Collaborator

This part should be removed.

Author

Removed; all the header-generation parts have been removed.

Comment on lines 490 to 495
self.convert_tfjs_to_keras()
self.load_metadata()
self.create_models()
self.transfer_weights()
self.quantize_and_convert()
self.generate_headers()
Collaborator

Is a status check missing for each step?

Author

Status checks added.

@Wkstr Wkstr force-pushed the feature/audio_classificaton branch from 5940e49 to 9c0736a on July 23, 2025 06:07
@Wkstr Wkstr requested a review from iChizer0 July 25, 2025 12:19
#include <esp_log.h>

static const char *TAG = "FeatureGenerator";
static const float g_hanning_window[CONFIG_FRAME_LEN]
Collaborator

Unless necessary, please avoid placing large amounts of data in the .data segment (or put the data in a separate header file). In addition, the following needs to be fixed here:

  1. The Speech Commands implementation explicitly states that it uses a Blackman window, not a Hann window.
  2. Have the window-generation function store the window on the heap, taking memory alignment into account; reference window function: Blackman Window.

Collaborator

[Screenshot attached: 2025-07-26 at 20:06:05]

Author

1. The window array holds the raw generated data; "Hann window" was a naming mistake.
2. The window-generation function is now initialized. Because the raw data is non-symmetric, the window function
$w(n) = 0.42 - 0.5 \cos\left(\frac{2\pi n}{N}\right) + 0.08 \cos\left(\frac{4\pi n}{N}\right), \quad 0 \le n \le N-1$ is the most accurate.
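
A small NumPy/SciPy check of that periodic Blackman formula (scipy is already in requirements.txt); CONFIG_FRAME_LEN is assumed to be 1024 here:

import numpy as np
from scipy.signal.windows import blackman

N = 1024  # assumed CONFIG_FRAME_LEN
n = np.arange(N)

# Periodic (non-symmetric) Blackman window, matching the formula above.
w_periodic = 0.42 - 0.5 * np.cos(2 * np.pi * n / N) + 0.08 * np.cos(4 * np.pi * n / N)

# SciPy's sym=False variant computes the same periodic window.
assert np.allclose(w_periodic, blackman(N, sym=False))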

@@ -15,7 +15,7 @@ set(ACOUSTICS_PORTING_LIB_TFLM_ENABLE ON CACHE BOOL "Enable TFLM support in acou
idf_component_register(
SRCS ${PROJ_SRCS}
INCLUDE_DIRS ${PROJ_INCLUDE_DIRS}
REQUIRES freertos esp_psram esp_system acoustics-porting
REQUIRES freertos esp_psram esp_system acoustics-porting espressif__dl_fft
Collaborator

dl_fft should later be moved into the porting component's idf_component.yml; also, the espressif__ prefix should be removable.

Comment on lines +7 to +9
espressif/esp-nn: '*'
espressif/esp-dsp: '*'
espressif/dl_fft: '*'
Collaborator

esp-nn and esp-dsp can be removed; when refactoring the node, please move dl_fft into porting.

tensorflow
tensorflowjs
scipy
tflite
Collaborator

What exactly is the reason for introducing the tflite package? It is not an official TensorFlow PyPI package.

Comment on lines +13 to +16
2. **Architecture Optimization**: Reconstructs the Keras model architecture by
replacing computationally expensive Dense layers with equivalent,
but more efficient, 1x1 Convolutional layers. This is a critical
optimization for on-device performance.
Collaborator

Has this been verified to bring a performance improvement on the device?

Comment on lines +82 to +90
class ConversionState(Enum):
    INITIALIZED = "initialized"
    TFJS_CONVERTED = "tfjs_converted"
    METADATA_LOADED = "metadata_loaded"
    MODELS_CREATED = "models_created"
    WEIGHTS_TRANSFERRED = "weights_transferred"
    QUANTIZED = "quantized"
    COMPLETED = "completed"
    FAILED = "failed"
Collaborator

What is the purpose of this state mechanism? Can it be simplified?
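
If the goal is only to enforce call order, a slimmer alternative is to record completed steps instead of modelling every transition; a rough sketch (names are placeholders, not the PR's design):

class ModelConverter:
    """Sketch: track completed steps rather than a full state enum."""

    def __init__(self) -> None:
        self._done: set[str] = set()

    def _require(self, *steps: str) -> None:
        missing = [s for s in steps if s not in self._done]
        if missing:
            raise RuntimeError(f"Run these steps first: {missing}")

    def load_metadata(self) -> None:
        self._require("tfjs_to_keras")
        # ... actual work ...
        self._done.add("load_metadata")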

Comment on lines +103 to +124
        paths = {
            "tfjs_model_path": tfjs_model_path,
            "metadata_path": metadata_path,
            "keras_output_path": keras_output_path,
            "tflite_output_path": tflite_output_path,
            "wav_input_dir": wav_input_dir,
            "preprocessing_model_path": preprocessing_model_path,
        }

        for param_name, path in paths.items():
            if not isinstance(path, str):
                raise TypeError(f"{param_name} must be a string, got {type(path)}")
            if not path or not path.strip():
                raise ValueError(f"{param_name} cannot be empty or None")

        # Public file path attributes (assigned early for easy access)
        self.tfjs_model_path = tfjs_model_path.strip()
        self.metadata_path = metadata_path.strip()
        self.keras_output_path = keras_output_path.strip()
        self.tflite_output_path = tflite_output_path.strip()
        self.wav_input_dir = wav_input_dir.strip()
        self.preprocessing_model_path = preprocessing_model_path.strip()
Collaborator

This part can be further optimized.

Comment on lines +138 to +167
    def _validate_state(self, expected_state: ConversionState) -> None:
        if self._conversion_state != expected_state:
            error_msg = f"Invalid operation: expected state '{expected_state.value}', but current state is '{self._conversion_state.value}'"

            if self._conversion_state == ConversionState.FAILED:
                error_msg += "\nThe conversion process has failed. Please create a new ModelConverter instance."
            elif self._conversion_state == ConversionState.COMPLETED:
                error_msg += (
                    "\nThe conversion process has already completed successfully."
                )
            else:
                state_order = [
                    ConversionState.INITIALIZED,
                    ConversionState.TFJS_CONVERTED,
                    ConversionState.METADATA_LOADED,
                    ConversionState.MODELS_CREATED,
                    ConversionState.WEIGHTS_TRANSFERRED,
                    ConversionState.QUANTIZED,
                    ConversionState.COMPLETED,
                ]

                try:
                    current_idx = state_order.index(self._conversion_state)
                    if current_idx < len(state_order) - 1:
                        next_state = state_order[current_idx + 1]
                        error_msg += f"\nNext expected operation corresponds to state: '{next_state.value}'"
                except ValueError:
                    pass

            raise RuntimeError(error_msg)
Collaborator

Same as above: what is the purpose of this state mechanism?

Comment on lines +175 to +193
    def _generate_features_from_wavs(self) -> Iterator[np.ndarray]:
        logging.info(
            "--- Starting feature generation from WAV files for quantization ---"
        )

        if not os.path.exists(self.wav_input_dir):
            raise FileNotFoundError(
                f"WAV input directory not found: '{self.wav_input_dir}'"
            )

        wav_files = glob.glob(os.path.join(self.wav_input_dir, "*.wav"))
        wav_files += glob.glob(os.path.join(self.wav_input_dir, "*.WAV"))

        if not wav_files:
            raise FileNotFoundError(f"No .wav files found in '{self.wav_input_dir}'")

        logging.info(
            f"Found {len(wav_files)} WAV files. Processing them to generate representative data..."
        )
Collaborator

A better approach here is to use mime types.
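
A sketch of the mime-type based filtering suggested here, using only the standard library:

import mimetypes
import os

def list_wav_files(wav_dir: str) -> list[str]:
    """Collect WAV files by mime type instead of matching '*.wav' / '*.WAV' globs."""
    files = []
    for name in sorted(os.listdir(wav_dir)):
        path = os.path.join(wav_dir, name)
        mime, _ = mimetypes.guess_type(path)
        if os.path.isfile(path) and mime in ("audio/x-wav", "audio/wav"):
            files.append(path)
    return files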

Comment on lines +185 to +186
wav_files = glob.glob(os.path.join(self.wav_input_dir, "*.wav"))
wav_files += glob.glob(os.path.join(self.wav_input_dir, "*.WAV"))
Collaborator

The number of samples needs to be limited, e.g. randomly draw 100; running this over a large dataset could otherwise OOM.
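
A sketch of the suggested cap, drawing a bounded random subset for the representative dataset (100 is the reviewer's example figure):

import random

MAX_REPRESENTATIVE_SAMPLES = 100  # example cap from the review comment

def limit_samples(wav_files: list[str], seed: int = 0) -> list[str]:
    """Randomly keep at most MAX_REPRESENTATIVE_SAMPLES files to bound memory use."""
    if len(wav_files) <= MAX_REPRESENTATIVE_SAMPLES:
        return wav_files
    return random.Random(seed).sample(wav_files, MAX_REPRESENTATIVE_SAMPLES)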

Comment on lines +263 to +304
        output_dir = os.path.dirname(self.keras_output_path)
        if output_dir:
            try:
                os.makedirs(output_dir, exist_ok=True)
            except PermissionError:
                self._set_state(ConversionState.FAILED)
                raise PermissionError(
                    f"Cannot create output directory: '{output_dir}'\n"
                    f"Please check directory permissions."
                )

        try:
            logging.info("🔄 Converting TF.js model using tensorflowjs converter...")

            dispatch_tensorflowjs_to_keras_h5_conversion(
                self.tfjs_model_path,
                self.keras_output_path,
            )

            logging.info("TensorFlow.js to Keras conversion completed successfully")

            # Verify the output file was created successfully
            if not os.path.exists(self.keras_output_path):
                self._set_state(ConversionState.FAILED)
                raise RuntimeError(
                    f"Keras model file was not created: '{self.keras_output_path}'\n"
                    f"The conversion process completed but no output file was generated."
                )

            # Verify the output file is not empty
            if os.path.getsize(self.keras_output_path) == 0:
                self._set_state(ConversionState.FAILED)
                raise RuntimeError(
                    f"Keras model file is empty: '{self.keras_output_path}'\n"
                    f"The conversion process may have failed silently."
                )

            self._set_state(ConversionState.TFJS_CONVERTED)
            file_size = os.path.getsize(self.keras_output_path)
            logging.info(
                f"✅ TF.js to Keras conversion successful! Output size: {file_size:,} bytes"
            )
Collaborator

These steps can still be streamlined.
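
One possible slimmer shape for this step (a sketch only): let makedirs and the converter raise on their own rather than re-checking every outcome:

import logging
import os

from tensorflowjs.converters.converter import (
    dispatch_tensorflowjs_to_keras_h5_conversion,
)

def convert_tfjs_to_keras(tfjs_model_path: str, keras_output_path: str) -> None:
    """Convert a TF.js layers model to a Keras .h5 file."""
    os.makedirs(os.path.dirname(keras_output_path) or ".", exist_ok=True)
    dispatch_tensorflowjs_to_keras_h5_conversion(tfjs_model_path, keras_output_path)
    logging.info(
        "TF.js -> Keras done (%d bytes)", os.path.getsize(keras_output_path)
    )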

Comment on lines +306 to +322
        except Exception as e:
            self._set_state(ConversionState.FAILED)
            # Check if it's a TensorFlow.js specific error based on the error message
            error_msg = str(e).lower()
            if any(
                keyword in error_msg
                for keyword in ["tensorflowjs", "tfjs", "model.json", "invalid model"]
            ):
                raise RuntimeError(
                    f"TensorFlow.js conversion failed: {str(e)}\n"
                    f"Please ensure the input file is a valid TF.js model."
                ) from e
            else:
                raise RuntimeError(
                    f"Unexpected error during TF.js to Keras conversion: {str(e)}\n"
                    f"Please check the input file format and try again."
                ) from e
Collaborator

Parsing the error message of a generic exception is, in my opinion, not good design.

Comment on lines +336 to +341
        if not os.access(self.metadata_path, os.R_OK):
            self._set_state(ConversionState.FAILED)
            raise PermissionError(
                f"Metadata file is not readable: '{self.metadata_path}'\n"
                f"Please check file permissions."
            )
Collaborator

open() already raises an exception if the file is not readable; there is no need to add an extra check here.
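
i.e. something along these lines is sufficient (sketch):

import json

def load_metadata(metadata_path: str) -> dict:
    """Let open()/json raise on missing, unreadable, or malformed files; no os.access pre-check."""
    with open(metadata_path, "r", encoding="utf-8") as f:
        return json.load(f)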

Comment on lines +362 to +381
            if not isinstance(word_labels, list):
                raise ValueError(
                    f"'wordLabels' must be a list, got {type(word_labels)}\n"
                    f"Please ensure 'wordLabels' is an array of strings."
                )

            if not word_labels:
                raise ValueError(
                    "'wordLabels' list cannot be empty.\n"
                    "Please provide at least one class label."
                )
            for i, label in enumerate(word_labels):
                if not isinstance(label, str):
                    raise ValueError(
                        f"All labels must be strings. Label at index {i} is {type(label)}: {label}"
                    )
                if not label.strip():
                    raise ValueError(
                        f"Label at index {i} is empty or contains only whitespace"
                    )
Collaborator

From a user-experience perspective, if a label is empty it would be better to log a warning and fall back to a string based on the corresponding id.
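
A sketch of that fallback, warning and substituting an index-based name instead of raising:

import logging

def normalize_labels(word_labels: list) -> list[str]:
    """Replace empty or non-string labels with index-based names, logging a warning."""
    labels = []
    for i, label in enumerate(word_labels):
        if not isinstance(label, str) or not label.strip():
            logging.warning("Label %d is missing or empty; using 'class_%d'", i, i)
            labels.append(f"class_{i}")
        else:
            labels.append(label.strip())
    return labels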

Comment on lines +387 to +401
        except json.JSONDecodeError as e:
            self._set_state(ConversionState.FAILED)
            raise ValueError(
                f"Invalid JSON format in metadata file: {str(e)}\n"
                f"Please ensure the file contains valid JSON."
            ) from e
        except (ValueError, TypeError) as e:
            self._set_state(ConversionState.FAILED)
            raise ValueError(f"Metadata validation failed: {str(e)}") from e
        except Exception as e:
            self._set_state(ConversionState.FAILED)
            raise RuntimeError(
                f"Unexpected error while processing metadata: {str(e)}\n"
                f"Please check the file format and try again."
            ) from e
Collaborator

There is no need to check exception types this granularly.

Comment on lines +520 to +535
        if self._original_model is None:
            raise RuntimeError("Failed to create original model architecture")
        if self._final_model is None:
            raise RuntimeError("Failed to create optimized model architecture")

        expected_output_shape = (None, num_classes)
        if self._original_model.output_shape != expected_output_shape:
            raise RuntimeError(
                f"Original model output shape mismatch. Expected {expected_output_shape}, "
                f"got {self._original_model.output_shape}"
            )
        if self._final_model.output_shape != expected_output_shape:
            raise RuntimeError(
                f"Optimized model output shape mismatch. Expected {expected_output_shape}, "
                f"got {self._final_model.output_shape}"
            )
Collaborator

My suggestion here is to determine the classes primarily from the original model's output shape, with metadata as an optional input.

Comment on lines +785 to +791
    parser.add_argument(
        "--files",
        nargs="+",
        metavar="FILE",
        help="Simplified usage: specify TF.js model and metadata files in order. "
        "Example: --files model.json metadata.json",
    )
Collaborator

This option can be dropped.

Comment on lines +852 to +867
_validate_path_exists(args.tfjs_model, "TF.js model file")
_validate_path_exists(args.metadata, "Metadata file")
_validate_path_exists(args.wav_dir, "WAV input directory")
_validate_path_exists(args.preprocessing_model, "Preprocessing model file")
_validate_output_path(args.keras_output, "Keras output file")
_validate_output_path(args.tflite_output, "TFLite output file")

logging.info("✅ All input paths validated successfully")
logging.info("📁 Input files:")
logging.info(f" - TF.js model: {args.tfjs_model}")
logging.info(f" - Metadata: {args.metadata}")
logging.info(f" - WAV directory: {args.wav_dir}")
logging.info(f" - Preprocessing model: {args.preprocessing_model}")
logging.info("📁 Output files:")
logging.info(f" - Keras model: {args.keras_output}")
logging.info(f" - TFLite model: {args.tflite_output}")
Collaborator

This is somewhat over-validated. Output paths do not need to exist; just create them dynamically. Input paths do not need to be checked either, since the same errors will be raised at the lines where they are actually used (the time between validation and use is very short).
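
i.e. a sketch of the lighter approach: create output directories on demand and let the first real file access report missing inputs:

import os

def prepare_output_path(path: str) -> str:
    """Create the parent directory on demand instead of validating it up front."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    return path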

Collaborator

The license of the original model repository needs to be provided.

Collaborator

I have commented on the more obvious issues overall; for the finer details I did not call out, please optimize them along the same lines as the issues I raised.

Collaborator

This part should be implemented via node and graph.
