Skip to content

Commit b1956fc

Browse files
Merge pull request #366 from rootcodelabs/ec2-bugfix
Updates to classifier
2 parents c2ce888 + b36bc76 commit b1956fc

File tree

12 files changed

+73
-34
lines changed

12 files changed

+73
-34
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ This repo will primarily contain:
106106
- `JIRA_WEBHOOK_SECRET` – Jira webhook secret you got in **Create Jira Webhook** step.
107107
108108
4. **Create a `.env` file for Jira Configuration:**
109-
- Create a `.env` file called `jira_config.env` and add the following:
109+
- Create a `.env` file in the folder called `jira-verification` and add the following:
110110
```env
111111
JIRA_WEBHOOK_SECRET=<<JIRA_WEBHOOK_SECRET>>
112112
```

docker-compose.gpu.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ services:
311311
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
312312
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
313313
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
314+
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
314315
ports:
315316
- "8003:8003"
316317
networks:
@@ -491,7 +492,7 @@ services:
491492
ports:
492493
- "3008:3008"
493494
env_file:
494-
- jira_config.env
495+
- ./jira-verification/.env
495496
environment:
496497
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
497498
networks:

docker-compose.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ services:
305305
- FIND_FINAL_FOLDER_ID_URL=http://hierarchy-validation:8009/find-folder-id
306306
- UPDATE_DATAMODEL_PROGRESS_URL=http://ruuter-private:8088/classifier/datamodel/progress/update
307307
- UPDATE_MODEL_TRAINING_STATUS_ENDPOINT=http://ruuter-private:8088/classifier/datamodel/update/training/status
308+
- GET_DATASET_METADATA_ENDPOINT=http://ruuter-private:8088/classifier/datasetgroup/group/metadata
308309
ports:
309310
- "8003:8003"
310311
networks:
@@ -473,7 +474,7 @@ services:
473474
ports:
474475
- "3008:3008"
475476
env_file:
476-
- jira_config.env
477+
- ./jira-verification/.env
477478
environment:
478479
RUUTER_PUBLIC_JIRA_URL: http://ruuter-public:8086/internal/jira/accept
479480
networks:

model-inference/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class UpdateRequest(BaseModel):
4646
bestBaseModel:str
4747
updateType: Optional[str] = None
4848
progressSessionId: int
49+
dgId:Optional[int]= None
4950

5051
class OutlookInferenceRequest(BaseModel):
5152
inputId:str

model-inference/inference_pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def find_missing_classes(self, main_classes, uploaded_classes):
104104

105105

106106

107-
def predict_class(self,text_input):
107+
def predict_class(self,text_input, platform):
108108

109109
logger.info("ENTERING PREDICT CLASS")
110110

@@ -117,10 +117,11 @@ def predict_class(self,text_input):
117117
self.base_model.to(self.device)
118118

119119
logger.info(f"CLASS HIERARCHY FILE {self.hierarchy_file}")
120+
logger.info(f"PLATFORM IN PREDICT CLASS {platform}")
120121

121-
122-
123122
data = self.hierarchy_file
123+
if platform == 'jira':
124+
data = data['classHierarchy']
124125
parent = 1
125126

126127
logger.info(f"DATA - {data}")

model-inference/inference_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ def inference(self, text:str, deployment_platform:str):
6262
if(deployment_platform == "jira" and self.active_jira_model):
6363

6464
logger.info("ENTERING JIRA INFERENCE")
65-
predicted_labels, probabilities = self.active_jira_model.predict_class(text)
65+
predicted_labels, probabilities = self.active_jira_model.predict_class(text, deployment_platform)
6666

6767
logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")
6868
logger.info(f"PROBABILITIES INSIDE .inference() FUNCTION - {probabilities}")
6969

7070

7171
if(deployment_platform == "outlook" and self.active_outlook_model):
7272
logger.info("ENTERING OUTLOOK INFERENCE")
73-
predicted_labels, probabilities = self.active_outlook_model.predict_class(text)
73+
predicted_labels, probabilities = self.active_outlook_model.predict_class(text, deployment_platform)
7474

7575

7676
logger.info(f"PREDICTED LABELS INSIDE .inference() FUNCTION - {predicted_labels}")

model-inference/model_inference.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
UPDATE_DATAMODEL_PROGRESS_URL = os.getenv("UPDATE_DATAMODEL_PROGRESS_URL")
2020
UPDATE_MODEL_TRAINING_STATUS_ENDPOINT = os.getenv("UPDATE_MODEL_TRAINING_STATUS_ENDPOINT")
2121
RUUTER_PRIVATE_URL = os.getenv("RUUTER_PRIVATE_URL")
22+
GET_DATASET_METADATA_ENDPOINT=os.getenv("GET_DATASET_METADATA_ENDPOINT")
2223

2324
class ModelInference:
2425
def __init__(self):
2526
pass
2627

27-
def get_class_hierarchy_by_model_id(self, model_id):
28+
def get_outlook_class_hierarchy_by_model_id(self, model_id):
2829

2930
try:
3031
logger.info(f"get_class_hierarchy_by_model_id - {model_id}")
@@ -123,9 +124,9 @@ def validate_class_hierarchy(self, class_hierarchy, model_id):
123124

124125

125126

126-
def get_class_hierarchy_and_validate(self, model_id):
127+
def get_outlook_class_hierarchy_and_validate(self, model_id):
127128
try:
128-
class_hierarchy = self.get_class_hierarchy_by_model_id(model_id)
129+
class_hierarchy = self.get_outlook_class_hierarchy_by_model_id(model_id)
129130
if class_hierarchy:
130131
is_valid = self.validate_class_hierarchy(class_hierarchy, model_id)
131132
return is_valid, class_hierarchy
@@ -252,4 +253,33 @@ def create_inference(self, payload):
252253
raise RuntimeError(f"Failed to call create inference. Reason: {e}")
253254

254255

255-
256+
def get_class_hierarchy_by_dg_id(self, cookies, dg_id):
257+
logger.info("********************************************************************")
258+
logger.info(f"****** Calling function get_class_hierarchy_by_dg_id ******")
259+
try:
260+
logger.info(f"get_class_hierarchy_by_dg_id - {dg_id}")
261+
logger.info(f"cookie : {cookies}")
262+
cookies_updated = {"customJwtCookie":cookies}
263+
logger.info(f"cookie_updated : {cookies_updated}")
264+
logger.info(f"GET_DATASET_METADATA_ENDPOINT : {GET_DATASET_METADATA_ENDPOINT}")
265+
266+
response_hierarchy = requests.get(GET_DATASET_METADATA_ENDPOINT, params={'groupId': dg_id}, cookies=cookies_updated)
267+
268+
logger.info(f"response_hierarchy : {response_hierarchy}")
269+
270+
if response_hierarchy.status_code == 200:
271+
logger.info("DATASET HIERARCHY RETREIVAL SUCCESSFUL")
272+
hierarchy = response_hierarchy.json()
273+
logger.info(f"DATASET HIERARCHY - {hierarchy}")
274+
class_hierarchy = hierarchy['response']['data'][0]
275+
logger.info(f"CLASS HIERARCHY - {class_hierarchy}")
276+
return class_hierarchy
277+
278+
else:
279+
logger.error(f"DATASET HIERARCHY RETRIEVAL FAILED: {response_hierarchy.status_code}")
280+
raise RuntimeError(f"ERROR RESPONSE\n {response_hierarchy.text}")
281+
282+
283+
except Exception as e:
284+
logger.error(f"Failed to retrieve the class hierarchy Reason: {e}")
285+
raise RuntimeError(f"Failed to retrieve the class hierarchy Reason: {e}")

model-inference/model_inference_api.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ async def download_outlook_model(request: Request, model_data:UpdateRequest):
7070
model_progress_session_id = model_data.progressSessionId
7171

7272
## Get class hierarchy and validate it
73-
is_valid, class_hierarchy = model_inference.get_class_hierarchy_and_validate(model_data.modelId)
73+
is_valid, class_hierarchy = model_inference.get_outlook_class_hierarchy_and_validate(model_data.modelId)
7474

7575
logger.info(f"IS VALID VALUE : {is_valid}")
7676
logger.info(f"CLASS HIERARCHY VALUE : {class_hierarchy}")
@@ -230,7 +230,7 @@ async def download_jira_model(request: Request, model_data:UpdateRequest):
230230

231231
logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")
232232

233-
class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
233+
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)
234234

235235
logger.info(f"JIRA UPDATE CLASS HIERARCHY - {class_hierarchy}")
236236

@@ -354,7 +354,7 @@ async def download_test_model(request: Request, model_data:UpdateRequest):
354354

355355
logger.info("JUST ABOUT TO ENTER get_class_hierarchy_by_model_id")
356356

357-
class_hierarchy = model_inference.get_class_hierarchy_by_model_id(model_data.modelId)
357+
class_hierarchy = model_inference.get_class_hierarchy_by_dg_id(cookies=cookie, dg_id=model_data.dgId)
358358

359359
logger.info(f"TEST UPDATE CLASS HIERARCHY - {class_hierarchy}")
360360

@@ -556,12 +556,10 @@ async def outlook_inference(request:Request, inference_data:OutlookInferenceRequ
556556
async def jira_inference(request:Request, inferenceData:JiraInferenceRequest):
557557
try:
558558

559-
560559
logger.info(f"INFERENCE DATA IN JIRA INFERENCE - {inferenceData}")
561560

562561
model_id = model_inference_wrapper.get_jira_model_id()
563562

564-
565563
if(model_id):
566564
# 1 . Check whether the if the Inference Exists
567565
is_exist, inference_id = model_inference.check_inference_data_exists(input_id=inferenceData.inputId)

model-inference/test_inference_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def inference(self, text: str, model_id: int):
3030
predicted_labels = None
3131
probabilities = None
3232
model = self.model_dictionary[model_id]
33-
predicted_labels, probabilities = model.predict_class(text_input=text)
33+
predicted_labels, probabilities = model.predict_class(text_input=text, platform="test")
3434
return predicted_labels, probabilities
3535
else:
3636
raise Exception(f"Model with ID {model_id} not found")

0 commit comments

Comments
 (0)