Skip to content

The-Pillars-of-Creation/NTO-second-tour

Repository files navigation

Vitya model

Наша нейросеть Витя умеет различать видео и классифицировать их на 9 категорий

label id
water 0
car 1
cloud 2
food 3
flower 4
dance 5
animal 6
sunset 7
fire 8

Мы использовали ResNet50V2 для покадравой классификации видео и давали тэг по наиболее частому тэгу кадров.

class VityaModel:
    def __init__(self) -> None:
        augmentaion_layer = Sequential([
            RandomFlip("horizontal", seed=RANDOM_SEED),
            RandomRotation(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomZoom(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomHeight(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            RandomWidth(AUGMENTATION_FACTOR, seed=RANDOM_SEED),
            Rescaling(1 / 255.)
        ])

        augmentaion_layer
        base_model = tf.keras.applications.ResNet50V2(include_top=False, weights=None)

        base_model.trainable = False

        input_layer = tf.keras.layers.Input(shape=(224, 224, 3), name="input_layer")
        x = augmentaion_layer(input_layer)
        x = base_model(x, training=False)
        x = tf.keras.layers.GlobalAveragePooling2D(name="global_average_pooling2d")(x)
        output_layer = Dense(len(TAGS), activation=softmax, name="output_layer")(x)

        model = tf.keras.Model(input_layer, output_layer)
        model.compile(
            loss=tf.keras.losses.CategoricalCrossentropy(),
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"]
        )
        
        model_checkpoint_path = "./checkpoints/vitya_weights"
        model.load_weights(model_checkpoint_path)
        print("Compiled VityaModel")

        self.model = model

Обучение модели train.ipynb

history_1 = model_1.fit(
    train_data,
    epochs=EPOCHS,
    steps_per_epoch=len(train_data),
    validation_data=test_data,
    validation_steps=int(0.15 * len(test_data)),
    callbacks=[
        tf.keras.callbacks.CSVLogger("history/history.csv"),
        checkpoint_callback,
        create_tensorboard_callback(dir_name="tensorboard", experiment_name="vitya")
    ]
)
Файл Назначение
solution.py итоговое решение для отправки в систему НТО
solution.ipynb тестировка модели на тестовых фотографиях
train.ipynb обучение модели
first_try.ipynb первая попытка обучения модели
dataload.py первоначальная загрузка данных для pytorch, не используется

About

Our solution to NTO second tour

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Contributors 4

  •  
  •  
  •  
  •