-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathgetImages.py
138 lines (120 loc) · 5.53 KB
/
getImages.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import html
import json
import flask
import functions_framework
import vertexai
from vertexai.preview.vision_models import ImageGenerationModel
from vertexai.generative_models import GenerativeModel, Part, FinishReason
import vertexai.preview.generative_models as generative_models
# Largest image_count a single request may ask for.
MAX_IMAGE_COUNT = 5
# Largest number_of_images passed to the image model in one call.
VERTEX_MAX_IMAGE_COUNT = 4

PROJECT_ID = "imagenio"
LOCATION = "us-central1"

# Initialise the SDK before any model handle is created.
vertexai.init(project=PROJECT_ID, location=LOCATION)

image_model = ImageGenerationModel.from_pretrained("imagegeneration@006")
caption_model = GenerativeModel("gemini-1.5-pro-preview-0409")

# Sampling parameters passed to the caption model on every request.
caption_generation_config = dict(
    max_output_tokens=8192,
    temperature=1,
    top_p=0.95,
)

# Every listed harm category uses the same blocking threshold.
safety_settings = {
    category: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE
    for category in (
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT,
    )
}
def _cors_error(message, status):
    """Return a plain-text error response that cross-origin browsers can read."""
    return (message, status, {"Access-Control-Allow-Origin": "*"})


@functions_framework.http
def get_image(request):
    """Generate images for a prompt and return each with a title and caption.

    Parameters are read from the JSON body when it is a non-empty object,
    otherwise from the query string:

        image_prompt: prompt for the image model (defaults to a cat picture).
        desc_prompt: instruction applied to every image by the caption model.
        image_count: number of images, an integer from 1 to MAX_IMAGE_COUNT.

    Args:
        request: the incoming Flask request.

    Returns:
        For OPTIONS, an empty 204 CORS preflight response. For an invalid
        image_count, a plain-text error (400 for a non-integer or
        non-positive value, 406 when above MAX_IMAGE_COUNT). Otherwise a
        JSON list of {"image", "caption", "title"} objects, where "image"
        is the base64-encoded image data.
    """
    if request.method == "OPTIONS":
        # CORS preflight: allow GET and POST from any origin with the
        # Content-Type header, and cache the preflight response for 3600s.
        headers = {
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST",
            "Access-Control-Allow-Headers": "Content-Type",
            "Access-Control-Max-Age": "3600",
        }
        return ("", 204, headers)

    request_json = request.get_json(silent=True)
    # A JSON body that is not an object (e.g. a list) has no .get(); in that
    # case, as for an empty or missing body, use the query string instead.
    if isinstance(request_json, dict) and request_json:
        params = request_json
    else:
        params = request.args

    default_image_prompt = 'a picture of a cute cat jumping'
    default_description_prompt = 'decribe the image'
    default_image_count = 1

    image_prompt = params.get('image_prompt', default_image_prompt)
    input_prompt = params.get('desc_prompt', default_description_prompt)
    text_prompt = f"""Do this for each image separately: "{html.escape(input_prompt)}". We will call the result of it as the information about an image. Give each image a title. Return the result as a list of objects in json format; each object will correspond one image and the fields for the object will be "title" for the title and "info" for the information."""

    # Validate image_count up front so bad input yields a client error
    # instead of an unhandled exception or a wasted model call.
    try:
        image_count = int(params.get('image_count', default_image_count))
    except (TypeError, ValueError):
        return _cors_error("Invalid image_count. It must be an integer.", 400)
    if image_count < 1:
        return _cors_error("Invalid image_count. Minimum image count is 1.", 400)
    if image_count > MAX_IMAGE_COUNT:
        return _cors_error(
            f"Invalid image_count. Maximum image count is {MAX_IMAGE_COUNT}.", 406
        )

    images = get_images_with_count(image_prompt, image_count)
    if not images:
        # Nothing to caption; skip the caption model entirely.
        resp = flask.jsonify([])
        resp.headers.set("Access-Control-Allow-Origin", "*")
        return resp

    image_strings = []
    caption_input = []
    for img in images:
        # NOTE(review): reads a private attribute of the SDK image object.
        image_bytes = img._image_bytes
        image_strings.append(base64.b64encode(image_bytes).decode("ascii"))
        caption_input.append(
            Part.from_data(mime_type="image/png", data=image_bytes)
        )

    # All images go to the caption model in one call, followed by the
    # instruction text.
    captions = caption_model.generate_content(
        caption_input + [text_prompt],
        generation_config=caption_generation_config,
        safety_settings=safety_settings,
    )
    captions_list = make_captions(captions)

    # zip() stops at the shorter sequence, so an image without a matching
    # caption entry is left out of the response.
    resp_images = [
        {"image": img, "caption": cap["description"], "title": cap["title"]}
        for img, cap in zip(image_strings, captions_list)
    ]
    resp = flask.jsonify(resp_images)
    resp.headers.set("Access-Control-Allow-Origin", "*")
    return resp
def get_images_with_count(image_prompt, image_count):
    """Generate up to ``image_count`` images for ``image_prompt``.

    The image model is called in batches of at most VERTEX_MAX_IMAGE_COUNT
    until enough images have been collected. A batch may yield fewer images
    than requested, so progress is measured by the number actually received.

    Args:
        image_prompt: text prompt passed to the image model.
        image_count: total number of images wanted; values below 1 yield an
            empty list without calling the model.

    Returns:
        A list of the generated image objects. It holds fewer than
        ``image_count`` items when the model repeatedly returns nothing.
    """
    # Give up after this many consecutive batches that add no image;
    # otherwise a prompt that never yields images would loop forever.
    max_empty_batches = 3
    empty_batches = 0

    images = []
    while len(images) < image_count:
        batch_size = min(VERTEX_MAX_IMAGE_COUNT, image_count - len(images))
        batch = image_model.generate_images(
            prompt=image_prompt,
            # Optional parameters
            number_of_images=batch_size,
            language="en",
            # You can't use a seed value and watermark at the same time.
            # add_watermark=False,
            # seed=100,
            aspect_ratio="1:1",
            safety_filter_level="block_some",
            person_generation="allow_adult",
        )
        count_before = len(images)
        images.extend(batch)
        print(f'Images generated so far: {len(images)}')

        if len(images) == count_before:
            empty_batches += 1
            if empty_batches >= max_empty_batches:
                print(f'No images returned in {empty_batches} consecutive batches; stopping.')
                break
        else:
            empty_batches = 0
    return images
def make_captions(captions):
    """Parse the caption model's reply into a list of title/description dicts.

    The reply text is expected to be a JSON list of objects with "title" and
    "info" fields. The model sometimes wraps the JSON in a Markdown code
    fence, with or without a "json" tag; the fence is removed before parsing.

    Args:
        captions: response object exposing the reply as ``captions.text``.

    Returns:
        A list of ``{"title": ..., "description": ...}`` dicts, one per
        object in the reply, in the same order.

    Raises:
        json.JSONDecodeError: if the reply is not valid JSON.
        KeyError: if an object lacks the "title" or "info" field.
    """
    fence = "```"
    captions_text = captions.text.strip()
    if captions_text.startswith(fence):
        # Strip the fence by content rather than by fixed offsets, so the
        # result does not depend on the whitespace around the fence.
        captions_text = captions_text[len(fence):]
        if captions_text.startswith("json"):
            captions_text = captions_text[len("json"):]
        if captions_text.endswith(fence):
            captions_text = captions_text[:-len(fence)]

    captions_list = json.loads(captions_text)
    return [
        {"title": caption["title"], "description": caption["info"]}
        for caption in captions_list
    ]