diff --git a/lir/config/__init__.py b/lir/config/__init__.py index f6bc35e9..e69de29b 100644 --- a/lir/config/__init__.py +++ b/lir/config/__init__.py @@ -1,5 +0,0 @@ -# This import is used to be able to write a short-hand version in `registry.yaml`, i.e. -# `lir.config.numpy_csv_writer`. Ignored by linting (F401; unused import). -from .transform import ( # noqa: F401 - NumpyWrappingConfigParser as numpy_wrapper, # noqa: N813 -) diff --git a/lir/config/aggregation.py b/lir/config/aggregation.py index cc4005e5..72e518bc 100644 --- a/lir/config/aggregation.py +++ b/lir/config/aggregation.py @@ -1,5 +1,6 @@ from functools import partial from pathlib import Path +from typing import Any from lir import registry from lir.aggregation import Aggregation, SubsetAggregation @@ -117,3 +118,8 @@ def subset_aggregation(config: ContextAwareDict, output_dir: Path) -> SubsetAggr aggregation_methods = [parse_aggregation(aggregation_config, subset_output_dir)] return SubsetAggregation(aggregation_methods, category_field) + + +class AggregationPlotConfigParser(ConfigParser): + def parse(self, config: ContextAwareDict, output_dir: Path) -> Any: + pass diff --git a/lir/config/base.py b/lir/config/base.py index ddfe74b5..236f47fb 100644 --- a/lir/config/base.py +++ b/lir/config/base.py @@ -5,9 +5,6 @@ from pathlib import Path from typing import Any, TypeVar -from lir import registry -from lir.transform.pairing import PairingMethod - class YamlParseError(ValueError): """ @@ -386,48 +383,6 @@ def reference(self) -> str: return ConfigParserFunction -def parse_pairing_config( - module_config: ContextAwareDict | str, - output_dir: Path, - context: list[str], -) -> PairingMethod: - """ - Parse and delegate pairing to the corresponding function for the defined pairing method. - - The argument `module_config` defines the pairing method. If its value is a `str`, the registry is queried and the - corresponding pairing method is returned. If its value is a `dict`, the pairing method is defined - by the value `module_config["method"]`, and the registry is queried for the config parser of - the corresponding pairing method. The remaining values in `module_config` are passed as arguments to the - configuration parser of the pairing method. - - If the registry cannot resolve the pairing method, an exception is raised. - - Parameters - ---------- - module_config : ContextAwareDict | str - Pairing method configuration. - output_dir : Path - Output directory for parser calls. - context : list[str] - Context used when ``module_config`` is a string. - - Returns - ------- - PairingMethod - Parsed pairing method. - """ - if isinstance(module_config, str): - class_name = module_config - args = ContextAwareDict(context) - else: - class_name = pop_field(module_config, 'method') - args = module_config - - return registry.get(class_name, search_path=['pairing'], default_config_parser=GenericConfigParser).parse( - args, output_dir - ) - - AnyType = TypeVar('AnyType') diff --git a/lir/config/lrsystem_architectures.py b/lir/config/lrsystem_architectures.py index 78fcccb2..bd3535c9 100644 --- a/lir/config/lrsystem_architectures.py +++ b/lir/config/lrsystem_architectures.py @@ -8,7 +8,6 @@ YamlParseError, check_is_empty, config_parser, - parse_pairing_config, pop_field, ) from lir.config.substitution import ( @@ -16,7 +15,7 @@ HyperparameterOption, substitute_parameters, ) -from lir.config.transform import parse_module +from lir.config.transform import parse_module, parse_pairing_config from lir.data.models import InstanceData, LLRData from lir.lrsystems.binary_lrsystem import BinaryLRSystem from lir.lrsystems.lrsystems import LRSystem diff --git a/lir/config/transform.py b/lir/config/transform.py index 797067b0..09a8c287 100644 --- a/lir/config/transform.py +++ b/lir/config/transform.py @@ -5,20 +5,20 @@ from lir.config.base import ( ConfigParser, ContextAwareDict, + GenericConfigParser, YamlParseError, check_not_none, config_parser, pop_field, ) from lir.transform import ( - BinaryClassifierTransformer, CsvWriter, FunctionTransformer, Identity, - NumpyTransformer, Transformer, as_transformer, ) +from lir.transform.pairing import PairingMethod class GenericTransformerConfigParser(ConfigParser): @@ -73,19 +73,6 @@ def parse( f'failed to instantiate module {self.component_class.__name__}: {e}', ) - if isinstance(instance, Transformer): - # The component already supports all necessary methods, - # through the `Transformer` interface. - return instance - if hasattr(instance, 'transform'): - # The component implements a `transform()` method, which means it - # is a transformer and can be used in the scikit-learn pipeline. - return instance - if hasattr(instance, 'predict_proba'): - # The component has a `predict_proba` method, which should be used as - # `transform()` step in the pipeline, which the wrapper class provides. - return BinaryClassifierTransformer(instance) - elif callable(self.component_class): # When none of the above conditions apply, the component class might be a function # or a callable class, which should be used as a `transform()` step in the pipeline, @@ -95,54 +82,6 @@ def parse( raise YamlParseError(config.context, f'unrecognized module type: `{self.component_class}`') -class NumpyWrappingConfigParser(ConfigParser): - """ - Wrap a Transformer to add a header to FeatureData. - - Parameters - ---------- - module_parser : ConfigParser - Parser used to create the wrapped transformer. - """ - - def __init__(self, module_parser: ConfigParser): - super().__init__() - self.module_parser = module_parser - - def parse(self, config: ContextAwareDict, output_dir: Path) -> Transformer: - """ - Parse the provided header configuration. - - Parameters - ---------- - config : ContextAwareDict - Configuration possibly containing ``header`` and module fields. - output_dir : Path - Output directory passed to the wrapped parser. - - Returns - ------- - Transformer - Wrapped transformer that preserves numpy headers. - """ - header = config.pop('header') if 'header' in config else None - return NumpyTransformer( - self.module_parser.parse(config, output_dir), - header=header, - ) - - def reference(self) -> str: - """ - Return the full name of the ``module_parser`` class argument. - - Returns - ------- - str - Reference string for the wrapped parser. - """ - return self.module_parser.reference() - - def parse_module( module_config: ContextAwareDict | str | None, output_dir: Path, @@ -220,3 +159,45 @@ def csv_writer(config: ContextAwareDict, output_dir: Path) -> CsvWriter: if 'path' not in config: config |= {'path': output_dir / f'{config.context[-1]}.csv'} return CsvWriter(**config) + + +def parse_pairing_config( + module_config: ContextAwareDict | str, + output_dir: Path, + context: list[str], +) -> PairingMethod: + """ + Parse and delegate pairing to the corresponding function for the defined pairing method. + + The argument `module_config` defines the pairing method. If its value is a `str`, the registry is queried and the + corresponding pairing method is returned. If its value is a `dict`, the pairing method is defined + by the value `module_config["method"]`, and the registry is queried for the config parser of + the corresponding pairing method. The remaining values in `module_config` are passed as arguments to the + configuration parser of the pairing method. + + If the registry cannot resolve the pairing method, an exception is raised. + + Parameters + ---------- + module_config : ContextAwareDict | str + Pairing method configuration. + output_dir : Path + Output directory for parser calls. + context : list[str] + Context used when ``module_config`` is a string. + + Returns + ------- + PairingMethod + Parsed pairing method. + """ + if isinstance(module_config, str): + class_name = module_config + args = ContextAwareDict(context) + else: + class_name = pop_field(module_config, 'method') + args = module_config + + return registry.get(class_name, search_path=['pairing'], default_config_parser=GenericConfigParser).parse( + args, output_dir + ) diff --git a/lir/registry.py b/lir/registry.py index 05a81b55..77315143 100644 --- a/lir/registry.py +++ b/lir/registry.py @@ -7,7 +7,7 @@ import confidence -from lir.config.base import ConfigParser, GenericConfigParser +from lir.config.base import ConfigParser, GenericConfigParser, _expand from . import resources as package_resources @@ -81,12 +81,15 @@ class ConfigParserLoader(ABC, Iterable): @staticmethod def _get_config_parser( - result_type: Any, default_config_parser: Callable[[Any], ConfigParser] | None + result_type: Any, + default_config_parser: Callable[[Any], ConfigParser] | None, + args: dict[str, Any] | None = None, ) -> ConfigParser: + args = args or {} if inspect.isclass(result_type) and issubclass(result_type, ConfigParser): - return result_type() + return result_type(**args) elif default_config_parser is not None: - return default_config_parser(result_type) + return default_config_parser(result_type, **args) else: raise InvalidRegistryEntryError( f'unable to instantiate {result_type}: ' @@ -320,19 +323,8 @@ def _parse( except Exception as e: raise ValueError(f'registry key `{key}` resolved to `{spec.get("class")}` but failed to materialize: {e}') - parser = ConfigParserLoader._get_config_parser(cls, default_config_parser) - - if 'wrapper' in spec: - try: - wrapper = _get_attribute_by_name(spec.get('wrapper')) # type: ignore[arg-type] - except Exception as e: - raise InvalidRegistryEntryError( - f'unable to instantiate class {spec["class"]}: ' - f'error while instantiating wrapper class: {spec["wrapper"]}: {e}' - ) - parser = wrapper(parser) - - return parser + parser_init_args = spec.get('args', {}) + return ConfigParserLoader._get_config_parser(cls, default_config_parser, args=parser_init_args) def _find(self, key: str, search_path: list[str] | None) -> Any: """ diff --git a/lir/resources/registry.yaml b/lir/resources/registry.yaml index 132db698..7875bdd5 100644 --- a/lir/resources/registry.yaml +++ b/lir/resources/registry.yaml @@ -27,15 +27,9 @@ modules: bootstrap: lir.algorithms.bootstraps.bootstrap # transformations - standard_scaler: - class: sklearn.preprocessing.StandardScaler - wrapper: lir.config.numpy_wrapper - probabilities_to_odds: - class: lir.util.probability_to_odds - wrapper: lir.config.numpy_wrapper - probabilities_to_logodds: - class: lir.util.probability_to_logodds - wrapper: lir.config.numpy_wrapper + standard_scaler: sklearn.preprocessing.StandardScaler + probabilities_to_odds: lir.util.probability_to_odds + probabilities_to_logodds: lir.util.probability_to_logodds # comparisons element_wise_difference: lir.transform.distance.ElementWiseDifference @@ -50,34 +44,20 @@ modules: identity: lir.transform.Identity # classifiers - logistic_regression: - class: sklearn.linear_model.LogisticRegression - wrapper: lir.config.numpy_wrapper - svm: - class: sklearn.svm.SVC - wrapper: lir.config.numpy_wrapper + logistic_regression: sklearn.linear_model.LogisticRegression + svm: sklearn.svm.SVC # calibrators kde: lir.algorithms.kde.KDECalibrator - isotonic_calibrator: - class: lir.algorithms.isotonic_regression.IsotonicCalibrator - wrapper: lir.config.numpy_wrapper + isotonic_calibrator: lir.algorithms.isotonic_regression.IsotonicCalibrator logistic_calibrator: lir.algorithms.logistic_regression.LogitCalibrator mcmc: lir.config.algorithms.mcmc # bounders - static_bounder: - class: lir.bounding.StaticBounder - wrapper: lir.config.numpy_wrapper - elub_bounder: - class: lir.algorithms.bayeserror.ELUBBounder - wrapper: lir.config.numpy_wrapper - iv_bounder: - class: lir.algorithms.invariance_bounds.IVBounder - wrapper: lir.config.numpy_wrapper - n_source_bounder: - class: lir.bounding.NSourceBounder - wrapper: lir.config.numpy_wrapper + static_bounder: lir.bounding.StaticBounder + elub_bounder: lir.algorithms.bayeserror.ELUBBounder + iv_bounder: lir.algorithms.invariance_bounds.IVBounder + n_source_bounder: lir.bounding.NSourceBounder metric: cllr: lir.metrics.cllr cllr_min: lir.metrics.cllr_min @@ -87,7 +67,11 @@ metric: devpav: lir.algorithms.devpav.devpav output: metrics: lir.aggregation.metrics_csv - pav: lir.aggregation.plot_pav + pav: + class: lir.config.aggregation.parse_aggregate_plot + args: + method: pav + ece: lir.aggregation.plot_ece lr_histogram: lir.aggregation.plot_lr_histogram llr_interval: lir.aggregation.plot_llr_interval diff --git a/lir/transform/__init__.py b/lir/transform/__init__.py index 96067776..afc857b4 100644 --- a/lir/transform/__init__.py +++ b/lir/transform/__init__.py @@ -386,61 +386,6 @@ def apply(self, instances: InstanceData) -> InstanceData: return self.wrapped_transformer.apply(instances) -class NumpyTransformer(TransformerWrapper): - """ - Implementation of a transformer wrapper. - - Parameters - ---------- - transformer : Transformer - Transformer instance wrapped by this adapter. - header : list[str] | None - Value passed via ``header``. - """ - - def __init__(self, transformer: Transformer, header: list[str] | None): - super().__init__(transformer) - self.header = header - - def apply(self, instances: InstanceData) -> InstanceData: - """ - Extend the instances with the desired header data, call base `apply`. - - Parameters - ---------- - instances : InstanceData - Input instances to be processed by this method. - - Returns - ------- - InstanceData - Instance data object produced by this operation. - """ - instances = super().apply(instances) - if self.header: - instances = instances.replace(header=self.header) - return instances - - def fit_apply(self, instances: InstanceData) -> InstanceData: - """ - Extend the instances with the desired header data, call base `fit_apply`. - - Parameters - ---------- - instances : InstanceData - Input instances to be processed by this method. - - Returns - ------- - InstanceData - Instance data object produced by this operation. - """ - instances = super().fit_apply(instances) - if self.header: - instances = instances.replace(header=self.header) - return instances - - class CsvWriter(Transformer): """ Implementation of a transformation step in a scikit-learn Pipeline that writes to CSV.