Source code for recipies.selector

import re
from .ingredients import Ingredients
from typing import Union
from polars import DataType
from .constants import Backend


class Selector:
    """Class responsible for selecting the variables affected by a recipe step.

    A column is selected only if it satisfies every criterion that is set;
    criteria left as ``None`` are not applied.

    Args:
        description: Text used to represent Selector when printed in summaries.
        names: Column names to select. Defaults to None.
        roles: Column roles to select, see also Ingredients. Defaults to None.
        types: Column data types to select, given by name (e.g. "Float64").
            Defaults to None.
        pattern: Regex pattern to search column names with. A plain string is
            compiled with ``re.compile``. Defaults to None.
    """

    def __init__(
        self,
        description: str,
        names: Union[str, list[str], None] = None,
        roles: Union[str, list[str], None] = None,
        types: Union[str, list[str], None] = None,
        pattern: Union[str, re.Pattern, None] = None,
    ):
        self.description = description
        self.set_names(names)
        self.set_roles(roles)
        self.set_types(types)
        self.set_pattern(pattern)

    def set_names(self, names: Union[str, list[str], None]):
        """Set the column names to select with this Selector.

        Args:
            names: column names to select; a single str is wrapped in a list,
                None disables name filtering.
        """
        self.names = enlist_str(names)

    def set_roles(self, roles: Union[str, list[str], None]):
        """Set the column roles to select with this Selector.

        Args:
            roles: column roles to select, see also Ingredients; a single str
                is wrapped in a list, None disables role filtering.
        """
        self.roles = enlist_str(roles)

    def set_types(self, roles: Union[str, list[str], None]):
        """Set the column data types to select with this Selector.

        Args:
            roles: column data type names to select (e.g. "Int64"); a single
                str is wrapped in a list, None disables type filtering. The
                parameter keeps its historical name ``roles`` so existing
                keyword callers keep working.
        """
        self.types = enlist_str(roles)

    def set_pattern(self, pattern: Union[str, re.Pattern, None]):
        """Set the pattern to search with this Selector.

        Args:
            pattern: Regex pattern to search column names with. A plain string
                is compiled first; None disables pattern filtering.
        """
        # Compile plain strings here so __call__ can always use `.search`.
        if isinstance(pattern, str):
            pattern = re.compile(pattern)
        self.pattern = pattern

    def __call__(self, ingr: Ingredients) -> list[str]:
        """Select variables from Ingredients.

        Args:
            ingr: object from which to select the variables.

        Raises:
            TypeError: when something other than an Ingredients object is passed.

        Returns:
            Selected column names, in the column order of ``ingr``.
        """
        if not isinstance(ingr, Ingredients):
            raise TypeError(f"Expected Ingredients, got {ingr.__class__}")
        selected = list(ingr.columns)
        if self.roles is not None:
            # A column qualifies if it carries at least one requested role.
            sel_roles = [v for v, r in ingr.roles.items() if intersection(r, self.roles)]
            selected = intersection(selected, sel_roles)
        if self.types is not None:
            # NOTE(review): assumes iterating the select_dtypes result yields
            # column names — confirm against the Ingredients implementation.
            sel_types = list(ingr.select_dtypes(include=self.types))
            selected = intersection(selected, sel_types)
        if self.names is not None:
            selected = intersection(selected, self.names)
        if self.pattern is not None:
            selected = [v for v in selected if self.pattern.search(v)]
        return selected

    def __repr__(self):
        return self.description
def enlist_dt(x: Union[DataType, list[DataType], None]) -> Union[list[DataType], None]:
    """Wrap a pl datatype in a list if it isn't a list yet.

    Args:
        x: a single datatype, a list of datatypes, or None.

    Raises:
        TypeError: If neither a datatype nor a list of datatypes is passed, or
            if a list contains a non-datatype element.

    Returns:
        ``[x]`` for a single datatype, the same list object for a list, and
        None for None.
    """
    if x is None:
        return None
    if isinstance(x, DataType):
        return [x]
    if not isinstance(x, list):
        raise TypeError(f"Expected a pl datatype, got {x.__class__}")
    if any(not isinstance(item, DataType) for item in x):
        raise TypeError("Only lists of datatypes are allowed.")
    return x
def enlist_str(x: Union[str, list[str], None]) -> Union[list[str], None]:
    """Wrap a str in a list if it isn't a list yet.

    Args:
        x: a single string, a list of strings, or None.

    Raises:
        TypeError: If neither a str nor a list of strings is passed, or if a
            list contains a non-str element.

    Returns:
        ``[x]`` for a single string, the same list object for a list, and None
        for None.
    """
    if x is None:
        return None
    if isinstance(x, str):
        return [x]
    if not isinstance(x, list):
        raise TypeError(f"Expected str or list of str, got {x.__class__}")
    if any(not isinstance(item, str) for item in x):
        raise TypeError("Only lists of str are allowed.")
    return x
def intersection(x: list, y: list) -> list:
    """Intersection of two lists.

    Note:
        Keeps the order of ``x`` and does not deduplicate (the result is a
        list, not a set). A bare str on either side is treated as a
        one-element list.

    Args:
        x: first list (or a single str).
        y: second list (or a single str).

    Returns:
        Elements in `x` that are also in `y`.
    """

    def _as_list(value):
        # A lone string would otherwise be iterated character by character.
        return [value] if isinstance(value, str) else value

    first = _as_list(x)
    second = _as_list(y)
    kept = []
    for item in first:
        if item in second:
            kept.append(item)
    return kept
def all_of(names: Union[str, list[str]]) -> Selector:
    """Define selector for any columns with one of the given names.

    Args:
        names: a single column name or a list of column names to select.

    Returns:
        Object representing the selection rule; its description is
        ``str(names)``.
    """
    label = str(names)
    return Selector(description=label, names=names)
def regex_names(regex: str) -> Selector:
    """Define selector for any columns where the name matches the regex pattern.

    Matching uses ``re.Pattern.search``, so the pattern may match anywhere in
    the column name unless it is anchored.

    Args:
        regex: string to be compiled into the regex pattern to search for.

    Raises:
        re.error: If ``regex`` is not a valid regular expression.

    Returns:
        Object representing the selection rule.
    """
    pattern = re.compile(regex)
    return Selector(description=f"regex: {regex}", pattern=pattern)
def starts_with(prefix: str) -> Selector:
    """Define selector for any columns where the name starts with the prefix.

    The prefix is matched literally: regex metacharacters in it (such as
    ``.`` or ``(``) are escaped before the pattern is built.

    Args:
        prefix: prefix to search for.

    Returns:
        Object representing the selection rule.
    """
    # Escape so e.g. "x.1" does not also match "xa1", and "a(" does not raise.
    return regex_names(f"^{re.escape(prefix)}")
def ends_with(suffix: str) -> Selector:
    """Define selector for any columns where the name ends with the suffix.

    The suffix is matched literally: regex metacharacters in it (such as
    ``.`` or ``(``) are escaped before the pattern is built.

    Args:
        suffix: suffix to search for.

    Returns:
        Object representing the selection rule.
    """
    # Escape so the suffix is treated as plain text, not as a regex.
    return regex_names(f"{re.escape(suffix)}$")
def contains(substring: str) -> Selector:
    """Define selector for any columns where the name contains the substring.

    The substring is matched literally: regex metacharacters in it (such as
    ``.`` or ``(``) are escaped before the pattern is built.

    Args:
        substring: substring to search for.

    Returns:
        Object representing the selection rule.
    """
    # Escape so the substring is treated as plain text, not as a regex.
    return regex_names(re.escape(substring))
def has_role(roles: Union[str, list[str]]) -> Selector:
    """Define selector for any columns with one of the given roles.

    Args:
        roles: a single role or a list of roles; a column carrying any of
            them is selected.

    Returns:
        Object representing the selection rule.
    """
    description = f"roles: {roles}"
    return Selector(description=description, roles=roles)
def has_type(types: Union[str, list[str]]) -> Selector:
    """Define selector for any columns with one of the given types.

    Args:
        types: a single data type name or a list of data type names.

    Note:
        Types are given as strings (e.g. "Int64" or "float64") and are passed
        to ``Ingredients.select_dtypes(include=...)`` when the selector is
        applied.

    Returns:
        Object representing the selection rule.
    """
    description = f"types: {types}"
    return Selector(description=description, types=types)
def all_predictors() -> Selector:
    """Define selector for all predictor columns.

    Returns:
        Object selecting columns with the "predictor" role, described as
        "all predictors".
    """
    return Selector(description="all predictors", roles=["predictor"])
def all_numeric_predictors(backend=Backend.POLARS) -> Selector:
    """Define selector for all numerical predictor columns.

    Args:
        backend: Currently unused. Both the capitalised dtype names (e.g.
            "Int64") and the lowercase ones (e.g. "int64", presumably
            numpy/pandas-style) are always included.

    Returns:
        Object representing the selection rule.
    """
    numeric_types = [
        "Int8", "Int16", "Int32", "Int64", "Float32", "Float64",
        "int16", "int32", "int64", "float16", "float32", "float64",
    ]
    selector = all_predictors()
    selector.set_types(numeric_types)
    selector.description = "all numeric predictors"
    return selector
def all_outcomes() -> Selector:
    """Define selector for all outcome columns.

    Returns:
        Object selecting columns with the "outcome" role, described as
        "all outcomes".
    """
    return Selector(description="all outcomes", roles=["outcome"])
def all_groups() -> Selector:
    """Define selector for all grouping variables.

    Returns:
        Object selecting columns with the "group" role.
    """
    group_roles = ["group"]
    return Selector(description="all grouping variables", roles=group_roles)
def select_groups(ingr: Ingredients) -> list[str]:
    """Select any grouping columns.

    Builds the selector from ``all_groups()`` (roles=["group"]) and applies it
    immediately to ``ingr``.

    Args:
        ingr: object from which to select the grouping columns.

    Returns:
        Grouping columns.
    """
    selector = all_groups()
    return selector(ingr)
def all_sequences() -> Selector:
    """Define selector for all sequence variables.

    Returns:
        Object selecting columns with the "sequence" role.
    """
    return Selector(description="all sequence variables", roles=["sequence"])
def select_sequence(ingr: Ingredients) -> list[str]:
    """Select any sequence columns.

    Defines and directly applies Selector(roles=["sequence"]) via
    ``all_sequences()``.

    Args:
        ingr: object from which to select the sequence columns.

    Returns:
        Sequence columns.
    """
    return all_sequences()(ingr)