Skip to content

Commit 392cf4b

Browse files
pacemaker:
- add pacemaker -t option for interactive dialogue - update -c/--clean option update inpute_template.yaml preparedata.py: change default to cache_ref_df=False setup.py: update url
1 parent ac4f57f commit 392cf4b

3 files changed

Lines changed: 143 additions & 215 deletions

File tree

bin/pacemaker.py

Lines changed: 115 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
11
#!/usr/bin/env python
22

33
import argparse
4+
import getpass # for getpass.getuser()
45
import glob
6+
import logging
57
import os
8+
import shutil
69

7-
import pandas as pd
810
import pkg_resources
9-
import ruamel.yaml as yaml
11+
import re
12+
import readline
13+
import socket
1014
import sys
1115

12-
import logging
16+
hostname = socket.gethostname()
17+
username = getpass.getuser()
18+
19+
import pandas as pd
20+
import ruamel.yaml as yaml
21+
22+
LOG_FMT = '%(asctime)s %(levelname).1s - %(message)s'.format(hostname)
23+
logging.basicConfig(level=logging.INFO, format=LOG_FMT, datefmt="%Y/%m/%d %H:%M:%S")
24+
log = logging.getLogger()
1325

1426
from shutil import copyfile
1527
from pyace.generalfit import GeneralACEFit
@@ -18,17 +30,16 @@
1830
from pyace.atomicenvironment import calculate_minimal_nn_atomic_env, calculate_minimal_nn_tp_atoms
1931
from pyace.validate import plot_analyse_error_distributions
2032

21-
files_to_remove = ["fitting_data_info.csv", "log.txt", "nohup.out",
33+
files_to_remove = ["fitting_data_info.csv", "fitting_data_info.pckl.gzip", "log.txt", "nohup.out",
2234
"target_potential.yaml", "current_extended_potential.yaml", "output_potential.yaml",
2335
"ladder_metrics.txt", "cycle_metrics.txt", "metrics.txt",
2436
"test_ladder_metrics.txt", "test_cycle_metrics.txt", "test_metrics.txt",
25-
"train_pred.pckl.gzip", "test_pred.pckl.gzip"
37+
"train_pred.pckl.gzip", "test_pred.pckl.gzip",
38+
"test_ef-distributions.png", "train_ef-distributions.png", "report"
2639
]
2740

2841
DEFAULT_SEED = 42
2942

30-
log = logging.getLogger()
31-
3243

3344
def main(args):
3445
parser = argparse.ArgumentParser(prog="pacemaker", description="Fitting utility for atomic cluster expansion "
@@ -86,7 +97,7 @@ def main(args):
8697
default=False)
8798

8899
parser.add_argument("-t", "--template",
89-
help="Create a template 'input.yaml' file",
100+
help="Generate a template 'input.yaml' file by dialog",
90101
dest="template", action="store_true",
91102
default=False)
92103

@@ -117,23 +128,22 @@ def main(args):
117128
sys.exit(0)
118129

119130
if args_parse.clean:
120-
print("Cleaning working directory...")
131+
print("Cleaning working directory. Removing files/folders:")
121132

122133
interim_potentails = glob.glob("interim_potential*.yaml")
123134
ensemble_potentails = glob.glob("ensemble_potential*.yaml")
124135
for filename in sorted(files_to_remove + interim_potentails + ensemble_potentails):
125136
if os.path.isfile(filename):
126137
os.remove(filename)
127-
print(" - " + filename)
128-
print("done")
138+
print(" - ", filename)
139+
elif os.path.isdir(filename):
140+
shutil.rmtree(filename)
141+
print(" - ", filename, "(folder)")
142+
print("Done")
129143
sys.exit(0)
130144

131145
if args_parse.template:
132-
print("Creating template 'input.yaml'...", end="")
133-
template_input_yaml_filename = pkg_resources.resource_filename('pyace.data', 'input_template.yaml')
134-
copyfile(template_input_yaml_filename, "input.yaml")
135-
print("done")
136-
sys.exit(0)
146+
generate_template_input()
137147

138148
if args_parse.dry_run:
139149
log.info("====== DRY RUN ======")
@@ -144,12 +154,14 @@ def main(args):
144154
log_file_name = args_parse.log
145155
log.info("Redirecting log into file {}".format(log_file_name))
146156
fileh = logging.FileHandler(log_file_name, 'a')
147-
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
157+
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
158+
formatter = logging.Formatter(LOG_FMT)
148159
fileh.setFormatter(formatter)
149-
# log = logging.getLogger()
150160
log.addHandler(fileh)
151161

152162
log.info("Start pacemaker")
163+
log.info("Hostname: {}".format(hostname))
164+
log.info("Username: {}".format(username))
153165
log.info("pacemaker/pyace version: {}".format(__version__))
154166
log.info("ace_evaluator version: {}".format(get_ace_evaluator_version()))
155167
log.info("Loading {}... ".format(input_yaml_filename))
@@ -199,10 +211,9 @@ def main(args):
199211
# raise ValueError("'backend' section is not given")
200212

201213
if "backend" in args_parse:
202-
backend_config["evaluator"]=args_parse.backend
214+
backend_config["evaluator"] = args_parse.backend
203215
log.info("Backend settings is overwritten from arguments: ", backend_config)
204216

205-
206217
if 'evaluator' in backend_config:
207218
evaluator_name = backend_config['evaluator']
208219
else:
@@ -289,16 +300,97 @@ def main(args):
289300
pred_data = predict_and_save(general_fit, target_bbasisconfig, general_fit.fitting_data,
290301
fname="train_pred.pckl.gzip")
291302
log.info("Ploting validation graphs")
292-
plot_analyse_error_distributions(pred_data, fig_prefix="train_",fig_path="report",
293-
imagetype=backend_config.get("imagetype","png"))
303+
plot_analyse_error_distributions(pred_data, fig_prefix="train_", fig_path="report",
304+
imagetype=backend_config.get("imagetype", "png"))
294305

295306
if general_fit.test_data is not None:
296307
log.info("For test data")
297308
pred_data = predict_and_save(general_fit, target_bbasisconfig, general_fit.test_data,
298309
fname="test_pred.pckl.gzip")
299310
log.info("Ploting validation graphs")
300311
plot_analyse_error_distributions(pred_data, fig_prefix="test_", fig_path="report",
301-
imagetype=backend_config.get("imagetype","png"))
312+
imagetype=backend_config.get("imagetype", "png"))
313+
314+
315+
def generate_template_input():
316+
print("Generating 'input.yaml'")
317+
readline.parse_and_bind("tab: complete")
318+
319+
# 1. Training set size
320+
train_filename = input("Enter training dataset filename (ex.: data.pckl.gzip, [TAB] - autocompletion): ")
321+
testset_size_inp = float(input("Enter test set fraction or size (ex.: 0.05 or [ENTER] - no test set): ") or 0)
322+
323+
# 2. Elements
324+
elements_str = input("""Please enter list of elements (ex.: "Cu", "AlNi", [ENTER] - determine from dataset): """)
325+
if elements_str:
326+
patt = re.compile("([A-Z][a-z]?)")
327+
elements = patt.findall(elements_str)
328+
elements = sorted(elements)
329+
else:
330+
# determine from training set
331+
print("Trying to load {}".format(train_filename))
332+
df = pd.read_pickle(train_filename, compression="gzip")
333+
if 'ase_atoms' in df.columns:
334+
print("Determining available elements...")
335+
elements_set = set()
336+
df["ase_atoms"].map(lambda at: elements_set.update(at.get_chemical_symbols()));
337+
elements = sorted(elements_set)
338+
print("Found elements: ", elements)
339+
else:
340+
print("ERROR! No `ase_atoms` column found")
341+
sys.exit(1)
342+
343+
print("Number of elements: ", len(elements))
344+
print("Elements: ", elements)
345+
346+
# number of functions per element
347+
number_of_functions_per_element = int(input(
348+
"""Enter number of functions per element ([ENTER] - default 700): """) or 700)
349+
print("Number of functions per element: ", number_of_functions_per_element)
350+
351+
cutoff = float(input("Enter cutoff (Angstrom, default:7.0): ") or 7.0)
352+
print("Cutoff: ", cutoff)
353+
354+
# weighting scheme
355+
default_energy_based_weighting = """{ type: EnergyBasedWeightingPolicy, DElow: 1.0, DEup: 10.0, DFup: 50.0, DE: 1.0, DF: 1.0, wlow: 0.75, energy: convex_hull, reftype: all,seed: 42}"""
356+
weighting = None
357+
while True:
358+
weighting_inp = input(
359+
"Enter weighting scheme type - `uniform` or `energy` ([ENTER] - `uniform`): ") or 'uniform'
360+
if weighting_inp in ['uniform', 'energy']:
361+
break
362+
if weighting_inp == "energy":
363+
weighting = default_energy_based_weighting
364+
print("Use EnergyBasedWeightingPolicy: ", weighting)
365+
else:
366+
weighting = None
367+
print("Use UniformWeightingPolicy")
368+
369+
template_input_yaml_filename = pkg_resources.resource_filename('pyace.data', 'input_template.yaml')
370+
copyfile(template_input_yaml_filename, "input.yaml")
371+
with open("input.yaml", "r") as f:
372+
input_yaml_text = f.read()
373+
374+
input_yaml_text = input_yaml_text.replace("{{ELEMENTS}}", str(elements))
375+
input_yaml_text = input_yaml_text.replace("{{CUTOFF}}", str(cutoff))
376+
input_yaml_text = input_yaml_text.replace("{{DATAFILENAME}}", train_filename)
377+
input_yaml_text = input_yaml_text.replace("{{number_of_functions_per_element}}",
378+
"number_of_functions_per_element: {}".format(
379+
number_of_functions_per_element))
380+
if weighting:
381+
input_yaml_text = input_yaml_text.replace("{{WEIGHTING}}", "weighting: " + weighting)
382+
else:
383+
input_yaml_text = input_yaml_text.replace("{{WEIGHTING}}", "")
384+
385+
if testset_size_inp > 0:
386+
input_yaml_text = input_yaml_text.replace("{{test_size}}", "test_size: {}".format(testset_size_inp))
387+
else:
388+
input_yaml_text = input_yaml_text.replace("{{test_size}}", "")
389+
390+
with open("input.yaml", "w") as f:
391+
print(input_yaml_text, file=f)
392+
print("Input file is written into `input.yaml`")
393+
sys.exit(0)
302394

303395

304396
def predict_and_save(general_fit, target_bbasisconfig, structures_dataframe, fname):

0 commit comments

Comments
 (0)