Coverage for src/methodsnm/diff.py: 96%

48 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-27 13:22 +0000

1from typing import Tuple, Type, Union 

2import numpy as np 

3 

class AD:
    """
    This class represents a datatype that can automatically generate a derivative for a
    composition of supplied founding elements (forward-mode automatic differentiation).

    Every AD object carries a value ``val`` and its derivative ``deriv``; the arithmetic
    operators below propagate both using the usual differentiation rules (sum, product,
    quotient and power rule). Any operand that is not an ``AD`` instance (Python or NumPy
    scalars, NumPy arrays) is treated as a constant, i.e. as having zero derivative.
    """

    def __init__(self, val: Union[float, np.ndarray], deriv: np.ndarray) -> None:
        """
        val: array_like, a single value or a numpy array holding one or more values of the element
        deriv: array_like, holds the derivative of the element
        """
        self.val = val
        self.deriv = deriv

    def __add__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self + other
        """
        if isinstance(other, AD):
            return AD(self.val + other.val, self.deriv + other.deriv)
        # constant operand: shifts the value, the derivative is unchanged
        return AD(self.val + other, self.deriv)

    def __radd__(self, other: float) -> 'AD':
        """
        returns other + self (other is a constant)
        """
        return AD(other + self.val, self.deriv)

    def __sub__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self - other
        """
        if isinstance(other, AD):
            return AD(self.val - other.val, self.deriv - other.deriv)
        return AD(self.val - other, self.deriv)

    def __rsub__(self, other: float) -> 'AD':
        """
        returns other - self (other is a constant)
        """
        return AD(other - self.val, -self.deriv)

    def __neg__(self) -> 'AD':
        """
        returns -self
        """
        return AD(-self.val, -self.deriv)

    def __mul__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self * other (product rule for AD operands)
        """
        if isinstance(other, AD):
            return AD(self.val * other.val, self.val * other.deriv + self.deriv * other.val)
        return AD(self.val * other, other * self.deriv)

    def __rmul__(self, other: float) -> 'AD':
        """
        returns other * self (other is a constant)
        """
        return AD(other * self.val, other * self.deriv)

    def __pow__(self, p: Union[int, float]) -> 'AD':
        """
        returns self^p for a constant exponent p, using the power rule
        d(u^p) = p * u^(p-1) * du

        Negative exponents are evaluated in floating point, so integer-valued ``val``
        does not trigger NumPy's "integers to negative integer powers" error.

        Raises:
            ZeroDivisionError: if p is negative and ``val`` contains a zero.
        """
        if p == 0:
            # keep the shape of an array-valued val; plain scalars keep returning 1.
            one = np.ones_like(self.val, dtype=float) if np.ndim(self.val) else 1.
            return AD(one, np.zeros_like(self.deriv))
        if p == 1:
            return self
        base = self.val
        if p < 0:
            if np.any(np.equal(base, 0)):
                raise ZeroDivisionError("You can't divide by zero!")
            # promote ints (Python, NumPy scalar or array) to float before negative powers
            base = base * 1.0
        return AD(np.power(base, p), p * (base ** (p - 1)) * self.deriv)

    def __truediv__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self / other (quotient rule for AD operands)

        Raises:
            ZeroDivisionError: if other is an AD whose value contains a zero.
        """
        if not isinstance(other, AD):
            return AD(self.val / other, self.deriv / other)
        # np.any keeps this check valid for array-valued val as well as scalars
        if np.any(np.equal(other.val, 0)):
            raise ZeroDivisionError("You can't divide by zero!")
        return AD(self.val / other.val, (other.val * self.deriv - self.val * other.deriv) / other.val**2)

    def __rtruediv__(self, other: float) -> 'AD':
        """
        returns other / self (other is a constant)

        Raises:
            ZeroDivisionError: if the value of self contains a zero.
        """
        if np.any(np.equal(self.val, 0)):
            raise ZeroDivisionError("You can't divide by zero!")
        return AD(other / self.val, - other * self.deriv / self.val**2)

    def __str__(self):  # pragma: no cover
        return f"AD({self.val}, {self.deriv})"

    def __repr__(self):  # pragma: no cover
        return f"AD({self.val}, {self.deriv})"

    def __eq__(self, other: Union['AD', float]):
        """
        Approximate equality (np.allclose).

        Two AD objects are equal when both their values and their derivatives are close.
        Against a constant, only the value is compared. Operands that cannot be compared
        numerically yield NotImplemented, so Python falls back to identity (False).
        """
        if isinstance(other, AD):
            return np.allclose(self.val, other.val) and np.allclose(self.deriv, other.deriv)
        try:
            return bool(np.allclose(self.val, other))
        except TypeError:
            return NotImplemented

104 

105 

106