diff --git a/docs/source/exporters/onnx/overview.mdx b/docs/source/exporters/onnx/overview.mdx index e70e8afa84..2c075d02f8 100644 --- a/docs/source/exporters/onnx/overview.mdx +++ b/docs/source/exporters/onnx/overview.mdx @@ -41,6 +41,7 @@ Supported architectures: - Donut-Swin - Electra - Flaubert +- Funnel - GPT-2 - GPT-BigCode - GPT-J diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 0d4202fd74..277cae2a29 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -1330,3 +1330,8 @@ def generate_dummy_inputs_for_validation( reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0] return super().generate_dummy_inputs_for_validation(reference_model_inputs) + + +class FunnelTransformerOnnxConfig(BertOnnxConfig): + DEFAULT_ONNX_OPSET = 12 + pass diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 54a36f06e9..20f1cf7517 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -507,6 +507,15 @@ class TasksManager: onnx="FlaubertOnnxConfig", tflite="FlaubertTFLiteConfig", ), + "funnel": supported_tasks_mapping( + "feature-extraction", + "text-classification", + "token-classification", + "multiple-choice", + "fill-mask", + "question-answering", + onnx="FunnelTransformerOnnxConfig", + ), "gpt2": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py index 6da01ff8de..1dddf3bf44 100644 --- a/optimum/utils/normalized_config.py +++ b/optimum/utils/normalized_config.py @@ -246,6 +246,7 @@ class NormalizedConfigManager: "yolos": NormalizedVisionConfig, "mpt": MPTNormalizedTextConfig, "gpt_bigcode": GPTBigCodeNormalizedTextConfig, + "funnel": NormalizedTextConfig, } @classmethod diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index ab4ce97b75..c55855d278 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -56,6 +56,7 @@ "distilbert": "hf-internal-testing/tiny-random-DistilBertModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "flaubert": "hf-internal-testing/tiny-random-flaubert", + "funnel": "hf-internal-testing/tiny-random-funnel", "gpt2": "hf-internal-testing/tiny-random-gpt2", "gpt-bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt-neo": "hf-internal-testing/tiny-random-GPTNeoModel", @@ -173,6 +174,7 @@ "distilbert": "distilbert-base-cased", "electra": "google/electra-base-generator", "flaubert": "hf-internal-testing/tiny-random-flaubert", # TODO + "funnel": "funnel-transformer/large", "gpt2": "gpt2", "gpt-neo": "EleutherAI/gpt-neo-125M", "gpt-neox": "EleutherAI/gpt-neox-20b", diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index ab39351319..ec4dece25b 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -1847,6 +1847,7 @@ class ORTModelForMultipleChoiceIntegrationTest(ORTModelTestMixin): "distilbert", "electra", "flaubert", + "funnel", "ibert", "mobilebert", "nystromformer",