"""
A class for simple arithmetic with uncertainty.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Yujie Wu
import math
from past.utils import old_div
# IEEE-754 not-a-number.  NaN never compares equal to anything, itself
# included, so test for it with `x != x` or `math.isnan(x)`, not `x == NAN`.
NAN = float("nan")
def _is_equal(a, b, rel_tol=1E-6):
    """
    Return whether `a` and `b` should be considered equal.

    If either argument is a `float`, the comparison is approximate.  The values
    are equal when they agree to within `rel_tol` relative error, or when both
    are NaN.  An infinity is only equal to an infinity of the same sign.
    Otherwise the `==` operator of `a` and `b` decides, so it must be defined
    for them.

    :param a: First value.
    :param b: Second value.
    :param rel_tol: Relative tolerance for the float comparison.  The default
        of 1E-6 is what all existing callers rely on.
    :return: True if the values are considered equal.
    """
    if isinstance(a, float) or isinstance(b, float):
        # NaN is the only value that does not compare equal to itself, so
        # this detects "both are NaN".
        if a != a and b != b:
            return True
        # math.isclose computes abs(a - b) <= rel_tol * max(|a|, |b|), but
        # treats infinities correctly.  The bare formula evaluates to
        # `inf <= inf` whenever either side is infinite, which would make inf
        # "equal" to every finite number and to -inf.
        return math.isclose(a, b, rel_tol=rel_tol)
    return a == b
class Measurement(object):
    r"""
    A value with an associated uncertainty, supporting simple arithmetic with
    uncertainty propagation.

    Basic method for uncertainty propagation:
    Say we have a few measurements: x1, x2, x3, ..., and each has an uncertainty: d1, d2, d3, ..., respectively.
    Now we have a function: f( x1, x2, x3, ... ), and we want to get the value of this function for the given
    measurements x1, x2, x3, ... and also the uncertainty of the result of f.
    A way to do this is the following:
        1. We need to get the contribution to the uncertainty of f result due to each measurement: fd1, fd2, fd3, ...
           This can be given by the following equations::
               fd1 = df / dx1 * d1
               fd2 = df / dx2 * d2
               fd3 = df / dx3 * d3
               ...
           where `df / dx1` is a partial derivative of f with respect to x1.
        2. With fd1, fd2, fd3, ..., we can get the uncertainty of f using this
           equation: fd = math.sqrt( fd1 * fd1 + fd2 * fd2 + fd3 * fd3 + ... )

    The operands of every binary operation are treated as independent
    (uncorrelated).  For example, ``a - a`` has uncertainty ``sqrt(2) * a.unc``,
    not 0.

    An operand of ``+``, ``-``, ``*`` or ``/`` may be one of:

    - another `Measurement`;
    - a 2-tuple ``(value, uncertainty)``;
    - a plain number, which is treated as exact (zero uncertainty).

    Attributes:
        val: the value (float)
        unc: its uncertainty (float; NaN by default)
    """

    def __init__(self, value, uncertainty=float("nan")):
        """
        :param value: The measured value; converted with `float()`.
        :param uncertainty: The uncertainty of `value`; converted with
            `float()`.  Defaults to NaN.
        """
        self.val = float(value)
        self.unc = float(uncertainty)

    @staticmethod
    def _coerce(operand):
        """
        Return a ``(value, uncertainty)`` tuple as a `Measurement`; return any
        other operand unchanged.
        """
        if isinstance(operand, tuple):
            return Measurement(operand[0], operand[1])
        return operand

    def __repr__(self):
        """
        Return "Measurement( <val>, <unc> )", with both numbers in %g format.
        """
        return "Measurement( %g, %g )" % (
            self.val,
            self.unc,
        )

    def __eq__(self, rhs):
        """
        Return True if `rhs` is a `Measurement` whose value and uncertainty
        both match this one's according to `_is_equal`.  That comparison is
        approximate for floats, and two NaNs count as equal.  Any
        non-`Measurement` `rhs` compares unequal.
        """
        if not isinstance(rhs, Measurement):
            return False
        return _is_equal(self.val, rhs.val) and _is_equal(self.unc, rhs.unc)

    def __ne__(self, rhs):
        """
        Return the negation of `__eq__`.  Python 2 does not derive ``!=`` from
        ``==`` automatically.
        """
        return not self.__eq__(rhs)

    # `__eq__` is tolerance-based, so no hash can be consistent with it.
    # Mark instances unhashable explicitly; Python 3 already does so
    # implicitly when a class defines `__eq__` without `__hash__`.
    __hash__ = None

    def __str__(self):
        """
        Return "<val>+-<unc>", with both numbers in %g format (e.g. "30+-0.5").
        """
        # FIXME: Extra digits after the first significant figure in uncertainty should be dropped, for example,
        #        instead of writing "5.345+-0.212", we should write "5.3+-0.2".
        # FIXME: For scientific notation, make sure the value and the uncertainty use the same power of 10, for example,
        #        instead of writing "5.34E-7+-3E-9", we should write "(5.34+-0.03)E-7".
        return "%g+-%g" % (
            self.val,
            self.unc,
        )

    def __float__(self):
        """
        Return the value, dropping the uncertainty.
        """
        return self.val

    def __int__(self):
        """
        Return ``int(self.val)``, dropping the uncertainty.
        """
        return int(self.val)

    def __tuple__(self):
        """
        Return ``(val, unc)``.

        NOTE: Python has no ``__tuple__`` hook, so ``tuple(m)`` does not call
        this; call ``m.__tuple__()`` explicitly.
        """
        return (
            self.val,
            self.unc,
        )

    def __add__(self, rhs):
        """
        Return ``self + rhs``.  The uncertainties of two measurements add in
        quadrature; adding an exact number keeps the uncertainty unchanged.
        """
        rhs = self._coerce(rhs)
        if isinstance(rhs, Measurement):
            val = self.val + rhs.val
            unc = math.sqrt(self.unc * self.unc + rhs.unc * rhs.unc)
            return Measurement(val, unc)
        return Measurement(self.val + rhs, self.unc)

    def __radd__(self, lhs):
        """
        Return ``lhs + self`` (addition is commutative).
        """
        return self.__add__(lhs)

    def __sub__(self, rhs):
        """
        Return ``self - rhs``, computed as ``self + (-rhs)``.
        """
        # Coerce first: a bare tuple cannot be negated.
        return self.__add__(-self._coerce(rhs))

    def __rsub__(self, lhs):
        """
        Return ``lhs - self``, computed as ``(-self) + lhs``.
        """
        return (-self) + lhs

    def __neg__(self):
        """
        Return the negated value with the same uncertainty.
        """
        return Measurement(-self.val, self.unc)

    def __pos__(self):
        """
        Return a copy of this measurement.
        """
        return Measurement(self.val, self.unc)

    def __abs__(self):
        """
        Return the absolute value with the same uncertainty.
        """
        return Measurement(abs(self.val), self.unc)

    def __mul__(self, rhs):
        """
        Return ``self * rhs``.

        For f = x * y the partial derivatives are df/dx = y and df/dy = x, so
        the uncertainty is sqrt((x * dy)**2 + (dx * y)**2).  An exact number y
        just scales dx by |y|.
        """
        rhs = self._coerce(rhs)
        if isinstance(rhs, Measurement):
            val = self.val * rhs.val
            unc = math.sqrt(self.val * self.val * rhs.unc * rhs.unc +
                            self.unc * self.unc * rhs.val * rhs.val)
            return Measurement(val, unc)
        return Measurement(self.val * rhs, math.fabs(self.unc * rhs))

    def __rmul__(self, lhs):
        """
        Return ``lhs * self`` (multiplication is commutative).
        """
        return self * lhs

    def __div__(self, rhs):
        """
        Return ``self / rhs``.

        For f = x / y the partial derivatives are df/dx = 1 / y and
        df/dy = -x / y**2.  An exact number y just scales dx by 1 / |y|.
        """
        # `self.val` is always a float, so `/` is true division here.
        rhs = self._coerce(rhs)
        if isinstance(rhs, Measurement):
            # The sign of df/dy is irrelevant because each term is squared.
            dfdy = self.val / (rhs.val * rhs.val)
            dfdx = 1.0 / rhs.val
            unc = math.sqrt(dfdy * dfdy * rhs.unc * rhs.unc +
                            dfdx * dfdx * self.unc * self.unc)
            return Measurement(self.val / rhs.val, unc)
        return Measurement(self.val / rhs, math.fabs(self.unc / rhs))

    def __rdiv__(self, lhs):
        """
        Return ``lhs / self``, computed as ``(1 / self) * lhs``.
        """
        return Measurement(1.0, 0.0) / self * lhs

    def __truediv__(self, rhs):
        """
        ``/`` operator under true division (Python 3); same as `__div__`.
        """
        return self.__div__(rhs)

    def __rtruediv__(self, lhs):
        """
        Reflected ``/`` under true division (Python 3); same as `__rdiv__`.
        Without it, ``2 / m`` and ``(value, uncertainty) / m`` raise TypeError
        on Python 3.
        """
        return self.__rdiv__(lhs)

    def to_string(self, num_digits):
        """
        Convert `Measurement` to string with given `num_digits` decimal places
        for both the value and the uncertainty.  For example, ``to_string(2)``
        on ``Measurement(30, 0.5)`` gives "30.00+-0.50".
        """
        fmt = '{{:.{0}f}}+-{{:.{0}f}}'.format(num_digits)
        return fmt.format(self.val, self.unc)

    @classmethod
    def from_string(cls, s):
        """
        Convert string (e.g: "5.0+-0.3") to `Measurement`.

        The string is split at the first "+-".  The text before it is the
        value and the text after it is the uncertainty.

        :raises ValueError: if `s` contains no "+-", or either part is not a
            valid float.
        """
        value, sep, uncertainty = s.partition("+-")
        if not sep:
            # Without this check, a missing separator would silently
            # mis-split the string (e.g. "12.5" -> value 12.0, unc 2.5).
            raise ValueError(
                "missing '+-' separator in measurement string: %r" % (s,))
        return cls(float(value), float(uncertainty))
def string2measurement(s):
    """
    Convert string (e.g: "5.0+-0.3") to `Measurement`.

    This method is deprecated in favor of Measurement.from_string, to which it
    simply forwards.

    :param s: String of the form "<value>+-<uncertainty>".
    :return: The parsed `Measurement`.
    """
    return Measurement.from_string(s)
def __test():
    """
    Print the results of a series of `Measurement` operations.

    The output is meant for manual comparison with the expected results listed
    in the comments that follow this function.
    """
    print("Testing `Measurement' class:")
    a = Measurement(30.0, 0.5)
    b = Measurement(20.0, 0.8)
    print("a =", a)
    print("b =", b)
    print("float( a ) =", float(a))
    print("int  ( a ) =", int(a))
    print("-a =", -a)
    print("+b =", +b)
    print("abs( -a ) =", abs(-a))
    print("a + 1 =", a + 1)
    print("a - 1 =", a - 1)
    print("a + b =", a + b)
    print("a - b =", a - b)
    print("a + (20.0, 0.8) =", a + (20.0, 0.8))
    print("a - (20.0, 0.8) =", a - Measurement(20.0, 0.8))
    print("1 + a =", 1 + a)
    print("1 - a =", 1 - a)
    print("(20.0, 0.8) + a =", (20.0, 0.8) + a)
    print("(20.0, 0.8) - a =", (20.0, 0.8) - a)
    print("a * 2 =", a * 2)
    print("a / 2 =", old_div(a, 2))
    print("a * b =", a * b)
    print("a / b =", old_div(a, b))
    print("a * (20.0, 0.8) =", a * (20.0, 0.8))
    print("a / (20.0, 0.8) =", old_div(a, (20.0, 0.8)))
    print("2 * a =", 2 * a)
    print("2 / a =", old_div(2, a))
    print("(20.0, 0.8) * a =", (20.0, 0.8) * a)
    print("(30.0, 0.5) / b =", old_div((30.0, 0.5), b))
    print("(a + b) * (a - b) =", (a + b) * (a - b))
    print("(a - b) / (a + b) =", old_div((a - b), (a + b)))
    a = string2measurement("1.2+-0.3")
    print("1.2+-0.3 =", a)
# Correct results:
#
# a = Measurement( 30.0, 0.5 )
# b = Measurement( 20.0, 0.8 )
#
# a = 30+-0.5
# b = 20+-0.8
# float( a ) = 30.0
# int  ( a ) = 30
# -a = -30+-0.5
# +b = 20+-0.8
# abs( -a ) = 30+-0.5
# a + 1 = 31+-0.5
# a - 1 = 29+-0.5
# a + b = 50+-0.943398
# a - b = 10+-0.943398
# a + (20.0, 0.8) = 50+-0.943398
# a - (20.0, 0.8) = 10+-0.943398
# 1 + a = 31+-0.5
# 1 - a = -29+-0.5
# (20.0, 0.8) + a = 50+-0.943398
# (20.0, 0.8) - a = -10+-0.943398
# a * 2 = 60+-1
# a / 2 = 15+-0.25
# a * b = 600+-26
# a / b = 1.5+-0.065
# a * (20.0, 0.8) = 600+-26
# a / (20.0, 0.8) = 1.5+-0.065
# 2 * a = 60+-1
# 2 / a = 0.0666667+-0.00111111
# (20.0, 0.8) * a = 600+-26
# (30.0, 0.5) / b = 1.5+-0.065
# (a + b) * (a - b) = 500+-48.1041
# (a - b) / (a + b) = 0.2+-0.0192416
# Run the self-test only when this module is executed as a script.
if __name__ == "__main__":
    __test()