""" Cкрипт обучения простой ML-модели. Используется для демонстрации работы фреймворка Hydra. Решается задача анализа зависимости состояния здоровья дерева (health) от его диаметра ствола (tree_dbh) и типа дерева (spc_common). Запуск из консоли: python src/trees-training-multirun.py (с параметрами по умолчанию из src/hydra/config.yaml) python src/trees-training-multirun.py -m params.epochs=20 (с переопределением параметра) python src/trees-training-multirun.py -m params.batch_size=32,64 (мультизапуск с перебором параметров) """ import os import logging import gzip import pandas as pd from sklearn.preprocessing import LabelEncoder, StandardScaler from sklearn.model_selection import train_test_split import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import hydra from hydra.core.config_store import ConfigStore from src.hydra.config import TreesConfig from inc.microfuncs import Microfuncs from inc.templates import MdTemplates logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) cs = ConfigStore.instance() cs.store(name="trees_config", node=TreesConfig) @hydra.main(config_path=".", config_name="config", version_base=None) def main(cfg: TreesConfig) -> None: # Определение кол-ва ядер и настройка для TensorFlow try: num_cores = int(cfg.params.num_cores) if num_cores > 0: tf.config.threading.set_intra_op_parallelism_threads(num_cores) tf.config.threading.set_inter_op_parallelism_threads(num_cores) logger.info(f'Set the number of cores: {num_cores}') except ValueError: logger.info("The value of the cfg.params.num_cores parameter is of invalid type") data_file = os.path.join(os.path.join(cfg.paths.data, cfg.files.train_data)) # Собираем имена файлов для этого запуска: pathto/datafile-artifacts-XXXXXX.png (и такой же .md) short_random = Microfuncs.generate_random_string(6) plot_file = os.path.join(cfg.paths.artifacts, Microfuncs.collect_filename(cfg.files.train_data, f"-artifacts-{short_random}", ".png") ) md_file = os.path.join(cfg.paths.artifacts, Microfuncs.collect_filename(cfg.files.train_data, f"-artifacts-{short_random}", ".md") ) # Загружаем датасет with gzip.open(data_file, 'rb') as gz_file: data = pd.read_csv(gz_file) logger.info(f'Dataset has been loaded. Shape of data: {data.shape}') # Предобработка данных data = data[['tree_dbh', 'spc_common', 'health']] data = data.dropna() # Кодирование категориальных признаков label_encoder = LabelEncoder() data['spc_common'] = label_encoder.fit_transform(data['spc_common']) data['health'] = label_encoder.fit_transform(data['health']) # Разделение данных на train и test X = data[['tree_dbh', 'spc_common']] y = data['health'] X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # Нормализация числовых признаков scaler = StandardScaler() X_train_scaled = scaler.fit_transform(X_train) X_test_scaled = scaler.transform(X_test) # Создание модели model = keras.Sequential([ layers.Dense(64, activation='relu', input_shape=(2,)), layers.Dense(32, activation='relu'), layers.Dense(3, activation='softmax') ]) # Компиляция модели model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # Обучение модели history = model.fit(X_train_scaled, y_train, epochs=cfg.params.epochs, batch_size=cfg.params.batch_size, validation_split=cfg.params.validation_split ) # Оценка модели test_loss, test_acc = model.evaluate(X_test_scaled, y_test) logger.info(f'Training is complete. Accuracy: {test_acc:.3f}') # Сохранение графика в файл Microfuncs.save_plot_to_file(history, plot_file) logger.info(f'The training graph has been saved to a file: {plot_file}') # Собираем содержимое md-файла из шаблона с подстановкой значений content = Microfuncs.replace_all(MdTemplates.workflow_artifact, { "": f"{test_acc:.3f}", "": "data/" + os.path.split(data_file)[1], "": "artifacts/" + os.path.split(plot_file)[1] }, ) # Сохраняем md-файл, снегерированный из шаблона with open(md_file, 'w', encoding='utf-8') as f: f.writelines(content) logger.info(f'Artifact has been saved to a file: {md_file}') if __name__ == "__main__": main()