diff --git a/docs/apache-airflow-providers-google/operators/cloud/automl.rst b/docs/apache-airflow-providers-google/operators/cloud/automl.rst index 4eb461409fa3c..4cb6a9724cadb 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/automl.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/automl.rst @@ -163,25 +163,21 @@ You can find example on how to use VertexAI operators here: :end-before: [END how_to_cloud_vertex_ai_delete_model_operator] .. _howto/operator:AutoMLPredictOperator: -.. _howto/operator:AutoMLBatchPredictOperator: Making Predictions ^^^^^^^^^^^^^^^^^^ To obtain predictions from Google Cloud AutoML model you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator` or -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator`. In the first case +:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator`. In the first case the model must be deployed. -Th :class:`~airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator` deprecated for tables, -video intelligence, vision and natural language is deprecated and will be removed after 31.03.2024. -Please use +For tables, video intelligence, vision and natural language you can use the following operators: + :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`, :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator`, :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.ListBatchPredictionJobsOperator`, -:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`, -instead. +:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`. You can find examples on how to use VertexAI operators here: .. 
exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst index 3213aec60690e..6dd405ce93213 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/dataflow.rst @@ -54,9 +54,6 @@ There are several ways to run a Dataflow pipeline depending on your environment, command-line tool to build and save the Flex Template spec file in Cloud Storage. See: :ref:`howto/operator:DataflowStartFlexTemplateOperator` -- **SQL pipeline**: Developer can write pipeline as SQL statement and then execute it in Dataflow. See: - :ref:`howto/operator:DataflowStartSqlJobOperator` - It is a good idea to test your pipeline using the non-templated pipeline, and then run the pipeline in production using the templates. @@ -283,29 +280,6 @@ Also for this action you can use the operator in the deferrable mode: :start-after: [START howto_operator_start_flex_template_job_deferrable] :end-before: [END howto_operator_start_flex_template_job_deferrable] -.. _howto/operator:DataflowStartSqlJobOperator: - -Dataflow SQL -"""""""""""" -Dataflow SQL supports a variant of the ZetaSQL query syntax and includes additional streaming -extensions for running Dataflow streaming jobs. - -Here is an example of running Dataflow SQL job with -:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartSqlJobOperator`: - -.. exampleinclude:: /../../providers/tests/system/google/cloud/dataflow/example_dataflow_sql.py - :language: python - :dedent: 4 - :start-after: [START howto_operator_start_sql_job] - :end-before: [END howto_operator_start_sql_job] - -.. warning:: - This operator requires ``gcloud`` command (Google Cloud SDK) must be installed on the Airflow worker - `__ - -See the `Dataflow SQL reference -`_. - .. 
_howto/operator:DataflowStartYamlJobOperator: Dataflow YAML diff --git a/docs/docker-stack/recipes.rst b/docs/docker-stack/recipes.rst index 3402acb1019ca..7666fa892f747 100644 --- a/docs/docker-stack/recipes.rst +++ b/docs/docker-stack/recipes.rst @@ -26,8 +26,7 @@ Google Cloud SDK installation ----------------------------- Some operators, such as :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator`, -:class:`~airflow.providers.google.cloud.operators.dataflow.DataflowStartSqlJobOperator`, require -the installation of `Google Cloud SDK `__ (includes ``gcloud``). +require the installation of `Google Cloud SDK `__ (includes ``gcloud``). You can also run these commands with BashOperator. Create a new Dockerfile like the one shown below. diff --git a/providers/src/airflow/providers/google/CHANGELOG.rst b/providers/src/airflow/providers/google/CHANGELOG.rst index 8b8a8bc83ff9d..de464fb4849da 100644 --- a/providers/src/airflow/providers/google/CHANGELOG.rst +++ b/providers/src/airflow/providers/google/CHANGELOG.rst @@ -27,6 +27,37 @@ Changelog --------- +13.0.0 +...... + +.. note:: + This release of provider is only available for Airflow 2.9+ as explained in the + `Apache Airflow providers support policy `_. + +Breaking changes +~~~~~~~~~~~~~~~~ + +.. warning:: + Deprecated classes, parameters and features have been removed from the Google provider package. + The following breaking changes were introduced: + + * Operators + + * Removed ``AutoMLBatchPredictOperator``. Please use the operators from ``airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job`` instead + * Removed ``DataflowStartSqlJobOperator``. Please use ``DataflowStartYamlJobOperator`` instead + * Removed ``PromptLanguageModelOperator``. Please use ``TextGenerationModelPredictOperator`` instead + * Removed ``GenerateTextEmbeddingsOperator``. Please use ``TextEmbeddingModelGetEmbeddingsOperator`` instead + * Removed ``PromptMultimodalModelOperator``. 
Please use ``GenerativeModelGenerateContentOperator`` instead + * Removed ``PromptMultimodalModelWithMediaOperator``. Please use ``GenerativeModelGenerateContentOperator`` instead + + * Hooks + + * Removed ``GenerativeModelHook.prompt_multimodal_model_with_media()``. Please use ``GenerativeModelHook.generative_model_generate_content()`` instead + * Removed ``GenerativeModelHook.prompt_multimodal_model()``. Please use ``GenerativeModelHook.generative_model_generate_content()`` instead + * Removed ``GenerativeModelHook.get_generative_model_part()``. Please use ``GenerativeModelHook.generative_model_generate_content()`` instead + * Removed ``GenerativeModelHook.prompt_language_model()``. Please use ``GenerativeModelHook.text_generation_model_predict()`` instead + * Removed ``GenerativeModelHook.generate_text_embeddings()``. Please use ``GenerativeModelHook.text_embedding_model_get_embeddings()`` instead + 12.0.0 ...... diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index 7e506641484b3..8f06d4974e137 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING import vertexai -from vertexai.generative_models import GenerativeModel, Part +from vertexai.generative_models import GenerativeModel from vertexai.language_models import TextEmbeddingModel, TextGenerationModel from vertexai.preview.caching import CachedContent from vertexai.preview.evaluation import EvalResult, EvalTask @@ -100,186 +100,6 @@ def get_cached_context_model( cached_context_model = preview_generative_model.from_cached_content(cached_content) return cached_context_model - @deprecated( - planned_removal_date="January 01, 2025", - use_instead="Part objects included in contents parameter of " - 
"airflow.providers.google.cloud.hooks.generative_model." - "GenerativeModelHook.generative_model_generate_content", - category=AirflowProviderDeprecationWarning, - ) - def get_generative_model_part(self, content_gcs_path: str, content_mime_type: str | None = None) -> Part: - """Return a Generative Model Part object.""" - part = Part.from_uri(content_gcs_path, mime_type=content_mime_type) - return part - - @deprecated( - planned_removal_date="January 01, 2025", - use_instead="airflow.providers.google.cloud.hooks.generative_model." - "GenerativeModelHook.text_generation_model_predict", - category=AirflowProviderDeprecationWarning, - ) - @GoogleBaseHook.fallback_to_default_project_id - def prompt_language_model( - self, - prompt: str, - pretrained_model: str, - temperature: float, - max_output_tokens: int, - top_p: float, - top_k: int, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - ) -> str: - """ - Use the Vertex AI PaLM API to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response. - :param pretrained_model: A pre-trained model optimized for performing natural - language tasks such as classification, summarization, extraction, content - creation, and ideation. - :param temperature: Temperature controls the degree of randomness in token - selection. - :param max_output_tokens: Token limit determines the maximum amount of text - output. - :param top_p: Tokens are selected from most probable to least until the sum - of their probabilities equals the top_p value. Defaults to 0.8. - :param top_k: A top_k of 1 means the selected token is the most probable - among all tokens. 
- """ - vertexai.init(project=project_id, location=location, credentials=self.get_credentials()) - - parameters = { - "temperature": temperature, - "max_output_tokens": max_output_tokens, - "top_p": top_p, - "top_k": top_k, - } - - model = self.get_text_generation_model(pretrained_model) - - response = model.predict( - prompt=prompt, - **parameters, - ) - return response.text - - @deprecated( - planned_removal_date="January 01, 2025", - use_instead="airflow.providers.google.cloud.hooks.generative_model." - "GenerativeModelHook.text_embedding_model_get_embeddings", - category=AirflowProviderDeprecationWarning, - ) - @GoogleBaseHook.fallback_to_default_project_id - def generate_text_embeddings( - self, - prompt: str, - pretrained_model: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - ) -> list: - """ - Use the Vertex AI PaLM API to generate text embeddings. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response. - :param pretrained_model: A pre-trained model optimized for generating text embeddings. - """ - vertexai.init(project=project_id, location=location, credentials=self.get_credentials()) - model = self.get_text_embedding_model(pretrained_model) - - response = model.get_embeddings([prompt])[0] # single prompt - - return response.values - - @deprecated( - planned_removal_date="January 01, 2025", - use_instead="airflow.providers.google.cloud.hooks.generative_model." 
- "GenerativeModelHook.generative_model_generate_content", - category=AirflowProviderDeprecationWarning, - ) - @GoogleBaseHook.fallback_to_default_project_id - def prompt_multimodal_model( - self, - prompt: str, - location: str, - generation_config: dict | None = None, - safety_settings: dict | None = None, - pretrained_model: str = "gemini-pro", - project_id: str = PROVIDE_PROJECT_ID, - ) -> str: - """ - Use the Vertex AI Gemini Pro foundation model to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param prompt: Required. Inputs or queries that a user or a program gives - to the Multi-modal model, in order to elicit a specific response. - :param generation_config: Optional. Generation configuration settings. - :param safety_settings: Optional. Per request settings for blocking unsafe content. - :param pretrained_model: By default uses the pre-trained model `gemini-pro`, - supporting prompts with text-only input, including natural language - tasks, multi-turn text and code chat, and code generation. It can - output text and code. - """ - vertexai.init(project=project_id, location=location, credentials=self.get_credentials()) - - model = self.get_generative_model(pretrained_model) - response = model.generate_content( - contents=[prompt], generation_config=generation_config, safety_settings=safety_settings - ) - - return response.text - - @deprecated( - planned_removal_date="January 01, 2025", - use_instead="airflow.providers.google.cloud.hooks.generative_model." 
- "GenerativeModelHook.generative_model_generate_content", - category=AirflowProviderDeprecationWarning, - ) - @GoogleBaseHook.fallback_to_default_project_id - def prompt_multimodal_model_with_media( - self, - prompt: str, - location: str, - media_gcs_path: str, - mime_type: str, - generation_config: dict | None = None, - safety_settings: dict | None = None, - pretrained_model: str = "gemini-pro-vision", - project_id: str = PROVIDE_PROJECT_ID, - ) -> str: - """ - Use the Vertex AI Gemini Pro foundation model to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param location: Required. The ID of the Google Cloud location that the service belongs to. - :param prompt: Required. Inputs or queries that a user or a program gives - to the Multi-modal model, in order to elicit a specific response. - :param generation_config: Optional. Generation configuration settings. - :param safety_settings: Optional. Per request settings for blocking unsafe content. - :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`, - supporting prompts with text-only input, including natural language - tasks, multi-turn text and code chat, and code generation. It can - output text and code. - :param media_gcs_path: A GCS path to a content file such as an image or a video. - Can be passed to the multi-modal model as part of the prompt. Used with vision models. - :param mime_type: Validates the media type presented by the file in the media_gcs_path. 
- """ - vertexai.init(project=project_id, location=location, credentials=self.get_credentials()) - - model = self.get_generative_model(pretrained_model) - part = self.get_generative_model_part(media_gcs_path, mime_type) - response = model.generate_content( - contents=[prompt, part], generation_config=generation_config, safety_settings=safety_settings - ) - - return response.text - @deprecated( planned_removal_date="April 09, 2025", use_instead="GenerativeModelHook.generative_model_generate_content", diff --git a/providers/src/airflow/providers/google/cloud/operators/automl.py b/providers/src/airflow/providers/google/cloud/operators/automl.py index 7ef0716615126..2a683938ed9ac 100644 --- a/providers/src/airflow/providers/google/cloud/operators/automl.py +++ b/providers/src/airflow/providers/google/cloud/operators/automl.py @@ -26,7 +26,6 @@ from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.automl_v1beta1 import ( - BatchPredictResult, ColumnSpec, Dataset, Model, @@ -322,145 +321,6 @@ def execute(self, context: Context): return PredictResponse.to_dict(result) -@deprecated( - planned_removal_date="January 01, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job", - category=AirflowProviderDeprecationWarning, -) -class AutoMLBatchPredictOperator(GoogleCloudBaseOperator): - """ - Perform a batch prediction on Google Cloud AutoML. - - .. warning:: - AutoMLBatchPredictOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. 
Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.ListBatchPredictionJobsOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`, - instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLBatchPredictOperator` - - :param project_id: ID of the Google Cloud project where model will be created if None then - default project_id is used. - :param location: The location of the project. - :param model_id: Name of the model_id requested to serve the batch prediction. - :param input_config: Required. The input configuration for batch prediction. - If a dict is provided, it must be of the same form as the protobuf message - `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` - :param output_config: Required. The Configuration specifying where output predictions should be - written. If a dict is provided, it must be of the same form as the protobuf message - `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` - :param prediction_params: Additional domain-specific parameters for the predictions, - any string must be up to 25000 characters long. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. 
- :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "model_id", - "input_config", - "output_config", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyModelPredictLink(),) - - def __init__( - self, - *, - model_id: str, - input_config: dict, - output_config: dict, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - prediction_params: dict[str, str] | None = None, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.model_id = model_id - self.location = location - self.project_id = project_id - self.prediction_params = prediction_params - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - self.input_config = input_config - self.output_config = output_config - - @cached_property - def hook(self) -> CloudAutoMLHook: - return CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def model(self) -> Model: - 
return self.hook.get_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - - def execute(self, context: Context): - self.log.info("Fetch batch prediction.") - operation = self.hook.batch_predict( - model_id=self.model_id, - input_config=self.input_config, - output_config=self.output_config, - project_id=self.project_id, - location=self.location, - params=self.prediction_params, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - operation_result = self.hook.wait_for_operation(timeout=self.timeout, operation=operation) - result = BatchPredictResult.to_dict(operation_result) - self.log.info("Batch prediction is ready.") - project_id = self.project_id or self.hook.project_id - if project_id: - TranslationLegacyModelPredictLink.persist( - context=context, - task_instance=self, - model_id=self.model_id, - project_id=project_id, - dataset_id=self.model.dataset_id, - ) - return result - - @deprecated( planned_removal_date="September 30, 2025", use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator, " diff --git a/providers/src/airflow/providers/google/cloud/operators/dataflow.py b/providers/src/airflow/providers/google/cloud/operators/dataflow.py index 3fcbc7f67b784..c881853374ead 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataflow.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataflow.py @@ -28,7 +28,7 @@ from googleapiclient.errors import HttpError from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.dataflow import ( DEFAULT_DATAFLOW_LOCATION, DataflowHook, @@ -40,7 +40,6 @@ TemplateJobStartTrigger, ) from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME -from 
airflow.providers.google.common.deprecated import deprecated from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID if TYPE_CHECKING: @@ -654,116 +653,6 @@ def on_kill(self) -> None: ) -@deprecated( - planned_removal_date="January 31, 2025", - use_instead="DataflowStartYamlJobOperator", - category=AirflowProviderDeprecationWarning, -) -class DataflowStartSqlJobOperator(GoogleCloudBaseOperator): - """ - Starts Dataflow SQL query. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:DataflowStartSqlJobOperator` - - .. warning:: - This operator requires ``gcloud`` command (Google Cloud SDK) must be installed on the Airflow worker - `__ - - :param job_name: The unique name to assign to the Cloud Dataflow job. - :param query: The SQL query to execute. - :param options: Job parameters to be executed. It can be a dictionary with the following keys. - - For more information, look at: - `https://cloud.google.com/sdk/gcloud/reference/beta/dataflow/sql/query - `__ - command reference - - :param location: The location of the Dataflow job (for example europe-west1) - :param project_id: The ID of the GCP project that owns the job. - If set to ``None`` or missing, the default project_id from the GCP connection is used. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud - Platform. - :param drain_pipeline: Optional, set to True if want to stop streaming job by draining it - instead of canceling during killing task instance. See: - https://cloud.google.com/dataflow/docs/guides/stopping-a-pipeline - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. 
- If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "job_name", - "query", - "options", - "location", - "project_id", - "gcp_conn_id", - ) - template_fields_renderers = {"query": "sql"} - - def __init__( - self, - job_name: str, - query: str, - options: dict[str, Any], - location: str = DEFAULT_DATAFLOW_LOCATION, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - drain_pipeline: bool = False, - impersonation_chain: str | Sequence[str] | None = None, - *args, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.job_name = job_name - self.query = query - self.options = options - self.location = location - self.project_id = project_id - self.gcp_conn_id = gcp_conn_id - self.drain_pipeline = drain_pipeline - self.impersonation_chain = impersonation_chain - self.job = None - self.hook: DataflowHook | None = None - - def execute(self, context: Context): - self.hook = DataflowHook( - gcp_conn_id=self.gcp_conn_id, - drain_pipeline=self.drain_pipeline, - impersonation_chain=self.impersonation_chain, - ) - - def set_current_job(current_job): - self.job = current_job - - job = self.hook.start_sql_job( - job_name=self.job_name, - query=self.query, - options=self.options, - location=self.location, - project_id=self.project_id, - on_new_job_callback=set_current_job, - ) - - return job - - def on_kill(self) -> None: - self.log.info("On kill.") - if self.job: - self.hook.cancel_job( - job_id=self.job.get("id"), - project_id=self.job.get("projectId"), - location=self.job.get("location"), - ) - - class DataflowStartYamlJobOperator(GoogleCloudBaseOperator): """ Launch a Dataflow YAML job and return the result. 
diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 42e4fdc588e43..71af5659552e2 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -31,328 +31,6 @@ from airflow.utils.context import Context -@deprecated( - planned_removal_date="January 01, 2025", - use_instead="TextGenerationModelPredictOperator", - category=AirflowProviderDeprecationWarning, -) -class PromptLanguageModelOperator(GoogleCloudBaseOperator): - """ - Uses the Vertex AI PaLM API to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the - service belongs to (templated). - :param location: Required. The ID of the Google Cloud location that the - service belongs to (templated). - :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response (templated). - :param pretrained_model: By default uses the pre-trained model `text-bison`, - optimized for performing natural language tasks such as classification, - summarization, extraction, content creation, and ideation. - :param temperature: Temperature controls the degree of randomness in token - selection. Defaults to 0.0. - :param max_output_tokens: Token limit determines the maximum amount of text - output. Defaults to 256. - :param top_p: Tokens are selected from most probable to least until the sum - of their probabilities equals the top_p value. Defaults to 0.8. - :param top_k: A top_k of 1 means the selected token is the most probable - among all tokens. Defaults to 0.4. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. 
- :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields = ("location", "project_id", "impersonation_chain", "prompt") - - def __init__( - self, - *, - project_id: str, - location: str, - prompt: str, - pretrained_model: str = "text-bison", - temperature: float = 0.0, - max_output_tokens: int = 256, - top_p: float = 0.8, - top_k: int = 40, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.prompt = prompt - self.pretrained_model = pretrained_model - self.temperature = temperature - self.max_output_tokens = max_output_tokens - self.top_p = top_p - self.top_k = top_k - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - self.hook = GenerativeModelHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - self.log.info("Submitting prompt") - response = self.hook.prompt_language_model( - project_id=self.project_id, - location=self.location, - prompt=self.prompt, - pretrained_model=self.pretrained_model, - temperature=self.temperature, - max_output_tokens=self.max_output_tokens, - top_p=self.top_p, - top_k=self.top_k, - ) - - self.log.info("Model response: %s", response) - self.xcom_push(context, key="prompt_response", value=response) - - return 
response - - -@deprecated( - planned_removal_date="January 01, 2025", - use_instead="TextEmbeddingModelGetEmbeddingsOperator", - category=AirflowProviderDeprecationWarning, -) -class GenerateTextEmbeddingsOperator(GoogleCloudBaseOperator): - """ - Uses the Vertex AI PaLM API to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the - service belongs to (templated). - :param location: Required. The ID of the Google Cloud location that the - service belongs to (templated). - :param prompt: Required. Inputs or queries that a user or a program gives - to the Vertex AI PaLM API, in order to elicit a specific response (templated). - :param pretrained_model: By default uses the pre-trained model `textembedding-gecko`, - optimized for performing text embeddings. - :param gcp_conn_id: The connection ID to use connecting to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
- """ - - template_fields = ("location", "project_id", "impersonation_chain", "prompt") - - def __init__( - self, - *, - project_id: str, - location: str, - prompt: str, - pretrained_model: str = "textembedding-gecko", - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.location = location - self.prompt = prompt - self.pretrained_model = pretrained_model - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - self.hook = GenerativeModelHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - self.log.info("Generating text embeddings") - response = self.hook.generate_text_embeddings( - project_id=self.project_id, - location=self.location, - prompt=self.prompt, - pretrained_model=self.pretrained_model, - ) - - self.log.info("Model response: %s", response) - self.xcom_push(context, key="prompt_response", value=response) - - return response - - -@deprecated( - planned_removal_date="January 01, 2025", - use_instead="GenerativeModelGenerateContentOperator", - category=AirflowProviderDeprecationWarning, -) -class PromptMultimodalModelOperator(GoogleCloudBaseOperator): - """ - Use the Vertex AI Gemini Pro foundation model to generate natural language text. - - :param project_id: Required. The ID of the Google Cloud project that the - service belongs to (templated). - :param location: Required. The ID of the Google Cloud location that the - service belongs to (templated). - :param prompt: Required. Inputs or queries that a user or a program gives - to the Multi-modal model, in order to elicit a specific response (templated). - :param generation_config: Optional. Generation configuration settings. - :param safety_settings: Optional. Per request settings for blocking unsafe content. 
-    :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
-        supporting prompts with text-only input, including natural language
-        tasks, multi-turn text and code chat, and code generation. It can
-        output text and code.
-    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    """
-
-    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
-
-    def __init__(
-        self,
-        *,
-        project_id: str,
-        location: str,
-        prompt: str,
-        generation_config: dict | None = None,
-        safety_settings: dict | None = None,
-        pretrained_model: str = "gemini-pro",
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.project_id = project_id
-        self.location = location
-        self.prompt = prompt
-        self.generation_config = generation_config
-        self.safety_settings = safety_settings
-        self.pretrained_model = pretrained_model
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-
-    def execute(self, context: Context):
-        self.hook = GenerativeModelHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-        response = self.hook.prompt_multimodal_model(
-            project_id=self.project_id,
-            location=self.location,
-            prompt=self.prompt,
-            generation_config=self.generation_config,
-            safety_settings=self.safety_settings,
-            pretrained_model=self.pretrained_model,
-        )
-
-        self.log.info("Model response: %s", response)
-        self.xcom_push(context, key="prompt_response", value=response)
-
-        return response
-
-
-@deprecated(
-    planned_removal_date="January 01, 2025",
-    use_instead="GenerativeModelGenerateContentOperator",
-    category=AirflowProviderDeprecationWarning,
-)
-class PromptMultimodalModelWithMediaOperator(GoogleCloudBaseOperator):
-    """
-    Use the Vertex AI Gemini Pro foundation model to generate natural language text.
-
-    :param project_id: Required. The ID of the Google Cloud project that the
-        service belongs to (templated).
-    :param location: Required. The ID of the Google Cloud location that the
-        service belongs to (templated).
-    :param prompt: Required. Inputs or queries that a user or a program gives
-        to the Multi-modal model, in order to elicit a specific response (templated).
-    :param generation_config: Optional. Generation configuration settings.
-    :param safety_settings: Optional. Per request settings for blocking unsafe content.
-    :param pretrained_model: By default uses the pre-trained model `gemini-pro-vision`,
-        supporting prompts with text-only input, including natural language
-        tasks, multi-turn text and code chat, and code generation. It can
-        output text and code.
-    :param media_gcs_path: A GCS path to a media file such as an image or a video.
-        Can be passed to the multi-modal model as part of the prompt. Used with vision models.
-    :param mime_type: Validates the media type presented by the file in the media_gcs_path.
-    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    """
-
-    template_fields = ("location", "project_id", "impersonation_chain", "prompt")
-
-    def __init__(
-        self,
-        *,
-        project_id: str,
-        location: str,
-        prompt: str,
-        media_gcs_path: str,
-        mime_type: str,
-        generation_config: dict | None = None,
-        safety_settings: dict | None = None,
-        pretrained_model: str = "gemini-pro-vision",
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.project_id = project_id
-        self.location = location
-        self.prompt = prompt
-        self.generation_config = generation_config
-        self.safety_settings = safety_settings
-        self.pretrained_model = pretrained_model
-        self.media_gcs_path = media_gcs_path
-        self.mime_type = mime_type
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-
-    def execute(self, context: Context):
-        self.hook = GenerativeModelHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-        response = self.hook.prompt_multimodal_model_with_media(
-            project_id=self.project_id,
-            location=self.location,
-            prompt=self.prompt,
-            generation_config=self.generation_config,
-            safety_settings=self.safety_settings,
-            pretrained_model=self.pretrained_model,
-            media_gcs_path=self.media_gcs_path,
-            mime_type=self.mime_type,
-        )
-
-        self.log.info("Model response: %s", response)
-        self.xcom_push(context, key="prompt_response", value=response)
-
-        return response
-
-
 @deprecated(
     planned_removal_date="April 09, 2025",
     use_instead="GenerativeModelGenerateContentOperator",
diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
index 21741a617ea92..762958d621a58 100644
--- a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -148,61 +148,6 @@ def setup_method(self):
         self.hook = GenerativeModelHook(gcp_conn_id=TEST_GCP_CONN_ID)
         self.hook.get_credentials = self.dummy_get_credentials
 
-    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
-    def test_prompt_language_model(self, mock_model) -> None:
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            self.hook.prompt_language_model(
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=TEST_PROMPT,
-                pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
-                temperature=TEST_TEMPERATURE,
-                max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
-                top_p=TEST_TOP_P,
-                top_k=TEST_TOP_K,
-            )
-        assert_warning("text_generation_model_predict", warnings)
-
-    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model"))
-    def test_generate_text_embeddings(self, mock_model) -> None:
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            self.hook.generate_text_embeddings(
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=TEST_PROMPT,
-                pretrained_model=TEST_TEXT_EMBEDDING_MODEL,
-            )
-        assert_warning("text_embedding_model_get_embeddings", warnings)
-
-    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
-    def test_prompt_multimodal_model(self, mock_model) -> None:
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            self.hook.prompt_multimodal_model(
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=TEST_PROMPT,
-                generation_config=TEST_GENERATION_CONFIG,
-                safety_settings=TEST_SAFETY_SETTINGS,
-                pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
-            )
-        assert_warning("generative_model_generate_content", warnings)
-
-    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model_part"))
-    @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
-    def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None:
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            self.hook.prompt_multimodal_model_with_media(
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=TEST_VISION_PROMPT,
-                generation_config=TEST_GENERATION_CONFIG,
-                safety_settings=TEST_SAFETY_SETTINGS,
-                pretrained_model=TEST_MULTIMODAL_VISION_MODEL,
-                media_gcs_path=TEST_MEDIA_GCS_PATH,
-                mime_type=TEST_MIME_TYPE,
-            )
-        assert_warning("generative_model_generate_content", warnings)
-
     @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
     def test_text_generation_model_predict(self, mock_model) -> None:
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
diff --git a/providers/tests/google/cloud/links/test_translate.py b/providers/tests/google/cloud/links/test_translate.py
index 69c860a8c53fb..1d3822ad32d3e 100644
--- a/providers/tests/google/cloud/links/test_translate.py
+++ b/providers/tests/google/cloud/links/test_translate.py
@@ -22,19 +22,14 @@
 # For no Pydantic environment, we need to skip the tests
 pytest.importorskip("google.cloud.aiplatform_v1")
 
-from google.cloud.automl_v1beta1 import Model
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.links.translate import (
     TRANSLATION_BASE_LINK,
     TranslationDatasetListLink,
     TranslationLegacyDatasetLink,
     TranslationLegacyModelLink,
-    TranslationLegacyModelPredictLink,
     TranslationLegacyModelTrainLink,
 )
 from airflow.providers.google.cloud.operators.automl import (
-    AutoMLBatchPredictOperator,
     AutoMLCreateDatasetOperator,
     AutoMLListDatasetOperator,
     AutoMLTrainModelOperator,
@@ -137,36 +132,3 @@ def test_get_link(self, create_task_instance_of_operator, session):
         )
         actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
         assert actual_url == expected_url
-
-
-class TestTranslationLegacyModelPredictLink:
-    @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator, session):
-        expected_url = (
-            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
-            f"predict;modelId={MODEL}?project={GCP_PROJECT_ID}"
-        )
-        link = TranslationLegacyModelPredictLink()
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLBatchPredictOperator,
-                dag_id="test_legacy_model_predict_link_dag",
-                task_id="test_legacy_model_predict_link_task",
-                model_id=MODEL,
-                project_id=GCP_PROJECT_ID,
-                location=GCP_LOCATION,
-                input_config="input_config",
-                output_config="input_config",
-            )
-        ti.task.model = Model(dataset_id=DATASET, display_name=MODEL)
-        session.add(ti)
-        session.commit()
-        link.persist(
-            context={"ti": ti},
-            task_instance=ti.task,
-            model_id=MODEL,
-            project_id=GCP_PROJECT_ID,
-            dataset_id=DATASET,
-        )
-        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
-        assert actual_url == expected_url
diff --git a/providers/tests/google/cloud/operators/test_automl.py b/providers/tests/google/cloud/operators/test_automl.py
index 94dca98be917b..7ae70c83c9ed3 100644
--- a/providers/tests/google/cloud/operators/test_automl.py
+++ b/providers/tests/google/cloud/operators/test_automl.py
@@ -26,13 +26,12 @@
 pytest.importorskip("google.cloud.aiplatform_v1")
 
 from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud.automl_v1beta1 import BatchPredictResult, Dataset, Model, PredictResponse
+from google.cloud.automl_v1beta1 import Dataset, Model, PredictResponse
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
 from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
 from airflow.providers.google.cloud.operators.automl import (
-    AutoMLBatchPredictOperator,
     AutoMLCreateDatasetOperator,
     AutoMLDeleteDatasetOperator,
     AutoMLDeleteModelOperator,
@@ -125,73 +124,6 @@ def test_templating(self, create_task_instance_of_operator, session):
         assert task.impersonation_chain == "impersonation_chain"
 
 
-class TestAutoMLBatchPredictOperator:
-    @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook, mock_link_persist):
-        mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult()
-        mock_hook.return_value.extract_object_id = extract_object_id
-        mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult()
-        mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
-        mock_context = {"ti": mock.MagicMock()}
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLBatchPredictOperator(
-                model_id=MODEL_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                input_config=INPUT_CONFIG,
-                output_config=OUTPUT_CONFIG,
-                task_id=TASK_ID,
-                prediction_params={},
-            )
-        op.execute(context=mock_context)
-        mock_hook.return_value.batch_predict.assert_called_once_with(
-            input_config=INPUT_CONFIG,
-            location=GCP_LOCATION,
-            metadata=(),
-            model_id=MODEL_ID,
-            output_config=OUTPUT_CONFIG,
-            params={},
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-        mock_link_persist.assert_called_once_with(
-            context=mock_context,
-            task_instance=op,
-            model_id=MODEL_ID,
-            project_id=GCP_PROJECT_ID,
-            dataset_id=DATASET_ID,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLBatchPredictOperator,
-                # Templated fields
-                model_id="{{ 'model' }}",
-                input_config="{{ 'input-config' }}",
-                output_config="{{ 'output-config' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLBatchPredictOperator = ti.task
-        assert task.model_id == "model"
-        assert task.input_config == "input-config"
-        assert task.output_config == "output-config"
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
 class TestAutoMLPredictOperator:
     @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
     @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
diff --git a/providers/tests/google/cloud/operators/test_dataflow.py b/providers/tests/google/cloud/operators/test_dataflow.py
index 83b33eaccf001..89b5f9180838f 100644
--- a/providers/tests/google/cloud/operators/test_dataflow.py
+++ b/providers/tests/google/cloud/operators/test_dataflow.py
@@ -17,14 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-from copy import deepcopy
 from unittest import mock
 
 import httplib2
 import pytest
 from googleapiclient.errors import HttpError
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataflow import (
     DEFAULT_DATAFLOW_LOCATION,
     DataflowJobStatus,
@@ -34,7 +33,6 @@
     DataflowDeletePipelineOperator,
     DataflowRunPipelineOperator,
     DataflowStartFlexTemplateOperator,
-    DataflowStartSqlJobOperator,
     DataflowStartYamlJobOperator,
     DataflowStopJobOperator,
     DataflowTemplatedJobStartOperator,
@@ -348,40 +346,6 @@ def test_execute_with_deferrable_mode(self, mock_hook, mock_defer_method, deferr
         mock_defer_method.assert_called_once()
 
 
-class TestDataflowStartSqlJobOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.dataflow.DataflowHook")
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            start_sql = DataflowStartSqlJobOperator(
-                task_id="start_sql_query",
-                job_name=TEST_SQL_JOB_NAME,
-                query=TEST_SQL_QUERY,
-                options=deepcopy(TEST_SQL_OPTIONS),
-                location=TEST_LOCATION,
-                do_xcom_push=True,
-            )
-        start_sql.execute(mock.MagicMock())
-
-        mock_hook.assert_called_once_with(
-            gcp_conn_id="google_cloud_default",
-            drain_pipeline=False,
-            impersonation_chain=None,
-        )
-        mock_hook.return_value.start_sql_job.assert_called_once_with(
-            job_name=TEST_SQL_JOB_NAME,
-            query=TEST_SQL_QUERY,
-            options=TEST_SQL_OPTIONS,
-            location=TEST_LOCATION,
-            project_id=None,
-            on_new_job_callback=mock.ANY,
-        )
-        start_sql.job = TEST_SQL_JOB
-        start_sql.on_kill()
-        mock_hook.return_value.cancel_job.assert_called_once_with(
-            job_id="test-job-id", project_id=None, location=None
-        )
-
-
 class TestDataflowStartYamlJobOperator:
     @pytest.fixture
     def sync_operator(self):
diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
index 709e5d1f78402..8712830c6eee3 100644
--- a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
+++ b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -35,11 +35,7 @@
     CountTokensOperator,
     CreateCachedContentOperator,
     GenerateFromCachedContentOperator,
-    GenerateTextEmbeddingsOperator,
     GenerativeModelGenerateContentOperator,
-    PromptLanguageModelOperator,
-    PromptMultimodalModelOperator,
-    PromptMultimodalModelWithMediaOperator,
     RunEvaluationOperator,
     SupervisedFineTuningTrainOperator,
     TextEmbeddingModelGetEmbeddingsOperator,
@@ -59,224 +55,6 @@ def assert_warning(msg: str, warnings):
     assert any(msg in str(w) for w in warnings)
 
 
-class TestVertexAIPromptLanguageModelOperator:
-    prompt = "In 10 words or less, what is Apache Airflow?"
-    pretrained_model = "text-bison"
-    temperature = 0.0
-    max_output_tokens = 256
-    top_p = 0.8
-    top_k = 40
-
-    def test_deprecation_warning(self):
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            PromptLanguageModelOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                pretrained_model=self.pretrained_model,
-                temperature=self.temperature,
-                max_output_tokens=self.max_output_tokens,
-                top_p=self.top_p,
-                top_k=self.top_k,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        assert_warning("TextGenerationModelPredictOperator", warnings)
-
-    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = PromptLanguageModelOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                pretrained_model=self.pretrained_model,
-                temperature=self.temperature,
-                max_output_tokens=self.max_output_tokens,
-                top_p=self.top_p,
-                top_k=self.top_k,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        op.execute(context={"ti": mock.MagicMock()})
-        mock_hook.assert_called_once_with(
-            gcp_conn_id=GCP_CONN_ID,
-            impersonation_chain=IMPERSONATION_CHAIN,
-        )
-        mock_hook.return_value.prompt_language_model.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=self.prompt,
-            pretrained_model=self.pretrained_model,
-            temperature=self.temperature,
-            max_output_tokens=self.max_output_tokens,
-            top_p=self.top_p,
-            top_k=self.top_k,
-        )
-
-
-class TestVertexAIGenerateTextEmbeddingsOperator:
-    prompt = "In 10 words or less, what is Apache Airflow?"
-    pretrained_model = "textembedding-gecko"
-
-    def test_deprecation_warning(self):
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            GenerateTextEmbeddingsOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                pretrained_model=self.pretrained_model,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        assert_warning("TextEmbeddingModelGetEmbeddingsOperator", warnings)
-
-    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = GenerateTextEmbeddingsOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                pretrained_model=self.pretrained_model,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        op.execute(context={"ti": mock.MagicMock()})
-        mock_hook.assert_called_once_with(
-            gcp_conn_id=GCP_CONN_ID,
-            impersonation_chain=IMPERSONATION_CHAIN,
-        )
-        mock_hook.return_value.generate_text_embeddings.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=self.prompt,
-            pretrained_model=self.pretrained_model,
-        )
-
-
-class TestVertexAIPromptMultimodalModelOperator:
-    prompt = "In 10 words or less, what is Apache Airflow?"
-    pretrained_model = "gemini-pro"
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-    }
-    generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0}
-
-    def test_deprecation_warning(self):
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            PromptMultimodalModelOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                generation_config=self.generation_config,
-                safety_settings=self.safety_settings,
-                pretrained_model=self.pretrained_model,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        assert_warning("GenerativeModelGenerateContentOperator", warnings)
-
-    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = PromptMultimodalModelOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.prompt,
-                generation_config=self.generation_config,
-                safety_settings=self.safety_settings,
-                pretrained_model=self.pretrained_model,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        op.execute(context={"ti": mock.MagicMock()})
-        mock_hook.assert_called_once_with(
-            gcp_conn_id=GCP_CONN_ID,
-            impersonation_chain=IMPERSONATION_CHAIN,
-        )
-        mock_hook.return_value.prompt_multimodal_model.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=self.prompt,
-            generation_config=self.generation_config,
-            safety_settings=self.safety_settings,
-            pretrained_model=self.pretrained_model,
-        )
-
-
-class TestVertexAIPromptMultimodalModelWithMediaOperator:
-    pretrained_model = "gemini-pro-vision"
-    vision_prompt = "In 10 words or less, describe this content."
-    media_gcs_path = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
-    mime_type = "image/jpeg"
-    safety_settings = {
-        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
-    }
-    generation_config = {"max_output_tokens": 256, "top_p": 0.8, "temperature": 0.0}
-
-    def test_deprecation_warning(self):
-        with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
-            PromptMultimodalModelWithMediaOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.vision_prompt,
-                generation_config=self.generation_config,
-                safety_settings=self.safety_settings,
-                pretrained_model=self.pretrained_model,
-                media_gcs_path=self.media_gcs_path,
-                mime_type=self.mime_type,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        assert_warning("GenerativeModelGenerateContentOperator", warnings)
-
-    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = PromptMultimodalModelWithMediaOperator(
-                task_id=TASK_ID,
-                project_id=GCP_PROJECT,
-                location=GCP_LOCATION,
-                prompt=self.vision_prompt,
-                generation_config=self.generation_config,
-                safety_settings=self.safety_settings,
-                pretrained_model=self.pretrained_model,
-                media_gcs_path=self.media_gcs_path,
-                mime_type=self.mime_type,
-                gcp_conn_id=GCP_CONN_ID,
-                impersonation_chain=IMPERSONATION_CHAIN,
-            )
-        op.execute(context={"ti": mock.MagicMock()})
-        mock_hook.assert_called_once_with(
-            gcp_conn_id=GCP_CONN_ID,
-            impersonation_chain=IMPERSONATION_CHAIN,
-        )
-        mock_hook.return_value.prompt_multimodal_model_with_media.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            location=GCP_LOCATION,
-            prompt=self.vision_prompt,
-            generation_config=self.generation_config,
-            safety_settings=self.safety_settings,
-            pretrained_model=self.pretrained_model,
-            media_gcs_path=self.media_gcs_path,
-            mime_type=self.mime_type,
-        )
-
-
 class TestVertexAITextGenerationModelPredictOperator:
     prompt = "In 10 words or less, what is Apache Airflow?"
     pretrained_model = "text-bison"
diff --git a/providers/tests/system/google/cloud/dataflow/example_dataflow_sql.py b/providers/tests/system/google/cloud/dataflow/example_dataflow_sql.py
deleted file mode 100644
index 2ba0bf0534c59..0000000000000
--- a/providers/tests/system/google/cloud/dataflow/example_dataflow_sql.py
+++ /dev/null
@@ -1,149 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Example Airflow DAG for Google Cloud Dataflow service
-"""
-
-from __future__ import annotations
-
-import os
-from datetime import datetime
-
-from airflow.models.dag import DAG
-from airflow.providers.google.cloud.operators.bigquery import (
-    BigQueryCreateEmptyDatasetOperator,
-    BigQueryCreateEmptyTableOperator,
-    BigQueryDeleteDatasetOperator,
-    BigQueryDeleteTableOperator,
-    BigQueryInsertJobOperator,
-)
-from airflow.providers.google.cloud.operators.dataflow import DataflowStartSqlJobOperator
-from airflow.utils.trigger_rule import TriggerRule
-
-from providers.tests.system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
-
-PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
-ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
-DAG_ID = "dataflow_sql"
-LOCATION = "europe-west3"
-DATAFLOW_SQL_JOB_NAME = f"{DAG_ID}_{ENV_ID}".replace("_", "-")
-BQ_SQL_DATASET = f"{DAG_ID}_{ENV_ID}".replace("-", "_")
-BQ_SQL_TABLE_INPUT = f"input_{ENV_ID}".replace("-", "_")
-BQ_SQL_TABLE_OUTPUT = f"output_{ENV_ID}".replace("-", "_")
-INSERT_ROWS_QUERY = (
-    f"INSERT {BQ_SQL_DATASET}.{BQ_SQL_TABLE_INPUT} VALUES "
-    "('John Doe', 900), "
-    "('Alice Storm', 1200),"
-    "('Bob Max', 1000),"
-    "('Peter Jackson', 800),"
-    "('Mia Smith', 1100);"
-)
-
-
-with DAG(
-    dag_id=DAG_ID,
-    start_date=datetime(2021, 1, 1),
-    schedule="@once",
-    catchup=False,
-    tags=["example", "dataflow-sql"],
-) as dag:
-    create_bq_dataset = BigQueryCreateEmptyDatasetOperator(
-        task_id="create_bq_dataset",
-        dataset_id=BQ_SQL_DATASET,
-        location=LOCATION,
-    )
-
-    create_bq_table = BigQueryCreateEmptyTableOperator(
-        task_id="create_bq_table",
-        dataset_id=BQ_SQL_DATASET,
-        table_id=BQ_SQL_TABLE_INPUT,
-        schema_fields=[
-            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
-            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
-        ],
-    )
-
-    insert_query_job = BigQueryInsertJobOperator(
-        task_id="insert_query_job",
-        configuration={
-            "query": {
-                "query": INSERT_ROWS_QUERY,
-                "useLegacySql": False,
-                "priority": "BATCH",
-            }
-        },
-        location=LOCATION,
-    )
-
-    # [START howto_operator_start_sql_job]
-    start_sql = DataflowStartSqlJobOperator(
-        task_id="start_sql_query",
-        job_name=DATAFLOW_SQL_JOB_NAME,
-        query=f"""
-            SELECT
-                emp_name as employee,
-                salary as employee_salary
-            FROM
-                bigquery.table.`{PROJECT_ID}`.`{BQ_SQL_DATASET}`.`{BQ_SQL_TABLE_INPUT}`
-            WHERE salary >= 1000;
-        """,
-        options={
-            "bigquery-project": PROJECT_ID,
-            "bigquery-dataset": BQ_SQL_DATASET,
-            "bigquery-table": BQ_SQL_TABLE_OUTPUT,
-        },
-        location=LOCATION,
-        do_xcom_push=True,
-    )
-    # [END howto_operator_start_sql_job]
-
-    delete_bq_table = BigQueryDeleteTableOperator(
-        task_id="delete_bq_table",
-        deletion_dataset_table=f"{PROJECT_ID}.{BQ_SQL_DATASET}.{BQ_SQL_TABLE_INPUT}",
-        trigger_rule=TriggerRule.ALL_DONE,
-    )
-
-    delete_bq_dataset = BigQueryDeleteDatasetOperator(
-        task_id="delete_bq_dataset",
-        dataset_id=BQ_SQL_DATASET,
-        delete_contents=True,
-        trigger_rule=TriggerRule.ALL_DONE,
-    )
-
-    (
-        # TEST SETUP
-        create_bq_dataset
-        >> create_bq_table
-        >> insert_query_job
-        # TEST BODY
-        >> start_sql
-        # TEST TEARDOWN
-        >> delete_bq_table
-        >> delete_bq_dataset
-    )
-
-    from tests_common.test_utils.watcher import watcher
-
-    # This test needs watcher in order to properly mark success/failure
-    # when "tearDown" task with trigger rule is part of the DAG
-    list(dag.tasks) >> watcher()
-
-from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
-
-# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
-test_run = get_test_run(dag)
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 3d609c3048e60..2548193c2fba3 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -381,7 +381,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
-        "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator",
         "airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator",
@@ -405,10 +404,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
         "airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator",
         "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
-        "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
         "airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator",
         "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator",
         "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator",