import numpy as np

def thermaldiode(x, input_unit="V", output_unit="K"):
    """
    Converts temperature sensitive silicon diode voltages to temperature using
    fitted functions and returns a tuple of temperature and temperature error
    arrays.

    Each reading is mapped to a temperature by one of five piecewise fitted
    calibration functions (coss/cubic with parameter sets par1..par5), chosen
    by the voltage range the reading falls in.

    @type x: np.array | int | float | list
    @param x: diode voltage reading(s). Any scalar (Python or NumPy) is
        treated as a one-element array. **All input voltages MUST BE
        positive** (this is not checked).
    @type input_unit: str
    @param input_unit: unit of x. Allowed values: "V" (default), "mV".
    @type output_unit: str
    @param output_unit: unit of the returned temperature and error.
        Allowed values: "K" (default), "C", "F".
    @rtype: tuple(np.array, np.array)
    @return: (temperature, temperature error), both with the shape of
        np.atleast_1d(x). A NaN reading yields NaN in both arrays.
    @raise ValueError: if input_unit or output_unit is not an allowed value.

    Calibration errors: 0.25 K for emf readings corresponding to temperatures
    between 2 K and 100 K, and 0.3 K for temperatures between 100 K and
    300 K. Each has been increased by 0.05 K to account for discrepancies in
    the fit of the calibration curve, so the returned errors are 0.3 K and
    0.35 K respectively (multiplied by 9/5 when output_unit is "F").

    Written by: Katrina Hooper
    Date: 03-02-2017
    Required packages: numpy
    """
    if input_unit not in ("V", "mV"):
        raise ValueError("input_unit must be 'V' or 'mV', got %r" % (input_unit,))
    if output_unit not in ("K", "C", "F"):
        raise ValueError(
            "output_unit must be 'K', 'C' or 'F', got %r" % (output_unit,))

    # Accept Python/NumPy scalars, lists and arrays uniformly.
    volts = np.atleast_1d(np.asarray(x, dtype=float))
    if input_unit == "mV":
        volts = volts / 1000.0

    # Voltage segments, each covered by its own fitted calibration curve.
    # Comparisons with NaN are False, so NaN readings fall in no segment and
    # keep the NaN fill value below.
    seg1 = volts < 0.82308
    seg2 = (volts >= 0.82308) & (volts < 1.0542)
    seg3 = (volts >= 1.0542) & (volts <= 1.1473)
    seg4 = (volts > 1.1473) & (volts <= 1.8539)
    seg5 = volts > 1.8539

    temp = np.full(volts.shape, np.nan)
    temp_err = np.full(volts.shape, np.nan)

    temp[seg1] = coss(volts[seg1], *par1)
    temp[seg2] = cubic(volts[seg2], *par2)
    temp[seg3] = coss(volts[seg3], *par3)
    temp[seg4] = cubic(volts[seg4], *par4)
    temp[seg5] = coss(volts[seg5], *par5)

    # In segment 2 the error depends on whether the fitted temperature (in K)
    # is below 100 K; the other segments have a fixed error.
    temp_err[seg1] = 0.35
    temp_err[seg2] = np.where(temp[seg2] < 100, 0.3, 0.35)
    temp_err[seg3 | seg4 | seg5] = 0.3

    if output_unit == "C":
        temp = temp - 273.15
    elif output_unit == "F":
        temp = (temp - 273.15) * 9.0 / 5.0 + 32.0
        # A temperature *difference* only scales; no offset applies.
        temp_err = temp_err * 9.0 / 5.0

    return (temp, temp_err)

### Fitting functions ###
def coss(x, a, b, c, d, e):
    return a * x + b * np.cos(c*x +d) + e
def cubic(x, a, b, c, d, e, f, g):
    return a * (x-b)**3 + c * (x-d)**2 + e + f*x + g

### Fitted coefficients (obtained with scipy.optimize.curve_fit) ###
# One parameter set per voltage segment used by thermaldiode():
#   par1, par3, par5 -> (a, b, c, d, e) arguments of coss()
#   par2, par4       -> (a, b, c, d, e, f, g) arguments of cubic()

# Segment 1: V < 0.82308
par1 = np.array((-364.44364646, 0.62360251, -17.49716892, -24.31045687,
                 439.47840062))
# Segment 2: 0.82308 <= V < 1.0542
par2 = np.array((-638.38887851, 0.7968693, 139.77185052, 1.99087984,
                 2.71879039, -48.26168127, -13.51226062))
# Segment 3: 1.0542 <= V <= 1.1473
par3 = np.array((-254.2274026, 4.43251716, -44.01866014, 20.22202958,
                 315.79583499))
# Segment 4: 1.1473 < V <= 1.8539
par4 = np.array((-34.57457284, 1.4199206, 12.46905518, -0.18175058,
                 32.3424489, -53.68686765, 32.3432621))
# Segment 5: V > 1.8539
par5 = np.array((-16.29864235, -0.32427019, -8.26553791, 11.74401258,
                 43.93899782))
