# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Optional, List, Set, TypeVar
import random
import math
import warnings
import lale.operators as Ops
from lale.pretty_print import hyperparams_to_string
from lale.search.PGO import remove_defaults_dict
from lale.util.Visitor import Visitor
# This function (together with the to_lale() method on the object it
# returns) is the only API in this module intended to be exported.
def make_sklearn_compat(op: 'Ops.Operator') -> 'SKlearnCompatWrapper':
    """Wrap a lale operator so that it can be handed to sklearn utilities.

    The returned wrapper can be passed to sklearn functions such as ``clone``
    and ``GridSearchCV``.  The wrapper may modify the wrapped lale
    operator/pipeline as part of providing compatibility with these methods.

    Once the sklearn operation is complete, call
    ``SKlearnCompatWrapper.to_lale()`` on the wrapper to recover the wrapped
    lale operator for further use.

    Args:
        op: the lale operator (or pipeline) to wrap.  An existing
            ``SKlearnCompatWrapper`` is returned unchanged.

    Returns:
        The sklearn-compatible wrapper.
    """
    wrapper = SKlearnCompatWrapper.make_wrapper(op)
    return wrapper
def sklearn_compat_clone(impl: Any) -> Any:
    """Clone *impl* with sklearn's ``clone(..., safe=False)``.

    ``None`` is passed through unchanged.  With ``safe=False``, sklearn's
    documented behavior is to fall back to a deep copy for objects that are
    not estimators, so non-sklearn implementations can be cloned too.

    sklearn is imported lazily, only when there is something to clone.
    """
    if impl is not None:
        from sklearn.base import clone
        return clone(impl, safe=False)
    return None
def clone_lale(op: 'Ops.Operator') -> 'Ops.Operator':
    """Return a clone of the lale operator *op*.

    The clone is produced by ``op._lale_clone``; every underlying
    implementation object is copied with ``sklearn_compat_clone``.
    """
    impl_cloner = sklearn_compat_clone
    return op._lale_clone(impl_cloner)
class WithoutGetParams:
    """Proxy that forwards every attribute lookup to the wrapped object,
    except ``get_params``, which it hides.

    Hiding ``get_params`` presumably keeps sklearn utilities such as
    ``clone`` from treating the wrapped lale operator as an estimator whose
    parameters should be inspected or rebuilt (TODO confirm against callers).
    """

    def __init__(self, base):
        self._base = base
        # Wrapping an object in itself would make __getattr__ recurse forever.
        # Use an identity check so the wrapped object's __eq__/__ne__ is never
        # invoked.
        assert self._base is not self

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails.  If "_base" itself
        # is missing (e.g. on an instance created by copy, which skips
        # __init__), forwarding would recurse forever, so fail fast instead.
        if name == "_base":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")
        # Deliberately hide get_params from whoever is inspecting us.
        if name == 'get_params':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}")
        return getattr(self._base, name)

    @classmethod
    def clone_wgp(cls, obj: 'WithoutGetParams') -> 'WithoutGetParams':
        """Unwrap any nesting of WithoutGetParams, clone the inner lale
        operator, and return the clone wrapped exactly once."""
        while isinstance(obj, WithoutGetParams):
            obj = obj._base
        assert isinstance(obj, Ops.Operator)
        return WithoutGetParams(clone_lale(obj))
def partition_sklearn_params(d: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group sklearn-style ``<group>__<param>`` keys by their group prefix.

    Only the first ``__`` separates group from parameter, so
    ``{"a__x": 1, "b__c__y": 2}`` becomes ``{"a": {"x": 1}, "b": {"c__y": 2}}``.
    Groups appear in the order they are first encountered.

    Raises:
        AssertionError: if a key contains no ``__`` separator.
    """
    grouped: Dict[str, Dict[str, Any]] = {}
    for key, value in d.items():
        pieces = key.split("__", 1)
        assert len(pieces) == 2
        prefix, suffix = pieces
        sub_params = grouped.setdefault(prefix, {})
        assert suffix not in sub_params
        sub_params[suffix] = value
    return grouped
def set_operator_params(op: 'Ops.Operator', **impl_params) -> 'Ops.TrainableOperator':
    """Apply sklearn-style hyperparameters to a lale operator (graph).

    May return a new operator, in which case the old one should be
    overwritten by the caller.

    - ``PlannedIndividualOp``: forwards to ``op.set_params(**impl_params)``.
    - ``Pipeline``: keys have the form ``<step name>__<param>``.  Each step
      is updated recursively, and steps whose update produced a different
      operator are swapped in via ``op.subst_steps``.
    - ``OperatorChoice``: ``impl_params['_lale_discriminant']`` names the
      branch to take.  The remaining parameters are applied to that branch,
      which is returned in place of the choice.

    Raises:
        AssertionError: if parameters name steps that are not part of the
            pipeline, if the choice discriminant is missing or does not match
            exactly one branch, or if *op* is of an unsupported type.  The
            unsupported-type case is raised explicitly, so it is reported
            even under ``python -O``.
    """
    if isinstance(op, Ops.PlannedIndividualOp):
        return op.set_params(**impl_params)
    elif isinstance(op, Ops.Pipeline):
        steps = op.steps()
        partitioned_params: Dict[str, Dict[str, Any]] = partition_sklearn_params(impl_params)
        found_names: Set[str] = set()
        step_map: Dict[Ops.Operator, Ops.TrainableOperator] = {}
        for s in steps:
            name = s.name()
            found_names.add(name)
            new_s = set_operator_params(s, **partitioned_params.get(name, {}))
            if s != new_s:
                step_map[s] = new_s
        # make sure that no parameters were passed in for operations
        # that are not actually part of this pipeline
        unknown_names = set(partitioned_params.keys()) - found_names
        assert not unknown_names, (
            f"parameters given for steps that are not part of this pipeline: "
            f"{sorted(unknown_names)}")
        if step_map:
            # The result of subst_steps is unused, so it presumably updates
            # op in place; op.steps()/op.edges() below read the updated graph.
            op.subst_steps(step_map)
            if not isinstance(op, Ops.TrainablePipeline):
                # As a result of choices made, we may now be a TrainableIndividualOp
                ret = Ops.get_pipeline_of_applicable_type(op.steps(), op.edges(), ordered=True)
                assert isinstance(ret, Ops.TrainableOperator)
                return ret
            return op
        assert isinstance(op, Ops.TrainableOperator)
        return op
    elif isinstance(op, Ops.OperatorChoice):
        discriminant_name: str = "_lale_discriminant"
        assert discriminant_name in impl_params, (
            f"missing {discriminant_name!r} parameter selecting a branch of the choice")
        choice_name = impl_params[discriminant_name]
        choices: List[Ops.Operator] = [step for step in op.steps() if step.name() == choice_name]
        assert len(choices) == 1, (
            f"expected exactly one operator named {choice_name!r}, found {len(choices)}")
        chosen_params = {k: v for k, v in impl_params.items() if k != discriminant_name}
        # we remove the OperatorChoice, replacing it with the branch that was taken
        return set_operator_params(choices[0], **chosen_params)
    else:
        # Raise explicitly (not `assert False`) so this still fails under -O
        # instead of silently returning None.
        raise AssertionError(f"Not yet supported operation of type: {op.__class__.__name__}")
class SKlearnCompatWrapper:
    """sklearn-facing adapter around a lale operator.

    Build instances with ``make_wrapper`` and recover the wrapped operator
    with ``to_lale``.  Attribute lookups the wrapper does not handle itself
    are forwarded to the wrapped object.
    """

    # The wrapped operator, normally hidden behind a WithoutGetParams.
    _base: 'WithoutGetParams'
    # This is used to trick clone into leaving us alone: the kwargs received
    # by a clone-style __init__, handed back (once) by get_params(deep=False).
    _old_params_for_clone: Optional[Dict[str, Any]]

    @classmethod
    def make_wrapper(cls, base: 'Ops.Operator'):
        """Return *base* unchanged if it is already a wrapper; otherwise wrap it.

        The operator is hidden behind a WithoutGetParams (unless it already
        is one) and adopted by the new wrapper without being copied.
        """
        if isinstance(base, SKlearnCompatWrapper):
            return base
        hidden: Any = base
        if not isinstance(base, WithoutGetParams):
            hidden = WithoutGetParams(base)
        return cls(__lale_wrapper_init_base=hidden)

    def __init__(self, **kwargs):
        if '__lale_wrapper_init_base' not in kwargs:
            # We are part of a get_params/init clone: take our own copy of
            # the wrapped operator.
            self.init_params_internal(**kwargs)
        else:
            # Called by make_wrapper: adopt the given base without copying.
            self._base = kwargs['__lale_wrapper_init_base']
            self._old_params_for_clone = None
        assert self._base != self

    def init_params_internal(self, **kwargs):
        """Adopt a fresh clone of ``kwargs['__lale_wrapper_base']`` and keep
        *kwargs* so the next shallow ``get_params`` can return them."""
        self._base = WithoutGetParams.clone_wgp(kwargs['__lale_wrapper_base'])
        self._old_params_for_clone = kwargs

    def get_params_internal(self, out: Dict[str, Any]):
        """Record the wrapped operator in *out* under ``'__lale_wrapper_base'``."""
        out['__lale_wrapper_base'] = self._base

    def set_params_internal(self, **impl_params):
        """Replace the wrapped operator with ``impl_params['__lale_wrapper_base']``."""
        self._base = impl_params['__lale_wrapper_base']
        assert self._base != self

    def fixup_params_internal(self, **params):
        """Return *params* unchanged (presumably an override point for
        subclasses; set_params passes its parameters through here)."""
        return params

    def to_lale(self) -> 'Ops.Operator':
        """Return the lale operator underneath all WithoutGetParams layers."""
        assert self is not None
        node: Any = self._base
        assert node != None
        while isinstance(node, WithoutGetParams):
            node = node._base
        assert isinstance(node, Ops.Operator)
        return node

    # sklearn calls __repr__ instead of __str__
    def __repr__(self):
        op = self.to_lale()
        if not isinstance(op, Ops.TrainableIndividualOp):
            return super().__repr__()
        name = op.name()
        hps = op.hyperparams()
        hyps = "" if hps is None else hyperparams_to_string(hps)
        return name + "(" + hyps + ")"

    def __getattr__(self, name):
        # copy skips __init__, so "_base" may be absent; bail out rather
        # than recursing through self._base forever.
        if name == "_base":
            raise AttributeError
        return getattr(self._base, name)

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        """Return the constructor parameters sklearn should use to rebuild us.

        With ``deep=False``, the kwargs stored by a clone-style ``__init__``
        are returned exactly once (then forgotten).  After that, the result
        is ``{'__lale_wrapper_base': <wrapped operator>}``.  With
        ``deep=True``, an empty dict is returned (not implemented yet).
        """
        out: Dict[str, Any] = {}
        if deep:
            # TODO: deep parameter listing is not implemented yet
            return out
        pending = self._old_params_for_clone
        if pending is not None:
            # lie to clone to make it happy
            self._old_params_for_clone = None
            return pending
        self.get_params_internal(out)
        return out

    def fit(self, X, y=None, **fit_params):
        """Forward to the wrapped operator's ``fit`` (after dropping default
        fit params via remove_defaults_dict); returns None if it has no fit."""
        if not hasattr(self._base, 'fit'):
            return None
        filtered_params = remove_defaults_dict(fit_params)
        return self._base.fit(X, y=y, **filtered_params)

    def set_params(self, **impl_params):
        """Apply sklearn-style parameters and return ``self``.

        ``'__lale_wrapper_base'`` replaces the wrapped operator outright.
        Otherwise the parameters are applied to the innermost lale operator
        via set_operator_params.  If that yields a different operator, the
        innermost wrapper layer is repointed at it.
        """
        if '__lale_wrapper_base' in impl_params:
            self.set_params_internal(**impl_params)
            return self
        holder: Any = self
        target: Any = self._base
        assert holder != target
        assert target != None
        while isinstance(target, WithoutGetParams):
            assert target != target._base
            holder, target = target, target._base
        assert isinstance(target, Ops.Operator)
        fixed_params = self.fixup_params_internal(**impl_params)
        replacement = set_operator_params(target, **fixed_params)
        assert isinstance(replacement, Ops.TrainableOperator)
        if replacement != target:
            holder._base = replacement
        return self

    def hyperparam_defaults(self) -> Dict[str, Any]:
        """Return the hyperparameter defaults of the wrapped lale operator."""
        return DefaultsVisitor.run(self.to_lale())
class DefaultsVisitor(Visitor):
    """Visitor that collects hyperparameter defaults from a lale operator graph.

    Individual operators report ``op.hyperparam_defaults()``.  A pipeline
    prefixes each step's defaults with ``<step name>__``.  An operator choice
    merges the (unprefixed) defaults of all its branches; on key collisions,
    later steps/branches win.
    """

    @classmethod
    def run(cls, op: 'Ops.Operator') -> Dict[str, Any]:
        """Return the defaults of *op* by having it accept a fresh visitor."""
        visitor = cls()
        target: Any = op
        return target.accept(visitor)

    def __init__(self):
        super().__init__()

    def visitIndividualOp(self, op: 'Ops.IndividualOp') -> Dict[str, Any]:
        return op.hyperparam_defaults()

    visitPlannedIndividualOp = visitIndividualOp
    visitTrainableIndividualOp = visitIndividualOp
    visitTrainedIndividualOp = visitIndividualOp

    def visitPipeline(self, op: 'Ops.PlannedPipeline') -> Dict[str, Any]:
        merged: Dict[str, Any] = {}
        for step in op.steps():
            merged.update(nest_HPparams(step.name(), step.accept(self)))
        return merged

    visitPlannedPipeline = visitPipeline
    visitTrainablePipeline = visitPipeline
    visitTrainedPipeline = visitPipeline

    def visitOperatorChoice(self, op: 'Ops.OperatorChoice') -> Dict[str, Any]:
        merged: Dict[str, Any] = {}
        for branch in op.steps():
            merged.update(branch.accept(self))
        return merged
# Auxiliary functions

# Generic value type used by the nest_*HPparams helpers.
V = TypeVar("V")
def nest_HPparam(name: str, key: str) -> str:
    """Prefix hyperparameter *key* with operator *name*, joined by ``__``.

    >>> nest_HPparam("lr", "C")
    'lr__C'
    """
    return "__".join((name, key))
def nest_HPparams(name: str, grid: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of *grid* whose keys are prefixed via nest_HPparam(name, key)."""
    nested: Dict[str, V] = {}
    for key, value in grid.items():
        nested[nest_HPparam(name, key)] = value
    return nested
def nest_all_HPparams(name: str, grids: List[Dict[str, V]]) -> List[Dict[str, V]]:
    """Prefix every parameter name in every grid with operator *name*.

    Each key becomes ``<name>__<key>``, the scikit-learn convention for
    addressing a step's parameters inside a pipeline.  A new list of new
    dicts is returned; the input grids are not modified.
    """
    nested_grids: List[Dict[str, V]] = []
    for grid in grids:
        nested_grids.append(nest_HPparams(name, grid))
    return nested_grids
def unnest_HPparams(k: str) -> List[str]:
    """Split a nested hyperparameter name on every ``__`` separator.

    >>> unnest_HPparams("pipe__lr__C")
    ['pipe', 'lr', 'C']
    """
    separator = "__"
    return k.split(separator)