11#!/usr/bin/env python
22
33import argparse
4+ import getpass # for getpass.getuser()
45import glob
6+ import logging
57import os
8+ import shutil
69
7- import pandas as pd
810import pkg_resources
9- import ruamel .yaml as yaml
11+ import re
12+ import readline
13+ import socket
1014import sys
1115
12- import logging
16+ hostname = socket .gethostname ()
17+ username = getpass .getuser ()
18+
19+ import pandas as pd
20+ import ruamel .yaml as yaml
21+
22+ LOG_FMT = '%(asctime)s %(levelname).1s - %(message)s' .format (hostname )
23+ logging .basicConfig (level = logging .INFO , format = LOG_FMT , datefmt = "%Y/%m/%d %H:%M:%S" )
24+ log = logging .getLogger ()
1325
1426from shutil import copyfile
1527from pyace .generalfit import GeneralACEFit
1830from pyace .atomicenvironment import calculate_minimal_nn_atomic_env , calculate_minimal_nn_tp_atoms
1931from pyace .validate import plot_analyse_error_distributions
2032
21- files_to_remove = ["fitting_data_info.csv" , "log.txt" , "nohup.out" ,
33+ files_to_remove = ["fitting_data_info.csv" , "fitting_data_info.pckl.gzip" , " log.txt" , "nohup.out" ,
2234 "target_potential.yaml" , "current_extended_potential.yaml" , "output_potential.yaml" ,
2335 "ladder_metrics.txt" , "cycle_metrics.txt" , "metrics.txt" ,
2436 "test_ladder_metrics.txt" , "test_cycle_metrics.txt" , "test_metrics.txt" ,
25- "train_pred.pckl.gzip" , "test_pred.pckl.gzip"
37+ "train_pred.pckl.gzip" , "test_pred.pckl.gzip" ,
38+ "test_ef-distributions.png" , "train_ef-distributions.png" , "report"
2639 ]
2740
2841DEFAULT_SEED = 42
2942
30- log = logging .getLogger ()
31-
3243
3344def main (args ):
3445 parser = argparse .ArgumentParser (prog = "pacemaker" , description = "Fitting utility for atomic cluster expansion "
@@ -86,7 +97,7 @@ def main(args):
8697 default = False )
8798
8899 parser .add_argument ("-t" , "--template" ,
89- help = "Create a template 'input.yaml' file" ,
100+ help = "Generate a template 'input.yaml' file by dialog " ,
90101 dest = "template" , action = "store_true" ,
91102 default = False )
92103
@@ -117,23 +128,22 @@ def main(args):
117128 sys .exit (0 )
118129
119130 if args_parse .clean :
120- print ("Cleaning working directory... " )
131+ print ("Cleaning working directory. Removing files/folders: " )
121132
122133 interim_potentails = glob .glob ("interim_potential*.yaml" )
123134 ensemble_potentails = glob .glob ("ensemble_potential*.yaml" )
124135 for filename in sorted (files_to_remove + interim_potentails + ensemble_potentails ):
125136 if os .path .isfile (filename ):
126137 os .remove (filename )
127- print (" - " + filename )
128- print ("done" )
138+ print (" - " , filename )
139+ elif os .path .isdir (filename ):
140+ shutil .rmtree (filename )
141+ print (" - " , filename , "(folder)" )
142+ print ("Done" )
129143 sys .exit (0 )
130144
131145 if args_parse .template :
132- print ("Creating template 'input.yaml'..." , end = "" )
133- template_input_yaml_filename = pkg_resources .resource_filename ('pyace.data' , 'input_template.yaml' )
134- copyfile (template_input_yaml_filename , "input.yaml" )
135- print ("done" )
136- sys .exit (0 )
146+ generate_template_input ()
137147
138148 if args_parse .dry_run :
139149 log .info ("====== DRY RUN ======" )
@@ -144,12 +154,14 @@ def main(args):
144154 log_file_name = args_parse .log
145155 log .info ("Redirecting log into file {}" .format (log_file_name ))
146156 fileh = logging .FileHandler (log_file_name , 'a' )
147- formatter = logging .Formatter ('%(asctime)s - %(name)s - %(levelname)s - %(message)s' )
157+ # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
158+ formatter = logging .Formatter (LOG_FMT )
148159 fileh .setFormatter (formatter )
149- # log = logging.getLogger()
150160 log .addHandler (fileh )
151161
152162 log .info ("Start pacemaker" )
163+ log .info ("Hostname: {}" .format (hostname ))
164+ log .info ("Username: {}" .format (username ))
153165 log .info ("pacemaker/pyace version: {}" .format (__version__ ))
154166 log .info ("ace_evaluator version: {}" .format (get_ace_evaluator_version ()))
155167 log .info ("Loading {}... " .format (input_yaml_filename ))
@@ -199,10 +211,9 @@ def main(args):
199211 # raise ValueError("'backend' section is not given")
200212
201213 if "backend" in args_parse :
202- backend_config ["evaluator" ]= args_parse .backend
214+ backend_config ["evaluator" ] = args_parse .backend
203215 log .info ("Backend settings is overwritten from arguments: " , backend_config )
204216
205-
206217 if 'evaluator' in backend_config :
207218 evaluator_name = backend_config ['evaluator' ]
208219 else :
@@ -289,16 +300,97 @@ def main(args):
289300 pred_data = predict_and_save (general_fit , target_bbasisconfig , general_fit .fitting_data ,
290301 fname = "train_pred.pckl.gzip" )
291302 log .info ("Ploting validation graphs" )
292- plot_analyse_error_distributions (pred_data , fig_prefix = "train_" ,fig_path = "report" ,
293- imagetype = backend_config .get ("imagetype" ,"png" ))
303+ plot_analyse_error_distributions (pred_data , fig_prefix = "train_" , fig_path = "report" ,
304+ imagetype = backend_config .get ("imagetype" , "png" ))
294305
295306 if general_fit .test_data is not None :
296307 log .info ("For test data" )
297308 pred_data = predict_and_save (general_fit , target_bbasisconfig , general_fit .test_data ,
298309 fname = "test_pred.pckl.gzip" )
299310 log .info ("Ploting validation graphs" )
300311 plot_analyse_error_distributions (pred_data , fig_prefix = "test_" , fig_path = "report" ,
301- imagetype = backend_config .get ("imagetype" ,"png" ))
312+ imagetype = backend_config .get ("imagetype" , "png" ))
313+
314+
def _ask_number(prompt, cast, default):
    """Prompt until the reply converts with `cast`; an empty reply returns `default`.

    :param prompt: text shown to the user
    :param cast: conversion callable (e.g. `int` or `float`) raising ValueError on bad input
    :param default: value returned when the user just presses [ENTER]
    :return: converted value or `default`
    """
    while True:
        reply = input(prompt).strip()
        if not reply:
            return default
        try:
            return cast(reply)
        except ValueError:
            # a typo should not abort the whole dialog with a traceback
            print("Invalid value {!r}, please try again".format(reply))


def generate_template_input():
    """Interactively generate `input.yaml` from the packaged template and exit.

    The dialog asks for the training dataset filename, test set size, list of
    elements (or determines them from the `ase_atoms` column of the dataset),
    number of functions per element, cutoff and weighting scheme. The answers
    are substituted into the placeholders of `input_template.yaml` shipped in
    `pyace.data` and the result is written to `input.yaml` in the current
    directory.

    The function never returns: it calls `sys.exit(0)` on success and
    `sys.exit(1)` if no elements could be determined.
    """
    print("Generating 'input.yaml'")
    readline.parse_and_bind("tab: complete")

    # 1. Training set size
    # strip(): tab-completion may leave a trailing space after the completed name
    train_filename = input(
        "Enter training dataset filename (ex.: data.pckl.gzip, [TAB] - autocompletion): ").strip()
    testset_size_inp = _ask_number(
        "Enter test set fraction or size (ex.: 0.05 or [ENTER] - no test set): ", float, 0.0)

    # 2. Elements
    elements_str = input(
        """Please enter list of elements (ex.: "Cu", "AlNi", [ENTER] - determine from dataset): """).strip()
    if elements_str:
        # set(): repeated symbols (e.g. "NiAlNi") must not produce duplicated elements
        elements = sorted(set(re.findall(r"[A-Z][a-z]?", elements_str)))
        if not elements:
            print("ERROR! No element symbols recognized in {!r}".format(elements_str))
            sys.exit(1)
    else:
        # determine from training set
        print("Trying to load {}".format(train_filename))
        # NOTE: unpickling can execute arbitrary code - only load trusted datasets
        df = pd.read_pickle(train_filename, compression="gzip")
        if 'ase_atoms' not in df.columns:
            print("ERROR! No `ase_atoms` column found")
            sys.exit(1)
        print("Determining available elements...")
        elements_set = set()
        for atoms in df["ase_atoms"]:
            elements_set.update(atoms.get_chemical_symbols())
        elements = sorted(elements_set)
        print("Found elements: ", elements)

    print("Number of elements: ", len(elements))
    print("Elements: ", elements)

    # number of functions per element
    number_of_functions_per_element = _ask_number(
        """Enter number of functions per element ([ENTER] - default 700): """, int, 700)
    print("Number of functions per element: ", number_of_functions_per_element)

    cutoff = _ask_number("Enter cutoff (Angstrom, default:7.0): ", float, 7.0)
    print("Cutoff: ", cutoff)

    # weighting scheme
    default_energy_based_weighting = """{ type: EnergyBasedWeightingPolicy, DElow: 1.0, DEup: 10.0, DFup: 50.0, DE: 1.0, DF: 1.0, wlow: 0.75, energy: convex_hull, reftype: all,seed: 42}"""
    while True:
        weighting_inp = input(
            "Enter weighting scheme type - `uniform` or `energy` ([ENTER] - `uniform`): ").strip().lower() or 'uniform'
        if weighting_inp in ('uniform', 'energy'):
            break
    if weighting_inp == "energy":
        weighting = default_energy_based_weighting
        print("Use EnergyBasedWeightingPolicy: ", weighting)
    else:
        weighting = None
        print("Use UniformWeightingPolicy")

    # Read the packaged template directly; `input.yaml` is created only once,
    # with the final content.
    template_input_yaml_filename = pkg_resources.resource_filename('pyace.data', 'input_template.yaml')
    with open(template_input_yaml_filename, "r") as f:
        input_yaml_text = f.read()

    # placeholder -> replacement text; optional sections are removed when not requested
    substitutions = [
        ("{{ELEMENTS}}", str(elements)),
        ("{{CUTOFF}}", str(cutoff)),
        ("{{DATAFILENAME}}", train_filename),
        ("{{number_of_functions_per_element}}",
         "number_of_functions_per_element: {}".format(number_of_functions_per_element)),
        ("{{WEIGHTING}}", "weighting: " + weighting if weighting else ""),
        ("{{test_size}}", "test_size: {}".format(testset_size_inp) if testset_size_inp > 0 else ""),
    ]
    for placeholder, value in substitutions:
        input_yaml_text = input_yaml_text.replace(placeholder, value)

    with open("input.yaml", "w") as f:
        print(input_yaml_text, file=f)
    print("Input file is written into `input.yaml`")
    sys.exit(0)
302394
303395
304396def predict_and_save (general_fit , target_bbasisconfig , structures_dataframe , fname ):
0 commit comments