import copy
from collections import abc
import more_itertools
def split_list(input_list, num_chunks):
    """
    Split a list into N nearly-equal chunks.

    Note: function is similar to numpy.array_split: the result always has
    exactly `num_chunks` chunks, the first ``len(input_list) % num_chunks``
    chunks hold one extra item, and trailing chunks are empty when there are
    fewer items than chunks.

    :param input_list: The list to split
    :type input_list: list
    :param num_chunks: The desired number of chunks
    :type num_chunks: int
    :return: `num_chunks` slices of `input_list`, in their original order.
        Every chunk is an independent slice (no shared empty lists).
    :rtype: list
    :raises ValueError: if `num_chunks` is less than 1
    """
    if num_chunks < 1:
        raise ValueError(f'num_chunks must be at least 1, got {num_chunks}')
    base_size, num_larger = divmod(len(input_list), num_chunks)
    chunks = []
    start = 0
    for chunk_idx in range(num_chunks):
        # The first `num_larger` chunks absorb the remainder, one item each,
        # so chunk sizes never differ by more than one.
        stop = start + base_size + (1 if chunk_idx < num_larger else 0)
        chunks.append(input_list[start:stop])
        start = stop
    return chunks
class IdSet(abc.MutableSet, set):
    """
    An id set is a set that uses the id of an object as the key instead of the
    hash. This means that two objects that compare equal but are different
    instances will be stored separately since id(obj1) != id(obj2).

    The ids are stored in the underlying builtin ``set`` while
    ``_id_to_obj_map`` maps each id back to its object. Holding a reference to
    every member keeps its id from being reused while it is in the set.

    The set-algebra methods accept other IdSets or arbitrary iterables of
    objects; plain iterables are compared by identity as well. Builtin
    ``set`` instances (other than IdSets) are rejected with a ValueError.

    NOTE: Using this set with python's builtin immutable datatypes is /strongly/
    discouraged (e.g. str and int are not guaranteed to have different ids
    for separate instances)
    """

    def __init__(self, initial_set=None):
        """
        :param initial_set: Objects to add to the new set
        :type initial_set: iterable or None
        :raises ValueError: if `initial_set` is a builtin (non-IdSet) set
        """
        self._id_to_obj_map = {}
        if initial_set is not None:
            self.update(initial_set)

    def __contains__(self, obj):
        return set.__contains__(self, id(obj))

    def __iter__(self):
        return iter(self._id_to_obj_map.values())

    def __len__(self):
        return set.__len__(self)

    def __repr__(self):
        # The inherited set.__repr__ would display the stored ids; show the
        # member objects instead.
        class_name = type(self).__name__
        if not self:
            return f'{class_name}()'
        item_reprs = ', '.join(repr(obj) for obj in self)
        return f'{class_name}({{{item_reprs}}})'

    def __copy__(self):
        return IdSet(self)

    def __deepcopy__(self, memo=None):
        # copy.deepcopy() calls __deepcopy__(memo); accepting `memo` makes it
        # reach this explicit error instead of an argument-count TypeError.
        raise TypeError("Deepcopy is incompatible with IdSet")

    def copy(self):
        # NB default copy method does not call __copy__
        return copy.copy(self)

    @classmethod
    def fromIterable(cls, iterable):
        """Return a new set containing every object of `iterable`."""
        new_set = cls()
        for item in iterable:
            new_set.add(item)
        return new_set

    # The set-algebra methods below are written in terms of identity
    # membership (``obj in id_set``) instead of delegating to the builtin set
    # methods, so plain iterables -- which yield objects, not ids -- are also
    # compared by identity.

    def isdisjoint(self, other):
        """Return True if no object of `other` is in this set."""
        self._checkOtherSets(other)
        return not any(obj in self for obj in other)

    def issubset(self, other):
        """Return True if every member of this set is in `other`."""
        self._checkOtherSets(other)
        other_set = self._asIdSet(other)
        return all(obj in other_set for obj in self)

    def issuperset(self, other):
        """Return True if every object of `other` is in this set."""
        self._checkOtherSets(other)
        return all(obj in self for obj in other)

    def union(self, *other_sets):
        """Return a new set with the members of this set and all others."""
        self._checkOtherSets(*other_sets)
        result = self.fromIterable(self)
        result.update(*other_sets)
        return result

    def intersection(self, *other_sets):
        """Return a new set with the members found in every other set."""
        self._checkOtherSets(*other_sets)
        id_sets = [self._asIdSet(o_set) for o_set in other_sets]
        return self.fromIterable(
            obj for obj in self if all(obj in id_set for id_set in id_sets))

    def difference(self, *other_sets):
        """Return a new set with the members found in no other set."""
        self._checkOtherSets(*other_sets)
        id_sets = [self._asIdSet(o_set) for o_set in other_sets]
        return self.fromIterable(
            obj for obj in self
            if not any(obj in id_set for id_set in id_sets))

    def symmetric_difference(self, *other_sets):
        """Return a new set equal to ``self ^ other_1 ^ other_2 ^ ...``."""
        self._checkOtherSets(*other_sets)
        result = self.fromIterable(self)
        result.symmetric_difference_update(*other_sets)
        return result

    def update(self, *other_sets):
        """Add every object of every given iterable to this set."""
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            for itm in o_set:
                self.add(itm)

    def intersection_update(self, *other_sets):
        """Keep only the members found in every other set."""
        self._checkOtherSets(*other_sets)
        keep = self.intersection(*other_sets)
        # Iterate a snapshot: discarding mutates the id->object dict.
        for obj in list(self):
            if obj not in keep:
                self.discard(obj)

    def difference_update(self, *other_sets):
        """Remove every object found in any other set."""
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            # Snapshot first: `o_set` may be this very set.
            for obj in list(o_set):
                self.discard(obj)

    def symmetric_difference_update(self, *other_sets):
        """Toggle membership of every object in each other set, in turn."""
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            # Coerce to an IdSet so an object repeated in a plain iterable is
            # toggled only once, and snapshot in case it is this very set.
            for obj in list(self._asIdSet(o_set)):
                if obj in self:
                    self.discard(obj)
                else:
                    self.add(obj)

    def add(self, obj):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        set.add(self, obj_id)

    def discard(self, obj):
        obj_id = id(obj)
        self._id_to_obj_map.pop(obj_id, None)
        set.discard(self, obj_id)

    @staticmethod
    def _asIdSet(other):
        """Return `other` as an IdSet, wrapping plain iterables by identity."""
        if isinstance(other, IdSet):
            return other
        return IdSet.fromIterable(other)

    def _checkOtherSets(self, *other_sets):
        # Builtin sets use hash-based membership; mixing them in would
        # silently combine equality and identity semantics.
        any_set_builtin = any(
            isinstance(o_set, set) and not isinstance(o_set, IdSet)
            for o_set in other_sets)
        if any_set_builtin:
            raise ValueError('Set operations only supported with other IdSets')
class IdItemsView(abc.ItemsView):
    """
    Items view for an id-keyed mapping, yielding ``(key_object, value)`` pairs.

    :param id_dict: The mapping values are looked up in (indexed by the key
        object itself).
    :param id_map: Mapping of ``id(key) -> key``; its values supply the key
        objects, in iteration order.
    """

    def __init__(self, id_dict, id_map):
        # Initialize the ItemsView base so its `_mapping` slot is set; the
        # inherited __repr__ reads it.
        super().__init__(id_dict)
        self.id_dict = id_dict
        self.id_map = id_map

    def __contains__(self, item):
        key, value = item
        # Check membership first so mappings whose __getitem__ inserts
        # missing keys are not mutated by a containment test.
        if key not in self.id_dict:
            return False
        stored = self.id_dict[key]
        # Same identity-or-equality test as collections.abc.ItemsView.
        return stored is value or stored == value

    def __iter__(self):
        # A fresh generator per iteration, so nested/concurrent iteration over
        # one view does not share (and reset) iterator state.
        for obj in self.id_map.values():
            yield obj, self.id_dict[obj]

    def __len__(self):
        return len(self.id_dict)
class IdDict(abc.MutableMapping, dict):
    """
    An id dict is a dictionary that uses the id of an object as the key
    instead of the hash. This means that two objects that compare equal but are
    different instances will be stored separately since id(obj1) != id(obj2).

    The underlying builtin ``dict`` maps ``id(key) -> value`` while
    ``_id_to_obj_map`` maps ``id(key) -> key``, so the key objects can be
    recovered (and are kept alive, which keeps their ids stable).

    NOTE: Using this dict with python's builtin immutable datatypes is /strongly/
    discouraged (e.g. str and int are not guaranteed to have different ids
    for separate instances)
    """

    def __init__(self, initial_dict=None):
        """
        :param initial_dict: Another IdDict whose items are copied in
        :type initial_dict: IdDict or None
        :raises ValueError: if `initial_dict` is not an IdDict
        """
        self._id_to_obj_map = {}
        if initial_dict is not None:
            if not isinstance(initial_dict, IdDict):
                err_msg = 'IdDict can only be initialized with another IdDict'
                raise ValueError(err_msg)
            self.update(initial_dict)

    def __getitem__(self, obj):
        try:
            return dict.__getitem__(self, id(obj))
        except KeyError:
            # Report the missing object rather than its id.
            raise KeyError(str(obj)) from None

    def __setitem__(self, obj, value):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        dict.__setitem__(self, obj_id, value)

    def __delitem__(self, obj):
        obj_id = id(obj)
        try:
            del self._id_to_obj_map[obj_id]
        except KeyError:
            # Same error shape as __getitem__ for a missing key.
            raise KeyError(str(obj)) from None
        dict.__delitem__(self, obj_id)

    def setdefault(self, key, default=None):
        """
        Return the value for `key`, storing `default` first if `key` is absent.

        A key that is present but mapped to None is left untouched.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def __contains__(self, obj):
        return dict.__contains__(self, id(obj))

    def __len__(self):
        return dict.__len__(self)

    def items(self):
        """Return a view of ``(key_object, value)`` pairs."""
        return IdItemsView(self, self._id_to_obj_map)

    def keys(self):
        """Return a view of the key objects."""
        return self._id_to_obj_map.values()

    def __eq__(self, other):
        if not isinstance(other, IdDict):
            return NotImplemented
        if len(self) != len(other):
            return False
        # Test membership before indexing so that a subclass whose
        # __getitem__ inserts missing keys (e.g. DefaultIdDict) is neither
        # mutated by the comparison nor falsely reported equal.
        return all(key in other and other[key] == self[key] for key in self)

    def __iter__(self):
        return iter(self._id_to_obj_map.values())

    def __repr__(self):
        item_reprs = []
        for k, v in self.items():
            item_reprs.append(f'{repr(k)}: {repr(v)}')
        return 'IdDict({' + ', '.join(item_reprs) + '})'

    def has_key(self, obj):
        return self.__contains__(obj)

    def update(self, other_dict):
        """
        :raises ValueError: if `other_dict` is not an IdDict
        """
        if not isinstance(other_dict, IdDict):
            raise ValueError('Update is only supported with other IdDicts')
        self.updateFromIterable(other_dict.items())

    def updateFromIterable(self, iterable):
        """Store every ``(key, value)`` pair of `iterable`."""
        for k, v in iterable:
            self[k] = v

    @classmethod
    def fromIterable(cls, iterable):
        """Return a new instance built from ``(key, value)`` pairs."""
        id_dict = cls()
        id_dict.updateFromIterable(iterable)
        return id_dict

    def clear(self):
        self._id_to_obj_map.clear()
        dict.clear(self)

    def copy(self):
        return IdDict(self)
class DefaultIdDict(IdDict):
    """
    A dict that is both an id dict and a defaultdict.

    Reading a missing key with ``d[key]`` stores and returns
    ``default_factory()``. As with ``collections.defaultdict``, ``get``,
    ``pop`` and ``in`` never create entries.
    """

    # Sentinel distinguishing "no default given" from an explicit None.
    _MISSING = object()

    def __init__(self, default_factory):
        """
        :param default_factory: Zero-argument callable producing the value
            for a key that is read before it has been set
        :type default_factory: callable
        """
        super().__init__()
        self._default_factory = default_factory

    def __getitem__(self, obj):
        try:
            return dict.__getitem__(self, id(obj))
        except KeyError:
            default_value = self._default_factory()
            self[obj] = default_value
            return default_value

    def get(self, key, default=None):
        """Return the value for `key`, or `default` without inserting."""
        # The inherited Mapping.get goes through __getitem__, which here
        # would insert a factory value instead of returning `default`.
        if key in self:
            return self[key]
        return default

    def pop(self, key, default=_MISSING):
        """
        Remove `key` and return its value.

        :raises KeyError: if `key` is absent and no `default` is given
        """
        # The inherited MutableMapping.pop relies on __getitem__ raising
        # KeyError, which never happens here; check membership explicitly.
        if key not in self:
            if default is self._MISSING:
                raise KeyError(str(key))
            return default
        value = self[key]
        del self[key]
        return value

    def copy(self):
        """Return a shallow copy that keeps the default factory."""
        # IdDict.copy would return a plain IdDict and drop the factory.
        new_dict = type(self)(self._default_factory)
        for key in self:
            new_dict[key] = self[key]
        return new_dict

    @classmethod
    def fromIterable(cls, iterable):
        raise NotImplementedError()

    def setdefault(self, key, default):
        raise NotImplementedError()
class DefaultFactoryDictMixin:
    """
    A mixin for `dict` subclasses that fills in missing keys by calling a
    factory function, similar to `defaultdict`. Unlike `defaultdict`, the
    factory receives the missing key itself rather than being called with no
    arguments. The created value is stored under the key and returned.

    .. NOTE::
        Despite the name, a class works as the factory as well: its
        constructor is called and passed the key(s).

    .. WARNING::
        A tuple key is unpacked into separate positional arguments, because a
        lookup such as ``d[a, b]`` arrives as the single tuple ``(a, b)``.
        Consequently this mixin will not work with factory functions that
        expect one tuple as their only argument.
    """

    def __init__(self, factory_func, *args, **kwargs):
        """
        :param factory_func: A callable invoked with the missing key(s) to
            create the value to store.
        :type factory_func: callable

        Remaining positional and keyword arguments are forwarded to the next
        class in the MRO (e.g. `dict`).
        """
        self._factory_func = factory_func
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        # Tuple keys are spread into positional args; any other key is
        # passed as the single argument.
        factory_args = key if isinstance(key, tuple) else (key,)
        created = self._factory_func(*factory_args)
        self[key] = created
        return created
class DefaultFactoryDict(DefaultFactoryDictMixin, dict):
    """
    A basic `dict` using the `DefaultFactoryDictMixin`. This is separated
    from the mixin to allow other `dict` subclasses to easily subclass
    `DefaultFactoryDictMixin`.

    The first argument is the factory; any further arguments are passed on to
    `dict` as initial contents.

    Example usage::

        stringified_objs = DefaultFactoryDict(str)
        assert 1 not in stringified_objs
        assert stringified_objs[1] == '1'  # created via str(1) and stored
        assert 1 in stringified_objs
    """