# Copyright 2019 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import enum
import itertools
from typing import Any, Dict, List, Set, Iterable, Iterator, Optional, Tuple, Union, Callable
from lale.schema_simplifier import findRelevantFields, narrowToGivenRelevantFields, simplify
from .schema_utils import Schema, SchemaEnum
# logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the helpers below.
logger = logging.getLogger(__name__)
class DiscoveredEnums(object):
    """A node in a tree of enumeration constants.

    ``enums`` holds the constants recorded at this node (or None), and
    ``children`` maps string keys to nested ``DiscoveredEnums`` nodes (or None).
    """

    def __init__(self, enums:Optional[SchemaEnum]=None, children:Optional[Dict[str,"DiscoveredEnums"]]=None) -> None:
        self.children = children
        self.enums = enums

    def __str__(self)->str:
        """Render the node as ``<enums; key->child,...>``.

        Constants are shown sorted by their rendered text: None becomes
        ``null``, strings are wrapped in single quotes, and anything else
        goes through ``str``.
        """
        def render(value):
            if value is None:
                return "null"
            if isinstance(value, str):
                return f"'{value}'"
            return str(value)

        enum_part = ", ".join(sorted(map(render, self.enums))) if self.enums else ""
        child_part = (
            ",".join(f"{str(key)}->{str(child)}" for key, child in self.children.items())
            if self.children
            else ""
        )
        # The "; " separator only appears when both halves are non-empty.
        separator = "; " if enum_part and child_part else ""
        return "<" + enum_part + separator + child_part + ">"
def schemaToDiscoveredEnums(schema:Schema)->Optional[DiscoveredEnums]:
    """Collect the enumeration constants mentioned anywhere in a schema.

    The result is deliberately conservative: constants that appear under
    ``not`` are reported as well, on the assumption that they may be valid
    in some contexts.

    Returns None for boolean schemas, for schemas carrying a ``type`` that
    is not an object with ``properties``, and whenever nothing is found.
    """
    def combineDiscoveredEnums(combine:Callable[[Iterable[SchemaEnum]], Optional[SchemaEnum]], des:Iterable[Optional[DiscoveredEnums]])->Optional[DiscoveredEnums]:
        # Merge several results: enum sets at this level are folded with
        # `combine`; children sharing a key are merged recursively.
        level_enums:List[SchemaEnum] = []
        grouped:Dict[str, List[DiscoveredEnums]] = {}
        for item in des:
            if item is None:
                continue
            if item.enums is not None:
                level_enums.append(item.enums)
            if item.children is not None:
                for child_name, child in item.children.items():
                    if child is not None:
                        grouped.setdefault(child_name, []).append(child)

        merged_enums:Optional[SchemaEnum] = combine(level_enums) if level_enums else None

        if not grouped:
            if merged_enums is None:
                return None
            return DiscoveredEnums(enums=merged_enums)

        merged_children:Dict[str,DiscoveredEnums] = {}
        for child_name, group in grouped.items():
            if not group:
                continue
            merged = combineDiscoveredEnums(combine, group)
            if merged is not None:
                merged_children[child_name] = merged
        return DiscoveredEnums(enums=merged_enums, children=merged_children)

    def joinDiscoveredEnums(des:Iterable[Optional[DiscoveredEnums]])->Optional[DiscoveredEnums]:
        # Combine by set union.
        def union_all(args:Iterable[SchemaEnum]) -> Optional[SchemaEnum]:
            return set.union(*args)
        return combineDiscoveredEnums(union_all, des)

    def meetDiscoveredEnums(des:Tuple[Optional[DiscoveredEnums], ...])->Optional[DiscoveredEnums]:
        # Combine by set intersection. Not called anywhere in this function.
        def intersect_all(args:Iterable[SchemaEnum]) -> Optional[SchemaEnum]:
            return set.intersection(*args)
        return combineDiscoveredEnums(intersect_all, des)

    # Boolean schemas carry no enumerations.
    if isinstance(schema, bool):
        return None

    if 'enum' in schema:
        # TODO: we should validate the enum elements according to the schema, like schema2search_space does
        return DiscoveredEnums(enums=set(schema['enum']))

    if 'type' in schema:
        # A typed schema is only explored when it is an object with properties;
        # any combinators next to 'type' are not consulted.
        if schema['type'] != "object" or 'properties' not in schema:
            return None
        found:Dict[str, DiscoveredEnums] = {}
        for prop_name, prop_schema in schema['properties'].items():
            prop_enums = schemaToDiscoveredEnums(prop_schema)
            if prop_enums is not None:
                found[prop_name] = prop_enums
        return DiscoveredEnums(children=found) if found else None

    if 'not' in schema:
        return schemaToDiscoveredEnums(schema['not'])

    # The first combinator present wins; all three are joined by union.
    for combinator in ('allOf', 'anyOf', 'oneOf'):
        if combinator in schema:
            return joinDiscoveredEnums([schemaToDiscoveredEnums(sub) for sub in schema[combinator]])

    return None
def accumulateDiscoveredEnumsToPythonEnums(de:Optional[DiscoveredEnums], path:List[str], acc:Dict[str, enum.Enum])->None:
    """Walk ``de`` and add one Python ``Enum`` class to ``acc`` per node with constants.

    The key in ``acc`` is the path joined with ``_`` (hyphens replaced by
    underscores); the Enum class name is the path joined with ``.``.
    None constants are skipped. Child keys are prepended to ``path`` before
    recursing. Nothing happens when ``de`` is None.
    """
    def as_member(constant:str)->Tuple[str,Any]:
        # Derive a member name for a constant; the value is kept untouched.
        if isinstance(constant, str):
            label = constant.replace('-', '_')
        elif isinstance(constant, (int, float, complex)):
            label = "num" + str(constant)
        else:
            logger.info(f"Unknown type ({type(constant)}) of enumeration constant {constant}, not handling very well")
            label = str(constant)
        return (label, constant)

    if de is None:
        return

    if de.enums is not None:
        acc_key = as_member("_".join(path))[0]
        class_name = ".".join(path)
        members = [as_member(constant) for constant in de.enums if constant is not None]
        acc[acc_key] = enum.Enum(class_name, members)

    if de.children is not None:
        for child_name, child in de.children.items():
            accumulateDiscoveredEnumsToPythonEnums(child, [child_name] + path, acc)
def discoveredEnumsToPythonEnums(de:Optional[DiscoveredEnums])->Dict[str, enum.Enum]:
    """Return a fresh dict of the Python enums accumulated from ``de``, starting at an empty path."""
    result:Dict[str, enum.Enum] = dict()
    accumulateDiscoveredEnumsToPythonEnums(de, list(), result)
    return result
def schemaToPythonEnums(schema:Schema)->Dict[str, enum.Enum]:
    """Discover the enumerations in ``schema`` and convert them to Python enums."""
    return discoveredEnumsToPythonEnums(schemaToDiscoveredEnums(schema))
def addDictAsFields(obj:Any, d:Dict[str, Any], force=False)->None:
    """Set each entry of ``d`` as an attribute on ``obj``.

    An empty-string key is skipped with a warning. A key that already
    exists on ``obj`` is skipped with an error log unless ``force`` is true,
    in which case it is overwritten. A ``d`` of None is a no-op.
    """
    if d is None:
        return
    for name, value in d.items():
        if name == "":
            logger.warning(f"There was a top level enumeration specified, so it is not being added to {obj._name}")
            continue
        if hasattr(obj, name) and not force:
            logger.error(f"The object {obj._name} already has the field {name}. This conflicts with our attempt at adding that key as an enumeration field")
            continue
        setattr(obj, name, value)
def addSchemaEnumsAsFields(obj:Any, schema:Schema, force=False)->None:
    """Convert the enumerations found in ``schema`` to Python enums and attach them to ``obj``.

    ``force`` is passed through to ``addDictAsFields``.
    """
    addDictAsFields(obj, schemaToPythonEnums(schema), force)