diff --git a/Categorical.ipynb b/Categorical.ipynb index 4e2fb27..146cc20 100644 --- a/Categorical.ipynb +++ b/Categorical.ipynb @@ -6,80 +6,83 @@ "metadata": {}, "outputs": [], "source": [ + "import pandas as pd\n", + "import tensorflow as tf\n", + "import dataset_ops\n", + "import functools\n", + "from transfer_learning import evaluate_model\n", + "import numpy as np # noqa\n", + "import matplotlib.pyplot as plt # noqa\n", + "from model_helper import make_model\n", + "from metrics import F1, Precision, Recall, soft_dice_loss, remove_clutter_one_sample, ClassPrecision, ClassRecall\n", "import datetime\n", - "import os\n", - "\n", + "import cuda\n", + "import pandas_format # noqa\n", + "from pathlib import Path\n", + "from tensorboard.plugins.hparams import api as hp\n", "try:\n", " from tqdm import notebook as tqdm\n", "except ImportError:\n", " tqdm = None\n", - " \n", - "import pandas as pd\n", - "import tensorflow as tf\n", - "import numpy as np\n", + "\n", "%matplotlib inline\n", "%load_ext tensorboard\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import dataset_ops\n", - "\n", + "cuda.initialize()\n", "\n", - "pd.set_option('display.max_columns', None) # show all columns\n", - "GPUs = tf.config.list_physical_devices('GPU')\n", - "if GPUs is None or len(GPUs) == 0:\n", - " print(\"WARNING: No GPU, all there is is:\")\n", - " for device in tf.config.list_physical_devices():\n", - " print(f'- {device}')\n", - "else:\n", - " for gpu in GPUs:\n", - " tf.config.experimental.set_memory_growth(gpu, True)\n", - " print(\"Initialized\", gpu)\n", "\n", - "dataset_manager = dataset_ops.TestsManager(dataset_dir='./h5', runs_filename='runs.hdf')\n", + "dataset_manager = dataset_ops.MicroPilotTestsManager(dataset_dir=Path('h5'), runs_filename='runs.hdf')\n", + "# dataset_manager = dataset_ops.PaparazziTestManager(dataset_dir=Path('pprz_h5'), runs_filename='pprz_runs.hdf')\n", "all_runs = dataset_manager.get_all_available_tests()\n", "\n", + "\n", "selected_runs = all_runs.loc[(all_runs['Test Length'] > 200) & (all_runs['Test Length'] < 20000)]\n", "# selected_runs = selected_runs.iloc[:40]\n", - "plt = selected_runs['Test Length'].plot(kind='hist', bins=25, figsize=[10,5])\n", - "plt.tick_params(labelsize=14)\n", - "plt.set_xlim([10,18000])\n", - "plt.set_xlabel('Test Length ($l_k$)', fontsize=15)\n", - "plt.set_ylabel('Number of Tests', fontsize=15)\n", - "plt.figure.savefig('test_lengths.png')\n", + "# selected_runs = all_runs.sample(frac=1, axis=1, random_state=55)\n", + "# tl_plot = selected_runs['Test Length'].plot(kind='hist', bins=25, figsize=[10,5])\n", + "# tl_plot.tick_params(labelsize=14)\n", + "# tl_plot.set_xlim([10,18000])\n", + "# tl_plot.set_xlabel('Test Length ($l_k$)', fontsize=15)\n", + "# tl_plot.set_ylabel('Number of Tests', fontsize=15)\n", + "# tl_plot.figure.savefig('paper_data/test_lengths.png')\n", "# #selected_runs\n", - "# print(all_runs.shape, selected_runs.shape)" + "# print(all_runs.shape, selected_runs.shape)\n", + "# selected_runs['Test Length'].mean()" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "inputs = ('SpeedFts', 'Pitch', 'Roll', 'Yaw', 'current_altitude', )\n", "outputs= ('elev', 'ai', 'rdr', 'throttle', 'Flaps')\n", "\n", - "# max_length = selected_runs['Test Length'].max()\n", - "max_length = 18000 \n", - "\n", + "max_length = 18000\n", "\n", "tfdataset = dataset_ops.TensorflowDataset(dataset_manager)\n", - "dataset = 
tfdataset.get_dataset(selected_runs, batch_size=25, features=inputs+outputs, max_length=max_length)\n",
-    "dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
+    "train_dataset, test_dataset, validation_dataset = dataset_ops.split_dataset(\n",
+    "    tfdataset.get_dataset(selected_runs, features=inputs+outputs, max_length=max_length),\n",
+    "    split_proportion=(6, 1, 3)\n",
+    ")  # 60% / 10% / 30%; adds up to 100%\n",
+    "train_dataset, test_dataset, validation_dataset = (\n",
+    "    dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
+    "    .batch(25)\n",
+    "    .shuffle(buffer_size=15)\n",
+    "    for dataset in (train_dataset, test_dataset, validation_dataset)\n",
+    ")\n",
     "\n",
+    "assert dataset_manager.count_states() > 0\n",
     "\n",
-    "test_dataset = dataset.enumerate().filter(lambda x,y: x % 20 == 0).map(lambda x,y: y)\n",
-    "validation_dataset = dataset.enumerate().filter(lambda x,y: x % 20 == 1).map(lambda x,y: y)\n",
-    "train_dataset = dataset.enumerate().filter(lambda x,y: x % 20 > 1).map(lambda x,y: y)\n",
-    "\n",
-    "dataset.element_spec"
-   ],
-   "metadata": {
-    "collapsed": false,
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   }
+    "train_dataset.element_spec"
+   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -91,53 +94,6 @@
   },
   "outputs": [],
   "source": [
-    "from model_helper import MaskStealingLayer\n",
-    "\n",
-    "def make_model(inputs, outputs, input_length):\n",
-    "    n_in = len(inputs)\n",
-    "    n_out = len(outputs)\n",
-    "    n_features = n_in + n_out\n",
-    "    \n",
-    "    bias_initializer = tf.keras.initializers.Constant(np.log(0.01))\n",
-    "    \n",
-    "    signals = tf.keras.Input(shape=[input_length, n_features], name='signals')\n",
-    "    mask = tf.keras.Input(shape=[input_length, 1], name='mask')\n",
-    "    \n",
-    "    x = signals\n",
-    "    x = MaskStealingLayer(0)((x, mask))\n",
-    "    \n",
-    "    x = tf.keras.layers.Conv1D(filters=64, kernel_size=3, padding=\"same\", name='conv_3')(x)\n",
-    "    # x = tf.keras.layers.MaxPool1D(pool_size=2)(x)\n",
-    "    x = tf.keras.layers.Conv1D(filters=64, kernel_size=5, padding=\"same\", name='conv_5')(x)\n",
-    "#     x = tf.keras.layers.MaxPool1D(pool_size=2)(x)\n",
-    "    x = tf.keras.layers.Conv1D(filters=64, kernel_size=10, padding=\"same\", name='conv_10')(x)\n",
-    "#     x = tf.keras.layers.MaxPool1D(pool_size=2)(x)\n",
-    "    x = tf.keras.layers.Conv1D(filters=64, kernel_size=15, padding=\"same\", name='conv_15')(x)\n",
-    "\n",
-    "    x = tf.keras.layers.Conv1D(filters=64, kernel_size=20, padding=\"same\", name='conv_20')(x)\n",
-    "#     x = tf.keras.layers.MaxPool1D(pool_size=2)(x)\n",
-    "    \n",
-    "    x = tf.keras.layers.GRU(128, return_sequences=True)(x)\n",
-    "#     x = tf.keras.layers.LeakyReLU()(x)\n",
-    "    \n",
-    "    x = tf.keras.layers.GRU(128, return_sequences=True)(x)\n",
-    "#     x = tf.keras.layers.LeakyReLU()(x)\n",
-    "    \n",
-    "    x = tf.keras.layers.Dense(128)(x)\n",
-    "    x = tf.keras.layers.LeakyReLU()(x)\n",
-    "    \n",
-    "    x = tf.keras.layers.Dense(dataset_manager.count_states(), bias_initializer=bias_initializer, activation='softmax')(x)\n",
-    "\n",
-    "    # x = tf.keras.layers.UpSampling1D(2 ** 1)(x)\n",
-    "    \n",
-    "    model = tf.keras.Model(inputs=[signals, mask], outputs=x)\n",
-    "    \n",
-    "    return model\n",
-    "\n",
-    "# %%\n",
-    "\n",
-    "from metrics import F1, Precision, Recall, soft_dice_loss, remove_clutter_one_sample, ClassPrecision, ClassRecall\n",
-    "\n",
     "def create_prec_recall_f1(tolerance):\n",
     "    prec = Precision(name=f'prec_{tolerance}', tolerance=tolerance)\n",
     "    recl = Recall(name=f'recall_{tolerance}', tolerance=tolerance)\n",
@@ -150,13 +106,31 @@
     "        F1(prec, recl),\n",
     "    ]\n",
     "\n",
+    "\n",
"evaluation_metrics = create_prec_recall_f1(25)\n", + "metrics_reporting = (create_prec_recall_f1(5)[:-1] +\n", + " create_prec_recall_f1(15)[:-1] +\n", + " create_prec_recall_f1(25)[:-1] +\n", + " [ClassPrecision(), ClassRecall()])\n", + "\n", "optimizer = tf.keras.optimizers.Adam(lr=3e-5)\n", "\n", - "model = make_model(inputs, outputs, max_length)\n", - "model.compile(loss=soft_dice_loss, optimizer=optimizer, metrics=evaluation_metrics)\n", + "# MP:\n", + "mp_model_builder = functools.partial(make_model, inputs, outputs, max_length, n_states=dataset_manager.count_states())\n", + "\n", + "full_model = mp_model_builder(convs=[(64, 3), (64, 5), (64, 10), (64, 15), (64, 20)], grus=[128, 128], name='mp_model')\n", + "full_model.summary()\n", + "full_model.compile(loss=soft_dice_loss, optimizer=optimizer, metrics=evaluation_metrics)\n", "\n", - "# model.summary()" + "cnn_baseline_model = mp_model_builder(convs=[(64, 3), (64, 5), (64, 10), (64, 15), (64, 20)], grus=[], name='convolutional_baseline')\n", + "cnn_baseline_model.summary()\n", + "cnn_baseline_model.compile(loss=soft_dice_loss, optimizer=optimizer, metrics=evaluation_metrics)\n", + "\n", + "rnn_baseline_model = mp_model_builder(convs=[(1, 1)], grus=[128, 128], name='recurrent_baseline')\n", + "rnn_baseline_model.summary()\n", + "rnn_baseline_model.compile(loss=soft_dice_loss, optimizer=optimizer, metrics=evaluation_metrics)\n", + "\n", + "evaluation_results = {}" ] }, { @@ -169,40 +143,52 @@ }, "outputs": [], "source": [ - "epochs = 100\n", + "epochs = 500\n", "# epochs = 5\n", "\n", - "training_start_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", - "log_dir=\"logs/fit/\" + training_start_time\n", - "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", - "\n", - "history = model.fit(train_dataset,\n", - " epochs=epochs,\n", - " validation_data=validation_dataset,\n", - " callbacks=[\n", - " tensorboard_callback,\n", - " tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5),\n", - " ])\n", - "model.save(f'models/categroical-{training_start_time}-{epochs}.h5')\n", - "tf.keras.utils.plot_model(model, show_shapes=True, to_file=f'models/categorical-{training_start_time}-{epochs}.png')\n", - "model.evaluate(test_dataset)\n", - "# history.history" + "for model_name, model in zip(('full', 'rnn', 'cnn',), (full_model, rnn_baseline_model, cnn_baseline_model)):\n", + " training_start_time = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", + " log_dir=\"logs/fit/\" + training_start_time\n", + " file_name = f'models/mp_cameraready-{model_name}-{training_start_time}-{epochs}.h5'\n", + " if Path(file_name).exists():\n", + " model.load_weights(file_name)\n", + " else:\n", + " tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n", + "\n", + " history = model.fit(train_dataset,\n", + " epochs=epochs,\n", + " validation_data=validation_dataset,\n", + " callbacks=[\n", + " tensorboard_callback,\n", + " tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5),\n", + " ])\n", + " model.save(file_name)\n", + " tf.keras.utils.plot_model(model, show_shapes=True, to_file=file_name.replace('.h5', '.png'))\n", + "\n", + " evaluation_results[model_name] = evaluate_model(model, validation_dataset)" ] }, { "cell_type": "code", "execution_count": null, - "outputs": [], - "source": [ - "print(model.evaluate(test_dataset))\n", - "print(model.evaluate(validation_dataset))" - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%%\n" } - } + 
}, + "outputs": [], + "source": [ + "rq3 = pd.DataFrame(evaluation_results).set_index(pd.Index(\n", + " ['prec_5', 'recall_5', 'prec_15', 'recall_15', 'prec_25', 'recall_25', 'class_precision', 'class_recall'])).T\n", + "for _tau in [5, 15, 25]:\n", + " p = rq3[f'prec_{_tau}']\n", + " r = rq3[f'recall_{_tau}']\n", + " f1 = 2 * p * r / (p + r)\n", + " rq3.insert(rq3.columns.to_list().index(f'recall_{_tau}') + 1, f\"F1_{_tau}\", f1)\n", + "\n", + "p, r = rq3['class_precision'], rq3['class_recall']\n", + "rq3['class_F1'] = 2 * p * r / (p + r)" + ] }, { "cell_type": "code", @@ -214,79 +200,70 @@ }, "outputs": [], "source": [ - "model.load_weights('models/categroical-20200325-225905-100.h5')\n", - "# model.load_weights('models/categroical-20200507-225905-100.h5')\n", - "import matplotlib.pyplot as plt" + "with pandas_format.PandasFloatFormatter('{:,.2f}%'):\n", + " print((rq3*100).T.to_latex())" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "fig, axs = plt.subplots(15, 2, figsize=(15, 8), sharex=True)\n", - "N = 600\n", + "# N = 2500\n", + "N = 1000\n", + "\n", + "folder_name = Path('plots') / 'output_compare' / ('categorical_' + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\"))\n", + "if not folder_name.exists():\n", + " folder_name.mkdir(parents=True)\n", "\n", - "folder_name = f'Batch/categorical_{datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")}'\n", - "# os.mkdir(folder_name)\n", "print('Plotting in', folder_name)\n", "\n", - "metrics_reporting = evaluation_metrics[:-1] + [\n", - " Precision(name='prec_15', tolerance=15),\n", - " Recall(name='recall_15', tolerance=15),\n", - " Precision(name='prec_5', tolerance=5),\n", - " Recall(name='recall_5', tolerance=5),\n", - " # ClassPrecision(),\n", - " # ClassRecall(),\n", - "]\n", "results = []\n", - "for set_name, data_set in (('Training', train_dataset), ('Test', test_dataset)):\n", + "for set_name, data_set in (('Training', train_dataset), ('Test', test_dataset), ('Validation', validation_dataset)):\n", " for bi, data in data_set.unbatch().batch(30).enumerate():\n", - " ins, gt = data\n", + " ins, ground_truth = data\n", " prediction = model.predict_on_batch(ins)\n", " mask = tf.squeeze(ins['mask'], axis=-1)\n", - " gt = tf.squeeze(gt)\n", + " ground_truth = tf.squeeze(ground_truth)\n", "\n", " for metric in metrics_reporting: \n", " metric.reset_states()\n", - " metric.update_state(gt, prediction)\n", + " metric.update_state(ground_truth, prediction)\n", " \n", " prediction = tf.math.argmax(prediction, axis=-1)\n", " no_clutter = tf.map_fn(remove_clutter_one_sample, prediction)\n", - " gt = tf.math.argmax(gt, axis=-1)\n", + " ground_truth = tf.math.argmax(ground_truth, axis=-1)\n", " \n", " run_length = tf.math.minimum(tf.argmin(mask, axis=-1), N)\n", - " \n", + " max_run_length_in_batch = int(tf.math.reduce_max(run_length))\n", + "\n", " results.append(\n", " [set_name] + [float(metric.result()) for metric in metrics_reporting] \n", " )\n", - " \n", - " # for prednc, truth, idx, ax in zip(no_clutter, gt, run_length, axs.reshape(-1)):\n", - " # if idx == 0:\n", - " # idx = N\n", - " # truth, prednc = truth[:idx], prednc[:idx]\n", - " # \n", - " # concat = tf.stack((prednc, truth), axis=0)\n", - " # ax.imshow(concat, aspect='auto', interpolation='nearest')\n", - " # ax.set_yticklabels(['', '$\\\\hat{O}$', '$O$'])\n", - " # ax.set_xlim([1, N])\n", - " # \n", - " # plt.tight_layout()\n", - " # 
fig.savefig(f'{folder_name}/{set_name}_{bi}.png')\n", - " # for ax in axs.reshape(-1): ax.clear()\n", + "\n", + " for prednc, truth, idx, ax in zip(no_clutter, ground_truth, run_length, axs.reshape(-1)):\n", + " if idx == 0:\n", + " idx = N\n", + " truth, prednc = truth[:idx], prednc[:idx]\n", + "\n", + " concat = tf.stack((prednc, truth), axis=0)\n", + " # ax.imshow(concat, aspect='auto', interpolation='nearest', vmin=0, vmax=dataset_manager.count_states())\n", + " ax.imshow(concat, aspect='auto', interpolation='nearest')#, vmin=0, vmax=dataset_manager.count_states())\n", + " ax.set_yticklabels(['', '$\\\\hat{O}$', '$O$'])\n", + " ax.set_xlim([1, max_run_length_in_batch])\n", + "\n", + " plt.tight_layout()\n", + " fig.savefig(folder_name / f'{set_name}_{bi}.png')\n", + " for ax in axs.reshape(-1): ax.clear()\n", "plt.close()\n", "columns = ['Dataset'] + [metric.name for metric in metrics_reporting] \n", "results = pd.DataFrame(results, columns=columns)\n", - "pd.options.display.float_format = '{:,.2f}'.format\n", "\n", "# results['class_precision'] *= 100\n", "# results['class_recall'] *= 100\n", - "# results['F1'] = 2*results['class_precision']*results['class_recall'] / (results['class_precision']+results['class_recall'])\n", + "results['class_F1'] = 2*results['class_precision']*results['class_recall'] / (results['class_precision']+results['class_recall'])\n", "\n", "results" ] @@ -294,11 +271,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "df = results.groupby('Dataset').aggregate('mean')*100\n", @@ -314,33 +287,13 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "is_executing": false, - "name": "#%%\n" - } - }, - "outputs": [], - "source": [ - "from sklearn.preprocessing import minmax_scale\n", - "# plt.plot(train_dataset[0]['signals'])\n", - "i = [*train_dataset.unbatch().take(1).as_numpy_iterator()][0][0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "pycharm": { - "is_executing": false, - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ - "inp = model.input\n", - "outputs = [layer.output for layer in model.layers[3:]]\n", - "functors = [tf.keras.backend.function([inp], [output]) for output in outputs]" + "paper_results = df\n", + "\n", + "with pandas_format.PandasFloatFormatter('{:,.2f}%'):\n", + " print(paper_results.loc[['Validation']].T.to_latex())" ] }, { @@ -348,40 +301,13 @@ "execution_count": null, "metadata": { "pycharm": { - "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ - "inputs=[*train_dataset.unbatch().batch(1).take(1).as_numpy_iterator()]\n", - "layer_outs = [func(inputs) for func in functors]\n", - "layer_outs\n", - "# conv1 = functors[0]\n", - "# conv1()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.plot(minmax_scale(layer_outs[9][0][0][150:900,-4]))\n", - "plt.plot(minmax_scale(layer_outs[9][0][0][150:900,-3]) + 1)\n", - "plt.plot(minmax_scale(layer_outs[9][0][0][150:900,-5]) + 2)\n", - "plt.plot(minmax_scale(layer_outs[9][0][0][150:900,-8]) + 3)\n", - "plt.axis('off')\n", - "plt.savefig('signal-out.png', bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "[*enumerate(model.layers[3:])]" + "with pandas_format.PandasFloatFormatter('{:,.2f}%'):\n", + " print(paper_results.loc[['Validation', 
'Test', 'Training'], ~paper_results.columns.str.contains('class_')].to_latex())" ] }, { @@ -389,40 +315,14 @@ "execution_count": null, "metadata": { "pycharm": { - "is_executing": false, "name": "#%%\n" } }, "outputs": [], "source": [ - "plt.plot(minmax_scale(i['signals'][150:900,-4]))\n", - "plt.plot(minmax_scale(i['signals'][150:900,-3]) + 1)\n", - "plt.plot(minmax_scale(i['signals'][150:900,-5]) + 2)\n", - "plt.plot(minmax_scale(i['signals'][150:900,-8]) + 3)\n", - "plt.axis('off')\n", - "plt.savefig('signal.png', bbox_inches='tight')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model.summary()" + "with pandas_format.PandasFloatFormatter('{:,.2f}%'):\n", + " display(paper_results.loc[['Validation']].T)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } } ], "metadata": { @@ -446,4 +346,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/dataset_ops.py b/dataset_ops.py index 8273b24..a891b44 100644 --- a/dataset_ops.py +++ b/dataset_ops.py @@ -2,7 +2,7 @@ import re from collections import defaultdict from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np import pandas as pd @@ -216,8 +216,7 @@ def _gen(): return _gen - def get_dataset(self, selected_runs: pd.DataFrame, *, features: List[str], max_length: int, - batch_size: int = -1) -> tf.data.Dataset: + def get_dataset(self, selected_runs: pd.DataFrame, *, features: List[str], max_length: int) -> tf.data.Dataset: ds = tf.data.Dataset.from_generator( self._create_padded_generator(selected_runs, features=features, max_length=max_length), output_types=({'signals': tf.float32, 'mask': tf.float32}, tf.float32), @@ -227,7 +226,45 @@ def get_dataset(self, selected_runs: pd.DataFrame, *, features: List[str], max_l }, tf.TensorShape([max_length, self.dataset_manager.count_states()])), ) - if batch_size == -1: - return ds - return ds.batch(batch_size) + return ds + + +def split_dataset(ds, split_proportion: Tuple[int, int, int]): + def get_second(_, x): + return x + + limits = [0] + for p in split_proportion: + limits.append(limits[-1] + p) + + def make_sieve(idx): + lower, upper, total = limits[idx], limits[idx + 1], limits[-1] + + def _sieve(index, _): + r = index % total + return (lower <= r) and (r < upper) + return tf.function(_sieve) + + return tuple( + ds.enumerate().filter(make_sieve(idx)).map(get_second) + for idx in range(len(split_proportion)) + ) + + +def load_and_split(dataset_manager, selected_runs, features, split_ratio, batch_size, max_length=None): + if not max_length: + max_length = selected_runs['Test Length'].max() + + tfdataset = TensorflowDataset(dataset_manager) + ds = tfdataset.get_dataset(selected_runs, features=features, max_length=max_length) + train_dataset, test_dataset, validation_dataset = split_dataset(ds, split_proportion=split_ratio) + train_dataset, test_dataset, validation_dataset = ( + dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + .batch(batch_size) + .shuffle(buffer_size=15) + for dataset in (train_dataset, test_dataset, validation_dataset) + ) + + return ds, train_dataset, test_dataset, validation_dataset + diff --git a/model_helper.py b/model_helper.py index 572b4ab..97c7b7a 100644 --- a/model_helper.py +++ b/model_helper.py @@ -31,18 +31,21 @@ def get_model_input_output_layers(inputs, outputs, 
input_length, n_states, convs
     x = signals
     x = MaskStealingLayer(0)((x, mask))
 
+    last_layer_size = input_length
     for filters, conv_size in convs:
         x = tf.keras.layers.Conv1D(filters=filters, kernel_size=conv_size, padding="same", name=f'conv_{conv_size}')(x)
+        last_layer_size = filters  # a Conv1D's output feature width is its filter count, not its kernel size
 
     for gru_size in grus:
         x = tf.keras.layers.GRU(gru_size, return_sequences=True)(x)
         # x = tf.keras.layers.LeakyReLU()(x)
+        last_layer_size = gru_size
 
     if skip_denses >= 3:
         raise ValueError('There are only 2 dense layers in the end!')
 
     if not skip_denses >= 2:
-        x = tf.keras.layers.Dense(grus[-1])(x)
+        x = tf.keras.layers.Dense(last_layer_size)(x)
         x = tf.keras.layers.LeakyReLU()(x)
 
     if not skip_denses >= 1:
@@ -54,6 +57,7 @@ def get_model_input_output_layers(inputs, outputs, input_length, n_states, convs
 
 
 def make_model(*args, **kwargs):
+    name = kwargs.pop('name', None)
     input_layer, output_layer = get_model_input_output_layers(*args, **kwargs)
-    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
+    model = tf.keras.Model(inputs=input_layer, outputs=output_layer, name=name)
     return model
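
The Dense-width bookkeeping above relies on the fact that a Conv1D layer's output feature width is its `filters` argument; `kernel_size` only sets the convolution window, so `last_layer_size` tracks `filters` in the convolution loop. A quick standalone check on toy shapes (a minimal sketch assuming TF 2.x; not part of the patch):

    import tensorflow as tf

    # Conv1D with padding="same" keeps the time dimension and emits `filters` channels.
    x = tf.zeros([1, 100, 10])  # (batch, time steps, feature width)
    y = tf.keras.layers.Conv1D(filters=64, kernel_size=15, padding="same")(x)
    print(y.shape)  # (1, 100, 64): width comes from filters, not kernel_size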
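
Similarly, `split_dataset` buckets elements by their index modulo the proportion total, giving a deterministic interleaved split: with `split_proportion=(6, 1, 3)`, each block of ten consecutive elements sends indices 0-5 to the first split, 6 to the second, and 7-9 to the third, matching the notebook's train/test/validation unpacking. A minimal sanity check on toy data (assumes `dataset_ops` as added in this patch is importable):

    import tensorflow as tf
    from dataset_ops import split_dataset

    ds = tf.data.Dataset.range(20)
    train, test, val = split_dataset(ds, split_proportion=(6, 1, 3))
    print([int(x) for x in train])  # [0, 1, 2, 3, 4, 5, 10, 11, 12, 13, 14, 15]
    print([int(x) for x in test])   # [6, 16]
    print([int(x) for x in val])    # [7, 8, 9, 17, 18, 19]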