-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
124 lines (113 loc) · 4.56 KB
/
Copy pathmain.py
File metadata and controls
124 lines (113 loc) · 4.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# stacks/main_stack.py
from aws_cdk import Stack
from constructs import Construct
from .auth import CognitoSamlAuth
from .backend import RagBackend
from .frontend import RagFrontend
from .ingest import RagIngest
from .waf import Waf
class RagChatbotStack(Stack):
    """CDK stack that wires together the RAG chatbot constructs.

    The constructor instantiates, in order: ``RagIngest``, ``Waf``,
    ``RagFrontend``, optionally ``CognitoSamlAuth``, and finally
    ``RagBackend``. Values exposed by earlier constructs (OpenSearch endpoint,
    collection and bucket ARNs, WebACL ARN, frontend domain) are handed to the
    constructs created after them.
    """

    def __init__(
        self,
        scope: Construct,
        construct_id: str,
        embeddings_model_id: str,
        video_text_model_id: str,
        opensearch_collection_name: str,
        opensearch_index_name: str,
        chat_model: str,
        embedding_model: str,
        chat_prompt: str,
        classifier_model: str,
        document_filter_model: str,
        platform_classifier_prompt: str,
        document_filter_prompt: str,
        config_path: str,
        max_concurrency: int,
        step_function_timeout_hours: int,
        chunk_size: str,
        overlap: str,
        docs_retrieved: int,
        docs_after_falloff: int,
        conversation_history_turns: int = 4,
        max_history_characters: int = 100000,
        temperature: float = 1.0,
        top_p: float = 0.999,
        max_tokens: int = 4096,
        api_key_value: str = None,
        # Cognito / SAML auth (optional)
        cognito_domain_prefix: str = None,
        saml_idp_name: str = None,
        saml_idp_metadata_url: str = None,
        saml_attribute_mapping: dict = None,
        # Custom domain for CloudFront (optional)
        frontend_domain_name: str = None,
        frontend_certificate_arn: str = None,
        **kwargs,
    ) -> None:
        """Build every construct of the chatbot inside this stack.

        Args:
            scope: Parent construct, passed to ``Stack.__init__``.
            construct_id: Identifier of this stack, passed to
                ``Stack.__init__``.
            embeddings_model_id, video_text_model_id,
            opensearch_collection_name, max_concurrency,
            step_function_timeout_hours, chunk_size, overlap:
                Forwarded unchanged to ``RagIngest``.
            opensearch_index_name: Forwarded to both ``RagIngest`` and
                ``RagBackend``.
            chat_model, embedding_model, chat_prompt, classifier_model,
            document_filter_model, platform_classifier_prompt,
            document_filter_prompt, docs_retrieved, docs_after_falloff,
            conversation_history_turns, max_history_characters, temperature,
            top_p, max_tokens, api_key_value:
                Forwarded unchanged to ``RagBackend``.
            config_path: Accepted for interface compatibility; not referenced
                anywhere in this constructor.
            cognito_domain_prefix, saml_idp_name, saml_idp_metadata_url,
            saml_attribute_mapping: ``CognitoSamlAuth`` is created only when
                all four values are truthy; otherwise it is skipped.
            frontend_domain_name, frontend_certificate_arn: Forwarded to
                ``RagFrontend`` as ``domain_name`` and ``certificate_arn``.
            **kwargs: Extra keyword arguments passed to ``Stack.__init__``.
        """
        super().__init__(scope, construct_id, **kwargs)

        # Ingestion side; it exposes the OpenSearch endpoint, collection ARN
        # and bucket ARN that the backend consumes further down.
        ingest = RagIngest(
            self,
            "RagIngest",
            opensearch_index_name=opensearch_index_name,
            opensearch_collection_name=opensearch_collection_name,
            embeddings_model_id=embeddings_model_id,
            video_text_model_id=video_text_model_id,
            region=self.region,
            max_concurrency=max_concurrency,
            step_function_timeout_hours=step_function_timeout_hours,
            chunk_size=chunk_size,
            overlap=overlap,
        )

        # The WebACL has to exist before CloudFront is created
        # (a CloudFront-scoped WebACL must be in us-east-1).
        web_acl = Waf(self, "Waf")

        # The frontend comes before auth and backend because both need its
        # public domain / URL.
        frontend = RagFrontend(
            self,
            "RagFrontend",
            web_acl_id=web_acl.web_acl_arn,
            domain_name=frontend_domain_name,
            certificate_arn=frontend_certificate_arn,
        )

        # Cognito SAML auth is optional: it is only wired up when every
        # required setting has been supplied.
        saml_settings = (
            cognito_domain_prefix,
            saml_idp_name,
            saml_idp_metadata_url,
            saml_attribute_mapping,
        )
        if all(saml_settings):
            # With a custom domain as the primary URL, the default CloudFront
            # URL is also registered as a fallback OAuth callback.
            fallback_callbacks = (
                [f"https://{frontend.distribution_domain_name}"]
                if frontend.custom_domain_name
                else []
            )
            CognitoSamlAuth(
                self,
                "CognitoSamlAuth",
                cognito_domain_prefix=cognito_domain_prefix,
                saml_idp_name=saml_idp_name,
                saml_idp_metadata_url=saml_idp_metadata_url,
                saml_attribute_mapping=saml_attribute_mapping,
                cloudfront_url=frontend.public_url,
                extra_callback_urls=fallback_callbacks,
            )

        # Backend API; it receives the frontend's public domain for its CORS
        # configuration.
        RagBackend(
            self,
            "RagBackend",
            opensearch_endpoint=ingest.opensearch_endpoint,
            opensearch_index_name=opensearch_index_name,
            opensearch_collection_arn=ingest.collection_arn,
            chat_model=chat_model,
            embedding_model=embedding_model,
            chat_prompt=chat_prompt,
            classifier_model=classifier_model,
            document_filter_model=document_filter_model,
            platform_classifier_prompt=platform_classifier_prompt,
            document_filter_prompt=document_filter_prompt,
            bucket_arn=ingest.bucket_arn,
            docs_retrieved=docs_retrieved,
            docs_after_falloff=docs_after_falloff,
            conversation_history_turns=conversation_history_turns,
            max_history_characters=max_history_characters,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            api_key_value=api_key_value,
            frontend_distribution_domain=frontend.public_domain_name,
        )