Coverage for src/methodsnm/visualize.py: 11%
161 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-27 13:22 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-27 13:22 +0000
1from methodsnm.fe_1d import *
2from methodsnm.fe_2d import *
3from methodsnm.meshfct import *
4import numpy as np
5import matplotlib.pyplot as plt
6#import pylab as plt
7from math import ceil, sqrt
def DrawSegmentFE(fe, sampling=100, xkcd_style=False, derivative=False):
    """
    Plot all shape functions (or their derivatives) of a 1D finite element
    over the reference interval [0, 1].

    Args:
    - fe: FE_1D, the finite element whose shape functions are drawn
    - sampling: int, number of equispaced sample points in [0, 1]
    - xkcd_style: bool, if True the plot is additionally drawn in xkcd sketch style
    - derivative: bool, if True `fe.evaluate_deriv` is sampled point by point
      instead of `fe.evaluate`

    Raises:
    - ValueError: if `fe` is not an instance of FE_1D
    """
    # Local import: ExitStack lets the xkcd style be entered conditionally.
    from contextlib import ExitStack

    if not isinstance(fe, FE_1D):
        raise ValueError("fe must be an instance of FE_1D")

    # Column vector of sample points, shape (sampling, 1).
    xvals = np.linspace(0, 1, sampling).reshape((sampling, 1))
    if derivative:
        yvals = np.array([fe.evaluate_deriv(xi) for xi in xvals])
    else:
        yvals = fe.evaluate(xvals)

    # Apply the styles only for the duration of this plot. Calling
    # plt.style.use / plt.xkcd() directly would modify the global rcParams
    # and silently restyle every figure created afterwards.
    with ExitStack() as styles:
        styles.enter_context(plt.style.context("fivethirtyeight"))
        if xkcd_style:
            styles.enter_context(plt.xkcd())
        plt.plot(xvals, yvals)
        plt.legend([f"$\\phi_{{{i}}}$" for i in range(fe.ndof)])
        plt.show()
25# identify best subplot pattern:
def identify_best_subplot_pattern(L):
    """
    Given the number of subplots L, this function identifies the best subplot pattern
    to use for a plot with L subplots. The function returns the number of rows and columns
    for the subplot grid.

    Each base pattern (columns, rows) is scaled by the smallest integer factor D
    such that the scaled grid has at least L cells. The candidate with the fewest
    cells wins; on a tie the pattern listed first is kept.

    Args:
    - L: int, the number of subplots (L = 0 yields an empty (0, 0) grid)

    Returns:
    - N: int, the number of rows in the subplot grid
    - M: int, the number of columns in the subplot grid

    Raises:
    - ValueError: if L is negative
    """
    if L < 0:
        raise ValueError("the number of subplots L must be non-negative")

    # Base aspect patterns given as (columns, rows).
    patterns = [(1, 1), (2, 1), (3, 1), (3, 2), (3, 3), (5, 2)]

    best = None  # (total cells, rows, columns) of the best candidate so far
    for cols, rows in patterns:
        cells = cols * rows
        # Smallest uniform scale factor so that cells * scale**2 >= L.
        scale = int(ceil(sqrt(L / cells)))
        total = cells * scale**2
        # Strict '<' keeps the earliest pattern when totals are equal.
        if best is None or total < best[0]:
            best = (total, rows * scale, cols * scale)

    _, N, M = best
    return N, M
53import matplotlib.tri as mtri
def DrawTriangleFE(fe, sampling=10, contour=False, figsize=(10,6)):
    """
    Plot every shape function of a finite element on the reference triangle,
    one subplot per degree of freedom.

    Args:
    - fe: finite element providing `ndof` and `evaluate(point)`; the result of
      `evaluate` is indexed by degree of freedom
    - sampling: int, number of subdivisions per edge of the reference triangle
    - contour: bool, if True draw filled 2D contours (with colorbar),
      otherwise draw 3D surfaces
    - figsize: tuple, size of the created matplotlib figure
    """
    # Closed outline of the reference triangle (first vertex repeated).
    x0vals = np.array([0, 1, 0, 0])
    y0vals = np.array([0, 0, 1, 0])

    # Lattice nodes (i, j), enumerated row by row; x = j/sampling, y = i/sampling.
    nodes = [(i, j) for i in range(sampling+1) for j in range(sampling+1-i)]
    xvals = np.array([j/sampling for (i, j) in nodes])
    yvals = np.array([i/sampling for (i, j) in nodes])

    # fe vals on grid: one row per lattice node, one column per dof
    fevals = np.array([fe.evaluate(np.array([xx, yy])) for (xx, yy) in zip(xvals, yvals)])

    # map from (i,j) to the running node number n
    ij2n = {ij: n for n, ij in enumerate(nodes)}

    # triangles for plotting: first one family of sub-triangles, then the
    # complementary family filling the gaps between them
    trigs = [[ij2n[i, j], ij2n[i, j+1], ij2n[i+1, j]]
             for j in range(sampling) for i in range(sampling-j)]
    trigs += [[ij2n[i, j+1], ij2n[i+1, j+1], ij2n[i+1, j]]
              for j in range(sampling-1) for i in range(sampling-j-1)]

    # The triangulation does not depend on the dof, so build it once.
    triang = mtri.Triangulation(xvals, yvals, trigs) if contour else None

    N, M = identify_best_subplot_pattern(fe.ndof)
    plt.figure(figsize=figsize)

    for dof in range(fe.ndof):
        n, m = divmod(dof, M)  # subplot row, subplot column
        if contour:
            ax = plt.subplot2grid((N, M), (n, m))
            tcf = ax.tricontourf(triang, fevals[:, dof])
            plt.colorbar(tcf)
            # outline is drawn last so it stays visible on top of the contours
            for i in range(3):
                ax.plot(x0vals[i:i+2], y0vals[i:i+2], linewidth=2.0, color="black", antialiased=True)
        else:
            ax = plt.subplot2grid((N, M), (n, m), projection='3d')
            for i in range(3):
                ax.plot(x0vals[i:i+2], y0vals[i:i+2], linewidth=2.0, color="black", antialiased=True)
            ax.plot_trisurf(xvals, yvals, trigs, fevals[:, dof], cmap=plt.cm.Spectral, linewidth=0.0, antialiased=True)
    plt.tight_layout(pad=1.0)
    plt.show()
def DrawMesh2D(mesh):
    """
    Draw the triangles of a 2D mesh with black lines and vertex markers.

    Args:
    - mesh: mesh object providing `points` (indexed as [:, 0] and [:, 1])
      and `elements()` (passed to matplotlib as the triangle list)
    """
    vertices = mesh.points
    triangles = mesh.elements()
    xs, ys = vertices[:, 0], vertices[:, 1]
    plt.triplot(xs, ys, triangles, 'ko-')
    plt.show()
def DrawMesh1D(mesh):
    """
    Draw the points of a 1D mesh as tick marks on the line y = 0.

    Args:
    - mesh: mesh object providing `points`
    """
    nodes = mesh.points
    baseline = np.zeros(len(nodes))
    plt.plot(nodes, baseline, '|', label='points')
    plt.xlabel("x")
    plt.legend()
    plt.show()
def DrawFunction1D(f, sampling = 10, mesh = None, show_mesh = False):
    """
    Plot one or several mesh functions over a 1D mesh.

    Every function is sampled element by element at `sampling` equispaced
    points of the reference interval [0, 1]; all functions share one axis.

    Args:
    - f: MeshFunction or list of MeshFunction, the function(s) to draw
    - sampling: int, number of sample points per element
    - mesh: mesh to draw on; defaults to the mesh of the first function
    - show_mesh: bool, if True the mesh points are marked on the line y = 0

    Raises:
    - ValueError: if an entry of `f` is not a MeshFunction
    """
    # A single function is wrapped into a list and handled by the list code path.
    if not isinstance(f, list):
        DrawFunction1D([f], sampling, mesh=mesh, show_mesh=show_mesh)
        return
    if any(not isinstance(fi, MeshFunction) for fi in f):
        raise ValueError("f must be a list of MeshFunction instances")
    if mesh is None:
        mesh = f[0].mesh

    # One row per sample point: [x, f_0(x), f_1(x), ...]
    rows = []
    for elnr, vs in enumerate(mesh.elements()):
        trafo = mesh.trafo(elnr)
        xl, xr = mesh.points[vs]
        for ip_x in np.linspace(0, 1, sampling):
            row = [xl*(1-ip_x) + xr*ip_x]
            for fi in f:
                row.append(fi.evaluate(np.array([ip_x]), trafo))
            rows.append(row)
    table = np.array(rows)

    plt.plot(table[:, 0], table[:, 1:], '-')
    plt.xlabel("x")
    if show_mesh:
        plt.plot(mesh.points, np.zeros(len(mesh.points)), '|', label='points')
    plt.show()
def RefTriangleIPs(sampling, shrink_eps = 0):
    """
    Build a regular lattice of points on the reference triangle together with
    a triangulation of that lattice.

    Args:
    - sampling: int, number of subdivisions per edge of the reference triangle
    - shrink_eps: float, if > 0 every coordinate c is replaced by
      (1-shrink_eps)*c + shrink_eps/3

    Returns:
    - ips: array of shape ((sampling+1)*(sampling+2)/2, 2), the lattice points
      [i/sampling, j/sampling] (or their shrunk counterparts)
    - trigs: integer array with one row of three point indices per sub-triangle
    """
    # Lattice nodes (i, j) with i + j <= sampling, enumerated row by row.
    nodes = [(i, j) for i in range(sampling+1) for j in range(sampling+1-i)]
    # map from (i,j) to the running node number n
    ij2n = {ij: n for n, ij in enumerate(nodes)}

    # triangles for plotting: first one family of sub-triangles, then the
    # complementary family filling the gaps between them
    trigs = [[ij2n[i, j], ij2n[i, j+1], ij2n[i+1, j]]
             for j in range(sampling) for i in range(sampling-j)]
    trigs += [[ij2n[i, j+1], ij2n[i+1, j+1], ij2n[i+1, j]]
              for j in range(sampling-1) for i in range(sampling-j-1)]
    trigs = np.array(trigs)

    # integration points:
    if shrink_eps > 0:
        ips = np.array([[(1-shrink_eps)*i/sampling+shrink_eps/3,
                         (1-shrink_eps)*j/sampling+shrink_eps/3] for (i, j) in nodes])
    else:
        ips = np.array([[i/sampling, j/sampling] for (i, j) in nodes])
    return ips, trigs
163from matplotlib import colors as pltcolor
def DrawFunction2D(f, sampling = 10, mesh = None, show_mesh = False,
                   vmin=None, vmax=None, shrink_eps = 0, contour=False, figsize=(10,6)):
    """
    Plot a mesh function over a 2D triangular mesh, either as a 3D surface or
    as filled contours.

    Each element is sampled on the lattice returned by `RefTriangleIPs`; the
    per-element samples and triangles are concatenated into one triangulation.
    Only the first function, `f[0]`, is evaluated and drawn.

    Args:
    - f: MeshFunction or list of MeshFunction
    - sampling: int, subdivisions per element edge used for sampling
    - mesh: mesh to draw on; defaults to the mesh of the first function
    - show_mesh: bool, accepted but not used by this function
    - vmin, vmax: color range limits; default to the min / max of the sampled values
    - shrink_eps: float, forwarded to `RefTriangleIPs`
    - contour: bool, if True draw filled contours with a colorbar,
      otherwise a 3D surface
    - figsize: tuple, size of the created matplotlib figure

    Raises:
    - ValueError: if an entry of `f` is not a MeshFunction
    """
    # A single function is wrapped into a list and handled by the list code path.
    if not isinstance(f, list):
        DrawFunction2D([f], sampling, mesh=mesh, show_mesh=show_mesh,
                       vmin=vmin, vmax=vmax, shrink_eps=shrink_eps,
                       contour=contour, figsize=figsize)
        return
    if any(not isinstance(fi, MeshFunction) for fi in f):
        raise ValueError("f must be a list of MeshFunction instances")
    if mesh is None:
        mesh = f[0].mesh

    ref_ips, ref_trigs = RefTriangleIPs(sampling, shrink_eps=shrink_eps)
    pts_per_el = len(ref_ips)
    trigs_per_el = ref_trigs.shape[0]
    n_elements = len(mesh.elements())

    # Global buffers, filled element by element in contiguous blocks.
    points = np.empty((pts_per_el*n_elements, 2))
    values = np.empty(pts_per_el*n_elements)
    triangles = np.empty((trigs_per_el*n_elements, 3), dtype=int)

    for elnr, _ in enumerate(mesh.elements()):
        trafo = mesh.trafo(elnr)
        pt_rows = slice(elnr*pts_per_el, (elnr+1)*pts_per_el)
        tri_rows = slice(elnr*trigs_per_el, (elnr+1)*trigs_per_el)
        points[pt_rows, :] = trafo(ref_ips)
        values[pt_rows] = f[0].evaluate(ref_ips, trafo)
        # Shift the local point indices to this element's block of points.
        triangles[tri_rows, :] = ref_trigs + elnr*pts_per_el

    if vmin is None:
        vmin = np.min(values)
    if vmax is None:
        vmax = np.max(values)

    plt.figure(figsize=figsize)
    if contour:
        triang = mtri.Triangulation(points[:, 0], points[:, 1], triangles)
        tcf = plt.tricontourf(triang, values, cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
        plt.colorbar(tcf)
    else:
        ax = plt.subplot(111, projection='3d')
        ax.plot_trisurf(points[:, 0], points[:, 1], triangles, values, cmap=plt.cm.jet, linewidth=0.0, antialiased=True, vmin=vmin, vmax=vmax)
    plt.tight_layout(pad=1.0)
    plt.show()
def DrawShapes(fes, sampling = 10):
    """
    Plot all global basis functions of a finite element space.

    For each degree of freedom an FEFunction is created whose coefficient
    vector has a single entry set to 1; the resulting functions are drawn
    together with `DrawFunction1D`.

    Args:
    - fes: finite element space providing `ndof`
    - sampling: int, number of sample points per element
    """
    basis_functions = []
    for dof in range(fes.ndof):
        uh = FEFunction(fes)
        uh.vector[dof] = 1
        basis_functions.append(uh)
    DrawFunction1D(basis_functions, sampling=sampling)
if __name__ == "__main__":
    # Demo when run as a script: draw the P1 shape functions on the
    # reference segment, then on the reference triangle.
    DrawSegmentFE(P1_Segment_FE(), sampling=10)
    DrawTriangleFE(P1_Triangle_FE(), sampling=10)