Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions lir/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +0,0 @@
# This import is used to be able to write a short-hand version in `registry.yaml`, i.e.
# `lir.config.numpy_wrapper`. Ignored by linting (F401; unused import).
from .transform import ( # noqa: F401
NumpyWrappingConfigParser as numpy_wrapper, # noqa: N813
)
6 changes: 6 additions & 0 deletions lir/config/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from pathlib import Path
from typing import Any

from lir import registry
from lir.aggregation import Aggregation, SubsetAggregation
Expand Down Expand Up @@ -117,3 +118,8 @@ def subset_aggregation(config: ContextAwareDict, output_dir: Path) -> SubsetAggr
aggregation_methods = [parse_aggregation(aggregation_config, subset_output_dir)]

return SubsetAggregation(aggregation_methods, category_field)


class AggregationPlotConfigParser(ConfigParser):
    """
    Placeholder configuration parser for aggregation plots.

    NOTE(review): `parse` has no implementation yet; it accepts its arguments
    and always returns ``None``.
    """

    def parse(self, config: ContextAwareDict, output_dir: Path) -> Any:
        """
        Parse an aggregation plot configuration (currently a no-op).

        Parameters
        ----------
        config : ContextAwareDict
            Configuration of the aggregation plot; currently unused.
        output_dir : Path
            Output directory; currently unused.

        Returns
        -------
        Any
            Always ``None`` until the parser is implemented.
        """
        # Explicit `None` makes the stub behaviour visible to callers and reviewers.
        return None
45 changes: 0 additions & 45 deletions lir/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
from pathlib import Path
from typing import Any, TypeVar

from lir import registry
from lir.transform.pairing import PairingMethod


class YamlParseError(ValueError):
"""
Expand Down Expand Up @@ -386,48 +383,6 @@ def reference(self) -> str:
return ConfigParserFunction


def parse_pairing_config(
    module_config: ContextAwareDict | str,
    output_dir: Path,
    context: list[str],
) -> PairingMethod:
    """
    Resolve a pairing method through the registry and parse its configuration.

    The pairing method is named either directly, when `module_config` is a `str`, or by the
    `method` field of `module_config`, when it is a dictionary. In the latter case the `method`
    field is removed from `module_config` and the remaining fields are handed to the
    configuration parser of the pairing method.

    The parser is looked up in the registry under the `pairing` search path, with
    `GenericConfigParser` as the default config parser. Exceptions raised by the registry
    lookup or by the parser are not caught here.

    Parameters
    ----------
    module_config : ContextAwareDict | str
        Pairing method configuration: a method name, or a dictionary with a `method` field.
    output_dir : Path
        Output directory, passed on to the parser.
    context : list[str]
        Context for the empty argument dictionary that is created when ``module_config``
        is a string; unused otherwise.

    Returns
    -------
    PairingMethod
        Parsed pairing method.
    """
    if not isinstance(module_config, str):
        # Dictionary form: `method` selects the parser, the rest are its arguments.
        method_name = pop_field(module_config, 'method')
        parser_args = module_config
    else:
        # String form: the value itself is the method name, without any arguments.
        method_name = module_config
        parser_args = ContextAwareDict(context)

    parser = registry.get(method_name, search_path=['pairing'], default_config_parser=GenericConfigParser)
    return parser.parse(parser_args, output_dir)


AnyType = TypeVar('AnyType')


Expand Down
3 changes: 1 addition & 2 deletions lir/config/lrsystem_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
YamlParseError,
check_is_empty,
config_parser,
parse_pairing_config,
pop_field,
)
from lir.config.substitution import (
ContextAwareDict,
HyperparameterOption,
substitute_parameters,
)
from lir.config.transform import parse_module
from lir.config.transform import parse_module, parse_pairing_config
from lir.data.models import InstanceData, LLRData
from lir.lrsystems.binary_lrsystem import BinaryLRSystem
from lir.lrsystems.lrsystems import LRSystem
Expand Down
107 changes: 44 additions & 63 deletions lir/config/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
from lir.config.base import (
ConfigParser,
ContextAwareDict,
GenericConfigParser,
YamlParseError,
check_not_none,
config_parser,
pop_field,
)
from lir.transform import (
BinaryClassifierTransformer,
CsvWriter,
FunctionTransformer,
Identity,
NumpyTransformer,
Transformer,
as_transformer,
)
from lir.transform.pairing import PairingMethod


class GenericTransformerConfigParser(ConfigParser):
Expand Down Expand Up @@ -73,19 +73,6 @@ def parse(
f'failed to instantiate module {self.component_class.__name__}: {e}',
)

if isinstance(instance, Transformer):
# The component already supports all necessary methods,
# through the `Transformer` interface.
return instance
if hasattr(instance, 'transform'):
# The component implements a `transform()` method, which means it
# is a transformer and can be used in the scikit-learn pipeline.
return instance
if hasattr(instance, 'predict_proba'):
# The component has a `predict_proba` method, which should be used as
# `transform()` step in the pipeline, which the wrapper class provides.
return BinaryClassifierTransformer(instance)

elif callable(self.component_class):
# When none of the above conditions apply, the component class might be a function
# or a callable class, which should be used as a `transform()` step in the pipeline,
Expand All @@ -95,54 +82,6 @@ def parse(
raise YamlParseError(config.context, f'unrecognized module type: `{self.component_class}`')


class NumpyWrappingConfigParser(ConfigParser):
    """
    Config parser that wraps the result of another parser in a ``NumpyTransformer``.

    Parameters
    ----------
    module_parser : ConfigParser
        Parser that builds the object to be wrapped.
    """

    def __init__(self, module_parser: ConfigParser):
        super().__init__()
        self.module_parser = module_parser

    def parse(self, config: ContextAwareDict, output_dir: Path) -> Transformer:
        """
        Build the wrapped module and return it inside a ``NumpyTransformer``.

        Parameters
        ----------
        config : ContextAwareDict
            Configuration for the wrapped parser, optionally with a ``header`` field.
            The ``header`` field, if present, is removed from `config`.
        output_dir : Path
            Output directory, passed on to the wrapped parser.

        Returns
        -------
        Transformer
            ``NumpyTransformer`` around the parsed module, constructed with the header
            taken from the configuration (``None`` if absent).
        """
        # Take `header` out before delegating, so the wrapped parser never sees it.
        header = None
        if 'header' in config:
            header = config.pop('header')

        wrapped_module = self.module_parser.parse(config, output_dir)
        return NumpyTransformer(wrapped_module, header=header)

    def reference(self) -> str:
        """
        Return the reference of the wrapped ``module_parser``.

        Returns
        -------
        str
            Whatever ``module_parser.reference()`` returns.
        """
        wrapped_parser = self.module_parser
        return wrapped_parser.reference()


def parse_module(
module_config: ContextAwareDict | str | None,
output_dir: Path,
Expand Down Expand Up @@ -220,3 +159,45 @@ def csv_writer(config: ContextAwareDict, output_dir: Path) -> CsvWriter:
if 'path' not in config:
config |= {'path': output_dir / f'{config.context[-1]}.csv'}
return CsvWriter(**config)


def parse_pairing_config(
    module_config: ContextAwareDict | str,
    output_dir: Path,
    context: list[str],
) -> PairingMethod:
    """
    Resolve a pairing method through the registry and parse its configuration.

    The pairing method is named either directly, when `module_config` is a `str`, or by the
    `method` field of `module_config`, when it is a dictionary. In the latter case the `method`
    field is removed from `module_config` and the remaining fields are handed to the
    configuration parser of the pairing method.

    The parser is looked up in the registry under the `pairing` search path, with
    `GenericConfigParser` as the default config parser. Exceptions raised by the registry
    lookup or by the parser are not caught here.

    Parameters
    ----------
    module_config : ContextAwareDict | str
        Pairing method configuration: a method name, or a dictionary with a `method` field.
    output_dir : Path
        Output directory, passed on to the parser.
    context : list[str]
        Context for the empty argument dictionary that is created when ``module_config``
        is a string; unused otherwise.

    Returns
    -------
    PairingMethod
        Parsed pairing method.
    """
    if not isinstance(module_config, str):
        # Dictionary form: `method` selects the parser, the rest are its arguments.
        method_name = pop_field(module_config, 'method')
        parser_args = module_config
    else:
        # String form: the value itself is the method name, without any arguments.
        method_name = module_config
        parser_args = ContextAwareDict(context)

    parser = registry.get(method_name, search_path=['pairing'], default_config_parser=GenericConfigParser)
    return parser.parse(parser_args, output_dir)
26 changes: 9 additions & 17 deletions lir/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import confidence

from lir.config.base import ConfigParser, GenericConfigParser
from lir.config.base import ConfigParser, GenericConfigParser, _expand

from . import resources as package_resources

Expand Down Expand Up @@ -81,12 +81,15 @@ class ConfigParserLoader(ABC, Iterable):

@staticmethod
def _get_config_parser(
result_type: Any, default_config_parser: Callable[[Any], ConfigParser] | None
result_type: Any,
default_config_parser: Callable[[Any], ConfigParser] | None,
args: dict[str, Any] | None = None,
) -> ConfigParser:
args = args or {}
if inspect.isclass(result_type) and issubclass(result_type, ConfigParser):
return result_type()
return result_type(**args)
elif default_config_parser is not None:
return default_config_parser(result_type)
return default_config_parser(result_type, **args)
else:
raise InvalidRegistryEntryError(
f'unable to instantiate {result_type}: '
Expand Down Expand Up @@ -320,19 +323,8 @@ def _parse(
except Exception as e:
raise ValueError(f'registry key `{key}` resolved to `{spec.get("class")}` but failed to materialize: {e}')

parser = ConfigParserLoader._get_config_parser(cls, default_config_parser)

if 'wrapper' in spec:
try:
wrapper = _get_attribute_by_name(spec.get('wrapper')) # type: ignore[arg-type]
except Exception as e:
raise InvalidRegistryEntryError(
f'unable to instantiate class {spec["class"]}: '
f'error while instantiating wrapper class: {spec["wrapper"]}: {e}'
)
parser = wrapper(parser)

return parser
parser_init_args = spec.get('args', {})
return ConfigParserLoader._get_config_parser(cls, default_config_parser, args=parser_init_args)

def _find(self, key: str, search_path: list[str] | None) -> Any:
"""
Expand Down
46 changes: 15 additions & 31 deletions lir/resources/registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,9 @@ modules:
bootstrap: lir.algorithms.bootstraps.bootstrap

# transformations
standard_scaler:
class: sklearn.preprocessing.StandardScaler
wrapper: lir.config.numpy_wrapper
probabilities_to_odds:
class: lir.util.probability_to_odds
wrapper: lir.config.numpy_wrapper
probabilities_to_logodds:
class: lir.util.probability_to_logodds
wrapper: lir.config.numpy_wrapper
standard_scaler: sklearn.preprocessing.StandardScaler
probabilities_to_odds: lir.util.probability_to_odds
probabilities_to_logodds: lir.util.probability_to_logodds

# comparisons
element_wise_difference: lir.transform.distance.ElementWiseDifference
Expand All @@ -50,34 +44,20 @@ modules:
identity: lir.transform.Identity

# classifiers
logistic_regression:
class: sklearn.linear_model.LogisticRegression
wrapper: lir.config.numpy_wrapper
svm:
class: sklearn.svm.SVC
wrapper: lir.config.numpy_wrapper
logistic_regression: sklearn.linear_model.LogisticRegression
svm: sklearn.svm.SVC

# calibrators
kde: lir.algorithms.kde.KDECalibrator
isotonic_calibrator:
class: lir.algorithms.isotonic_regression.IsotonicCalibrator
wrapper: lir.config.numpy_wrapper
isotonic_calibrator: lir.algorithms.isotonic_regression.IsotonicCalibrator
logistic_calibrator: lir.algorithms.logistic_regression.LogitCalibrator
mcmc: lir.config.algorithms.mcmc

# bounders
static_bounder:
class: lir.bounding.StaticBounder
wrapper: lir.config.numpy_wrapper
elub_bounder:
class: lir.algorithms.bayeserror.ELUBBounder
wrapper: lir.config.numpy_wrapper
iv_bounder:
class: lir.algorithms.invariance_bounds.IVBounder
wrapper: lir.config.numpy_wrapper
n_source_bounder:
class: lir.bounding.NSourceBounder
wrapper: lir.config.numpy_wrapper
static_bounder: lir.bounding.StaticBounder
elub_bounder: lir.algorithms.bayeserror.ELUBBounder
iv_bounder: lir.algorithms.invariance_bounds.IVBounder
n_source_bounder: lir.bounding.NSourceBounder
metric:
cllr: lir.metrics.cllr
cllr_min: lir.metrics.cllr_min
Expand All @@ -87,7 +67,11 @@ metric:
devpav: lir.algorithms.devpav.devpav
output:
metrics: lir.aggregation.metrics_csv
pav: lir.aggregation.plot_pav
pav:
class: lir.config.aggregation.parse_aggregate_plot
args:
method: pav

ece: lir.aggregation.plot_ece
lr_histogram: lir.aggregation.plot_lr_histogram
llr_interval: lir.aggregation.plot_llr_interval
Expand Down
Loading
Loading