import copy
from collections import abc
import more_itertools
def split_list(input_list, num_chunks):
    """
    Split a list into `num_chunks` chunks whose sizes differ by at most one.

    Note: function is similar to numpy.array_split. When the list does not
    divide evenly, the first ``len(input_list) % num_chunks`` chunks get one
    extra item. When there are fewer items than chunks, the trailing chunks
    are empty.

    :param input_list: The list to split
    :type input_list: list
    :param num_chunks: The desired number of chunks
    :type num_chunks: int
    :return: A list of `num_chunks` slices of `input_list`, in order. As a
        special case, an empty `input_list` is returned as a single chunk
        (``[input_list]``) regardless of `num_chunks`.
    :rtype: list
    :raises ValueError: If `input_list` is non-empty and `num_chunks` < 1.
    """
    if not input_list:
        # Historical special case, kept for backward compatibility.
        return [input_list]
    if num_chunks < 1:
        raise ValueError(f'num_chunks must be at least 1, got {num_chunks}')
    base_size, num_larger = divmod(len(input_list), num_chunks)
    chunks = []
    start = 0
    for chunk_idx in range(num_chunks):
        # The first `num_larger` chunks absorb the remainder, one item each.
        stop = start + base_size + (1 if chunk_idx < num_larger else 0)
        # Slicing creates a new object per chunk, so empty trailing chunks
        # are never aliases of one another.
        chunks.append(input_list[start:stop])
        start = stop
    return chunks
class IdSet(abc.MutableSet, set):
    """
    An id set is a set that uses the id of an object as the key instead of the
    hash. This means that two objects that compare equal but are different
    instances will be stored separately since id(obj1) != id(obj2).

    The builtin ``set`` base holds the member ids, and ``_id_to_obj_map`` maps
    each id back to its object. Holding that reference keeps every member
    alive, so a member's id cannot be recycled while it is in the set.

    Set operations accept other IdSets or plain iterables of objects. Plain
    iterables are converted to IdSets first, so their members are also
    compared by identity. Builtin ``set`` operands are rejected with
    ValueError.

    NOTE: Using this set with python's builtin immutable datatypes is /strongly/
    discouraged (e.g. str and int are not guaranteed to have different ids
    for separate instances)
    """

    def __init__(self, initial_set=None):
        """
        :param initial_set: Optional iterable of objects to add.
        :raises ValueError: If `initial_set` is a builtin ``set``.
        """
        self._id_to_obj_map = {}
        if initial_set is not None:
            self.update(initial_set)

    def __contains__(self, obj):
        obj_id = id(obj)
        return set.__contains__(self, obj_id)

    def __iter__(self):
        return iter(self._id_to_obj_map.values())

    def __len__(self):
        return set.__len__(self)

    def __copy__(self):
        return IdSet(self)

    def __deepcopy__(self, memo):
        # copy.deepcopy always passes a memo dict. Without this parameter the
        # call failed with an unrelated "positional argument" TypeError.
        raise TypeError("Deepcopy is incompatible with IdSet")

    def copy(self):
        # NB default copy method does not call __copy__ (set.copy would
        # return a plain set of ids).
        return copy.copy(self)

    @classmethod
    def fromIterable(cls, iterable):
        """Build a new set containing every object yielded by `iterable`."""
        new_set = cls()
        for item in iterable:
            new_set.add(item)
        return new_set

    def isdisjoint(self, other):
        """Return True if no member of `other` is (by identity) in this set."""
        self._checkOtherSets(other)
        # Membership is tested on live objects, so no id conversion is needed.
        return not any(item in self for item in other)

    def issubset(self, other):
        """Return True if every member of this set is in `other`."""
        (other_set,) = self._asIdSets(other)
        return set.issubset(self, other_set)

    def issuperset(self, other):
        """Return True if every member of `other` is in this set."""
        (other_set,) = self._asIdSets(other)
        return set.issuperset(self, other_set)

    def union(self, *other_sets):
        """Return a new set with the members of this set and all others."""
        result = self.fromIterable(self)
        result.update(*other_sets)
        return result

    def intersection(self, *other_sets):
        """Return a new set with the members common to this and all others."""
        id_intersection = set.intersection(self, *self._asIdSets(*other_sets))
        return self.fromIterable(
            self._id_to_obj_map[obj_id] for obj_id in id_intersection)

    def difference(self, *other_sets):
        """Return a new set with members of this set that are in no other."""
        id_difference = set.difference(self, *self._asIdSets(*other_sets))
        return self.fromIterable(
            self._id_to_obj_map[obj_id] for obj_id in id_difference)

    def symmetric_difference(self, *other_sets):
        """
        Return a new set holding the symmetric difference of this set and
        each of `other_sets`, applied in order.
        """
        result = self.fromIterable(self)
        result.symmetric_difference_update(*other_sets)
        return result

    def update(self, *other_sets):
        """Add every member of each of `other_sets` to this set."""
        self._checkOtherSets(*other_sets)
        for o_set in other_sets:
            for itm in o_set:
                self.add(itm)

    def intersection_update(self, *other_sets):
        """Keep only the members that are also in every one of `other_sets`."""
        keep = self.intersection(*other_sets)
        for itm in list(self):
            if itm not in keep:
                self.discard(itm)

    def difference_update(self, *other_sets):
        """Remove every member of each of `other_sets` from this set."""
        for o_set in self._asIdSets(*other_sets):
            # Snapshot first: o_set may be this very set.
            for itm in list(o_set):
                self.discard(itm)

    def symmetric_difference_update(self, *other_sets):
        """Toggle membership of every member of each of `other_sets`."""
        for o_set in self._asIdSets(*other_sets):
            # Snapshot first: o_set may be this very set.
            for itm in list(o_set):
                if itm in self:
                    self.discard(itm)
                else:
                    self.add(itm)

    def add(self, obj):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        set.add(self, obj_id)

    def discard(self, obj):
        obj_id = id(obj)
        if obj_id in self._id_to_obj_map:
            del self._id_to_obj_map[obj_id]
        set.discard(self, obj_id)

    def _checkOtherSets(self, *other_sets):
        any_set_builtin = any(
            isinstance(o_set, set) and not isinstance(o_set, IdSet)
            for o_set in other_sets)
        if any_set_builtin:
            raise ValueError('Set operations only supported with other IdSets')

    def _asIdSets(self, *other_sets):
        """
        Validate `other_sets` and return them as a list of IdSets.

        IdSets are passed through unchanged. Any other iterable is converted
        so that its members are compared by identity and kept alive for the
        duration of the operation.
        """
        self._checkOtherSets(*other_sets)
        return [o_set if isinstance(o_set, IdSet) else IdSet.fromIterable(o_set)
                for o_set in other_sets]
class IdItemsView(abc.ItemsView):
    """
    Items view over an IdDict, yielding ``(key_object, value)`` pairs.

    :param id_dict: The dict whose values are looked up by key object.
    :param id_map: Mapping of ``id(key) -> key`` used to recover the key
        objects during iteration.
    """

    def __init__(self, id_dict, id_map):
        # Let MappingView record the mapping as ``_mapping`` so that its
        # inherited ``__repr__`` works.
        super().__init__(id_dict)
        self.id_dict = id_dict
        self.id_map = id_map

    def __contains__(self, item):
        k, v = item
        if k not in self.id_dict:
            return False
        cur_val = self.id_dict[k]
        # Identity shortcut mirrors dict item views (e.g. NaN values).
        return cur_val is v or cur_val == v

    def __iter__(self):
        # A fresh generator per call: independent or nested iterations over
        # the same view no longer share (and clobber) a single cursor.
        for obj in self.id_map.values():
            yield obj, self.id_dict[obj]

    def __len__(self):
        return len(self.id_dict)
class IdDict(abc.MutableMapping, dict):
    """
    An id dict is a dictionary that uses the id of an object as the key
    instead of the hash. This means that two objects that compare equal but are
    different instances will be stored separately since id(obj1) != id(obj2).

    The builtin ``dict`` base stores ``id(key) -> value`` while
    ``_id_to_obj_map`` stores ``id(key) -> key``. Holding the key there keeps
    it alive, so its id cannot be reused while it is in the dict.

    NOTE: Using this dict with python's builtin immutable datatypes is /strongly/
    discouraged (e.g. str and int are not guaranteed to have different ids
    for separate instances)
    """

    def __init__(self, initial_dict=None):
        """
        :param initial_dict: Optional IdDict whose items are copied in.
        :raises ValueError: If `initial_dict` is given but is not an IdDict.
        """
        self._id_to_obj_map = {}
        if initial_dict is not None:
            if not isinstance(initial_dict, IdDict):
                err_msg = 'IdDict can only be initialized with another IdDict'
                raise ValueError(err_msg)
            self.update(initial_dict)

    def __getitem__(self, obj):
        obj_id = id(obj)
        try:
            return dict.__getitem__(self, obj_id)
        except KeyError:
            raise KeyError(str(obj)) from None

    def __setitem__(self, obj, value):
        obj_id = id(obj)
        self._id_to_obj_map[obj_id] = obj
        dict.__setitem__(self, obj_id, value)

    def __delitem__(self, obj):
        obj_id = id(obj)
        if obj_id not in self._id_to_obj_map:
            # Same KeyError shape as __getitem__, not the bare integer id.
            raise KeyError(str(obj))
        del self._id_to_obj_map[obj_id]
        dict.__delitem__(self, obj_id)

    def setdefault(self, key, default=None):
        """
        Return the value for `key`, first storing `default` if `key` is
        absent. An existing value of None is left untouched.
        """
        if key not in self:
            self[key] = default
        return self[key]

    def __contains__(self, obj):
        obj_id = id(obj)
        return dict.__contains__(self, obj_id)

    def __len__(self):
        return dict.__len__(self)

    def items(self):
        return IdItemsView(self, self._id_to_obj_map)

    def keys(self):
        # KeysView delegates membership to self.__contains__, so
        # ``obj in d.keys()`` is an identity test just like ``obj in d``.
        return abc.KeysView(self)

    def __eq__(self, other):
        if not isinstance(other, IdDict):
            return NotImplemented
        # Check membership before indexing so that comparing against a
        # dict with auto-inserting lookups never mutates it.
        return len(self) == len(other) and all(
            key in other and other[key] == self[key] for key in self)

    def __iter__(self):
        return iter(self._id_to_obj_map.values())

    def __repr__(self):
        item_reprs = []
        for k, v in self.items():
            item_reprs.append(f'{repr(k)}: {repr(v)}')
        return 'IdDict({' + ', '.join(item_reprs) + '})'

    def __copy__(self):
        """
        Shallow copy with its own ``_id_to_obj_map``.

        The default copy protocol copied ``__dict__`` shallowly, so the copy
        shared ``_id_to_obj_map`` with the original and deleting from one
        corrupted the other. Other instance attributes are still carried over
        shallowly.
        """
        cls = type(self)
        new = cls.__new__(cls)
        new.__dict__.update(self.__dict__)
        new._id_to_obj_map = dict(self._id_to_obj_map)
        for obj_id, value in dict.items(self):
            dict.__setitem__(new, obj_id, value)
        return new

    def has_key(self, obj):
        return self.__contains__(obj)

    def update(self, other_dict):
        """
        :raises ValueError: If `other_dict` is not an IdDict.
        """
        if not isinstance(other_dict, IdDict):
            raise ValueError('Update is only supported with other IdDicts')
        self.updateFromIterable(other_dict.items())

    def updateFromIterable(self, iterable):
        """Set ``self[k] = v`` for every ``(k, v)`` pair in `iterable`."""
        for k, v in iterable:
            self[k] = v

    @classmethod
    def fromIterable(cls, iterable):
        """Build a new instance from an iterable of ``(key, value)`` pairs."""
        id_dict = cls()
        id_dict.updateFromIterable(iterable)
        return id_dict

    def clear(self):
        self._id_to_obj_map.clear()
        dict.clear(self)

    def copy(self):
        return IdDict(self)
class DefaultIdDict(IdDict):
    """
    A dict that is both an id dict and a defaultdict.

    Reading a missing key with ``d[key]`` stores and returns
    ``default_factory()``. As with `collections.defaultdict`, ``get``,
    ``pop`` and ``setdefault`` never call the factory.
    """

    # Marks "no default supplied" for pop(), since None is a valid default.
    _MISSING = object()

    def __init__(self, default_factory):
        """
        :param default_factory: Zero-argument callable whose result is stored
            for a key that is read before it has been set.
        """
        super().__init__()
        self._default_factory = default_factory

    def __getitem__(self, obj):
        obj_id = id(obj)
        try:
            return dict.__getitem__(self, obj_id)
        except KeyError:
            default_value = self._default_factory()
            self[obj] = default_value
            return default_value

    def get(self, key, default=None):
        # The inherited Mapping.get goes through __getitem__, which would
        # insert and return a factory value instead of returning `default`.
        if key in self:
            return self[key]
        return default

    def pop(self, key, default=_MISSING):
        """
        Remove `key` and return its value. If it is absent, return `default`
        or raise KeyError when no default is given.
        """
        # The inherited MutableMapping.pop also goes through __getitem__, so a
        # missing key used to yield a freshly created factory value.
        if key in self:
            value = self[key]
            del self[key]
            return value
        if default is self._MISSING:
            raise KeyError(str(key))
        return default

    @classmethod
    def fromIterable(cls, iterable):
        # Not supported: there is no way to pass a default factory here.
        raise NotImplementedError()

    def setdefault(self, key, default=None):
        """
        Return the value for `key`, first storing `default` (not a factory
        value) if `key` is absent.
        """
        if key not in self:
            self[key] = default
        return dict.__getitem__(self, id(key))
class DefaultFactoryDictMixin:
    """
    Mixin for `dict` subclasses that fills in missing keys from a factory.

    It behaves much like `collections.defaultdict`, except that the factory
    receives the missing key as its argument(s) instead of being called with
    no arguments.

    .. NOTE::
        Any callable works as the factory, including a class: its constructor
        is called with the key.

    .. WARNING::
        A tuple key is unpacked into positional arguments, so ``d[a, b]``
        calls ``factory_func(a, b)``. A factory that expects a single tuple
        argument therefore cannot be used with tuple keys.
    """

    def __init__(self, factory_func, *args, **kwargs):
        """
        :param factory_func: Callable used to build the value for a missing
            key.
        :type  factory_func: callable

        Any remaining positional and keyword arguments are forwarded to the
        next ``__init__`` in the MRO (e.g. ``dict.__init__``).
        """
        self._factory_func = factory_func
        super().__init__(*args, **kwargs)

    def __missing__(self, key):
        # Tuple keys (including those produced by ``d[a, b]``) are splatted.
        # Any other key is passed as the sole argument.
        factory_args = key if isinstance(key, tuple) else (key,)
        value = self._factory_func(*factory_args)
        self[key] = value
        return value
class DefaultFactoryDict(DefaultFactoryDictMixin, dict):
    """
    A basic `dict` using the `DefaultFactoryDictMixin`. This is separated
    from the mixin to allow other `dict` subclasses to easily subclass
    `DefaultFactoryDictMixin`.

    Construct with ``DefaultFactoryDict(factory_func, *args, **kwargs)``.
    Arguments after `factory_func` are forwarded to ``dict``.

    Example usage::

        stringified_objs = DefaultFactoryDict(str)
        assert 1 not in stringified_objs
        print(stringified_objs[1])  # prints 1 (the stored value is '1')
        assert 1 in stringified_objs
    """