diff --git a/demos/inverse/example.py b/demos/inverse/example.py index b709de6..b799614 100644 --- a/demos/inverse/example.py +++ b/demos/inverse/example.py @@ -23,13 +23,14 @@ def get_tensor_map(self): def psi(F, rho): E = self.E * rho nu = 0.3 - mu = E/(2.*(1. + nu)) - kappa = E/(3.*(1. - 2.*nu)) + mu = E / (2.0 * (1.0 + nu)) + kappa = E / (3.0 * (1.0 - 2.0 * nu)) J = np.linalg.det(F) - Jinv = J**(-2./3.) + Jinv = J ** (-2.0 / 3.0) I1 = np.trace(F.T @ F) - energy = (mu/2.)*(Jinv*I1 - 3.) + (kappa/2.) * (J - 1.)**2. + energy = (mu / 2.0) * (Jinv * I1 - 3.0) + (kappa / 2.0) * (J - 1.0) ** 2.0 return energy + P_fn = jax.grad(psi) def first_PK_stress(u_grad, rho): @@ -37,11 +38,12 @@ def first_PK_stress(u_grad, rho): F = u_grad + I P = P_fn(F, rho) return P + return first_PK_stress def get_surface_maps(self): def surface_map(u, x): - return np.array([0., 0., 1e3]) + return np.array([0.0, 0.0, 1e3]) return [surface_map] @@ -54,77 +56,97 @@ def set_params(self, params): # Specify mesh-related information (first-order hexahedron element). -ele_type = 'HEX8' +ele_type = "HEX8" cell_type = get_meshio_cell_type(ele_type) -data_dir = os.path.join(os.path.dirname(__file__), 'data') -Lx, Ly, Lz = 1., 1., 1. -meshio_mesh = box_mesh(Nx=5, Ny=5, Nz=5, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type) +data_dir = os.path.join(os.path.dirname(__file__), "data") +Lx, Ly, Lz = 1.0, 1.0, 1.0 +meshio_mesh = box_mesh( + Nx=5, Ny=5, Nz=5, Lx=Lx, Ly=Ly, Lz=Lz, data_dir=data_dir, ele_type=ele_type +) mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) # Define Dirichlet boundary values. def get_dirichlet_bottom(scale): def dirichlet_bottom(point): - z_disp = scale*Lz + z_disp = scale * Lz return z_disp + return dirichlet_bottom + def zero_dirichlet_val(point): - return 0. + return 0.0 # Define boundary locations. def bottom(point): - return np.isclose(point[2], 0., atol=1e-5) + return np.isclose(point[2], 0.0, atol=1e-5) + def top(point): return np.isclose(point[2], Lz, atol=1e-5) -dirichlet_bc_info = [[bottom]*3, [0, 1, 2], [zero_dirichlet_val]*2 + [get_dirichlet_bottom(1.)]] + +dirichlet_bc_info = [ + [bottom] * 3, + [0, 1, 2], + [zero_dirichlet_val] * 2 + [get_dirichlet_bottom(1.0)], +] location_fns = [top] # Create an instance of the problem. -problem = HyperElasticity(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns) +problem = HyperElasticity( + mesh, + vec=3, + dim=3, + ele_type=ele_type, + dirichlet_bc_info=dirichlet_bc_info, + location_fns=location_fns, +) # Define parameters. -rho = 0.5*np.ones((problem.fe.num_cells, problem.fe.num_quads)) -E = 1.e6 -scale_d = 1. +rho = 0.5 * np.ones((problem.fe.num_cells, problem.fe.num_quads)) +E = 1.0e6 +scale_d = 1.0 params = [E, rho, scale_d] # Implicit differentiation wrapper -fwd_pred = ad_wrapper(problem) +fwd_pred = ad_wrapper(problem) sol_list = fwd_pred(params) -vtk_path = os.path.join(data_dir, f'vtk/u.vtu') +vtk_path = os.path.join(data_dir, f"vtk/u.vtu") save_sol(problem.fe, sol_list[0], vtk_path) + def test_fn(sol_list): - return np.sum(sol_list[0]**2) + return np.sum(sol_list[0] ** 2) + def composed_fn(params): return test_fn(fwd_pred(params)) + val = test_fn(sol_list) -h = 1e-3 # small perturbation +h = 1e-3 # small perturbation # Forward difference -E_plus = (1 + h)*E +E_plus = (1 + h) * E params_E = [E_plus, rho, scale_d] -dE_fd = (composed_fn(params_E) - val)/(h*E) +dE_fd = (composed_fn(params_E) - val) / (h * E) -rho_plus = rho.at[0, 0].set((1 + h)*rho[0, 0]) +rho_plus = rho.at[0, 0].set((1 + h) * rho[0, 0]) params_rho = [E, rho_plus, scale_d] -drho_fd_00 = (composed_fn(params_rho) - val)/(h*rho[0, 0]) +drho_fd_00 = (composed_fn(params_rho) - val) / (h * rho[0, 0]) -scale_d_plus = (1 + h)*scale_d +scale_d_plus = (1 + h) * scale_d params_scale_d = [E, rho, scale_d_plus] -dscale_d_fd = (composed_fn(params_scale_d) - val)/(h*scale_d) +dscale_d_fd = (composed_fn(params_scale_d) - val) / (h * scale_d) # Derivative obtained by automatic differentiation @@ -132,9 +154,12 @@ def composed_fn(params): # Comparison -print(f"\nDerivative comparison between automatic differentiation (AD) and finite difference (FD)") -print(f"\ndE = {dE}, dE_fd = {dE_fd}, WRONG results! Please avoid gradients w.r.t self.E") +print( + f"\nDerivative comparison between automatic differentiation (AD) and finite difference (FD)" +) +print( + f"\ndE = {dE}, dE_fd = {dE_fd}, WRONG results! Please avoid gradients w.r.t self.E" +) print(f"This is due to the use of glob variable self.E, inside a jax jitted function.") print(f"\ndrho[0, 0] = {drho[0, 0]}, drho_fd_00 = {drho_fd_00}") print(f"\ndscale_d = {dscale_d}, dscale_d_fd = {dscale_d_fd}") - diff --git a/demos/topology_optimization/animation.py b/demos/topology_optimization/animation.py index 82ab7b3..4af433f 100644 --- a/demos/topology_optimization/animation.py +++ b/demos/topology_optimization/animation.py @@ -1,5 +1,5 @@ import os from jax_fem.common import make_video -data_path = os.path.join(os.path.dirname(__file__), 'data') -make_video(data_path) \ No newline at end of file +data_path = os.path.join(os.path.dirname(__file__), "data") +make_video(data_path) diff --git a/demos/topology_optimization/example.py b/demos/topology_optimization/example.py index 615fe63..566b967 100644 --- a/demos/topology_optimization/example.py +++ b/demos/topology_optimization/example.py @@ -1,12 +1,16 @@ # Import some useful modules. +import os +import glob +import sys import numpy as onp import jax import jax.numpy as np -import os -import glob import matplotlib.pyplot as plt +sys.path.append("./") + + # Import JAX-FEM specific modules. from jax_fem.problem import Problem from jax_fem.solver import solver, ad_wrapper @@ -15,12 +19,12 @@ from jax_fem.mma import optimize -# Define constitutive relationship. -# Generally, JAX-FEM solves -div.(f(u_grad,alpha_1,alpha_2,...,alpha_N)) = b. +# Define constitutive relationship. +# Generally, JAX-FEM solves -div.(f(u_grad,alpha_1,alpha_2,...,alpha_N)) = b. # Here, we have f(u_grad,alpha_1,alpha_2,...,alpha_N) = sigma(u_grad, theta), -# reflected by the function 'stress'. The functions 'custom_init'and 'set_params' +# reflected by the function 'stress'. The functions 'custom_init'and 'set_params' # override base class methods. In particular, set_params sets the design variable theta. -class Elasticity(Problem): +class Elasticity(Problem): # 定义具体问题 def custom_init(self): # Override base class method. # Set up 'self.fe.flex_inds' so that location-specific TO can be realized. @@ -30,32 +34,40 @@ def custom_init(self): def get_tensor_map(self): def stress(u_grad, theta): # Plane stress assumption + # u_grad: gradient of displacement field + # theta: density field # Reference: https://en.wikipedia.org/wiki/Hooke%27s_law - Emax = 70.e3 - Emin = 1e-3*Emax + Emax = 70.0e3 + Emin = 1e-3 * Emax nu = 0.3 - penal = 3. - E = Emin + (Emax - Emin)*theta[0]**penal - epsilon = 0.5*(u_grad + u_grad.T) + penal = 3.0 + E = Emin + (Emax - Emin) * theta[0] ** penal + epsilon = 0.5 * (u_grad + u_grad.T) eps11 = epsilon[0, 0] eps22 = epsilon[1, 1] eps12 = epsilon[0, 1] - sig11 = E/(1 + nu)/(1 - nu)*(eps11 + nu*eps22) - sig22 = E/(1 + nu)/(1 - nu)*(nu*eps11 + eps22) - sig12 = E/(1 + nu)*eps12 + sig11 = E / (1 + nu) / (1 - nu) * (eps11 + nu * eps22) + sig22 = E / (1 + nu) / (1 - nu) * (nu * eps11 + eps22) + sig12 = E / (1 + nu) * eps12 sigma = np.array([[sig11, sig12], [sig12, sig22]]) return sigma + return stress def get_surface_maps(self): + # 没看懂这个函数的作用 def surface_map(u, x): - return np.array([0., 100.]) + return np.array([0.0, 100.0]) + return [surface_map] def set_params(self, params): # Override base class method. + # 这个函数的作用是将设计变量theta设置到问题中 full_params = np.ones((self.fe.num_cells, params.shape[1])) - full_params = full_params.at[self.fe.flex_inds].set(params) + full_params = full_params.at[self.fe.flex_inds].set( + params + ) # 将params中的值设置在 full_params 的flex_inds指定的位置 thetas = np.repeat(full_params[:, None, :], self.fe.num_quads, axis=1) self.full_params = full_params self.internal_vars = [thetas] @@ -64,66 +76,89 @@ def compute_compliance(self, sol): # Surface integral boundary_inds = self.boundary_inds_list[0] _, nanson_scale = self.fe.get_face_shape_grads(boundary_inds) - # (num_selected_faces, 1, num_nodes, vec) * # (num_selected_faces, num_face_quads, num_nodes, 1) - u_face = sol[self.fe.cells][boundary_inds[:, 0]][:, None, :, :] * self.fe.face_shape_vals[boundary_inds[:, 1]][:, :, :, None] - u_face = np.sum(u_face, axis=2) # (num_selected_faces, num_face_quads, vec) + # (num_selected_faces, 1, num_nodes, vec) * # (num_selected_faces, num_face_quads, num_nodes, 1) + # Nonson公式是连续介质力学中的概念,用于描述材料变形时面积向量的变换;nonson_scale表示每个面的变换尺度因子,用于计算变形后的面积向量 + u_face = ( + sol[self.fe.cells][boundary_inds[:, 0]][:, None, :, :] + * self.fe.face_shape_vals[boundary_inds[:, 1]][:, :, :, None] + ) + u_face = np.sum(u_face, axis=2) # (num_selected_faces, num_face_quads, vec) # (num_cells, num_faces, num_face_quads, dim) -> (num_selected_faces, num_face_quads, dim) - + # subset_quad_points = self.get_physical_surface_quad_points(boundary_inds) subset_quad_points = self.physical_surface_quad_points[0] neumann_fn = self.get_surface_maps()[0] - traction = -jax.vmap(jax.vmap(neumann_fn))(u_face, subset_quad_points) # (num_selected_faces, num_face_quads, vec) + traction = -jax.vmap(jax.vmap(neumann_fn))( + u_face, subset_quad_points + ) # (num_selected_faces, num_face_quads, vec) val = np.sum(traction * u_face * nanson_scale[:, :, None]) return val # Do some cleaning work. Remove old solution files. -data_path = os.path.join(os.path.dirname(__file__), 'data') -files = glob.glob(os.path.join(data_path, f'vtk/*')) +data_path = os.path.join(os.path.dirname(__file__), "data") +files = glob.glob(os.path.join(data_path, f"vtk/*")) for f in files: os.remove(f) # Specify mesh-related information. We use first-order quadrilateral element. -ele_type = 'QUAD4' +ele_type = "QUAD4" cell_type = get_meshio_cell_type(ele_type) -Lx, Ly = 60., 30. +Lx, Ly = 60.0, 30.0 meshio_mesh = rectangle_mesh(Nx=60, Ny=30, domain_x=Lx, domain_y=Ly) mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type]) +# meshio_mesh.points: coordinates of grid nodes +# meshio_mesh.cells_dict[cell_type]: connectivity of grid nodes (node id in one element) # Define boundary conditions and values. def fixed_location(point): - return np.isclose(point[0], 0., atol=1e-5) - + return np.isclose(point[0], 0.0, atol=1e-5) # 找到与 0.0 在 1e-5 误差范围内的点 + + def load_location(point): - return np.logical_and(np.isclose(point[0], Lx, atol=1e-5), np.isclose(point[1], 0., atol=0.1*Ly + 1e-5)) + return np.logical_and( + np.isclose(point[0], Lx, atol=1e-5), + np.isclose(point[1], 0.0, atol=0.1 * Ly + 1e-5), + ) + def dirichlet_val(point): - return 0. + return 0.0 -dirichlet_bc_info = [[fixed_location]*2, [0, 1], [dirichlet_val]*2] +dirichlet_bc_info = [[fixed_location] * 2, [0, 1], [dirichlet_val] * 2] +# 零位移(dirichlet)约束条件 location_fns = [load_location] # Define forward problem. -problem = Elasticity(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info, location_fns=location_fns) - - -# Apply the automatic differentiation wrapper. -# The flag 'use_petsc' specifies how the forward problem (could be linear or nonlinear) -# and the backward adjoint problem (always linear) should be solved by specifying use_petsc_adjoint. +problem = Elasticity( + mesh, + vec=2, + dim=2, + ele_type=ele_type, + dirichlet_bc_info=dirichlet_bc_info, + location_fns=location_fns, +) + + +# Apply the automatic differentiation wrapper. +# The flag 'use_petsc' specifies how the forward problem (could be linear or nonlinear) +# and the backward adjoint problem (always linear) should be solved by specifying use_petsc_adjoint. # This is a critical step that makes the problem solver differentiable. -fwd_pred = ad_wrapper(problem, linear=True, use_petsc=True, use_petsc_adjoint=True) +fwd_pred = ad_wrapper( + problem, linear=True, use_petsc=True, use_petsc_adjoint=True +) # 决定是否可以自动微分的关键一步 -# Define the objective function 'J_total(theta)'. +# Define the objective function 'J_total(theta)'. # In the following, 'sol = fwd_pred(params)' basically says U = U(theta). def J_total(params): - # J(u(theta), theta) + # J(u(theta), theta) sol_list = fwd_pred(params) compliance = problem.compute_compliance(sol_list[0]) return compliance @@ -131,15 +166,24 @@ def J_total(params): # Output solution files to local disk outputs = [] + + def output_sol(params, obj_val): print(f"\nOutput solution - need to solve the forward problem again...") sol_list = fwd_pred(params) sol = sol_list[0] - vtu_path = os.path.join(data_path, f'vtk/sol_{output_sol.counter:03d}.vtu') - save_sol(problem.fe, np.hstack((sol, np.zeros((len(sol), 1)))), vtu_path, cell_infos=[('theta', problem.full_params[:, 0])]) + vtu_path = os.path.join(data_path, f"vtk/sol_{output_sol.counter:03d}.vtu") + save_sol( + problem.fe, + np.hstack((sol, np.zeros((len(sol), 1)))), + vtu_path, + cell_infos=[("theta", problem.full_params[:, 0])], + ) print(f"compliance = {obj_val}") outputs.append(obj_val) output_sol.counter += 1 + + output_sol.counter = 0 @@ -150,6 +194,7 @@ def objectiveHandle(rho): # dJ has shape (...) = rho.shape J, dJ = jax.value_and_grad(J_total)(rho) output_sol(rho, J) + # rho: 1800x1 return J, dJ @@ -159,8 +204,9 @@ def consHandle(rho, epoch): # c should have shape (numConstraints,) # dc should have shape (numConstraints, ...) def computeGlobalVolumeConstraint(rho): - g = np.mean(rho)/vf - 1. + g = np.mean(rho) / vf - 1.0 return g + c, gradc = jax.value_and_grad(computeGlobalVolumeConstraint)(rho) c, gradc = c.reshape((1,)), gradc[None, ...] return c, gradc @@ -168,17 +214,21 @@ def computeGlobalVolumeConstraint(rho): # Finalize the details of the MMA optimizer, and solve the TO problem. vf = 0.5 -optimizationParams = {'maxIters':51, 'movelimit':0.1} -rho_ini = vf*np.ones((len(problem.fe.flex_inds), 1)) +optimizationParams = {"maxIters": 51, "movelimit": 0.1} +rho_ini = vf * np.ones((len(problem.fe.flex_inds), 1)) numConstraints = 1 -optimize(problem.fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints) -print(f"As a reminder, compliance = {J_total(np.ones((len(problem.fe.flex_inds), 1)))} for full material") +optimize( + problem.fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints +) +print( + f"As a reminder, compliance = {J_total(np.ones((len(problem.fe.flex_inds), 1)))} for full material" +) # Plot the optimization results. obj = onp.array(outputs) plt.figure(figsize=(10, 8)) -plt.plot(onp.arange(len(obj)) + 1, obj, linestyle='-', linewidth=2, color='black') +plt.plot(onp.arange(len(obj)) + 1, obj, linestyle="-", linewidth=2, color="black") plt.xlabel(r"Optimization step", fontsize=20) plt.ylabel(r"Objective value", fontsize=20) plt.tick_params(labelsize=20) diff --git a/jax_fem/generate_mesh.py b/jax_fem/generate_mesh.py index 5f50b3a..9ff40e5 100644 --- a/jax_fem/generate_mesh.py +++ b/jax_fem/generate_mesh.py @@ -10,12 +10,13 @@ import jax.numpy as np -class Mesh(): +class Mesh: """ A custom mesh manager might be better using a third-party library like meshio. """ - def __init__(self, points, cells, ele_type='TET4'): + + def __init__(self, points, cells, ele_type="TET4"): # TODO (Very important for debugging purpose!): Assert that cells must have correct orders self.points = points self.cells = cells @@ -62,6 +63,7 @@ def quality(pts): v12 = np.cross(v1, v2) v3 = p4 - p1 return np.dot(v12, v3) + qlts = jax.vmap(quality)(points[cells]) return qlts @@ -70,24 +72,24 @@ def get_meshio_cell_type(ele_type): """Reference: https://github.com/nschloe/meshio/blob/9dc6b0b05c9606cad73ef11b8b7785dd9b9ea325/src/meshio/xdmf/common.py#L36 """ - if ele_type == 'TET4': - cell_type = 'tetra' - elif ele_type == 'TET10': - cell_type = 'tetra10' - elif ele_type == 'HEX8': - cell_type = 'hexahedron' - elif ele_type == 'HEX27': - cell_type = 'hexahedron27' - elif ele_type == 'HEX20': - cell_type = 'hexahedron20' - elif ele_type == 'TRI3': - cell_type = 'triangle' - elif ele_type == 'TRI6': - cell_type = 'triangle6' - elif ele_type == 'QUAD4': - cell_type = 'quad' - elif ele_type == 'QUAD8': - cell_type = 'quad8' + if ele_type == "TET4": + cell_type = "tetra" + elif ele_type == "TET10": + cell_type = "tetra10" + elif ele_type == "HEX8": + cell_type = "hexahedron" + elif ele_type == "HEX27": + cell_type = "hexahedron27" + elif ele_type == "HEX20": + cell_type = "hexahedron20" + elif ele_type == "TRI3": + cell_type = "triangle" + elif ele_type == "TRI6": + cell_type = "triangle6" + elif ele_type == "QUAD4": + cell_type = "quad" + elif ele_type == "QUAD8": + cell_type = "quad8" else: raise NotImplementedError return cell_type @@ -97,7 +99,7 @@ def rectangle_mesh(Nx, Ny, domain_x, domain_y): dim = 2 x = onp.linspace(0, domain_x, Nx + 1) y = onp.linspace(0, domain_y, Ny + 1) - xv, yv = onp.meshgrid(x, y, indexing='ij') + xv, yv = onp.meshgrid(x, y, indexing="ij") points_xy = onp.stack((xv, yv), axis=dim) points = points_xy.reshape(-1, dim) points_inds = onp.arange(len(points)) @@ -107,11 +109,11 @@ def rectangle_mesh(Nx, Ny, domain_x, domain_y): inds3 = points_inds_xy[1:, 1:] inds4 = points_inds_xy[:-1, 1:] cells = onp.stack((inds1, inds2, inds3, inds4), axis=dim).reshape(-1, 4) - out_mesh = meshio.Mesh(points=points, cells={'quad': cells}) + out_mesh = meshio.Mesh(points=points, cells={"quad": cells}) return out_mesh - -def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'): + +def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type="HEX8"): """References: https://gitlab.onelab.info/gmsh/gmsh/-/blob/master/examples/api/hex.py https://gitlab.onelab.info/gmsh/gmsh/-/blob/gmsh_4_7_1/tutorial/python/t1.py @@ -120,25 +122,25 @@ def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'): Accepts ele_type = 'HEX8', 'TET4' or 'TET10' """ - assert ele_type != 'HEX20', f"gmsh cannot produce HEX20 mesh?" + assert ele_type != "HEX20", f"gmsh cannot produce HEX20 mesh?" cell_type = get_meshio_cell_type(ele_type) _, _, _, _, degree, _ = get_elements(ele_type) - msh_dir = os.path.join(data_dir, 'msh') + msh_dir = os.path.join(data_dir, "msh") os.makedirs(msh_dir, exist_ok=True) - msh_file = os.path.join(msh_dir, 'box.msh') + msh_file = os.path.join(msh_dir, "box.msh") - offset_x = 0. - offset_y = 0. - offset_z = 0. + offset_x = 0.0 + offset_y = 0.0 + offset_z = 0.0 domain_x = Lx domain_y = Ly domain_z = Lz gmsh.initialize() gmsh.option.setNumber("Mesh.MshFileVersion", 2.2) # save in old MSH format - if cell_type.startswith('tetra'): + if cell_type.startswith("tetra"): Rec2d = False # tris or quads Rec3d = False # tets, prisms or hexas else: @@ -156,8 +158,8 @@ def box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_dir, ele_type='HEX8'): gmsh.finalize() mesh = meshio.read(msh_file) - points = mesh.points # (num_total_nodes, dim) - cells = mesh.cells_dict[cell_type] # (num_cells, num_nodes) + points = mesh.points # (num_total_nodes, dim) + cells = mesh.cells_dict[cell_type] # (num_cells, num_nodes) out_mesh = meshio.Mesh(points=points, cells={cell_type: cells}) return out_mesh @@ -172,13 +174,13 @@ def cylinder_mesh(data_dir, R=5, H=10, circle_mesh=5, hight_mesh=20, rect_ratio= hight_mesh:num of meshs in hight rect_ratio: rect length/R """ - rect_coor = R*rect_ratio - msh_dir = os.path.join(data_dir, 'msh') + rect_coor = R * rect_ratio + msh_dir = os.path.join(data_dir, "msh") os.makedirs(msh_dir, exist_ok=True) - geo_file = os.path.join(msh_dir, 'cylinder.geo') - msh_file = os.path.join(msh_dir, 'cylinder.msh') + geo_file = os.path.join(msh_dir, "cylinder.geo") + msh_file = os.path.join(msh_dir, "cylinder.msh") - string=''' + string = """ Point(1) = {{0, 0, 0, 1.0}}; Point(2) = {{-{rect_coor}, {rect_coor}, 0, 1.0}}; Point(3) = {{{rect_coor}, {rect_coor}, 0, 1.0}}; @@ -227,19 +229,25 @@ def cylinder_mesh(data_dir, R=5, H=10, circle_mesh=5, hight_mesh=20, rect_ratio= Surface{{1:5}}; Layers {{{hight_mesh}}}; Recombine; }} - Mesh 3;'''.format(R=R, H=H, rect_coor=rect_coor, circle_mesh=circle_mesh, hight_mesh=hight_mesh) + Mesh 3;""".format( + R=R, H=H, rect_coor=rect_coor, circle_mesh=circle_mesh, hight_mesh=hight_mesh + ) with open(geo_file, "w") as f: f.write(string) - os.system("gmsh -3 {geo_file} -o {msh_file} -format msh2".format(geo_file=geo_file, msh_file=msh_file)) + os.system( + "gmsh -3 {geo_file} -o {msh_file} -format msh2".format( + geo_file=geo_file, msh_file=msh_file + ) + ) mesh = meshio.read(msh_file) - points = mesh.points # (num_total_nodes, dim) - cells = mesh.cells_dict['hexahedron'] # (num_cells, num_nodes) + points = mesh.points # (num_total_nodes, dim) + cells = mesh.cells_dict["hexahedron"] # (num_cells, num_nodes) # The mesh somehow has two redundant points... points = onp.vstack((points[1:14], points[15:])) cells = onp.where(cells > 14, cells - 2, cells - 1) - out_mesh = meshio.Mesh(points=points, cells={'hexahedron': cells}) + out_mesh = meshio.Mesh(points=points, cells={"hexahedron": cells}) return out_mesh diff --git a/jax_fem/mma.py b/jax_fem/mma.py index 69f5a3a..4c2efdf 100644 --- a/jax_fem/mma.py +++ b/jax_fem/mma.py @@ -6,6 +6,7 @@ Improvement is made to avoid N^2 memory operation so that the MMA solver is more scalable. """ + from numpy import diag as diags from numpy.linalg import solve import numpy as np @@ -17,6 +18,7 @@ import scipy from jax import config + config.update("jax_enable_x64", True) @@ -29,10 +31,10 @@ def compute_filter_kd_tree(fe): flex_cell_centroids = np.take(cell_centroids, fe.flex_inds, axis=0) V = np.sum(fe.JxW) - avg_elem_V = V/fe.num_cells + avg_elem_V = V / fe.num_cells - avg_elem_size = avg_elem_V**(1./fe.dim) - rmin = 1.5*avg_elem_size + avg_elem_size = avg_elem_V ** (1.0 / fe.dim) + rmin = 1.5 * avg_elem_size kd_tree = scipy.spatial.KDTree(flex_cell_centroids) I = [] @@ -42,13 +44,13 @@ def compute_filter_kd_tree(fe): num_nbs = 20 dd, ii = kd_tree.query(flex_cell_centroids[i], num_nbs) neighbors = np.take(flex_cell_centroids, ii, axis=0) - vals = np.where(rmin - dd > 0., rmin - dd, 0.) - I += [i]*num_nbs + vals = np.where(rmin - dd > 0.0, rmin - dd, 0.0) + I += [i] * num_nbs J += ii.tolist() V += vals.tolist() H_sp = scipy.sparse.csc_array((V, (I, J)), shape=(flex_num_cells, flex_num_cells)) - # TODO(Tianju): No need to create the full matrix. + # TODO(Tianju): No need to create the full matrix. # Will cause memory issue for large size problem. # High priority! @@ -58,69 +60,88 @@ def compute_filter_kd_tree(fe): def applySensitivityFilter(ft, rho, dJ, dvc): - dJ = np.matmul(ft['H'], rho*dJ/np.maximum(1e-3, rho)/ft['Hs'][:, None]) - dvc = np.matmul(ft['H'][None, :, :], rho[None, :, :]*dvc/np.maximum(1e-3, rho[None, :, :])/ft['Hs'][None, :, None]) + dJ = np.matmul(ft["H"], rho * dJ / np.maximum(1e-3, rho) / ft["Hs"][:, None]) + dvc = np.matmul( + ft["H"][None, :, :], + rho[None, :, :] + * dvc + / np.maximum(1e-3, rho[None, :, :]) + / ft["Hs"][None, :, None], + ) return dJ, dvc -#%% Optimizer +# %% Optimizer class MMA: # The code was modified from [MMA Svanberg 1987]. Please cite the paper if # you end up using this code. def __init__(self): - self.epoch = 0; + self.epoch = 0 + def resetMMACounter(self): - self.epoch = 0; + self.epoch = 0 + def registerMMAIter(self, xval, xold1, xold2): - self.epoch += 1; - self.xval = xval; - self.xold1 = xold1; - self.xold2 = xold2; + self.epoch += 1 + self.xval = xval + self.xold1 = xold1 + self.xold2 = xold2 + def setNumConstraints(self, numConstraints): - self.numConstraints = numConstraints; + self.numConstraints = numConstraints + def setNumDesignVariables(self, numDesVar): - self.numDesignVariables = numDesVar; + self.numDesignVariables = numDesVar + def setMinandMaxBoundsForDesignVariables(self, xmin, xmax): - self.xmin = xmin; - self.xmax = xmax; + self.xmin = xmin + self.xmax = xmax + def setObjectiveWithGradient(self, obj, objGrad): - self.objective = obj; - self.objectiveGradient = objGrad; + self.objective = obj + self.objectiveGradient = objGrad + def setConstraintWithGradient(self, cons, consGrad): - self.constraint = cons; - self.consGrad = consGrad; + self.constraint = cons + self.consGrad = consGrad + def setScalingParams(self, zconst, zscale, ylinscale, yquadscale): - self.zconst = zconst; - self.zscale = zscale; - self.ylinscale = ylinscale; - self.yquadscale = yquadscale; + self.zconst = zconst + self.zscale = zscale + self.ylinscale = ylinscale + self.yquadscale = yquadscale + def setMoveLimit(self, movelim): - self.moveLimit = movelim; + self.moveLimit = movelim + def setLowerAndUpperAsymptotes(self, low, upp): - self.lowAsymp = low; - self.upAsymp = upp; + self.lowAsymp = low + self.upAsymp = upp def getOptimalValues(self): - return self.xmma, self.ymma, self.zmma; + return self.xmma, self.ymma, self.zmma + def getLagrangeMultipliers(self): - return self.lam, self.xsi, self.eta, self.mu, self.zet; + return self.lam, self.xsi, self.eta, self.mu, self.zet + def getSlackValue(self): - return self.slack; + return self.slack + def getAsymptoteValues(self): - return self.lowAsymp, self.upAsymp; + return self.lowAsymp, self.upAsymp # Function for the MMA sub problem def mmasub(self, xval): - m = self.numConstraints; - n = self.numDesignVariables; - iter = self.epoch; - xmin, xmax = self.xmin, self.xmax; - xold1, xold2 = self.xold1, self.xold2; - f0val, df0dx = self.objective, self.objectiveGradient; - fval, dfdx = self.constraint, self.consGrad; - low, upp = self.lowAsymp, self.upAsymp; - a0, a, c, d = self.zconst, self.zscale, self.ylinscale, self.yquadscale; - move = self.moveLimit; + m = self.numConstraints + n = self.numDesignVariables + iter = self.epoch + xmin, xmax = self.xmin, self.xmax + xold1, xold2 = self.xold1, self.xold2 + f0val, df0dx = self.objective, self.objectiveGradient + fval, dfdx = self.constraint, self.consGrad + low, upp = self.lowAsymp, self.upAsymp + a0, a, c, d = self.zconst, self.zscale, self.ylinscale, self.yquadscale + move = self.moveLimit epsimin = 0.0000001 raa0 = 0.00001 @@ -133,211 +154,217 @@ def mmasub(self, xval): zeron = np.zeros((n, 1)) # Calculation of the asymptotes low and upp if iter <= 2: - low = xval-asyinit*(xmax-xmin) - upp = xval+asyinit*(xmax-xmin) + low = xval - asyinit * (xmax - xmin) + upp = xval + asyinit * (xmax - xmin) else: - zzz = (xval-xold1)*(xold1-xold2) + zzz = (xval - xold1) * (xold1 - xold2) factor = eeen.copy() - factor[np.where(zzz>0)] = asyincr - factor[np.where(zzz<0)] = asydecr - low = xval-factor*(xold1-low) - upp = xval+factor*(upp-xold1) - lowmin = xval-10*(xmax-xmin) - lowmax = xval-0.01*(xmax-xmin) - uppmin = xval+0.01*(xmax-xmin) - uppmax = xval+10*(xmax-xmin) - low = np.maximum(low,lowmin) - low = np.minimum(low,lowmax) - upp = np.minimum(upp,uppmax) - upp = np.maximum(upp,uppmin) + factor[np.where(zzz > 0)] = asyincr + factor[np.where(zzz < 0)] = asydecr + low = xval - factor * (xold1 - low) + upp = xval + factor * (upp - xold1) + lowmin = xval - 10 * (xmax - xmin) + lowmax = xval - 0.01 * (xmax - xmin) + uppmin = xval + 0.01 * (xmax - xmin) + uppmax = xval + 10 * (xmax - xmin) + low = np.maximum(low, lowmin) + low = np.minimum(low, lowmax) + upp = np.minimum(upp, uppmax) + upp = np.maximum(upp, uppmin) # Calculation of the bounds alfa and beta - zzz1 = low+albefa*(xval-low) - zzz2 = xval-move*(xmax-xmin) - zzz = np.maximum(zzz1,zzz2) - alfa = np.maximum(zzz,xmin) - zzz1 = upp-albefa*(upp-xval) - zzz2 = xval+move*(xmax-xmin) - zzz = np.minimum(zzz1,zzz2) - beta = np.minimum(zzz,xmax) + zzz1 = low + albefa * (xval - low) + zzz2 = xval - move * (xmax - xmin) + zzz = np.maximum(zzz1, zzz2) + alfa = np.maximum(zzz, xmin) + zzz1 = upp - albefa * (upp - xval) + zzz2 = xval + move * (xmax - xmin) + zzz = np.minimum(zzz1, zzz2) + beta = np.minimum(zzz, xmax) # Calculations of p0, q0, P, Q and b - xmami = xmax-xmin - xmamieps = 0.00001*eeen - xmami = np.maximum(xmami,xmamieps) - xmamiinv = eeen/xmami - ux1 = upp-xval - ux2 = ux1*ux1 - xl1 = xval-low - xl2 = xl1*xl1 - uxinv = eeen/ux1 - xlinv = eeen/xl1 + xmami = xmax - xmin + xmamieps = 0.00001 * eeen + xmami = np.maximum(xmami, xmamieps) + xmamiinv = eeen / xmami + ux1 = upp - xval + ux2 = ux1 * ux1 + xl1 = xval - low + xl2 = xl1 * xl1 + uxinv = eeen / ux1 + xlinv = eeen / xl1 p0 = zeron.copy() q0 = zeron.copy() - p0 = np.maximum(df0dx,0) - q0 = np.maximum(-df0dx,0) - pq0 = 0.001*(p0+q0)+raa0*xmamiinv - p0 = p0+pq0 - q0 = q0+pq0 - p0 = p0*ux2 - q0 = q0*xl2 - P = np.zeros((m,n)) ## @@ make sparse with scipy? - Q = np.zeros((m,n)) ## @@ make sparse with scipy? - P = np.maximum(dfdx,0) - Q = np.maximum(-dfdx,0) - PQ = 0.001*(P+Q)+raa0*np.dot(eeem,xmamiinv.T) - P = P+PQ - Q = Q+PQ + p0 = np.maximum(df0dx, 0) + q0 = np.maximum(-df0dx, 0) + pq0 = 0.001 * (p0 + q0) + raa0 * xmamiinv + p0 = p0 + pq0 + q0 = q0 + pq0 + p0 = p0 * ux2 + q0 = q0 * xl2 + P = np.zeros((m, n)) ## @@ make sparse with scipy? + Q = np.zeros((m, n)) ## @@ make sparse with scipy? + P = np.maximum(dfdx, 0) + Q = np.maximum(-dfdx, 0) + PQ = 0.001 * (P + Q) + raa0 * np.dot(eeem, xmamiinv.T) + P = P + PQ + Q = Q + PQ # P = (diags(ux2.flatten(),0).dot(P.T)).T # Q = (diags(xl2.flatten(),0).dot(Q.T)).T - P = ux2.T*P - Q = xl2.T*Q + P = ux2.T * P + Q = xl2.T * Q - b = (np.dot(P,uxinv)+np.dot(Q,xlinv)-fval) + b = np.dot(P, uxinv) + np.dot(Q, xlinv) - fval # Solving the subproblem by a primal-dual Newton method - xmma,ymma,zmma,lam,xsi,eta,mu,zet,s = subsolv(m,n,epsimin,low,upp,alfa,\ - beta,p0,q0,P,Q,a0,a,b,c,d) + xmma, ymma, zmma, lam, xsi, eta, mu, zet, s = subsolv( + m, n, epsimin, low, upp, alfa, beta, p0, q0, P, Q, a0, a, b, c, d + ) # Return values - self.xmma, self.ymma, self.zmma = xmma, ymma, zmma; - self.lam, self.xsi, self.eta, self.mu, self.zet = lam,xsi,eta,mu,zet; - self.slack = s; - self.lowAsymp, self.upAsymp = low, upp; + self.xmma, self.ymma, self.zmma = xmma, ymma, zmma + self.lam, self.xsi, self.eta, self.mu, self.zet = lam, xsi, eta, mu, zet + self.slack = s + self.lowAsymp, self.upAsymp = low, upp -def subsolv(m,n,epsimin,low,upp,alfa,beta,p0,q0,P,Q,a0,a,b,c,d): - een = np.ones((n,1)) - eem = np.ones((m,1)) +def subsolv(m, n, epsimin, low, upp, alfa, beta, p0, q0, P, Q, a0, a, b, c, d): + een = np.ones((n, 1)) + eem = np.ones((m, 1)) epsi = 1 - epsvecn = epsi*een - epsvecm = epsi*eem - x = 0.5*(alfa+beta) + epsvecn = epsi * een + epsvecm = epsi * eem + x = 0.5 * (alfa + beta) y = eem.copy() z = np.array([[1.0]]) lam = eem.copy() - xsi = een/(x-alfa) - xsi = np.maximum(xsi,een) - eta = een/(beta-x) - eta = np.maximum(eta,een) - mu = np.maximum(eem,0.5*c) + xsi = een / (x - alfa) + xsi = np.maximum(xsi, een) + eta = een / (beta - x) + eta = np.maximum(eta, een) + mu = np.maximum(eem, 0.5 * c) zet = np.array([[1.0]]) s = eem.copy() itera = 0 # Start while epsi>epsimin while epsi > epsimin: - epsvecn = epsi*een - epsvecm = epsi*eem - ux1 = upp-x - xl1 = x-low - ux2 = ux1*ux1 - xl2 = xl1*xl1 - uxinv1 = een/ux1 - xlinv1 = een/xl1 - plam = p0+np.dot(P.T,lam) - qlam = q0+np.dot(Q.T,lam) - gvec = np.dot(P,uxinv1)+np.dot(Q,xlinv1) - dpsidx = plam/ux2-qlam/xl2 - rex = dpsidx-xsi+eta - rey = c+d*y-mu-lam - rez = a0-zet-np.dot(a.T,lam) - relam = gvec-a*z-y+s-b - rexsi = xsi*(x-alfa)-epsvecn - reeta = eta*(beta-x)-epsvecn - remu = mu*y-epsvecm - rezet = zet*z-epsi - res = lam*s-epsvecm - residu1 = np.concatenate((rex, rey, rez), axis = 0) - residu2 = np.concatenate((relam, rexsi, reeta, remu, rezet, res), axis = 0) - residu = np.concatenate((residu1, residu2), axis = 0) - residunorm = np.sqrt((np.dot(residu.T,residu)).item()) + epsvecn = epsi * een + epsvecm = epsi * eem + ux1 = upp - x + xl1 = x - low + ux2 = ux1 * ux1 + xl2 = xl1 * xl1 + uxinv1 = een / ux1 + xlinv1 = een / xl1 + plam = p0 + np.dot(P.T, lam) + qlam = q0 + np.dot(Q.T, lam) + gvec = np.dot(P, uxinv1) + np.dot(Q, xlinv1) + dpsidx = plam / ux2 - qlam / xl2 + rex = dpsidx - xsi + eta + rey = c + d * y - mu - lam + rez = a0 - zet - np.dot(a.T, lam) + relam = gvec - a * z - y + s - b + rexsi = xsi * (x - alfa) - epsvecn + reeta = eta * (beta - x) - epsvecn + remu = mu * y - epsvecm + rezet = zet * z - epsi + res = lam * s - epsvecm + residu1 = np.concatenate((rex, rey, rez), axis=0) + residu2 = np.concatenate((relam, rexsi, reeta, remu, rezet, res), axis=0) + residu = np.concatenate((residu1, residu2), axis=0) + residunorm = np.sqrt((np.dot(residu.T, residu)).item()) residumax = np.max(np.abs(residu)) ittt = 0 # Start while (residumax>0.9*epsi) and (ittt<200) - while (residumax > 0.9*epsi) and (ittt < 200): - ittt = ittt+1 - itera = itera+1 - ux1 = upp-x - xl1 = x-low - ux2 = ux1*ux1 - xl2 = xl1*xl1 - ux3 = ux1*ux2 - xl3 = xl1*xl2 - uxinv1 = een/ux1 - xlinv1 = een/xl1 - uxinv2 = een/ux2 - xlinv2 = een/xl2 - plam = p0+np.dot(P.T,lam) - qlam = q0+np.dot(Q.T,lam) - gvec = np.dot(P,uxinv1)+np.dot(Q,xlinv1) + while (residumax > 0.9 * epsi) and (ittt < 200): + ittt = ittt + 1 + itera = itera + 1 + ux1 = upp - x + xl1 = x - low + ux2 = ux1 * ux1 + xl2 = xl1 * xl1 + ux3 = ux1 * ux2 + xl3 = xl1 * xl2 + uxinv1 = een / ux1 + xlinv1 = een / xl1 + uxinv2 = een / ux2 + xlinv2 = een / xl2 + plam = p0 + np.dot(P.T, lam) + qlam = q0 + np.dot(Q.T, lam) + gvec = np.dot(P, uxinv1) + np.dot(Q, xlinv1) # GG = (diags(uxinv2.flatten(),0).dot(P.T)).T-(diags\ # (xlinv2.flatten(),0).dot(Q.T)).T - GG = uxinv2.T*P - xlinv2.T*Q - - dpsidx = plam/ux2-qlam/xl2 - delx = dpsidx-epsvecn/(x-alfa)+epsvecn/(beta-x) - dely = c+d*y-lam-epsvecm/y - delz = a0-np.dot(a.T,lam)-epsi/z - dellam = gvec-a*z-y-b+epsvecm/lam - diagx = plam/ux3+qlam/xl3 - diagx = 2*diagx+xsi/(x-alfa)+eta/(beta-x) - diagxinv = een/diagx - diagy = d+mu/y - diagyinv = eem/diagy - diaglam = s/lam - diaglamyi = diaglam+diagyinv + GG = uxinv2.T * P - xlinv2.T * Q + + dpsidx = plam / ux2 - qlam / xl2 + delx = dpsidx - epsvecn / (x - alfa) + epsvecn / (beta - x) + dely = c + d * y - lam - epsvecm / y + delz = a0 - np.dot(a.T, lam) - epsi / z + dellam = gvec - a * z - y - b + epsvecm / lam + diagx = plam / ux3 + qlam / xl3 + diagx = 2 * diagx + xsi / (x - alfa) + eta / (beta - x) + diagxinv = een / diagx + diagy = d + mu / y + diagyinv = eem / diagy + diaglam = s / lam + diaglamyi = diaglam + diagyinv # Start if mresidunorm) and (itto<50) while (resinew > residunorm) and (itto < 50): - itto = itto+1 - x = xold+steg*dx - y = yold+steg*dy - z = zold+steg*dz - lam = lamold+steg*dlam - xsi = xsiold+steg*dxsi - eta = etaold+steg*deta - mu = muold+steg*dmu - zet = zetold+steg*dzet - s = sold+steg*ds - ux1 = upp-x - xl1 = x-low - ux2 = ux1*ux1 - xl2 = xl1*xl1 - uxinv1 = een/ux1 - xlinv1 = een/xl1 - plam = p0+np.dot(P.T,lam) - qlam = q0+np.dot(Q.T,lam) - gvec = np.dot(P,uxinv1)+np.dot(Q,xlinv1) - dpsidx = plam/ux2-qlam/xl2 - rex = dpsidx-xsi+eta - rey = c+d*y-mu-lam - rez = a0-zet-np.dot(a.T,lam) - relam = gvec-np.dot(a,z)-y+s-b - rexsi = xsi*(x-alfa)-epsvecn - reeta = eta*(beta-x)-epsvecn - remu = mu*y-epsvecm - rezet = np.dot(zet,z)-epsi - res = lam*s-epsvecm - residu1 = np.concatenate((rex,rey,rez),axis = 0) - residu2 = np.concatenate((relam,rexsi,reeta,remu,rezet,res), \ - axis = 0) - residu = np.concatenate((residu1,residu2),axis = 0) - resinew = np.sqrt(np.dot(residu.T,residu)) - steg = steg/2 + itto = itto + 1 + x = xold + steg * dx + y = yold + steg * dy + z = zold + steg * dz + lam = lamold + steg * dlam + xsi = xsiold + steg * dxsi + eta = etaold + steg * deta + mu = muold + steg * dmu + zet = zetold + steg * dzet + s = sold + steg * ds + ux1 = upp - x + xl1 = x - low + ux2 = ux1 * ux1 + xl2 = xl1 * xl1 + uxinv1 = een / ux1 + xlinv1 = een / xl1 + plam = p0 + np.dot(P.T, lam) + qlam = q0 + np.dot(Q.T, lam) + gvec = np.dot(P, uxinv1) + np.dot(Q, xlinv1) + dpsidx = plam / ux2 - qlam / xl2 + rex = dpsidx - xsi + eta + rey = c + d * y - mu - lam + rez = a0 - zet - np.dot(a.T, lam) + relam = gvec - np.dot(a, z) - y + s - b + rexsi = xsi * (x - alfa) - epsvecn + reeta = eta * (beta - x) - epsvecn + remu = mu * y - epsvecm + rezet = np.dot(zet, z) - epsi + res = lam * s - epsvecm + residu1 = np.concatenate((rex, rey, rez), axis=0) + residu2 = np.concatenate( + (relam, rexsi, reeta, remu, rezet, res), axis=0 + ) + residu = np.concatenate((residu1, residu2), axis=0) + resinew = np.sqrt(np.dot(residu.T, residu)) + steg = steg / 2 # End: while (resinew>residunorm) and (itto<50) residunorm = resinew.copy() residumax = max(abs(residu)) - steg = 2*steg + steg = 2 * steg # End: while (residumax>0.9*epsi) and (ittt<200) - epsi = 0.1*epsi + epsi = 0.1 * epsi # End: while epsi>epsimin xmma = x.copy() @@ -408,41 +436,43 @@ def subsolv(m,n,epsimin,low,upp,alfa,beta,p0,q0,P,Q,a0,a,b,c,d): zetmma = zet smma = s - return xmma,ymma,zmma,lamma,xsimma,etamma,mumma,zetmma,smma + return xmma, ymma, zmma, lamma, xsimma, etamma, mumma, zetmma, smma -def optimize(fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints): +def optimize( + fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints +): # TODO: Scale objective function value to be always within 1-100 # See comments in https://doi.org/10.1016/j.compstruc.2018.01.008 H, Hs = compute_filter_kd_tree(fe) - ft = {'H':H, 'Hs':Hs} + ft = {"H": H, "Hs": Hs} rho = rho_ini loop = 0 - m = numConstraints # num constraints - n = len(rho.reshape(-1)) # num params + m = numConstraints # num constraints + n = len(rho.reshape(-1)) # num params mma = MMA() mma.setNumConstraints(numConstraints) mma.setNumDesignVariables(n) - mma.setMinandMaxBoundsForDesignVariables\ - (np.zeros((n,1)),np.ones((n,1))) + mma.setMinandMaxBoundsForDesignVariables(np.zeros((n, 1)), np.ones((n, 1))) - xval = rho.reshape(-1)[:, None] + xval = rho.reshape(-1)[:, None] xold1, xold2 = xval.copy(), xval.copy() mma.registerMMAIter(xval, xold1, xold2) - mma.setLowerAndUpperAsymptotes(np.ones((n,1)), np.ones((n,1))) - mma.setScalingParams(1.0, np.zeros((m,1)), \ - 10000*np.ones((m,1)), np.zeros((m,1))) + mma.setLowerAndUpperAsymptotes(np.ones((n, 1)), np.ones((n, 1))) + mma.setScalingParams( + 1.0, np.zeros((m, 1)), 10000 * np.ones((m, 1)), np.zeros((m, 1)) + ) # Move limit is an important parameter that affects TO result; default can be 0.2 - mma.setMoveLimit(optimizationParams['movelimit']) + mma.setMoveLimit(optimizationParams["movelimit"]) - while loop < optimizationParams['maxIters']: + while loop < optimizationParams["maxIters"]: loop = loop + 1 print(f"MMA solver...") - + J, dJ = objectiveHandle(rho) vc, dvc = consHandle(rho, loop) @@ -478,6 +508,6 @@ def optimize(fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numCo print(f"MMA took {time_elapsed} [s]") - print(f'Iter {loop:d}; J {J:.5f}; constraint {vc}\n\n\n') + print(f"Iter {loop:d}; J {J:.5f}; constraint {vc}\n\n\n") return rho diff --git a/jax_fem/problem.py b/jax_fem/problem.py index 337d8c7..b36a56d 100644 --- a/jax_fem/problem.py +++ b/jax_fem/problem.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, List, Union import functools -from jax_fem.common import timeit +from jax_fem.common import timeit from jax_fem.generate_mesh import Mesh from jax_fem.fe import FiniteElement from jax_fem import logger @@ -17,9 +17,11 @@ class Problem: mesh: Mesh vec: int dim: int - ele_type: str = 'HEX8' + ele_type: str = "HEX8" gauss_order: int = None - dirichlet_bc_info: Optional[List[Union[List[Callable], List[int], List[Callable]]]] = None + dirichlet_bc_info: Optional[ + List[Union[List[Callable], List[int], List[Callable]]] + ] = None location_fns: Optional[List[Callable]] = None additional_info: Any = () @@ -34,20 +36,34 @@ def __post_init__(self): self.num_vars = len(self.mesh) - self.fes = [FiniteElement(mesh=self.mesh[i], - vec=self.vec[i], - dim=self.dim, - ele_type=self.ele_type[i], - gauss_order=self.gauss_order[i] if type(self.gauss_order) == type([]) else self.gauss_order, - dirichlet_bc_info=self.dirichlet_bc_info[i] if type(self.dirichlet_bc_info) == type([]) else self.dirichlet_bc_info) \ - for i in range(self.num_vars)] + self.fes = [ + FiniteElement( + mesh=self.mesh[i], + vec=self.vec[i], + dim=self.dim, + ele_type=self.ele_type[i], + gauss_order=( + self.gauss_order[i] + if type(self.gauss_order) == type([]) + else self.gauss_order + ), + dirichlet_bc_info=( + self.dirichlet_bc_info[i] + if type(self.dirichlet_bc_info) == type([]) + else self.dirichlet_bc_info + ), + ) + for i in range(self.num_vars) + ] self.cells_list = [fe.cells for fe in self.fes] # Assume all fes have the same number of cells, same dimension self.num_cells = self.fes[0].num_cells - self.boundary_inds_list = self.fes[0].get_boundary_conditions_inds(self.location_fns) + self.boundary_inds_list = self.fes[0].get_boundary_conditions_inds( + self.location_fns + ) - self.offset = [0] + self.offset = [0] for i in range(len(self.fes) - 1): self.offset.append(self.offset[i] + self.fes[i].num_total_dofs) @@ -55,7 +71,11 @@ def find_ind(*x): inds = [] for i in range(len(x)): x[i].reshape(-1) - crt_ind = self.fes[i].vec * x[i][:, None] + np.arange(self.fes[i].vec)[None, :] + self.offset[i] + crt_ind = ( + self.fes[i].vec * x[i][:, None] + + np.arange(self.fes[i].vec)[None, :] + + self.offset[i] + ) inds.append(crt_ind.reshape(-1)) return np.hstack(inds) @@ -67,29 +87,41 @@ def find_ind(*x): self.cells_list_face_list = [] for i, boundary_inds in enumerate(self.boundary_inds_list): - cells_list_face = [cells[boundary_inds[:, 0]] for cells in self.cells_list] # [(num_selected_faces, num_nodes), ...] - inds_face = onp.array(jax.vmap(find_ind)(*cells_list_face)) # (num_selected_faces, num_nodes*vec + ...) - I_face = onp.repeat(inds_face[:, :, None], inds_face.shape[1], axis=2).reshape(-1) - J_face = onp.repeat(inds_face[:, None, :], inds_face.shape[1], axis=1).reshape(-1) + cells_list_face = [ + cells[boundary_inds[:, 0]] for cells in self.cells_list + ] # [(num_selected_faces, num_nodes), ...] + inds_face = onp.array( + jax.vmap(find_ind)(*cells_list_face) + ) # (num_selected_faces, num_nodes*vec + ...) + I_face = onp.repeat( + inds_face[:, :, None], inds_face.shape[1], axis=2 + ).reshape(-1) + J_face = onp.repeat( + inds_face[:, None, :], inds_face.shape[1], axis=1 + ).reshape(-1) self.I = onp.hstack((self.I, I_face)) self.J = onp.hstack((self.J, J_face)) self.cells_list_face_list.append(cells_list_face) - - self.cells_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*self.cells_list) # (num_cells, num_nodes + ...) + + self.cells_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])( + *self.cells_list + ) # (num_cells, num_nodes + ...) dumb_array_dof = [np.zeros((fe.num_nodes, fe.vec)) for fe in self.fes] # TODO: dumb_array_dof is useless? dumb_array_node = [np.zeros(fe.num_nodes) for fe in self.fes] # _, unflatten_fn_node = jax.flatten_util.ravel_pytree(dumb_array_node) _, self.unflatten_fn_dof = jax.flatten_util.ravel_pytree(dumb_array_dof) - + dumb_sol_list = [np.zeros((fe.num_total_nodes, fe.vec)) for fe in self.fes] - dumb_dofs, self.unflatten_fn_sol_list = jax.flatten_util.ravel_pytree(dumb_sol_list) + dumb_dofs, self.unflatten_fn_sol_list = jax.flatten_util.ravel_pytree( + dumb_sol_list + ) self.num_total_dofs_all_vars = len(dumb_dofs) self.num_nodes_cumsum = onp.cumsum([0] + [fe.num_nodes for fe in self.fes]) # (num_cells, num_vars, num_quads) - self.JxW = onp.transpose(onp.stack([fe.JxW for fe in self.fes]), axes=(1, 0, 2)) + self.JxW = onp.transpose(onp.stack([fe.JxW for fe in self.fes]), axes=(1, 0, 2)) # (num_cells, num_quads, num_nodes +..., dim) self.shape_grads = onp.concatenate([fe.shape_grads for fe in self.fes], axis=2) # (num_cells, num_quads, num_nodes + ..., 1, dim) @@ -97,7 +129,7 @@ def find_ind(*x): # TODO: assert all vars quad points be the same # (num_cells, num_quads, dim) - self.physical_quad_points = self.fes[0].get_physical_quad_points() + self.physical_quad_points = self.fes[0].get_physical_quad_points() self.selected_face_shape_grads = [] self.nanson_scale = [] @@ -109,8 +141,12 @@ def find_ind(*x): s_shape_vals = [] for fe in self.fes: # (num_selected_faces, num_face_quads, num_nodes, dim), (num_selected_faces, num_face_quads) - face_shape_grads_physical, nanson_scale = fe.get_face_shape_grads(boundary_inds) - selected_face_shape_vals = fe.face_shape_vals[boundary_inds[:, 1]] # (num_selected_faces, num_face_quads, num_nodes) + face_shape_grads_physical, nanson_scale = fe.get_face_shape_grads( + boundary_inds + ) + selected_face_shape_vals = fe.face_shape_vals[ + boundary_inds[:, 1] + ] # (num_selected_faces, num_face_quads, num_nodes) s_shape_grads.append(face_shape_grads_physical) n_scale.append(nanson_scale) s_shape_vals.append(selected_face_shape_vals) @@ -118,11 +154,13 @@ def find_ind(*x): # (num_selected_faces, num_face_quads, num_nodes + ..., dim) s_shape_grads = onp.concatenate(s_shape_grads, axis=2) # (num_selected_faces, num_vars, num_face_quads) - n_scale = onp.transpose(onp.stack(n_scale), axes=(1, 0, 2)) + n_scale = onp.transpose(onp.stack(n_scale), axes=(1, 0, 2)) # (num_selected_faces, num_face_quads, num_nodes + ...) s_shape_vals = onp.concatenate(s_shape_vals, axis=2) # (num_selected_faces, num_face_quads, dim) - physical_surface_quad_points = self.fes[0].get_physical_surface_quad_points(boundary_inds) + physical_surface_quad_points = self.fes[0].get_physical_surface_quad_points( + boundary_inds + ) self.selected_face_shape_grads.append(s_shape_grads) self.nanson_scale.append(n_scale) @@ -136,33 +174,38 @@ def find_ind(*x): self.pre_jit_fns() def custom_init(self): - """Child class should override if more things need to be done in initialization - """ + """Child class should override if more things need to be done in initialization""" pass def get_laplace_kernel(self, tensor_map): - def laplace_kernel(cell_sol_flat, cell_shape_grads, cell_v_grads_JxW, *cell_internal_vars): + def laplace_kernel( + cell_sol_flat, cell_shape_grads, cell_v_grads_JxW, *cell_internal_vars + ): # cell_sol_flat: (num_nodes*vec + ...,) # cell_sol_list: [(num_nodes, vec), ...] # cell_shape_grads: (num_quads, num_nodes + ..., dim) # cell_v_grads_JxW: (num_quads, num_nodes + ..., 1, dim) cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) - cell_shape_grads = cell_shape_grads[:, :self.fes[0].num_nodes, :] + cell_shape_grads = cell_shape_grads[:, : self.fes[0].num_nodes, :] cell_sol = cell_sol_list[0] - cell_v_grads_JxW = cell_v_grads_JxW[:, :self.fes[0].num_nodes, :, :] + cell_v_grads_JxW = cell_v_grads_JxW[:, : self.fes[0].num_nodes, :, :] vec = self.fes[0].vec # (1, num_nodes, vec, 1) * (num_quads, num_nodes, 1, dim) -> (num_quads, num_nodes, vec, dim) u_grads = cell_sol[None, :, :, None] * cell_shape_grads[:, :, None, :] u_grads = np.sum(u_grads, axis=1) # (num_quads, vec, dim) - u_grads_reshape = u_grads.reshape(-1, vec, self.dim) # (num_quads, vec, dim) + u_grads_reshape = u_grads.reshape( + -1, vec, self.dim + ) # (num_quads, vec, dim) # (num_quads, vec, dim) - u_physics = jax.vmap(tensor_map)(u_grads_reshape, *cell_internal_vars).reshape(u_grads.shape) + u_physics = jax.vmap(tensor_map)( + u_grads_reshape, *cell_internal_vars + ).reshape(u_grads.shape) # (num_quads, num_nodes, vec, dim) -> (num_nodes, vec) val = np.sum(u_physics[:, None, :, :] * cell_v_grads_JxW, axis=(0, -1)) - val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) + val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) return val return laplace_kernel @@ -180,18 +223,34 @@ def mass_kernel(cell_sol_flat, x, cell_JxW, *cell_internal_vars): cell_JxW = cell_JxW[0] vec = self.fes[0].vec # (1, num_nodes, vec) * (num_quads, num_nodes, 1) -> (num_quads, num_nodes, vec) -> (num_quads, vec) - u = np.sum(cell_sol[None, :, :] * self.fes[0].shape_vals[:, :, None], axis=1) - u_physics = jax.vmap(mass_map)(u, x, *cell_internal_vars) # (num_quads, vec) + u = np.sum( + cell_sol[None, :, :] * self.fes[0].shape_vals[:, :, None], axis=1 + ) + u_physics = jax.vmap(mass_map)( + u, x, *cell_internal_vars + ) # (num_quads, vec) # (num_quads, 1, vec) * (num_quads, num_nodes, 1) * (num_quads, 1, 1) -> (num_nodes, vec) - val = np.sum(u_physics[:, None, :] * self.fes[0].shape_vals[:, :, None] * cell_JxW[:, None, None], axis=0) - val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) + val = np.sum( + u_physics[:, None, :] + * self.fes[0].shape_vals[:, :, None] + * cell_JxW[:, None, None], + axis=0, + ) + val = jax.flatten_util.ravel_pytree(val)[0] # (num_nodes*vec + ...,) return val return mass_kernel def get_surface_kernel(self, surface_map): - def surface_kernel(cell_sol_flat, x, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface): + def surface_kernel( + cell_sol_flat, + x, + face_shape_vals, + face_shape_grads, + face_nanson_scale, + *cell_internal_vars_surface, + ): # face_shape_vals: (num_face_quads, num_nodes + ...) # face_shape_grads: (num_face_quads, num_nodes + ..., dim) # x: (num_face_quads, dim) @@ -199,14 +258,21 @@ def surface_kernel(cell_sol_flat, x, face_shape_vals, face_shape_grads, face_nan cell_sol_list = self.unflatten_fn_dof(cell_sol_flat) cell_sol = cell_sol_list[0] - face_shape_vals = face_shape_vals[:, :self.fes[0].num_nodes] + face_shape_vals = face_shape_vals[:, : self.fes[0].num_nodes] face_nanson_scale = face_nanson_scale[0] # (1, num_nodes, vec) * (num_face_quads, num_nodes, 1) -> (num_face_quads, vec) u = np.sum(cell_sol[None, :, :] * face_shape_vals[:, :, None], axis=1) - u_physics = jax.vmap(surface_map)(u, x, *cell_internal_vars_surface) # (num_face_quads, vec) + u_physics = jax.vmap(surface_map)( + u, x, *cell_internal_vars_surface + ) # (num_face_quads, vec) # (num_face_quads, 1, vec) * (num_face_quads, num_nodes, 1) * (num_face_quads, 1, 1) -> (num_nodes, vec) - val = np.sum(u_physics[:, None, :] * face_shape_vals[:, :, None] * face_nanson_scale[:, None, None], axis=0) + val = np.sum( + u_physics[:, None, :] + * face_shape_vals[:, :, None] + * face_nanson_scale[:, None, None], + axis=0, + ) return jax.flatten_util.ravel_pytree(val)[0] @@ -214,74 +280,121 @@ def surface_kernel(cell_sol_flat, x, face_shape_vals, face_shape_grads, face_nan def pre_jit_fns(self): def value_and_jacfwd(f, x): - pushfwd = functools.partial(jax.jvp, f, (x, )) + pushfwd = functools.partial(jax.jvp, f, (x,)) basis = np.eye(len(x.reshape(-1)), dtype=x.dtype).reshape(-1, *x.shape) - y, jac = jax.vmap(pushfwd, out_axes=(None, -1))((basis, )) + y, jac = jax.vmap(pushfwd, out_axes=(None, -1))((basis,)) return y, jac def get_kernel_fn_cell(): - def kernel(cell_sol_flat, physical_quad_points, cell_shape_grads, cell_JxW, cell_v_grads_JxW, *cell_internal_vars): + def kernel( + cell_sol_flat, + physical_quad_points, + cell_shape_grads, + cell_JxW, + cell_v_grads_JxW, + *cell_internal_vars, + ): """ universal_kernel should be able to cover all situations (including mass_kernel and laplace_kernel). mass_kernel and laplace_kernel are from legacy JAX-FEM. They can still be used, but not mandatory. """ - # TODO: If there is no kernel map, returning 0. is not a good choice. + # TODO: If there is no kernel map, returning 0. is not a good choice. # Return a zero array with proper shape will be better. - if hasattr(self, 'get_mass_map'): + if hasattr(self, "get_mass_map"): mass_kernel = self.get_mass_kernel(self.get_mass_map()) - mass_val = mass_kernel(cell_sol_flat, physical_quad_points, cell_JxW, *cell_internal_vars) + mass_val = mass_kernel( + cell_sol_flat, + physical_quad_points, + cell_JxW, + *cell_internal_vars, + ) else: - mass_val = 0. + mass_val = 0.0 - if hasattr(self, 'get_tensor_map'): + if hasattr(self, "get_tensor_map"): laplace_kernel = self.get_laplace_kernel(self.get_tensor_map()) - laplace_val = laplace_kernel(cell_sol_flat, cell_shape_grads, cell_v_grads_JxW, *cell_internal_vars) + laplace_val = laplace_kernel( + cell_sol_flat, + cell_shape_grads, + cell_v_grads_JxW, + *cell_internal_vars, + ) else: - laplace_val = 0. + laplace_val = 0.0 - if hasattr(self, 'get_universal_kernel'): + if hasattr(self, "get_universal_kernel"): universal_kernel = self.get_universal_kernel() - universal_val = universal_kernel(cell_sol_flat, physical_quad_points, cell_shape_grads, cell_JxW, - cell_v_grads_JxW, *cell_internal_vars) + universal_val = universal_kernel( + cell_sol_flat, + physical_quad_points, + cell_shape_grads, + cell_JxW, + cell_v_grads_JxW, + *cell_internal_vars, + ) else: - universal_val = 0. + universal_val = 0.0 return laplace_val + mass_val + universal_val - def kernel_jac(cell_sol_flat, *args): kernel_partial = lambda cell_sol_flat: kernel(cell_sol_flat, *args) - return value_and_jacfwd(kernel_partial, cell_sol_flat) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) + return value_and_jacfwd( + kernel_partial, cell_sol_flat + ) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) return kernel, kernel_jac def get_kernel_fn_face(ind): - def kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, face_shape_grads, face_nanson_scale, *cell_internal_vars_surface): + def kernel( + cell_sol_flat, + physical_surface_quad_points, + face_shape_vals, + face_shape_grads, + face_nanson_scale, + *cell_internal_vars_surface, + ): """ universal_kernel should be able to cover all situations (including surface_kernel). surface_kernel is from legacy JAX-FEM. It can still be used, but not mandatory. """ - if hasattr(self, 'get_surface_maps'): - surface_kernel = self.get_surface_kernel(self.get_surface_maps()[ind]) - surface_val = surface_kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, - face_shape_grads, face_nanson_scale, *cell_internal_vars_surface) + if hasattr(self, "get_surface_maps"): + surface_kernel = self.get_surface_kernel( + self.get_surface_maps()[ind] + ) + surface_val = surface_kernel( + cell_sol_flat, + physical_surface_quad_points, + face_shape_vals, + face_shape_grads, + face_nanson_scale, + *cell_internal_vars_surface, + ) else: - surface_val = 0. + surface_val = 0.0 - if hasattr(self, 'get_universal_kernels_surface'): + if hasattr(self, "get_universal_kernels_surface"): universal_kernel = self.get_universal_kernels_surface()[ind] - universal_val = universal_kernel(cell_sol_flat, physical_surface_quad_points, face_shape_vals, - face_shape_grads, face_nanson_scale, *cell_internal_vars_surface) + universal_val = universal_kernel( + cell_sol_flat, + physical_surface_quad_points, + face_shape_vals, + face_shape_grads, + face_nanson_scale, + *cell_internal_vars_surface, + ) else: - universal_val = 0. + universal_val = 0.0 return surface_val + universal_val def kernel_jac(cell_sol_flat, *args): # return jax.jacfwd(kernel)(cell_sol_flat, *args) kernel_partial = lambda cell_sol_flat: kernel(cell_sol_flat, *args) - return value_and_jacfwd(kernel_partial, cell_sol_flat) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) + return value_and_jacfwd( + kernel_partial, cell_sol_flat + ) # kernel(cell_sol_flat, *args), jax.jacfwd(kernel)(cell_sol_flat, *args) return kernel, kernel_jac @@ -292,13 +405,12 @@ def kernel_jac(cell_sol_flat, *args): self.kernel_jac = kernel_jac num_surfaces = len(self.boundary_inds_list) - if hasattr(self, 'get_surface_maps'): + if hasattr(self, "get_surface_maps"): assert num_surfaces == len(self.get_surface_maps()) - elif hasattr(self, 'get_universal_kernels_surface'): - assert num_surfaces == len(self.get_universal_kernels_surface()) + elif hasattr(self, "get_universal_kernels_surface"): + assert num_surfaces == len(self.get_universal_kernels_surface()) else: assert num_surfaces == 0, "Missing definitions for surface integral" - self.kernel_face = [] self.kernel_jac_face = [] @@ -310,24 +422,37 @@ def kernel_jac(cell_sol_flat, *args): self.kernel_jac_face.append(kernel_jac_face) @timeit - def split_and_compute_cell(self, cells_sol_flat, np_version, jac_flag, internal_vars): - """Volume integral in weak form - """ + def split_and_compute_cell( + self, cells_sol_flat, np_version, jac_flag, internal_vars + ): + """Volume integral in weak form""" vmap_fn = self.kernel_jac if jac_flag else self.kernel num_cuts = 20 if num_cuts > self.num_cells: num_cuts = self.num_cells batch_size = self.num_cells // num_cuts - input_collection = [cells_sol_flat, self.physical_quad_points, self.shape_grads, self.JxW, self.v_grads_JxW, *internal_vars] + input_collection = [ + cells_sol_flat, + self.physical_quad_points, + self.shape_grads, + self.JxW, + self.v_grads_JxW, + *internal_vars, + ] if jac_flag: values = [] jacs = [] for i in range(num_cuts): if i < num_cuts - 1: - input_col = jax.tree_map(lambda x: x[i * batch_size:(i + 1) * batch_size], input_collection) + input_col = jax.tree_map( + lambda x: x[i * batch_size : (i + 1) * batch_size], + input_collection, + ) else: - input_col = jax.tree_map(lambda x: x[i * batch_size:], input_collection) + input_col = jax.tree_map( + lambda x: x[i * batch_size :], input_collection + ) val, jac = vmap_fn(*input_col) values.append(val) @@ -340,26 +465,40 @@ def split_and_compute_cell(self, cells_sol_flat, np_version, jac_flag, internal_ values = [] for i in range(num_cuts): if i < num_cuts - 1: - input_col = jax.tree_map(lambda x: x[i * batch_size:(i + 1) * batch_size], input_collection) + input_col = jax.tree_map( + lambda x: x[i * batch_size : (i + 1) * batch_size], + input_collection, + ) else: - input_col = jax.tree_map(lambda x: x[i * batch_size:], input_collection) + input_col = jax.tree_map( + lambda x: x[i * batch_size :], input_collection + ) val = vmap_fn(*input_col) values.append(val) values = np_version.vstack(values) return values - def compute_face(self, cells_sol_flat, np_version, jac_flag, internal_vars_surfaces): - """Surface integral in weak form - """ + def compute_face( + self, cells_sol_flat, np_version, jac_flag, internal_vars_surfaces + ): + """Surface integral in weak form""" if jac_flag: values = [] jacs = [] for i, boundary_inds in enumerate(self.boundary_inds_list): vmap_fn = self.kernel_jac_face[i] - selected_cell_sols_flat = cells_sol_flat[boundary_inds[:, 0]] # (num_selected_faces, num_nodes*vec + ...)) - input_collection = [selected_cell_sols_flat, self.physical_surface_quad_points[i], self.selected_face_shape_vals[i], - self.selected_face_shape_grads[i], self.nanson_scale[i], *internal_vars_surfaces[i]] + selected_cell_sols_flat = cells_sol_flat[ + boundary_inds[:, 0] + ] # (num_selected_faces, num_nodes*vec + ...)) + input_collection = [ + selected_cell_sols_flat, + self.physical_surface_quad_points[i], + self.selected_face_shape_vals[i], + self.selected_face_shape_grads[i], + self.nanson_scale[i], + *internal_vars_surfaces[i], + ] val, jac = vmap_fn(*input_collection) values.append(val) @@ -369,57 +508,96 @@ def compute_face(self, cells_sol_flat, np_version, jac_flag, internal_vars_surfa values = [] for i, boundary_inds in enumerate(self.boundary_inds_list): vmap_fn = self.kernel_face[i] - selected_cell_sols_flat = cells_sol_flat[boundary_inds[:, 0]] # (num_selected_faces, num_nodes*vec + ...)) + selected_cell_sols_flat = cells_sol_flat[ + boundary_inds[:, 0] + ] # (num_selected_faces, num_nodes*vec + ...)) # TODO: duplicated code - input_collection = [selected_cell_sols_flat, self.physical_surface_quad_points[i], self.selected_face_shape_vals[i], - self.selected_face_shape_grads[i], self.nanson_scale[i], *internal_vars_surfaces[i]] + input_collection = [ + selected_cell_sols_flat, + self.physical_surface_quad_points[i], + self.selected_face_shape_vals[i], + self.selected_face_shape_grads[i], + self.nanson_scale[i], + *internal_vars_surfaces[i], + ] val = vmap_fn(*input_collection) values.append(val) return values def compute_residual_vars_helper(self, weak_form_flat, weak_form_face_flat): res_list = [np.zeros((fe.num_total_nodes, fe.vec)) for fe in self.fes] - weak_form_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))(weak_form_flat) # [(num_cells, num_nodes, vec), ...] - res_list = [res_list[i].at[self.cells_list[i].reshape(-1)].add(weak_form_list[i].reshape(-1, - self.fes[i].vec)) for i in range(self.num_vars)] + weak_form_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))( + weak_form_flat + ) # [(num_cells, num_nodes, vec), ...] + res_list = [ + res_list[i] + .at[self.cells_list[i].reshape(-1)] + .add(weak_form_list[i].reshape(-1, self.fes[i].vec)) + for i in range(self.num_vars) + ] for ind, cells_list_face in enumerate(self.cells_list_face_list): - weak_form_face_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))(weak_form_face_flat[ind]) # [(num_selected_faces, num_nodes, vec), ...] - res_list = [res_list[i].at[cells_list_face[i].reshape(-1)].add(weak_form_face_list[i].reshape(-1, - self.fes[i].vec)) for i in range(self.num_vars)] + weak_form_face_list = jax.vmap(lambda x: self.unflatten_fn_dof(x))( + weak_form_face_flat[ind] + ) # [(num_selected_faces, num_nodes, vec), ...] + res_list = [ + res_list[i] + .at[cells_list_face[i].reshape(-1)] + .add(weak_form_face_list[i].reshape(-1, self.fes[i].vec)) + for i in range(self.num_vars) + ] return res_list def compute_residual_vars(self, sol_list, internal_vars, internal_vars_surfaces): logger.debug(f"Computing cell residual...") - cells_sol_list = [sol[cells] for cells, sol in zip(self.cells_list, sol_list)] # [(num_cells, num_nodes, vec), ...] - cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*cells_sol_list) # (num_cells, num_nodes*vec + ...) - weak_form_flat = self.split_and_compute_cell(cells_sol_flat, np, False, internal_vars) # (num_cells, num_nodes*vec + ...) - weak_form_face_flat = self.compute_face(cells_sol_flat, np, False, internal_vars_surfaces) # [(num_selected_faces, num_nodes*vec + ...), ...] + cells_sol_list = [ + sol[cells] for cells, sol in zip(self.cells_list, sol_list) + ] # [(num_cells, num_nodes, vec), ...] + cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])( + *cells_sol_list + ) # (num_cells, num_nodes*vec + ...) + weak_form_flat = self.split_and_compute_cell( + cells_sol_flat, np, False, internal_vars + ) # (num_cells, num_nodes*vec + ...) + weak_form_face_flat = self.compute_face( + cells_sol_flat, np, False, internal_vars_surfaces + ) # [(num_selected_faces, num_nodes*vec + ...), ...] return self.compute_residual_vars_helper(weak_form_flat, weak_form_face_flat) def compute_newton_vars(self, sol_list, internal_vars, internal_vars_surfaces): logger.debug(f"Computing cell Jacobian and cell residual...") - cells_sol_list = [sol[cells] for cells, sol in zip(self.cells_list, sol_list)] # [(num_cells, num_nodes, vec), ...] - cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])(*cells_sol_list) # (num_cells, num_nodes*vec + ...) + cells_sol_list = [ + sol[cells] for cells, sol in zip(self.cells_list, sol_list) + ] # [(num_cells, num_nodes, vec), ...] + cells_sol_flat = jax.vmap(lambda *x: jax.flatten_util.ravel_pytree(x)[0])( + *cells_sol_list + ) # (num_cells, num_nodes*vec + ...) # (num_cells, num_nodes*vec + ...), (num_cells, num_nodes*vec + ..., num_nodes*vec + ...) - weak_form_flat, cells_jac_flat = self.split_and_compute_cell(cells_sol_flat, onp, True, internal_vars) + weak_form_flat, cells_jac_flat = self.split_and_compute_cell( + cells_sol_flat, onp, True, internal_vars + ) self.V = onp.array(cells_jac_flat.reshape(-1)) # [(num_selected_faces, num_nodes*vec + ...,), ...], [(num_selected_faces, num_nodes*vec + ..., num_nodes*vec + ...,), ...] - weak_form_face_flat, cells_jac_face_flat = self.compute_face(cells_sol_flat, onp, True, internal_vars_surfaces) + weak_form_face_flat, cells_jac_face_flat = self.compute_face( + cells_sol_flat, onp, True, internal_vars_surfaces + ) for cells_jac_f_flat in cells_jac_face_flat: self.V = onp.hstack((self.V, onp.array(cells_jac_f_flat.reshape(-1)))) return self.compute_residual_vars_helper(weak_form_flat, weak_form_face_flat) def compute_residual(self, sol_list): - return self.compute_residual_vars(sol_list, self.internal_vars, self.internal_vars_surfaces) + return self.compute_residual_vars( + sol_list, self.internal_vars, self.internal_vars_surfaces + ) def newton_update(self, sol_list): - return self.compute_newton_vars(sol_list, self.internal_vars, self.internal_vars_surfaces) + return self.compute_newton_vars( + sol_list, self.internal_vars, self.internal_vars_surfaces + ) def set_params(self, params): - """Used for solving inverse problems. - """ - raise NotImplementedError("Child class must implement this function!") \ No newline at end of file + """Used for solving inverse problems.""" + raise NotImplementedError("Child class must implement this function!") diff --git a/jax_fem/utils.py b/jax_fem/utils.py index d5722cd..d46befe 100644 --- a/jax_fem/utils.py +++ b/jax_fem/utils.py @@ -11,12 +11,14 @@ def save_sol(fe, sol, sol_file, cell_infos=None, point_infos=None): sol_dir = os.path.dirname(sol_file) os.makedirs(sol_dir, exist_ok=True) out_mesh = meshio.Mesh(points=fe.points, cells={cell_type: fe.cells}) - out_mesh.point_data['sol'] = onp.array(sol, dtype=onp.float32) + out_mesh.point_data["sol"] = onp.array(sol, dtype=onp.float32) if cell_infos is not None: for cell_info in cell_infos: name, data = cell_info # TODO: vector-valued cell data - assert data.shape == (fe.num_cells,), f"cell data wrong shape, get {data.shape}, while num_cells = {fe.num_cells}" + assert data.shape == ( + fe.num_cells, + ), f"cell data wrong shape, get {data.shape}, while num_cells = {fe.num_cells}" out_mesh.cell_data[name] = [onp.array(data, dtype=onp.float32)] if point_infos is not None: for point_info in point_infos: @@ -34,14 +36,18 @@ def modify_vtu_file(input_file_path, output_file_path): fin = open(input_file_path, "r") fout = open(output_file_path, "w") for line in fin: - fout.write(line.replace('', '')) + fout.write( + line.replace( + '', + '', + ) + ) fin.close() fout.close() def read_abaqus_and_write_vtk(abaqus_file, vtk_file): - """Used for a quick inspection. Paraview can't open .inp file so we convert it to .vtu - """ + """Used for a quick inspection. Paraview can't open .inp file so we convert it to .vtu""" meshio_mesh = meshio.read(abaqus_file) meshio_mesh.write(vtk_file) @@ -51,4 +57,4 @@ def json_parse(json_filepath): args = json.load(f) json_formatted_str = json.dumps(args, indent=4) print(json_formatted_str) - return args \ No newline at end of file + return args