Coverage for src/methodsnm/diff.py: 96%
48 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-27 13:22 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-27 13:22 +0000
1from typing import Tuple, Type, Union
2import numpy as np
class AD:
    """
    This class represents a datatype that can automatically generate a derivative for a
    composition of supplied founding elements.

    Each instance carries a value together with its derivative.  The arithmetic
    operators below propagate derivatives via the sum, product, quotient and
    power rules.  Any operand that is not an ``AD`` (Python or NumPy scalar, ...)
    is treated as a constant, i.e. as having zero derivative.
    """

    def __init__(self, val: float, deriv: np.ndarray) -> None:
        """
        val: array_like, single value or numpy array, that holds one or more values for the element
        deriv: array_like, holds the derivative of the element
        """
        self.val = val
        self.deriv = deriv

    def __add__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self + other
        """
        if isinstance(other, AD):
            return AD(self.val + other.val, self.deriv + other.deriv)
        # Constants (including numpy scalars such as np.int64) have zero derivative.
        return AD(self.val + other, self.deriv)

    def __radd__(self, other: float) -> 'AD':
        """
        returns other + self
        """
        return AD(other + self.val, self.deriv)

    def __sub__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self - other
        """
        if isinstance(other, AD):
            return AD(self.val - other.val, self.deriv - other.deriv)
        return AD(self.val - other, self.deriv)

    def __rsub__(self, other: float) -> 'AD':
        """
        returns other - self
        """
        return AD(other - self.val, -self.deriv)

    def __neg__(self) -> 'AD':
        """
        returns -self
        """
        return AD(-self.val, -self.deriv)

    def __mul__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self * other
        """
        if isinstance(other, AD):
            # product rule: (u v)' = u v' + u' v
            return AD(self.val * other.val, self.val * other.deriv + self.deriv * other.val)
        return AD(self.val * other, other * self.deriv)

    def __rmul__(self, other: float) -> 'AD':
        """
        returns other * self
        """
        return AD(other * self.val, other * self.deriv)

    def __pow__(self, p: int) -> 'AD':
        """
        returns self^p

        p is a constant exponent.  Negative exponents are evaluated as
        1 / self^(-p), which yields the derivative p * val^(p-1) * deriv.
        It also works for integer values (np.power refuses negative integer
        powers of integers).

        Raises ZeroDivisionError for a negative exponent when val is zero.
        """
        if p == 0:
            return AD(1., np.zeros_like(self.deriv))
        if p == 1:
            return self
        if p < 0:
            # Reciprocal via __rtruediv__ applies the quotient rule and its zero check.
            return 1.0 / (self ** (-p))
        return AD(np.power(self.val, p), p * (self.val ** (p - 1)) * self.deriv)

    def __truediv__(self, other: Union[float, 'AD']) -> 'AD':
        """
        returns self / other

        Raises ZeroDivisionError if the divisor (value) is zero.
        """
        if isinstance(other, AD):
            if np.any(other.val == 0):
                raise ZeroDivisionError("You can't divide by zero!")
            # quotient rule: (u / v)' = (v u' - u v') / v^2
            return AD(self.val / other.val,
                      (other.val * self.deriv - self.val * other.deriv) / other.val**2)
        # Check explicitly so a numpy-typed value does not silently produce inf.
        if np.any(other == 0):
            raise ZeroDivisionError("You can't divide by zero!")
        return AD(self.val / other, self.deriv / other)

    def __rtruediv__(self, other: float) -> 'AD':
        """
        returns other / self

        Raises ZeroDivisionError if self's value is zero.
        """
        if np.any(self.val == 0):
            raise ZeroDivisionError("You can't divide by zero!")
        return AD(other / self.val, - other * self.deriv / self.val**2)

    def __str__(self):  # pragma: no cover
        return f"AD({self.val}, {self.deriv})"

    def __repr__(self):  # pragma: no cover
        return f"AD({self.val}, {self.deriv})"

    def __eq__(self, other: Union['AD', float]):
        """
        Approximate equality using numpy's default isclose tolerances.

        Against another AD, both value and derivative must match.  Against a
        plain number (Python or numpy scalar), only the value is compared.
        Other types return NotImplemented, so Python falls back to its default.
        """
        if isinstance(other, AD):
            return np.allclose(self.val, other.val) and np.allclose(self.deriv, other.deriv)
        if isinstance(other, (int, float, np.number)):
            return np.allclose(self.val, other)
        return NotImplemented