import re
from .ingredients import Ingredients
from typing import Union
from polars import DataType
from .constants import Backend
[docs]
class Selector:
    """Rule describing which columns of an Ingredients object a recipe step affects.

    Up to four criteria can be combined: column names, column roles, column
    data types and a regex pattern on the column name. A criterion left at
    ``None`` is ignored; every criterion that is set must hold for a column
    to be selected.

    Args:
        description: Text used to represent the Selector when printed in summaries.
        names: Column names to select. Defaults to None.
        roles: Column roles to select, see also Ingredients. Defaults to None.
        types: Column data types to select. Defaults to None.
        pattern: Regex pattern to search column names with. Defaults to None.
    """

    def __init__(
        self,
        description: str,
        names: Union[str, list[str], None] = None,
        roles: Union[str, list[str], None] = None,
        types: Union[str, list[str], None] = None,
        pattern: Union[re.Pattern, None] = None,
    ):
        self.description = description
        # Go through the setters so every criterion is validated/normalised
        # the same way as when it is changed after construction.
        self.set_names(names)
        self.set_roles(roles)
        self.set_types(types)
        self.set_pattern(pattern)

    def set_names(self, names: Union[str, list[str], None]):
        """Set the column names to select with this Selector.

        Args:
            names: Column name(s) to select, or None to disable the name criterion.

        Raises:
            TypeError: Propagated from ``enlist_str`` for unsupported input.
        """
        self.names = enlist_str(names)

    def set_roles(self, roles: Union[str, list[str], None]):
        """Set the column roles to select with this Selector.

        Args:
            roles: Column role(s) to select, see also Ingredients, or None to
                disable the role criterion.

        Raises:
            TypeError: Propagated from ``enlist_str`` for unsupported input.
        """
        self.roles = enlist_str(roles)

    def set_types(self, roles: Union[str, list[str], None]):
        """Set the column data types to select with this Selector.

        Args:
            roles: Column data type name(s) to select, or None to disable the
                type criterion. The parameter is called ``roles`` for
                historical reasons and is kept for backward compatibility;
                it holds data types, not roles.

        Raises:
            TypeError: Propagated from ``enlist_str`` for unsupported input.
        """
        self.types = enlist_str(roles)

    def set_pattern(self, pattern: Union[re.Pattern, None]):
        """Set the pattern to search with this Selector.

        Args:
            pattern: Compiled regex pattern to search column names with, or
                None to disable the pattern criterion. Stored as given.
        """
        self.pattern = pattern

    def __call__(self, ingr: Ingredients) -> list[str]:
        """Select variables from Ingredients.

        Args:
            ingr: Object from which to select the variables.

        Raises:
            TypeError: When something other than an Ingredients object is passed.

        Returns:
            Selected variables, in the order in which they appear in ``ingr.columns``.
        """
        if not isinstance(ingr, Ingredients):
            raise TypeError(f"Expected Ingredients, got {ingr.__class__}")

        # Start from every column and narrow down one criterion at a time.
        selected = list(ingr.columns)

        if self.roles is not None:
            # A column qualifies as soon as it shares at least one role with the request.
            with_role = [v for v, r in ingr.roles.items() if intersection(r, self.roles)]
            selected = intersection(selected, with_role)

        if self.types is not None:
            # NOTE(review): assumes iterating the result of select_dtypes yields
            # column names; confirm against Ingredients.
            with_type = list(ingr.select_dtypes(include=self.types))
            selected = intersection(selected, with_type)

        if self.names is not None:
            selected = intersection(selected, self.names)

        if self.pattern is not None:
            # `search` matches anywhere in the name, not only at the start.
            selected = list(filter(self.pattern.search, selected))

        return selected

    def __repr__(self):
        return self.description
[docs]
def enlist_dt(x: Union[DataType, list[DataType], None]) -> Union[list[DataType], None]:
    """Normalise a single polars data type, or a list of them, into a list.

    Args:
        x: A data type, a list of data types, or None.

    Raises:
        TypeError: If ``x`` is neither a data type, a list made up only of
            data types, nor None.

    Returns:
        ``[x]`` when a single data type is given, the list itself (same
        object, not a copy) when a list is given, and None when None is given.
    """
    if x is None:
        return None
    if isinstance(x, DataType):
        return [x]
    if isinstance(x, list):
        # Reject the whole list as soon as one element is not a data type.
        for item in x:
            if not isinstance(item, DataType):
                raise TypeError("Only lists of datatypes are allowed.")
        return x
    raise TypeError(f"Expected a pl datatype, got {x.__class__}")
[docs]
def enlist_str(x: Union[str, list[str], None]) -> Union[list[str], None]:
    """Normalise a single string, or a list of strings, into a list.

    Args:
        x: A string, a list of strings, or None.

    Raises:
        TypeError: If ``x`` is neither a str, a list made up only of str,
            nor None.

    Returns:
        ``[x]`` when a single string is given, the list itself (same object,
        not a copy) when a list is given, and None when None is given.
    """
    if x is None:
        return None
    if isinstance(x, str):
        return [x]
    if isinstance(x, list):
        # Reject the whole list as soon as one element is not a string.
        for item in x:
            if not isinstance(item, str):
                raise TypeError("Only lists of str are allowed.")
        return x
    raise TypeError(f"Expected str or list of str, got {x.__class__}")
[docs]
def intersection(x: list, y: list) -> list:
    """Return the elements of ``x`` that also occur in ``y``.

    Note:
        The order of ``x`` is preserved and duplicates in ``x`` are kept,
        so the result is a list rather than a set. A bare string passed for
        either argument is treated as a one-element list instead of being
        iterated character by character.

    Args:
        x: First list (or a single string).
        y: Second list (or a single string).

    Returns:
        Elements in `x` that are also in `y`.
    """
    first = [x] if isinstance(x, str) else x
    second = [y] if isinstance(y, str) else y
    return [item for item in first if item in second]
[docs]
def all_of(names: Union[str, list[str]]) -> Selector:
    """Build a selector that picks the columns carrying one of the given names.

    Args:
        names: Name or list of names to select.

    Returns:
        Object representing the selection rule; its description is ``str(names)``.
    """
    label = str(names)
    return Selector(names=names, description=label)
[docs]
def regex_names(regex: str) -> Selector:
    """Build a selector that picks the columns whose name matches a regex.

    Args:
        regex: String that is compiled into the regex pattern to search
            column names with.

    Returns:
        Object representing the selection rule.
    """
    compiled = re.compile(regex)
    label = f"regex: {regex}"
    return Selector(pattern=compiled, description=label)
[docs]
def starts_with(prefix: str) -> Selector:
    """Define selector for any columns where the name starts with the prefix.

    The prefix is matched literally: it is escaped before being placed in
    the regex, so characters such as ``.``, ``(`` or ``+`` in a column
    prefix are not interpreted as regex syntax.

    Args:
        prefix: Literal prefix to search for.

    Returns:
        Object representing the selection rule.
    """
    return regex_names(f"^{re.escape(prefix)}")
[docs]
def ends_with(suffix: str) -> Selector:
    """Define selector for any columns where the name ends with the suffix.

    The suffix is matched literally: it is escaped before being placed in
    the regex, so characters such as ``.``, ``)`` or ``$`` in a column
    suffix are not interpreted as regex syntax.

    Args:
        suffix: Literal suffix to search for.

    Returns:
        Object representing the selection rule.
    """
    return regex_names(f"{re.escape(suffix)}$")
[docs]
def contains(substring: str) -> Selector:
    """Define selector for any columns where the name contains the substring.

    The substring is matched literally: it is escaped before being used as
    a regex, so characters such as ``.``, ``[`` or ``*`` in it are not
    interpreted as regex syntax. Use ``regex_names`` for real patterns.

    Args:
        substring: Literal substring to search for.

    Returns:
        Object representing the selection rule.
    """
    return regex_names(re.escape(substring))
[docs]
def has_role(roles: Union[str, list[str]]) -> Selector:
    """Build a selector that picks the columns carrying one of the given roles.

    Args:
        roles: Role or list of roles to select.

    Returns:
        Object representing the selection rule.
    """
    label = f"roles: {roles}"
    return Selector(roles=roles, description=label)
[docs]
def has_type(types: Union[str, list[str]]) -> Selector:
    """Build a selector that picks the columns having one of the given data types.

    Args:
        types: Data type name or list of data type names to select.

    Note:
        Data types are given as strings; they are meant to correspond to the
        string representation returned by `df[[varname]].dtype.name`.

    Returns:
        Object representing the selection rule.
    """
    label = f"types: {types}"
    return Selector(types=types, description=label)
[docs]
def all_predictors() -> Selector:
    """Build a selector for every column that has the "predictor" role.

    Returns:
        Object representing the selection rule, described as "all predictors".
    """
    selector = has_role(["predictor"])
    # Replace the generic "roles: ..." label with a friendlier one.
    selector.description = "all predictors"
    return selector
[docs]
def all_numeric_predictors(backend=Backend.POLARS) -> Selector:
    """Build a selector for every numeric column that has the "predictor" role.

    Args:
        backend: Accepted for API compatibility; the function body does not
            read it, so the same list of type names is applied whatever
            value is passed.

    Returns:
        Object representing the selection rule, described as
        "all numeric predictors".
    """
    # Capitalised and lower-case spellings of the numeric type names are both
    # listed, exactly as the selection rule expects them.
    numeric_type_names = [
        "Int8", "Int16", "Int32", "Int64", "Float32", "Float64",
        "int16", "int32", "int64", "float16", "float32", "float64",
    ]
    selector = all_predictors()
    selector.set_types(numeric_type_names)
    selector.description = "all numeric predictors"
    return selector
[docs]
def all_outcomes() -> Selector:
    """Build a selector for every column that has the "outcome" role.

    Returns:
        Object representing the selection rule, described as "all outcomes".
    """
    selector = has_role(["outcome"])
    # Replace the generic "roles: ..." label with a friendlier one.
    selector.description = "all outcomes"
    return selector
[docs]
def all_groups() -> Selector:
    """Build a selector for every column that has the "group" role.

    Returns:
        Object representing the selection rule, described as
        "all grouping variables".
    """
    group_roles = ["group"]
    return Selector(roles=group_roles, description="all grouping variables")
[docs]
def select_groups(ingr: Ingredients) -> list[str]:
    """Select any grouping columns from the given Ingredients.

    Builds the selector returned by ``all_groups()`` and applies it
    directly to ``ingr``.

    Args:
        ingr: Object from which to select the grouping columns.

    Returns:
        Grouping columns.
    """
    selector = all_groups()
    return selector(ingr)
[docs]
def all_sequences() -> Selector:
    """Build a selector for every column that has the "sequence" role.

    Returns:
        Object representing the selection rule, described as
        "all sequence variables".
    """
    sequence_roles = ["sequence"]
    return Selector(roles=sequence_roles, description="all sequence variables")
[docs]
def select_sequence(ingr: Ingredients) -> list[str]:
    """Select any sequence columns from the given Ingredients.

    Builds the selector returned by ``all_sequences()`` and applies it
    directly to ``ingr``.

    Args:
        ingr: Object from which to select the sequence columns.

    Returns:
        Sequence columns.
    """
    selector = all_sequences()
    return selector(ingr)