"""
Need a diff in a jiffy? Use diffy!
"""
import collections
from collections import abc
from functools import singledispatch
from schrodinger.utils.scollections import IdDict
from schrodinger.utils.scollections import IdSet
"""
Generic Functions
"""
@singledispatch
def get_diff(new_state, old_state):
    """
    Given two states of an object, calculate what changed between them.

    This is the generic fallback; type-specific implementations are added
    with ``get_diff.register``. Dispatch is on the type of `new_state`.

    :param new_state: The new state of the object.
    :param old_state: The old state of the object.
    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch selects the implementation from the first argument, so
    # report that argument's type in the error.
    err_msg = f'get_diff not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_removed(new_state, old_state):
    """
    Given two states of an object, calculate what was removed between them.

    Type-specific implementations are added with ``get_removed.register``.
    Dispatch is on the type of `new_state`.

    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch dispatches on the first argument; name its type.
    err_msg = f'get_removed not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_added(new_state, old_state):
    """
    Given two states of an object, calculate what was added between them.

    Type-specific implementations are added with ``get_added.register``.
    Dispatch is on the type of `new_state`.

    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch dispatches on the first argument; name its type.
    err_msg = f'get_added not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_updated(new_state, old_state):
    """
    Given two states of an object, calculate what was updated between them.

    Type-specific implementations are added with ``get_updated.register``.
    Dispatch is on the type of `new_state`.

    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch dispatches on the first argument; name its type.
    err_msg = f'get_updated not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
@singledispatch
def get_moved(new_state, old_state):
    """
    Given two states of an object, calculate what was moved between them.

    Type-specific implementations are added with ``get_moved.register``.
    Dispatch is on the type of `new_state`.

    :raises NotImplementedError: If no implementation is registered for the
        type of `new_state`.
    """
    # singledispatch dispatches on the first argument; name its type.
    err_msg = f'get_moved not implemented for type {type(new_state).__name__}'
    raise NotImplementedError(err_msg)
"""
Implementations of generic functions.
"""
# Result type of get_diff for lists; each field is a set of (item, index)
# tuples.
ListDiff = collections.namedtuple('ListDiff', ['added', 'removed', 'moved'])
@get_diff.register(list)
def get_diff_list(new_state, old_state):
    """
    Calculate what was added, removed, and moved between two states of a list.

    Moves are detected by matching items by identity (see `get_moved`), while
    `get_added` and `get_removed` detect positional changes with ``!=``.

    :return: A namedtuple describing what was added, removed, and moved between
        two lists. See `get_added`, `get_removed`, and `get_moved` for more
        details.
    :rtype: ListDiff(set, set, set)
    """
    return ListDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        moved=get_moved(new_state, old_state),
    )
# Result type of get_diff for sets: the items added and the items removed.
SetDiff = collections.namedtuple('SetDiff', ('added', 'removed'))
@get_diff.register(set)
def get_diff_set(new_state, old_state):
    """
    Calculate what was added and removed between two states of a set.

    :return: A namedtuple describing what was added and removed. See
        `get_added` and `get_removed` for more details.
    :rtype: SetDiff(set, set)
    """
    return SetDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
    )
@get_removed.register(set)
def get_removed_set(new_state, old_state):
    """
    Calculate what was removed between two states of a set.

    Both states are converted with ``set()`` first, so `old_state` may be any
    iterable of hashable items.

    :return: Items in `old_state` that are not in `new_state`.
    :rtype: set
    """
    return set(old_state) - set(new_state)
@get_removed.register(list)
def get_removed_list(new_state, old_state):
    """
    Calculate what was removed between two states of a list.

    An item of `old_state` is a removal candidate when the value at its index
    differs (``!=``) from the value at the same index in `new_state`, or when
    it lies past the end of `new_state`. Candidates that `get_moved` reports
    as moved are not counted as removed.

    :return: A set of tuples, each describing an item that was removed and
        its index in `old_state`.
    :rtype: set((object, int))
    """
    raw_removed = {
        _HashableTuple((old_item, idx))
        for idx, (old_item, new_item) in enumerate(zip(old_state, new_state))
        if old_item != new_item
    }
    for idx, item in enumerate(old_state[len(new_state):], len(new_state)):
        raw_removed.add(_HashableTuple((item, idx)))

    # Moved items are matched by identity, so their old position must be
    # looked up by identity too. `list.index` matches on equality and could
    # return the index of an earlier, equal-but-distinct object; it also
    # costs a linear scan per moved item. Keep the first occurrence, as
    # `list.index` would for the same object.
    first_old_idx = {}
    for idx, item in enumerate(old_state):
        first_old_idx.setdefault(id(item), idx)

    moved = get_moved(new_state, old_state)
    return raw_removed.difference(
        _HashableTuple((item, first_old_idx[id(item)])) for item, _ in moved)
@get_added.register(list)
def get_added_list(new_state, old_state):
    """
    Calculate what was added between two states of a list.

    An item of `new_state` is an addition candidate when the value at its
    index differs (``!=``) from the value at the same index in `old_state`,
    or when it lies past the end of `old_state`. Candidates that `get_moved`
    reports as moved are not counted as added.

    :return: A set of tuples, each describing an item that was added and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    raw_added = {
        _HashableTuple((new_item, idx))
        for idx, (old_item, new_item) in enumerate(zip(old_state, new_state))
        if old_item != new_item
    }
    for idx, item in enumerate(new_state[len(old_state):], len(old_state)):
        raw_added.add(_HashableTuple((item, idx)))

    # Moved entries are (item, index in new_state) pairs, the same shape as
    # the entries of raw_added, so they can be subtracted directly.
    moved = get_moved(new_state, old_state)
    return raw_added.difference(
        _HashableTuple((item, idx)) for item, idx in moved)
@get_added.register(set)
def get_added_set(new_state, old_state):
    """
    Calculate what was added between two states of a set.

    Both states are converted with ``set()`` first (matching
    `get_removed_set`), so `old_state` may be any iterable of hashable items.

    :return: Items in `new_state` that are not in `old_state`.
    :rtype: set
    """
    return set(new_state) - set(old_state)
# Result type of get_diff for dicts; each field is a dict of the affected
# key/value pairs.
DictDiff = collections.namedtuple('DictDiff', 'added, removed, updated')
@get_diff.register(dict)
def get_diff_dict(new_state, old_state):
    """
    Return dictionary items that have been added, removed, and updated.

    :return: A namedtuple describing what was added, removed, and updated
            between two dicts. See `get_added`, `get_removed`, and
            `get_updated` for more details.
    :rtype: DictDiff(dict, dict, dict)
    """
    return DictDiff(
        added=get_added(new_state, old_state),
        removed=get_removed(new_state, old_state),
        updated=get_updated(new_state, old_state),
    )
@get_added.register(dict)
def get_added_dict(new_state, old_state):
    """
    Calculate which keys were added between two states of a dict.

    :return: A dictionary with items in `new_state` whose keys are not in
        `old_state`, in `new_state` iteration order.
    """
    return {
        key: value
        for key, value in new_state.items()
        if key not in old_state
    }
@get_removed.register(dict)
def get_removed_dict(new_state, old_state):
    """
    Calculate which keys were removed between two states of a dict.

    :return: A dictionary with items in `old_state` whose keys are not in
        `new_state`, in `old_state` iteration order.
    """
    return {
        key: value
        for key, value in old_state.items()
        if key not in new_state
    }
@get_updated.register(dict)
def get_updated_dict(new_state, old_state):
    """
    Calculate which values changed between two states of a dict.

    Only keys present in both states are considered; a key counts as updated
    when its old and new values differ under ``!=``.

    :return: A dictionary with values that have changed from `old_state` to
            `new_state`. The values in the returned dictionary will be those
            from `new_state`.
    """
    return {
        key: new_state[key]
        for key, old_value in old_state.items()
        if key in new_state and new_state[key] != old_value
    }
@get_moved.register(list)
def get_moved_list(new_state, old_state):
    """
    Calculate which items changed position between two states of a list.

    Items of `new_state` are mapped to their indices with an `IdDict`
    (identity-keyed, per the `get_diff_list` contract). Items found at the
    same index in `old_state` are dropped; the remaining items that also
    occur in `old_state` are reported as moved.

    :return: A set of tuples, each describing an item that was moved and
        its index in `new_state`.
    :rtype: set((object, int))
    """
    item_to_new_idx = IdDict.fromIterable(
        (item, idx) for idx, item in enumerate(new_state))
    # Items that kept their position are not moves.
    for idx, item in enumerate(old_state):
        if item_to_new_idx.get(item) == idx:
            item_to_new_idx.pop(item)
    # Only items present in both states can have moved; anything else is an
    # addition, which get_added reports.
    shared_items = IdSet(old_state)
    shared_items = shared_items.intersection(IdSet(item_to_new_idx.keys()))
    return {
        _HashableTuple((item, item_to_new_idx[item])) for item in shared_items
    }
class _HashableTuple(tuple):
    def __hash__(self):
        """
        This will return an identical hash value to a regular tuple if this
        tuple is already made up of hashable items. If any items are not hashable,
        we use the id of the value instead.
        """
        hashes = []
        for idx, item in enumerate(self):
            try:
                hashes.append(hash(item))
            except TypeError:
                # If the item is unhashable, settle for its id.
                hashes.append(id(item))
        return hash(tuple(hashes))
    def __eq__(self, other):
        if len(self) != len(other):
            return False
        else:
            for s_item, o_item in zip(self, other):
                if not isinstance(s_item, abc.Hashable):
                    if s_item is not o_item:
                        return False
                else:
                    if s_item != o_item:
                        return False
            return True