From ed41a945b40d00ba3a5b18c990b33ad1942e19b8 Mon Sep 17 00:00:00 2001 From: Michael <m.langguth@fz-juelich.de> Date: Tue, 27 Oct 2020 12:05:37 +0100 Subject: [PATCH] First initial release of config_train.py, which is still incomplete. .gitignore has also been adapted. --- .gitignore | 3 +- .../HPC_scripts/config_train.py | 67 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 video_prediction_tools/HPC_scripts/config_train.py diff --git a/.gitignore b/.gitignore index 6dc2f339..b3079606 100644 --- a/.gitignore +++ b/.gitignore @@ -124,6 +124,7 @@ virt_env*/ # Ignore (Batch) runscripts video_prediction_tools/HPC_scripts/** !video_prediction_tools/HPC_scripts/*_template.sh +!video_predcition_tools/HPC_scripts/config_train.py video_prediction_tools/Zam347_scripts/** !video_prediction_tools/Zam347_scripts/*_template.sh - +!video_predcition_tools/Zam347_scripts/config_train.py diff --git a/video_prediction_tools/HPC_scripts/config_train.py b/video_prediction_tools/HPC_scripts/config_train.py new file mode 100644 index 00000000..66a47f58 --- /dev/null +++ b/video_prediction_tools/HPC_scripts/config_train.py @@ -0,0 +1,67 @@ +""" +Basic task of the Python-script: + +Creates user-defined runscripts for training, set ups a user-defined target directory and allows for full control +on the setting of hyperparameters. +""" + +__email__ = "b.gong@fz-juelich.de" +__authors__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" +__date__ = "2020-10-27" + +# import modules +import sys, os, glob +import numpy as np +import datetime as dt +import json as js +import metadata +sys.path.append(path.abspath('../video_prediction/')) +from models import get_model_class + +known_architectures = ["savp","convLSTM","vae","mcnet"] +if not (model in known_architectures): + + +# start script + +def main(): + # get required information from the user by keyboard interaction + + # path to preprocessed data + exp_dir = input("Enter the path to the preprocessed data (directory where tf-records files are located):\n") + exp_dir = os.path.join(exp_dir,"train") + # sanity check (does preprocessed data exist?) + if not (os.path.isdir(exp_dir)): + raise NotADirectoryError("Passed path to preprocessed data '"+exp_dir+"' does not exist!") + file_list = glob.glob(os.path.join(exp_dir,"sequence*.tfrecords")) + if len(file_list) == 0: + raise FileNotFoundError("Passed path to preprocessed data '"+exp_dir+"' exists,"+\ + "but no tfrecord-files can be found therein") + + # path to virtual environment to be used + venv_name = input("Enter the name of the virtual environment which should be used:\n") + + # sanity check (does virtual environment exist?) + if not (os.path.isfile("../",venv_name,"bin","activate")): + raise FileNotFoundError("Could not find a virtual environment named "+venv_name) + + # model + + + # experimental ID + exp_id = input("Enter your desired experimental id:\n") + + # also get current timestamp and user-name + timestamp = dt.datetime.now().strftime("%Y%m%dT%H%M%S") + user_name = os.environ["USER"] + + + + + + + + + + + -- GitLab