|
28 | 28 |
|
29 | 29 | log = logging.getLogger(__name__) |
30 | 30 |
|
| 31 | +REF_PROP_NAME = '1-body-000001:static' |
| 32 | +REF_GENERIC_PROTOTYPE_NAME = '1-body-000001' |
| 33 | + |
31 | 34 | # ## QUERY DATA |
32 | 35 | LATTICE_COLUMNS = ["_lat_ax", "_lat_ay", "_lat_az", |
33 | 36 | "_lat_bx", "_lat_by", "_lat_bz", |
@@ -129,17 +132,7 @@ def query_data(config: Dict, seed=None, query_limit=None, db_conn_string=None): |
129 | 132 | if REF_ENERGY_KW not in config: |
130 | 133 | try: |
131 | 134 | # TODO: generalize query of reference property |
132 | | - REF_PROP_NAME = '1-body-000001:static' |
133 | | - REF_GENERIC_PROTOTYPE_NAME = '1-body-000001' |
134 | | - ref_prop = storage.query(StaticProperty).join(StructureEntry, GenericEntry).filter( |
135 | | - Property.CALCULATOR == reference_calculator, |
136 | | - Property.NAME == REF_PROP_NAME, |
137 | | - StructureEntry.COMPOSITION.like(config["element"] + "-%"), |
138 | | - StructureEntry.NUMBER_OF_ATOMS == 1, |
139 | | - GenericEntry.PROTOTYPE_NAME == REF_GENERIC_PROTOTYPE_NAME |
140 | | - ).one() |
141 | | - # free atom reference energy |
142 | | - ref_energy = ref_prop.energy / ref_prop.n_atom |
| 135 | + ref_energy = query_reference_energy(config["element"], reference_calculator, storage) |
143 | 136 | except NoResultFound as e: |
144 | 137 | log.error(("No reference energy for {} was found in database. " + |
145 | 138 | "Either add property named `{}` with generic named `{}` to database or use `{}` " + |
@@ -214,6 +207,20 @@ def query_data(config: Dict, seed=None, query_limit=None, db_conn_string=None): |
214 | 207 | return df_total, ref_energy |
215 | 208 |
|
216 | 209 |
|
def query_reference_energy(element, reference_calculator, storage):
    """Return the per-atom energy of the reference property for ``element``.

    Selects exactly one ``StaticProperty`` from ``storage`` (joined with
    ``StructureEntry`` and ``GenericEntry``) that satisfies all of:

    * ``Property.CALCULATOR`` equals ``reference_calculator``
    * ``Property.NAME`` equals ``REF_PROP_NAME``
    * ``StructureEntry.COMPOSITION`` matches ``element + "-%"``
    * ``StructureEntry.NUMBER_OF_ATOMS`` equals 1
    * ``GenericEntry.PROTOTYPE_NAME`` equals ``REF_GENERIC_PROTOTYPE_NAME``

    :param element: element symbol; used as the prefix of the composition pattern
    :param reference_calculator: value compared against ``Property.CALCULATOR``
    :param storage: object providing ``query(...)``
        (presumably an SQLAlchemy session -- TODO confirm)
    :return: ``energy / n_atom`` of the matched property
    :raises: propagates whatever ``.one()`` raises when the match is not
        exactly one row (``query_data`` catches ``NoResultFound``)
    """
    # Imported at call time, exactly as before, so the module itself can be
    # imported without structdborm being resolved up front.
    from structdborm import StructureEntry, StaticProperty, GenericEntry, Property

    joined_query = storage.query(StaticProperty).join(StructureEntry, GenericEntry)
    criteria = (
        Property.CALCULATOR == reference_calculator,
        Property.NAME == REF_PROP_NAME,
        StructureEntry.COMPOSITION.like(element + "-%"),
        StructureEntry.NUMBER_OF_ATOMS == 1,
        GenericEntry.PROTOTYPE_NAME == REF_GENERIC_PROTOTYPE_NAME,
    )
    reference = joined_query.filter(*criteria).one()
    # free atom reference energy
    return reference.energy / reference.n_atom
| 222 | + |
| 223 | + |
217 | 224 | class StructuresDatasetWeightingPolicy: |
218 | 225 | def generate_weights(self, df): |
219 | 226 | raise NotImplementedError |
@@ -639,7 +646,7 @@ def get_fit_dataframe(self, force_query=None, weights_policy=None, ignore_weight |
639 | 646 |
|
640 | 647 | class EnergyBasedWeightingPolicy(StructuresDatasetWeightingPolicy): |
641 | 648 |
|
642 | | - def __init__(self, nfit=20000, |
| 649 | + def __init__(self, nfit=None, |
643 | 650 | cutoff=None, |
644 | 651 | DElow=1.0, |
645 | 652 | DEup=10.0, |
@@ -705,6 +712,10 @@ def __str__(self): |
705 | 712 | reftype=self.reftype, seed=self.seed) |
706 | 713 |
|
707 | 714 | def generate_weights(self, df): |
| 715 | + if self.nfit is None: |
| 716 | + self.nfit = len(df) |
| 717 | + log.info("Set nfit to the dataset size {}".format(self.nfit)) |
| 718 | + |
708 | 719 | if self.reftype == "bulk": |
709 | 720 | log.info("Reducing to bulk data") |
710 | 721 | df = df[df.pbc] |
@@ -1019,7 +1030,8 @@ def generate_weights(self, df): |
1019 | 1030 | if col_to_drop in df.columns: |
1020 | 1031 | df.drop(columns=col_to_drop, inplace=True) |
1021 | 1032 |
|
1022 | | - mdf = pd.merge(df, self.weights_df[[WEIGHTS_ENERGY_COLUMN, WEIGHTS_FORCES_COLUMN]], left_index=True, right_index=True) |
| 1033 | + mdf = pd.merge(df, self.weights_df[[WEIGHTS_ENERGY_COLUMN, WEIGHTS_FORCES_COLUMN]], left_index=True, |
| 1034 | + right_index=True) |
1023 | 1035 | if not (mdf[FORCES_COLUMN].map(len) == mdf[WEIGHTS_FORCES_COLUMN].map(len)).all(): |
1024 | 1036 | error_msg = ("Shape of the `{}` column doesn't correspond to the shape of " |
1025 | 1037 | "`forces` column in original dataframe").format(WEIGHTS_FORCES_COLUMN) |
|
0 commit comments