Skip to content
97 changes: 82 additions & 15 deletions firebase_admin/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datetime import datetime, timedelta, timezone
from urllib import parse
import re
import os
import json
from base64 import b64encode
from typing import Any, Optional, Dict
Expand Down Expand Up @@ -49,6 +50,8 @@
'https://cloudtasks.googleapis.com/v2/' + _CLOUD_TASKS_API_RESOURCE_PATH
_FIREBASE_FUNCTION_URL_FORMAT = \
'https://{location_id}-{project_id}.cloudfunctions.net/{resource_id}'
_EMULATOR_HOST_ENV_VAR = 'CLOUD_TASKS_EMULATOR_HOST'
_EMULATED_SERVICE_ACCOUNT_DEFAULT = '[email protected]'

_FUNCTIONS_HEADERS = {
'X-GOOG-API-FORMAT-VERSION': '2',
Expand All @@ -58,6 +61,17 @@
# Default canonical location ID of the task queue.
_DEFAULT_LOCATION = 'us-central1'

def _get_emulator_host() -> Optional[str]:
emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR)
if emulator_host:
if '//' in emulator_host:
raise ValueError(
f'Invalid {_EMULATOR_HOST_ENV_VAR}: "{emulator_host}". It must follow format '
'"host:port".')
return emulator_host
return None


def _get_functions_service(app) -> _FunctionsService:
return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService)

Expand Down Expand Up @@ -103,13 +117,19 @@ def __init__(self, app: App):
'projectId option, or use service account credentials. Alternatively, set the '
'GOOGLE_CLOUD_PROJECT environment variable.')

self._credential = app.credential.get_credential()
self._emulator_host = _get_emulator_host()
if self._emulator_host:
self._credential = _utils.EmulatorAdminCredentials()
else:
self._credential = app.credential.get_credential()

self._http_client = _http_client.JsonHttpClient(credential=self._credential)

def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue:
"""Creates a TaskQueue instance."""
return TaskQueue(
function_name, extension_id, self._project_id, self._credential, self._http_client)
function_name, extension_id, self._project_id, self._credential, self._http_client,
self._emulator_host)

@classmethod
def handle_functions_error(cls, error: Any):
Expand All @@ -125,7 +145,8 @@ def __init__(
extension_id: Optional[str],
project_id,
credential,
http_client
http_client,
emulator_host: Optional[str] = None
) -> None:

# Validate function_name
Expand All @@ -134,6 +155,7 @@ def __init__(
self._project_id = project_id
self._credential = credential
self._http_client = http_client
self._emulator_host = emulator_host
self._function_name = function_name
self._extension_id = extension_id
# Parse resources from function_name
Expand Down Expand Up @@ -167,16 +189,26 @@ def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str:
str: The ID of the task relative to this queue.
"""
task = self._validate_task_options(task_data, self._resource, opts)
service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT)
emulator_url = self._get_emulator_url(self._resource)
service_url = emulator_url or self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT)
task_payload = self._update_task_payload(task, self._resource, self._extension_id)
try:
resp = self._http_client.body(
'post',
url=service_url,
headers=_FUNCTIONS_HEADERS,
json={'task': task_payload.__dict__}
json={'task': task_payload.to_api_dict()}
)
task_name = resp.get('name', None)
if self._is_emulated():
# Emulator returns a response with format {task: {name: <task_name>}}
# The task name also has an extra '/' at the start compared to prod
task_info = resp.get('task') or {}
task_name = task_info.get('name')
if task_name:
task_name = task_name[1:]
else:
# Production returns a response with format {name: <task_name>}
task_name = resp.get('name')
task_resource = \
self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks')
return task_resource.resource_id
Expand All @@ -197,7 +229,11 @@ def delete(self, task_id: str) -> None:
ValueError: If the input arguments are invalid.
"""
_Validators.check_non_empty_string('task_id', task_id)
service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT + f'/{task_id}')
emulator_url = self._get_emulator_url(self._resource)
if emulator_url:
service_url = emulator_url + f'/{task_id}'
else:
service_url = self._get_url(self._resource, _CLOUD_TASKS_API_URL_FORMAT + f'/{task_id}')
try:
self._http_client.body(
'delete',
Expand Down Expand Up @@ -235,8 +271,8 @@ def _validate_task_options(
"""Validate and create a Task from optional ``TaskOptions``."""
task_http_request = {
'url': '',
'oidc_token': {
'service_account_email': ''
'oidcToken': {
'serviceAccountEmail': ''
},
'body': b64encode(json.dumps(data).encode()).decode(),
'headers': {
Expand All @@ -250,7 +286,7 @@ def _validate_task_options(
task.http_request['headers'] = {**task.http_request['headers'], **opts.headers}
if opts.schedule_time is not None and opts.schedule_delay_seconds is not None:
raise ValueError(
'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.')
'Both schedule_delay_seconds and schedule_time cannot be set at the same time.')
if opts.schedule_time is not None and opts.schedule_delay_seconds is None:
if not isinstance(opts.schedule_time, datetime):
raise ValueError('schedule_time should be UTC datetime.')
Expand Down Expand Up @@ -288,7 +324,10 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
"""Prepares task to be sent with credentials."""
# Get function url from task or generate from resources
if not _Validators.is_non_empty_string(task.http_request['url']):
task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT)
if self._is_emulated():
task.http_request['url'] = ''
else:
task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT)

# Refresh the credential to ensure all attributes (e.g. service_account_email, id_token)
# are populated, preventing cold start errors.
Expand All @@ -298,20 +337,40 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
except RefreshError as err:
raise ValueError(f'Initial task payload credential refresh failed: {err}') from err

# If extension id is provided, it emplies that it is being run from a deployed extension.
# If extension id is provided, it implies that it is being run from a deployed extension.
# Meaning that it's credential should be a Compute Engine Credential.
if _Validators.is_non_empty_string(extension_id) and \
isinstance(self._credential, ComputeEngineCredentials):
id_token = self._credential.token
task.http_request['headers'] = \
{**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'}
# Delete oidc token
del task.http_request['oidc_token']
del task.http_request['oidcToken']
else:
task.http_request['oidc_token'] = \
{'service_account_email': self._credential.service_account_email}
try:
task.http_request['oidcToken'] = \
{'serviceAccountEmail': self._credential.service_account_email}
except AttributeError as error:
if self._is_emulated():
task.http_request['oidcToken'] = \
{'serviceAccountEmail': _EMULATED_SERVICE_ACCOUNT_DEFAULT}
else:
raise ValueError(
'Failed to determine service account. Initialize the SDK with service '
'account credentials or set service account ID as an app option.'
) from error
return task

def _get_emulator_url(self, resource: Resource):
if self._emulator_host:
emulator_url_format = f'http://{self._emulator_host}/' + _CLOUD_TASKS_API_RESOURCE_PATH
url = self._get_url(resource, emulator_url_format)
return url
return None

def _is_emulated(self):
return self._emulator_host is not None


class _Validators:
"""A collection of data validation utilities."""
Expand Down Expand Up @@ -436,6 +495,14 @@ class Task:
schedule_time: Optional[str] = None
dispatch_deadline: Optional[str] = None

def to_api_dict(self) -> dict:
"""Converts the Task object to a dictionary suitable for the Cloud Tasks API."""
return {
'httpRequest': self.http_request,
'name': self.name,
'scheduleTime': self.schedule_time,
'dispatchDeadline': self.dispatch_deadline,
}

@dataclass
class Resource:
Expand Down
91 changes: 72 additions & 19 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@ def setup_class(cls):
def teardown_class(cls):
testutils.cleanup_apps()

def _instrument_functions_service(self, app=None, status=200, payload=_DEFAULT_RESPONSE):
def _instrument_functions_service(
self, app=None, status=200, payload=_DEFAULT_RESPONSE, mounted_url=_CLOUD_TASKS_URL):
if not app:
app = firebase_admin.get_app()
functions_service = functions._get_functions_service(app)
recorder = []
functions_service._http_client.session.mount(
_CLOUD_TASKS_URL,
mounted_url,
testutils.MockAdapter(payload, status, recorder))
return functions_service, recorder

Expand Down Expand Up @@ -125,8 +126,8 @@ def test_task_enqueue(self):
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-email'}
assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension(self):
resource_name = (
Expand All @@ -147,8 +148,8 @@ def test_task_enqueue_with_extension(self):
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-email'}
assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_compute_engine(self):
app = firebase_admin.initialize_app(
Expand All @@ -168,8 +169,8 @@ def test_task_enqueue_compute_engine(self):
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'}
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
assert task['httpRequest']['oidcToken'] == {'serviceAccountEmail': 'mock-gce-email'}
assert task['httpRequest']['headers'] == {'Content-Type': 'application/json'}

def test_task_enqueue_with_extension_compute_engine(self):
resource_name = (
Expand All @@ -194,8 +195,8 @@ def test_task_enqueue_with_extension_compute_engine(self):
assert task_id == 'test-task-id'

task = json.loads(recorder[0].body.decode())['task']
assert 'oidc_token' not in task['http_request']
assert task['http_request']['headers'] == {
assert 'oidcToken' not in task['httpRequest']
assert task['httpRequest']['headers'] == {
'Content-Type': 'application/json',
'Authorization': 'Bearer mock-compute-engine-token'}

Expand All @@ -209,6 +210,58 @@ def test_task_delete(self):
expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag'
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header

def test_task_enqueue_with_emulator_host(self, monkeypatch):
emulator_host = 'localhost:8124'
emulator_url = f'http://{emulator_host}/'
request_url = emulator_url + _DEFAULT_TASK_PATH.replace('/tasks/test-task-id', '/tasks')

monkeypatch.setenv('CLOUD_TASKS_EMULATOR_HOST', emulator_host)
app = firebase_admin.initialize_app(
_utils.EmulatorAdminCredentials(), {'projectId': 'test-project'}, name='emulator-app')

expected_task_name = (
'/projects/test-project/locations/us-central1'
'/queues/test-function-name/tasks/test-task-id'
)
expected_response = json.dumps({'task': {'name': expected_task_name}})
_, recorder = self._instrument_functions_service(
app, payload=expected_response, mounted_url=emulator_url)

queue = functions.task_queue('test-function-name', app=app)
task_id = queue.enqueue(_DEFAULT_DATA)

assert len(recorder) == 1
assert recorder[0].method == 'POST'
assert recorder[0].url == request_url
assert recorder[0].headers['Content-Type'] == 'application/json'

task = json.loads(recorder[0].body.decode())['task']
assert task['httpRequest']['oidcToken'] == {
'serviceAccountEmail': '[email protected]'
}
assert task_id == 'test-task-id'

def test_task_enqueue_without_emulator_host_error(self, monkeypatch):
app = firebase_admin.initialize_app(
_utils.EmulatorAdminCredentials(),
{'projectId': 'test-project'}, name='no-emulator-app')

_, recorder = self._instrument_functions_service(app)
monkeypatch.delenv('CLOUD_TASKS_EMULATOR_HOST', raising=False)
queue = functions.task_queue('test-function-name', app=app)
with pytest.raises(ValueError) as excinfo:
queue.enqueue(_DEFAULT_DATA)
assert "Failed to determine service account" in str(excinfo.value)
assert len(recorder) == 0

def test_get_emulator_url_invalid_format(self, monkeypatch):
monkeypatch.setenv('CLOUD_TASKS_EMULATOR_HOST', 'http://localhost:8124')
app = firebase_admin.initialize_app(
testutils.MockCredential(), {'projectId': 'test-project'}, name='invalid-host-app')
with pytest.raises(ValueError) as excinfo:
functions.task_queue('test-function-name', app=app)
assert 'Invalid CLOUD_TASKS_EMULATOR_HOST' in str(excinfo.value)

class TestTaskQueueOptions:

_DEFAULT_TASK_OPTS = {'schedule_delay_seconds': None, 'schedule_time': None, \
Expand Down Expand Up @@ -259,13 +312,13 @@ def test_task_options_delay_seconds(self):
assert len(recorder) == 1
task = json.loads(recorder[0].body.decode())['task']

task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00'))
task_schedule_time = datetime.fromisoformat(task['scheduleTime'].replace('Z', '+00:00'))
delta = abs(task_schedule_time - expected_schedule_time)
assert delta <= timedelta(seconds=1)

assert task['dispatch_deadline'] == '200s'
assert task['http_request']['headers']['x-test-header'] == 'test-header-value'
assert task['http_request']['url'] in ['http://google.com', 'https://google.com']
assert task['dispatchDeadline'] == '200s'
assert task['httpRequest']['headers']['x-test-header'] == 'test-header-value'
assert task['httpRequest']['url'] in ['http://google.com', 'https://google.com']
assert task['name'] == _DEFAULT_TASK_PATH

def test_task_options_utc_time(self):
Expand All @@ -287,12 +340,12 @@ def test_task_options_utc_time(self):
assert len(recorder) == 1
task = json.loads(recorder[0].body.decode())['task']

task_schedule_time = datetime.fromisoformat(task['schedule_time'].replace('Z', '+00:00'))
task_schedule_time = datetime.fromisoformat(task['scheduleTime'].replace('Z', '+00:00'))
assert task_schedule_time == expected_schedule_time

assert task['dispatch_deadline'] == '200s'
assert task['http_request']['headers']['x-test-header'] == 'test-header-value'
assert task['http_request']['url'] in ['http://google.com', 'https://google.com']
assert task['dispatchDeadline'] == '200s'
assert task['httpRequest']['headers']['x-test-header'] == 'test-header-value'
assert task['httpRequest']['url'] in ['http://google.com', 'https://google.com']
assert task['name'] == _DEFAULT_TASK_PATH

def test_schedule_set_twice_error(self):
Expand All @@ -304,7 +357,7 @@ def test_schedule_set_twice_error(self):
queue.enqueue(_DEFAULT_DATA, opts)
assert len(recorder) == 0
assert str(excinfo.value) == \
'Both sechdule_delay_seconds and schedule_time cannot be set at the same time.'
'Both schedule_delay_seconds and schedule_time cannot be set at the same time.'


@pytest.mark.parametrize('schedule_time', [
Expand Down
Loading