# Jack Eckert

# dependencies
import copy
import math
from typing import Any

# base class for functions

# arg is the argument of the function. By default it is the identity function, represented
# by the string 'x'; any string works and acts as a marker that lets the program recognise
# the identity function. A function object may be passed instead to form a composition.

# coef is the coefficient of the function; it should always be an integer or float and
# defaults to 1.
class BaseFunction():
    """Shared state and behaviour for a term of the form ``coef * f(arg)``.

    ``arg`` is either a string, which marks the identity function (``'x'`` by
    default), or another function object, in which case the term is the
    composition f(arg(x)). ``coef`` is a numeric multiplier (default 1).
    Subclasses provide the ``f``-specific text in ``__str__`` and pass the
    outer derivative to ``BaseFunction.derive``, which applies the chain rule.
    """

    def __init__(self, coef=None, arg=None):
        # ``is None`` rather than ``== None`` so values with a custom __eq__
        # are never misread as "missing".
        self.__coefficient = 1 if coef is None else coef
        self.__argument = 'x' if arg is None else arg

    def getArg(self):
        return self.__argument

    def setArg(self, new):
        self.__argument = new

    def getCoef(self):
        return self.__coefficient

    def setCoef(self, new):
        self.__coefficient = new

    # used to consolidate an outside constant (k) into the function's coefficient
    def integrateCoef(self, k):
        self.__coefficient *= k

    def __neg__(self):
        # Subclasses with extra constructor parameters override this.
        return type(self)(-self.getCoef(), self.getArg())

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return self.getArg() == other.getArg() and self.getCoef() == other.getCoef()

    def __str__(self):
        """Return the coefficient prefix: '' for 1, '-' for -1, else the value rounded to 3 places."""
        coef = self.getCoef()
        if coef == 1:
            return ''
        if coef == -1:
            return '-'
        # __str__ must return a str; returning the bare rounded number made
        # str(BaseFunction(2)) raise TypeError.
        return str(round(coef, 3))

    def __repr__(self):
        return str(self)

    def derive(self, shell):
        """Apply the chain rule to an outer derivative.

        Args:
            shell: the derivative of the outer function evaluated at the
                argument, f'(g(x)); a function object or a plain number.

        Returns:
            ``shell`` itself when the argument is the identity marker;
            otherwise a list of factors, the inner derivative g'(x) followed
            by ``shell``, whose product is the derivative.
        """
        inner = self.getArg()
        if isinstance(inner, str):
            return shell
        # Differentiate our own argument: ``shell`` may be a plain number
        # (e.g. a first-power term), which has no getArg().
        innerDerivative = inner.derive()
        # A nested composition already returns a factor list; splice it in so
        # the product stays a flat list of factors.
        if isinstance(innerDerivative, list):
            return [*innerDerivative, shell]
        return [innerDerivative, shell]

# power function, i.e. a function raised to a constant power.
class Power(BaseFunction):
    """Power term ``coef * arg^exp`` with a constant exponent ``exp`` (default 1)."""

    def __init__(self, coef=None, arg=None, exp=1):
        super().__init__(coef, arg)
        self.__exponent = exp

    def getExp(self):
        return self.__exponent

    def setExp(self, new):
        # Previously took no value and merely returned the exponent; now a real
        # setter, matching setArg/setCoef.
        self.__exponent = new

    def __neg__(self):
        return type(self)(-self.getCoef(), self.getArg(), self.getExp())

    def __eq__(self, other):
        # The base check guarantees ``other`` is a Power before getExp() is used.
        return super().__eq__(other) and self.getExp() == other.getExp()

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        coefficientString = super().__str__()

        # composed arguments are parenthesised; the identity marker prints bare
        argument = self.getArg()
        argumentString = argument if isinstance(argument, str) else f"({argument})"

        # a first power omits the "^1"
        if self.getExp() == 1:
            return f"{coefficientString}{argumentString}"
        return f"{coefficientString}{argumentString}^{self.getExp()}"

    def derive(self):
        """Return the derivative of ``coef * arg^exp``.

        The outer derivative is ``n * coef * f^(n - 1)``; BaseFunction.derive
        adds the chain-rule factor when the argument is not the identity.
        A zero exponent makes the term constant, so 0 is returned.
        """
        exponent = self.getExp()
        if exponent == 0:
            return 0
        if exponent == 1:
            shell = self.getCoef()
        else:
            shell = Power(self.getCoef() * exponent, self.getArg(), exponent - 1)
        return super().derive(shell)
        
# exponential function, i.e. a constant raised to a function power.
class Exponential(BaseFunction):
    """Exponential term ``coef * base^arg``.

    A non-natural base is folded into the exponent using
    base^g = e^(ln(base) * g). The stored argument therefore already contains
    the ln(base) factor, and the term always prints as ``exp(...)``.
    ``base`` is kept for equality checks and negation. math.log raises
    ValueError for base <= 0.
    """

    def __init__(self, coef=None, arg=None, base=math.e):
        super().__init__(coef, arg)
        if base != math.e:
            scale = math.log(base)
            argument = self.getArg()
            if isinstance(argument, str):
                # identity argument x becomes the linear term ln(base) * x
                self.setArg(Power(scale, argument))
            else:
                # ln(base) scales the exponent g(x), not the outer coefficient.
                # Scale a shallow copy so the caller's function object is untouched.
                scaled = copy.copy(argument)
                scaled.integrateCoef(scale)
                self.setArg(scaled)
        self.__base = base

    def getBase(self):
        return self.__base

    def setBase(self, new):
        self.__base = new

    def __neg__(self):
        # The stored argument is already scaled by ln(base). Construct with
        # the natural base so the constructor does not scale it a second time,
        # then restore the original base.
        negated = type(self)(-self.getCoef(), self.getArg())
        negated.setBase(self.getBase())
        return negated

    def __str__(self):
        coefficientString = super().__str__()
        return f"{coefficientString}exp({self.getArg()})"

    def __eq__(self, other):
        return super().__eq__(other) and self.getBase() == other.getBase()

    def derive(self):
        # d/dx coef * e^g = g' * coef * e^g: the outer derivative is the term itself.
        shell = self
        return super().derive(shell)

# sine function  
class Sine(BaseFunction):
    """Sine term ``coef * sin(arg)``."""

    def __init__(self, coef=None, arg=None):
        super().__init__(coef=coef, arg=arg)

    def __str__(self):
        return f"{super().__str__()}sin({self.getArg()})"

    def derive(self):
        # d/du sin(u) = cos(u); the base class adds the chain-rule factor
        return super().derive(Cosine(self.getCoef(), self.getArg()))

# cosine function 
class Cosine(BaseFunction):
    """Cosine term ``coef * cos(arg)``."""

    def __init__(self, coef=None, arg=None):
        super().__init__(coef=coef, arg=arg)

    def __str__(self):
        return f"{super().__str__()}cos({self.getArg()})"

    def derive(self):
        # d/du cos(u) = -sin(u); the base class adds the chain-rule factor
        outer = Sine(self.getCoef(), self.getArg())
        return super().derive(-outer)

class Equation():
    """A sum of terms, each term being a list of factors that are multiplied.

    ``Equation([[Power(exp=2), Cosine()], Sine()])`` represents
    x^2*cos(x) + sin(x). A term given without a list is wrapped as a
    single-factor list. The arithmetic operators return new Equations and
    leave their operands unchanged.
    """

    def __init__(self, eqList):
        # Copy the outer list and every term list. Previously the caller's list
        # was normalised in place, and operators that shared it mutated their
        # operands.
        self.__eqList = [list(term) if isinstance(term, list) else [term]
                         for term in eqList]

    def __iter__(self):
        # A fresh iterator per loop, so nested loops over one Equation work.
        return iter(self.__eqList)

    def __next__(self):
        # Legacy manual stepping, kept for backward compatibility. ``for`` loops
        # go through __iter__ and do not touch ``self.index``.
        index = getattr(self, 'index', -1)
        if index < len(self.__eqList) - 1:
            self.index = index + 1
            return self.__eqList[self.index]
        raise StopIteration

    def getEqList(self):
        return self.__eqList

    def setEqList(self, new):
        self.__eqList = new

    def __getitem__(self, i):
        return self.__eqList[i]

    def __setitem__(self, i, new):
        self.__eqList[i] = new

    def addTerm(self, new):
        """Append ``new`` as a term in place; a non-list value becomes a single-factor term."""
        self.__eqList.append(list(new) if isinstance(new, list) else [new])

    def removeAt(self, i):
        del self.__eqList[i]

    def __str__(self):
        # each factor in parentheses, factors juxtaposed, terms joined by " + "
        return " + ".join("".join(f"({seg})" for seg in term) for term in self.__eqList)

    def __mul__(self, k):
        """Return a new Equation with ``k`` prepended as a factor of every term."""
        product = Equation(self.__eqList)
        for term in product.__eqList:
            term.insert(0, k)
        return product

    def __rmul__(self, k):
        return self.__mul__(k)

    def __neg__(self):
        # Previously missing, which made ``k - equation`` (via __rsub__) raise TypeError.
        return self * -1

    def __add__(self, k):
        """Return a new Equation with ``k`` appended as an extra term."""
        total = Equation(self.__eqList)
        total.addTerm(k)
        return total

    def __radd__(self, k):
        return self.__add__(k)

    def __sub__(self, k):
        # A raw list is a product of factors; negate it with a leading -1 factor.
        if isinstance(k, list):
            return self + [-1, *k]
        return self + -k

    def __rsub__(self, k):
        return -self + k

    #TODO
    def simplify(self):
        pass

# Testing
if __name__ == "__main__":
    # Smoke test: build x^2*cos(x) + sin(x), subtract cos(x) and print the result.
    sample = Equation([
        [Power(exp=2), Cosine()],
        Sine(),
    ])
    difference = sample - Cosine()
    print(difference)