Skip to content
Snippets Groups Projects
Commit 3b924699 authored by BING GONG's avatar BING GONG
Browse files

fix bug to pass the frac_save_checkpoint to arguments in main

parent 3ee82854
No related branches found
No related tags found
No related merge requests found
Pipeline #76852 passed
...@@ -594,6 +594,8 @@ def main(): ...@@ -594,6 +594,8 @@ def main():
parser.add_argument("--model", type=str, help="Model class name") parser.add_argument("--model", type=str, help="Model class name")
parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint")
parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model")
parser.add_argument("--seed", default=1234, type=int) parser.add_argument("--seed", default=1234, type=int)
args = parser.parse_args() args = parser.parse_args()
...@@ -602,7 +604,8 @@ def main(): ...@@ -602,7 +604,8 @@ def main():
#create a training instance #create a training instance
train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict, train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset, model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset,
gpu_mem_frac=args.gpu_mem_frac,seed=args.seed,args=args) gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start,
prob_save_model=args.prob_save_model)
print('----------------------------------- Options ------------------------------------') print('----------------------------------- Options ------------------------------------')
for k, v in args._get_kwargs(): for k, v in args._get_kwargs():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment