Skip to content

Commit 74a151b

Browse files
Send slack notification when model training finishes
* update remote url of cluster-utils submodule * remove unused ufo-map submodule
1 parent 2c91efd commit 74a151b

File tree

4 files changed

+17
-9
lines changed

4 files changed

+17
-9
lines changed

.gitmodules

+1-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
[submodule "src/ufo-map"]
2-
path = src/ufo_map
3-
url = [email protected]:nikolami/ufo-map.git
41
[submodule "cluster-utils"]
52
path = cluster-utils
6-
url = [email protected]:FlorianNachtigall/cluster-utils.git
3+
url = [email protected]:ai4up/cluster-utils.git

bin/train.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55
import time
66
import logging
7+
import datetime
78

89
PROJECT_ROOT = os.path.realpath(os.path.join(__file__, '..', '..'))
910
PROJECT_SRC = os.path.join(PROJECT_ROOT, 'src')
@@ -17,16 +18,20 @@
1718
import preprocessing as pp
1819
from prediction_age import AgePredictor
1920
import cluster_utils.dataset as cluster_dataset
20-
21-
logger = logging.getLogger(__name__)
22-
logger.setLevel(logging.INFO)
21+
import cluster_utils.slack_notifications as slack
2322

2423
COUNTRY = 'spain'
2524
N_CITIES = 4000
2625
CITIES = []
2726
DATA_DIR = '/p/projects/eubucco/data/2-database-city-level-v0_1'
2827
RESULT_DIR = '/p/tmp/floriann/ml-training'
2928

29+
start_time = time.time()
30+
logger = logging.getLogger(__name__)
31+
logger.setLevel(logging.INFO)
32+
slack_channel = os.environ.get('SLACK_CHANNEL')
33+
slack_token = os.environ.get('SLACK_TOKEN')
34+
3035
logger.info('Extracting features...')
3136
df = cluster_dataset.load(country_name=COUNTRY, path=DATA_DIR, cities=CITIES, n_cities=N_CITIES)
3237

@@ -42,3 +47,10 @@
4247
timestr = time.strftime('%Y%m%d-%H-%M-%S')
4348
model_path = f'{RESULT_DIR}/model-{COUNTRY}-{N_CITIES or len(CITIES)}-{timestr}.pkl'
4449
predictor.save(model_path)
50+
51+
logger.info('Sending slack notification...')
52+
try:
53+
duration = str(datetime.timedelta(seconds=time.time() - start_time)).split('.')[0]
54+
slack.send_message(f'Model training for {COUNTRY} finished after {duration}. 🚀', slack_channel, slack_token)
55+
except Exception as e:
56+
logger.error(f'Failed to send Slack message: {e}')

cluster-utils

src/ufo_map

-1
This file was deleted.

0 commit comments

Comments
 (0)