jax-am's Introduction

A GPU-accelerated differentiable simulation toolbox for additive manufacturing (AM) based on JAX.

JAX-AM

JAX-AM is a collection of several numerical tools, currently including Discrete Element Method (DEM), Lattice Boltzmann Methods (LBM), Computational Fluid Dynamics (CFD), Phase Field Method (PFM) and Finite Element Method (FEM), that cover the analysis of the Process-Structure-Property relationship in AM.

Our vision is to share with the AM community a free, open-source (under the GPL-3.0 License) software that facilitates the relevant computational research. In the JAX ecosystem, we hope to emphasize the potential of JAX for scientific computing. At the same time, AI-enabled research in AM can be made easy with JAX-AM.

🔥 Join us for the development of JAX-AM!

Discrete Element Method (DEM)

DEM can be used to simulate powder dynamics in metal AM.

Free falling of 64,000 spherical particles.

Lattice Boltzmann Methods (LBM)

LBM can simulate melt pool dynamics with a free-surface model.

A powder bed fusion process considering surface tension, Marangoni effect and recoil pressure.

Computational Fluid Dynamics (CFD)

CFD helps to understand the AM process by solving the (incompressible) Navier-Stokes equations for velocity, pressure and temperature.

Melt pool dynamics.

Phase Field Method (PFM)

PFM models the grain development that is critical to form the structure of the as-built sample.

Microstructure evolution.

Directional solidification with isotropic (left) and anisotropic (right) grain growth.

Finite Element Method (FEM)

📣 📣 📣

We have decided to move the development of FEM to a separate repository.

Check JAX-FEM. This is a design decision that aims to push the FEM module into a general-purpose, independent package that works for problems beyond additive manufacturing. Therefore, the code related to FEM in this repository (JAX-AM) will NOT be updated in the future.

Documentation

Please see the web documentation for the installation and use of this project.

License

This project is licensed under the GNU General Public License v3 - see the LICENSE for details.

Citations

If you find this library useful in academic or industrial work, we would appreciate your support by 1) starring the project on GitHub and 2) citing the relevant papers:

@article{xue2023jax,
  title={JAX-FEM: A differentiable GPU-accelerated 3D finite element solver for automatic inverse design and mechanistic data science},
  author={Xue, Tianju and Liao, Shuheng and Gan, Zhengtao and Park, Chanwook and Xie, Xiaoyu and Liu, Wing Kam and Cao, Jian},
  journal={Computer Physics Communications},
  pages={108802},
  year={2023},
  publisher={Elsevier}
}
@article{xue2022physics,
  title={Physics-embedded graph network for accelerating phase-field simulation of microstructure evolution in additive manufacturing},
  author={Xue, Tianju and Gan, Zhengtao and Liao, Shuheng and Cao, Jian},
  journal={npj Computational Materials},
  volume={8},
  number={1},
  pages={201},
  year={2022},
  publisher={Nature Publishing Group UK London}
}

jax-am's People

Contributors

cgh20171006, drzgan, ijku, itk22, jinchoi-git, mohamadrezash204, qiwei-chen, shuhengliao, snms95, tianjuxue

jax-am's Issues

Tolerances in load boundary condition definitions

Hello,

I'm currently working on implementing point loads in FEM simulations using your library and have some questions about the tolerance (atol) used for defining load boundary conditions. I noticed that in some examples, including applications/fem/top_opt/box.py, the load location function uses certain constants multiplied by the domain dimensions:

Lx, Ly, Lz = 2., 0.5, 1.
Nx, Ny, Nz = 80, 20, 40

meshio_mesh = box_mesh(Nx, Ny, Nz, Lx, Ly, Lz, data_path)
jax_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['hexahedron'])

def fixed_location(point):
    return np.isclose(point[0], 0., atol=1e-5)

# Note the tolerance here:    
def load_location(point):
    return np.logical_and(np.isclose(point[0], Lx, atol=1e-5), np.isclose(point[2], 0., atol=0.1*Lz+1e-5))

def dirichlet_val(point):
    return 0.

def neumann_val(point):
    return np.array([0., 0., -1e6])

From my understanding, there are two factors to consider when determining this tolerance. First, we need the tolerance (atol = c*Lz) to be larger than the mesh edge length (dz = Lz/Nz) to ensure that at least one node falls within this range. However, I also understand that the multiplication factor c should ideally be decoupled from the mesh resolution to ensure consistent results across different mesh sizes.

Could you provide some guidance on choosing suitable tolerances and discuss any recommended practices or considerations? For instance, is there an optimal range or formula to calculate c in relation to the mesh size or problem scale?

Also, I'm curious if there might be a more systematic way to handle this tolerance selection within the library itself, such as by providing a helper function to calculate suitable tolerances based on mesh size and problem scale. I believe this would greatly assist users in correctly applying loads and constraints to their simulations.

Thank you in advance for your help and guidance. I appreciate your work on this library and look forward to your response.
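One way to decouple the tolerance from the mesh resolution is sketched below. It assumes a uniform structured mesh like the one `box_mesh` produces in this example. The idea is to fix the physical width of the loaded patch, snap that width to a whole number of cells (at least one), and check how many node layers fall inside it. The helpers `patch_tolerance` and `count_node_layers` are hypothetical and are not part of JAX-AM.

```python
import numpy as onp


def patch_tolerance(patch_width, L, N, eps=1e-5):
    """Hypothetical helper: atol selecting a patch of physical width `patch_width`
    starting at coordinate 0, on a uniform grid of N cells over [0, L].
    The width is snapped to a whole number of cells, and is never less than one cell."""
    h = L / N
    n_cells = max(1, int(round(patch_width / h)))
    return n_cells * h + eps


def count_node_layers(L, N, atol):
    """Number of node layers with |coord - 0| <= atol (the same test np.isclose uses)."""
    coords = onp.linspace(0.0, L, N + 1)
    return int(onp.sum(onp.isclose(coords, 0.0, atol=atol)))


# Same patch width (0.1 * Lz) on three mesh resolutions.
Lz = 1.0
for Nz in (20, 40, 80):
    atol = patch_tolerance(0.1 * Lz, Lz, Nz)
    print(Nz, atol, count_node_layers(Lz, Nz, atol))
```

Because the patch width is fixed in physical units, the loaded area, and therefore the total force (traction times area), should stay essentially constant as the mesh is refined. This assumes boundary faces are selected when their nodes fall inside the tolerance. The one-cell minimum prevents an empty selection on coarse meshes.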

Improving autodiff capabilities of JAX-FEM module

Hi,

Ideally, the solvers should be chosen from JAX to have easy access to all of JAX's transformations. However, we know that external libraries like PETSc have very efficient and diverse solvers, which can be treated as black-box solvers in JAX. So, we would like to use these solvers but at the same time maintain as much of JAX's capabilities as possible.

Currently, the solver part (the ad_wrapper function in solver.py) is implemented with the custom_vjp decorator, which overrides JAX's default traced autodiff. However, this does not prevent JAX from tracing into the function. Because this function calls external libraries, it would raise errors when JAX-FEM is used inside larger pipelines. Further, providing only the VJP rule also rules out forward-mode autodiff.

As of now, there are two ways to handle this:

  1. Create a Primitive and assign individual transformation rules (Best for performance but probably very hard!) - We will have to do this wherever we choose external libraries.
  2. Use callbacks for all external calls. This makes those operations run on the CPU, but it keeps all of JAX's transformations available and is much easier and cleaner. Specifying a custom_jvp + pure_callback would keep the entire package consistent.

I suggest we discuss this possibility during our meeting.
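Below is a minimal sketch of option 2. A dense NumPy solve stands in for PETSc, and the names `host_solve`, `external_solve` and `solve` are hypothetical. One caveat may be worth checking first. As far as we understand, a custom_jvp whose tangent rule calls pure_callback supports jax.jvp, but jax.grad would then have to transpose the callback, which JAX cannot do. The sketch therefore routes the callback through `jax.lax.custom_linear_solve`. That function already knows how to differentiate and transpose a linear solve, so both jvp and grad remain available.

```python
import numpy as onp
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def host_solve(A, b):
    # Runs on the host with concrete NumPy arrays; a stand-in for a PETSc KSP solve.
    return onp.linalg.solve(onp.asarray(A), onp.asarray(b)).astype(b.dtype)


def external_solve(A, b):
    # pure_callback lets traced JAX code (jit, grad, while_loop, ...) call host_solve.
    result_shape = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(host_solve, result_shape, A, b)


def solve(A, b):
    """x = A^{-1} b; derivatives come from custom_linear_solve's implicit rules."""
    return jax.lax.custom_linear_solve(
        lambda x: A @ x,                                                # forward operator
        b,
        solve=lambda matvec, rhs: external_solve(A, rhs),               # solves A x = rhs
        transpose_solve=lambda vecmat, rhs: external_solve(A.T, rhs),   # solves A^T y = rhs
    )
```

Derivatives with respect to A enter through the `matvec` closure (dx = A^{-1}(db - dA x)), so only forward solves with A and A^T ever reach the external library.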

Feature request: Adjustable verbosity

Hello @tianjuxue,
I would like to request a feature that gives the user control over how much information is printed when running an FEM simulation. In my use case, printing all the computation details at every iteration causes a lot of clutter. I would love to be able to control the output verbosity, as many other software packages allow.
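A minimal sketch of how this could look with Python's standard logging module. It assumes the solver's per-iteration print calls (for example the one in jax_solve) were replaced by logger calls. The logger name "jax_am.fem.solver" and the function report_newton_progress are hypothetical.

```python
import logging

# Hypothetical logger name; the current JAX-AM code prints directly instead.
logger = logging.getLogger("jax_am.fem.solver")


def report_newton_progress(residual_norms):
    """Stand-in for solver output: per-iteration details at DEBUG, a summary at INFO."""
    for i, res in enumerate(residual_norms):
        logger.debug("Newton iteration %d: residual = %.3e", i, res)
    logger.info("Newton solver finished after %d iterations", len(residual_norms))
```

A user would then choose the verbosity once, for example `logging.basicConfig()` followed by `logging.getLogger("jax_am.fem.solver").setLevel(logging.INFO)` for summaries only, `logging.DEBUG` for every iteration, or `logging.WARNING` for silence.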

Maybe specify specific package version in setup.py?

I am trying to run the demo, but failed with: ImportError: cannot import name 'config' from 'jax.config' (.../envs/jax-am/lib/python3.9/site-packages/jax/config.py)

It seems to be caused by a mismatched (i.e., newer) JAX version, and I found that setup.py only declares the required package names without specific versions.

Would it be better to specify the package versions in setup.py?

Compatibility issue with jaxopt

Dear @tianjuxue,
Once again, thank you for this package. It's great! Recently, I was trying to implement an optimisation loop using jaxopt and I encountered an error which I think might be related to the inner workings of the JAX-AM package. Here is a minimal example which exposes this problem on my system. I adapted it from the plate.py example in the applications/fem/top_opt and you should be able to run this example within that subdirectory.

import jax.numpy as np
import jax
import os
import glob

import jaxopt

from jax_am.fem.generate_mesh import Mesh
from jax_am.fem.solver import ad_wrapper
from jax_am.fem.utils import save_sol
from jax_am.common import rectangle_mesh

from fem_model import Elasticity

problem_name = 'plate'
root_path = os.path.join(os.path.dirname(__file__), 'data')

files = glob.glob(os.path.join(root_path, f'vtk/{problem_name}/*'))
for f in files:
    os.remove(f)

L = 60.
W = 30.
N_L = 60
N_W = 30
meshio_mesh = rectangle_mesh(N_L, N_W, L, W)
jax_mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict['quad'])


def fixed_location(point):
    return np.isclose(point[0], 0., atol=1e-5)


def load_location(point):
    return np.logical_and(np.isclose(point[0], L, atol=1e-5),
                          np.isclose(point[1], 0., atol=1. + 1e-5))


def dirichlet_val(point):
    return 0.


def neumann_val(point):
    return np.array([0., -1])


dirichlet_bc_info = [[fixed_location] * 2, [0, 1], [dirichlet_val] * 2]
neumann_bc_info = [[load_location], [neumann_val]]
problem = Elasticity(jax_mesh,
                     vec=2,
                     dim=2,
                     ele_type='QUAD4',
                     dirichlet_bc_info=dirichlet_bc_info,
                     neumann_bc_info=neumann_bc_info,
                     additional_info=(problem_name, ))
fwd_pred = ad_wrapper(problem, linear=True, use_petsc=True)


def J_fn(dofs, params):
    """J(u, p)
        """
    sol = dofs.reshape((problem.num_total_nodes, problem.vec))
    compliance = problem.compute_compliance(neumann_val, sol)
    return compliance


def J_total(params):
    """J(u(p), p)
    """
    sol = fwd_pred(params)
    dofs = sol.reshape(-1)
    obj_val = J_fn(dofs, params)
    return obj_val


outputs = []


def output_sol(params, obj_val):
    print("\nOutput solution - need to solve the forward problem again...")
    sol = fwd_pred(params)
    vtu_path = os.path.join(
        root_path, f'vtk/{problem_name}/sol_{output_sol.counter:03d}.vtu')
    save_sol(problem,
             sol,
             vtu_path,
             cell_infos=[('theta', problem.full_params[:, 0])])
    print(f"compliance = {obj_val}")
    outputs.append(obj_val)
    output_sol.counter += 1


output_sol.counter = 0

vf = 0.5


def objectiveHandle(rho):
    J, dJ = jax.value_and_grad(J_total)(rho)
    output_sol(rho, J)
    return J, dJ


# Initialize params
params = vf * np.ones((len(problem.flex_inds), 1))
print(f'Initial compliance is: {J_total(params):.5e}')

solver = jaxopt.LBFGS(fun=objectiveHandle, value_and_grad=True, has_aux=False)

# Training loop
opt_state = solver.init_state(params)

for _ in range(2):
    params, opt_state = solver.update(params=params, state=opt_state)

Here is the traceback that I am getting:

Traceback (most recent call last):
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 114, in <module>
    params, opt_state = solver.update(params=params, state=opt_state)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 348, in update
    ls_state = zoom_linesearch(f=self._value_and_grad_with_aux,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 467, in zoom_linesearch
    state = lax.while_loop(lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1150, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py", line 1133, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2150, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 385, in body
    (phi_i, dphi_i, g_i), aux_i = restricted_func_and_grad(a_i)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 329, in restricted_func_and_grad
    (phi, aux), g = f_value_and_grad(xkp1, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 62, in value_and_grad_with_aux
    v, g = value_and_grad(*a, **kw)
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 99, in objectiveHandle
    J, dJ = jax.value_and_grad(J_total)(rho)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/api.py", line 718, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/api.py", line 2174, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 139, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 70, in J_total
    sol = fwd_pred(params)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 614, in __call__
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 763, in bind
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 396, in process_custom_vjp_call
    res_and_primals_out = fwd.call_wrapped(*fwd_in)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 545, in f_fwd
    sol = fwd_pred(params)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 614, in __call__
    out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/custom_derivatives.py", line 763, in bind
    outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 1985, in process_custom_vjp_call
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2172, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 541, in fwd_pred
    sol = solver(problem, linear=linear, use_petsc=use_petsc)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 473, in solver
    return solver_row_elimination(problem, linear, precond, initial_guess, use_petsc)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 228, in solver_row_elimination
    res_vec, A_fn = newton_update_helper(dofs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 220, in newton_update_helper
    res_vec = problem.newton_update(dofs.reshape(sol_shape)).reshape(-1)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 665, in newton_update
    return self.compute_newton_vars(sol, **self.internal_vars)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 640, in compute_newton_vars
    weak_form, cells_jac = self.split_and_compute_cell(cells_sol, onp, True, **internal_vars)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 519, in split_and_compute_cell
    values = np_version.vstack(values)
  File "<__array_function__ internals>", line 200, in vstack
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/numpy/core/shape_base.py", line 293, in vstack
    arrs = atleast_2d(*tup)
  File "<__array_function__ internals>", line 200, in atleast_2d
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/numpy/core/shape_base.py", line 121, in atleast_2d
    ary = asanyarray(ary)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax/_src/core.py", line 598, in __array__
    raise TracerArrayConversionError(self)
jax._src.traceback_util.UnfilteredStackTrace: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[90,4,2])>with<DynamicJaxprTrace(level=1/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 114, in <module>
    params, opt_state = solver.update(params=params, state=opt_state)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/lbfgs.py", line 348, in update
    ls_state = zoom_linesearch(f=self._value_and_grad_with_aux,
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 467, in zoom_linesearch
    state = lax.while_loop(lambda state: (~state.done) & (state.i <= maxiter) & (~state.failed),
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 385, in body
    (phi_i, dphi_i, g_i), aux_i = restricted_func_and_grad(a_i)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/zoom_linesearch.py", line 329, in restricted_func_and_grad
    (phi, aux), g = f_value_and_grad(xkp1, *args, **kwargs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jaxopt/_src/base.py", line 62, in value_and_grad_with_aux
    v, g = value_and_grad(*a, **kw)
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 99, in objectiveHandle
    J, dJ = jax.value_and_grad(J_total)(rho)
  File "/home/igork/IRP/NeurOpt/jax-am-forked/applications/fem/top_opt/jaxopt_mre.py", line 70, in J_total
    sol = fwd_pred(params)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 545, in f_fwd
    sol = fwd_pred(params)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 541, in fwd_pred
    sol = solver(problem, linear=linear, use_petsc=use_petsc)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 473, in solver
    return solver_row_elimination(problem, linear, precond, initial_guess, use_petsc)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 228, in solver_row_elimination
    res_vec, A_fn = newton_update_helper(dofs)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/solver.py", line 220, in newton_update_helper
    res_vec = problem.newton_update(dofs.reshape(sol_shape)).reshape(-1)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 665, in newton_update
    return self.compute_newton_vars(sol, **self.internal_vars)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 640, in compute_newton_vars
    weak_form, cells_jac = self.split_and_compute_cell(cells_sol, onp, True, **internal_vars)
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/jax_am/fem/core.py", line 519, in split_and_compute_cell
    values = np_version.vstack(values)
  File "<__array_function__ internals>", line 200, in vstack
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/numpy/core/shape_base.py", line 293, in vstack
    arrs = atleast_2d(*tup)
  File "<__array_function__ internals>", line 200, in atleast_2d
  File "/home/igork/miniconda3/envs/jax-nrto/lib/python3.10/site-packages/numpy/core/shape_base.py", line 121, in atleast_2d
    ary = asanyarray(ary)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float64[90,4,2])>with<DynamicJaxprTrace(level=1/1)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

So it seems that the incompatibility arises in jax_am/fem/core.py. Have you encountered a similar error before, or do you have any ideas how it could be solved? I suspect the issue might be related to jaxopt introducing its own decorators for implicit differentiation, which might clash with ad_wrapper.
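Reading the traceback, the failure seems to come from tracing rather than from a clash of decorators. zoom_linesearch calls the objective inside lax.while_loop, which traces it. ad_wrapper's forward pass then reaches split_and_compute_cell in core.py, which calls NumPy (onp.vstack) on a tracer. One hedged workaround is to run the non-traceable value-and-gradient computation on the host through jax.pure_callback, so that traced code only sees its results. The sketch below uses a toy NumPy quadratic in place of J_total and has not been tried with JAX-AM itself. The function names are hypothetical.

```python
import numpy as onp
import jax
import jax.numpy as jnp


def host_value_and_grad(x):
    # Stand-in for a computation that only works on concrete NumPy arrays.
    x = onp.asarray(x)
    value = onp.sum(x ** 2)
    return value.astype(x.dtype), (2.0 * x).astype(x.dtype)


def traced_value_and_grad(x):
    # Declaring output shapes/dtypes lets JAX trace around the host call.
    out_shapes = (jax.ShapeDtypeStruct((), x.dtype),
                  jax.ShapeDtypeStruct(x.shape, x.dtype))
    return jax.pure_callback(host_value_and_grad, out_shapes, x)


def gradient_descent(x0, lr=0.1, steps=50):
    # lax.while_loop traces its body, just as jaxopt's zoom line search does.
    def cond(state):
        i, _ = state
        return i < steps

    def body(state):
        i, x = state
        _, g = traced_value_and_grad(x)
        return i + 1, x - lr * g

    _, x = jax.lax.while_loop(cond, body, (0, x0))
    return x
```

With jaxopt, the analogous step would be passing a function like traced_value_and_grad as `fun` with `value_and_grad=True`, with the host function evaluating the original objective on concrete arrays. Whether jaxopt's line search behaves well with a callback objective is an assumption worth testing.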

Error in results for spatially varying material property

Hello, I am attempting to solve a 2D plane strain boundary value problem with a linear elastic material whose Young's modulus varies spatially, E = E(x, y), given as nodal values. I have used 'QUAD4' as the element type, and I assume the Young's modulus is interpolated with the same basis functions as the solution. Below is the code I use to interpolate the nodal values to the quadrature points of each element:

import jax
import jax.numpy as np
from jax_am.fem.core import FEM

class Elasticity(FEM):
    def custom_init(self):
        """ Override base class method
        """
        self.flex_inds = np.arange(len(self.points))
        
    def get_tensor_map(self):
        def stress(u_grad, theta):
            nu = 0.3
            E = theta[0]
            epsilon = 0.5*(u_grad + u_grad.T)
            eps11 = epsilon[0,0]
            eps22 = epsilon[1,1]
            eps12 = epsilon[0,1]
            sig11 = E/((1 + nu)*(1 - 2*nu))*((1-nu)*eps11 + nu*eps22)
            sig22 = E/((1 + nu)*(1 - 2*nu))*(nu*eps11 + (1-nu)*eps22) 
            sig12 = E/((1 + nu)*2)*eps12
            sigma = np.array([[sig11, sig12],[sig12, sig22]])
            return sigma
        return stress
    
    def set_params(self, params):
        NCA = self.cells
        NCA = np.array(NCA)
        
        def compute_params(e):
            a = NCA[e,:]
            Ee = params[a]
            E = np.matmul(self.shape_vals, Ee)
            return np.transpose(E)
        
        result = jax.vmap(compute_params)(np.arange(self.num_cells))
        result = np.transpose(result, axes=(0,2,1))
        
        self.full_params = params
        self.internal_vars['laplace'] = [result]

When I used a uniform material parameter distribution (param = 0.01*np.ones((len(problem.flex_inds), 1))), the results matched those from a benchmark study. However, when I introduced a spatially varying (heterogeneous) distribution for the material parameter, the accuracy of the results decreased.

I attempted to find solutions in existing examples, but most of them assumed that the varying field had a constant value within each element, which does not apply to my case.

I would appreciate any guidance or suggestions to identify where I might have made errors in my implementation for handling the spatially varying material parameter and to improve the accuracy of the results.
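One thing that may be worth checking in the posted stress function. epsilon is the tensor strain, so eps12 = 0.5*(du1/dx2 + du2/dx1), and the isotropic relation is sigma12 = 2*mu*eps12 = E/(1 + nu)*eps12. The posted sig12 = E/((1 + nu)*2)*eps12 is half of that. A benchmark whose solution has little shear might not expose this, while a heterogeneous E field typically induces shear. Below is a sketch of the same plane-strain law written with Lamé parameters, which avoids component-wise formulas. The name plane_strain_stress is hypothetical.

```python
import jax
import jax.numpy as jnp


def plane_strain_stress(u_grad, E, nu=0.3):
    """Isotropic linear-elastic stress under plane strain for a 2x2 displacement gradient."""
    mu = E / (2.0 * (1.0 + nu))                         # shear modulus
    lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))    # first Lamé parameter
    eps = 0.5 * (u_grad + u_grad.T)                     # tensor (not engineering) strain
    return lmbda * jnp.trace(eps) * jnp.eye(2) + 2.0 * mu * eps
```

Inside get_tensor_map, `stress(u_grad, theta)` could simply return `plane_strain_stress(u_grad, theta[0])`. The normal components agree with the posted sig11 and sig22 expressions, and only the shear term changes.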

Package Issues with jax-am

When I try to pip install the package, I get an error like this:
ERROR: Cannot install jax-am==0.0.1 and jax-am==0.0.2 because these package versions have conflicting dependencies.

Please help me rectify this error.

Solver choices

I looked through your paper (https://arxiv.org/pdf/2212.00964.pdf), and it does not seem to mention the solver choices.

https://github.com/tianjuxue/jax-am/blob/main/jax_am/fem/tests/fenicsx_gold.py#L100-L106 uses an unpreconditioned BiCG solver (https://petsc.org/release/manualpages/KSP/KSPBICG/) with no null space set. By contrast, as far as I can tell from https://github.com/tianjuxue/jax-am/blob/main/jax_am/fem/solver.py#L31, JAX-FEM uses a Jacobi preconditioner with BiCGSTAB.

It would also be interesting to see more detailed timing breakdowns, such as:

  1. How much time the linear solve takes for each of the problems
  2. How much time is spent on assembly inside the Newton iteration
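For the timing breakdown, here is a sketch of how the linear-solve time could be measured separately from assembly. It mirrors the JAX path described above (Jacobi preconditioner with BiCGSTAB, via jax.scipy.sparse.linalg.bicgstab). A small dense 1D Laplacian stands in for the assembled stiffness matrix. Since JAX dispatches work asynchronously, block_until_ready is needed before reading the clock, and the first call also includes compilation time. The helper names are hypothetical.

```python
import time
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

jax.config.update("jax_enable_x64", True)


def laplacian_1d(n):
    # Stand-in for an assembled stiffness matrix: tridiagonal [-1, 2, -1].
    return 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)


def timed_jacobi_bicgstab(A, b, tol=1e-10):
    """Solve A x = b with Jacobi-preconditioned BiCGSTAB; return (x, elapsed seconds)."""
    inv_diag = 1.0 / jnp.diag(A)
    t0 = time.perf_counter()
    x, _ = bicgstab(lambda v: A @ v, b, M=lambda v: inv_diag * v,
                    tol=tol, atol=tol, maxiter=10000)
    x.block_until_ready()   # wait for the asynchronously dispatched solve to finish
    return x, time.perf_counter() - t0
```

The same perf_counter plus block_until_ready pattern, wrapped around the residual and Jacobian assembly in each Newton step, would give the second number.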

Issue with installation

I cloned the repo, created a bare (mini)conda environment.
The only packages I installed were jax and pip.
When I tried to install using pip install ., the following error came up.

Failed to build fenics-basix **ERROR**: Could not build wheels for fenics-basix, which is required to install pyproject.toml-based projects

Potentially harmful use of assert statements in solver.py

Dear @tianjuxue,
I noticed a problematic use of assert statements in the FEM's solver.py code which could lead to failures. Here is the relevant code:

def jax_solve(problem, A_fn, b, x0, precond):
    pc = get_jacobi_precond(jacobi_preconditioner(problem)) if precond else None
    x, info = jax.scipy.sparse.linalg.bicgstab(A_fn, b, x0=x0, M=pc, tol=1e-10, atol=1e-10, maxiter=10000)

    # Verify convergence
    err = np.linalg.norm(A_fn(x) - b)
    print(f"JAX scipy linear solve res = {err}")

    # HERE IS THE PROBLEMATIC ASSERT:
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"

    return x

The assert statement above acts as a control flow statement and requires concrete values to work properly. In one of my use cases, the err variable is actually a:

Traced<ShapedArray(float64[])>with<BatchTrace(level=1/0)> with
 val = Array([0.], dtype=float64)
 batch_dim = 0

and the assert statement breaks. I am not exactly sure why err is not a concrete value in my case, but I do feel that plain asserts here don't fit JAX's functional-purity philosophy. A feasible alternative could be jax.lax.cond or the assertions from the Chex library.
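A Python assert calls bool() on err, which fails whenever err is a tracer. Here err is a BatchTracer, which suggests the solver is being called under jax.vmap. Below is a hedged sketch of another functional option, jax.experimental.checkify, which turns the convergence check into an error value that survives jit and vmap. The name residual_guard is hypothetical, and JAX-AM does not currently do this.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def residual_guard(err, tol=0.1):
    # Functional replacement for: assert err < tol, f"... err = {err}"
    checkify.check(err < tol, "JAX linear solver failed to converge with err = {err}", err=err)
    return err


# checkify returns (error, output); the error can be inspected or raised afterwards.
checked_guard = checkify.checkify(residual_guard)
```

After a jitted or vmapped call, `error.throw()` raises if any check failed. Similarly, the `print(f"...")` line can become `jax.debug.print("JAX scipy linear solve res = {}", err)`, which prints actual values for tracers.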

0 Solution for small problems

  • If we try to run a simple linear elastic problem with <=10 elements [10x10], the solution is always zero.
  • The element stiffness matrix and assembly are ok.
  • The issue seems to be with the list of values in split_and_compute_cell

Memory leakage during optimization

It seems there is some memory leakage somewhere.
It can be clearly seen if the topopt example is run and you monitor the memory. It is steadily increasing. This results in OOM-KILL events on HPCs.

  1. Part of the reason is not destroying the PETSc objects. This is easily fixed.
  2. We need a tool to check this every time we make changes [Scalene perhaps?]
  3. There is another source that I have not been able to pinpoint yet
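For item 2, here is a lightweight check that needs only the standard library. It measures Python-heap growth across repeated calls of a function, and python_heap_growth is a hypothetical helper. tracemalloc only sees allocations made through Python's allocator, so memory held by PETSc, XLA or the GPU will not show up. For those, a profiler such as Scalene, or watching the process RSS, is still needed.

```python
import tracemalloc


def python_heap_growth(fn, n_iter=20):
    """Bytes of Python-heap growth over n_iter calls of fn (after one warm-up call)."""
    fn()  # warm-up: lets caches and lazy imports settle before measuring
    tracemalloc.start()
    try:
        before, _ = tracemalloc.get_traced_memory()
        for _ in range(n_iter):
            fn()
        after, _ = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return after - before
```

Run on one optimization step of the topopt example, a result that grows roughly linearly with n_iter would point to objects being retained between steps.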

Memory leakage problem in \applications\fem\thermal

I found that the memory usage increases very quickly when I run the thermal simulation. It might be caused by the lists that store inner_faces, external_faces and all_faces, because Python lists are reported to consume more memory than other data structures such as numpy arrays. I am trying to replace the lists with numpy arrays to see whether the memory usage decreases.
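Below is a quick way to compare the two representations, using hypothetical (cell, local face) index pairs in place of the actual face lists. A smaller container lowers the baseline footprint, but it does not by itself stop memory from growing over time. Steady growth usually means that objects created in each step are still referenced somewhere.

```python
import sys
import numpy as onp

n_faces = 100_000
# Hypothetical face list: (cell index, local face index) pairs.
faces_list = [(i, i % 6) for i in range(n_faces)]
faces_array = onp.array(faces_list, dtype=onp.int32)   # shape (n_faces, 2)

# List container plus tuple objects; the int objects inside the tuples are not
# counted, so this is a lower bound on the list's footprint.
list_bytes = sys.getsizeof(faces_list) + sum(sys.getsizeof(t) for t in faces_list)
array_bytes = faces_array.nbytes

print(f"list: {list_bytes / 1e6:.1f} MB (lower bound), array: {array_bytes / 1e6:.1f} MB")
```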

Cannot import 'config' from 'jax.config'

I followed the instructions at "https://jax-am.readthedocs.io/en/latest/installation.html" to install jax-am.
When I run the example with "python -m applications.fem.demo.example", the following error occurs.

Traceback (most recent call last):
  File "/home/cl_linux/anaconda3/envs/jax-am/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/cl_linux/anaconda3/envs/jax-am/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/cl_linux/jax-am/applications/fem/demo/example.py", line 5, in <module>
    from jax_am.fem.models import LinearElasticity
  File "/home/cl_linux/jax-am/jax_am/fem/models.py", line 5, in <module>
    from jax_am.fem.core import FEM
  File "/home/cl_linux/jax-am/jax_am/fem/core.py", line 14, in <module>
    from jax.config import config
ImportError: cannot import name 'config' from 'jax.config' (/home/cl_linux/anaconda3/envs/jax-am/lib/python3.9/site-packages/jax/config.py)
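Newer JAX releases no longer support importing config from the jax.config submodule. Instead, the configuration object is reached as jax.config on the top-level package. Below is a minimal sketch of the replacement for line 14 of core.py. It assumes, based on the float64 values seen elsewhere in these issues, that the module uses config to enable 64-bit precision; adapt the option name to whatever core.py actually sets.

```python
import jax
import jax.numpy as jnp

# Old import, which fails on newer JAX:
#     from jax.config import config
#     config.update("jax_enable_x64", True)
# Replacement: use the config object exposed on the top-level package.
# Assumption: core.py enables 64-bit floats; adjust the option if it does not.
jax.config.update("jax_enable_x64", True)

x = jnp.array(1.0)
print(x.dtype)  # float64 once x64 is enabled
```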

Compute the anisotropy

Previously, I used an earlier version of the allen-cahn file for my computations. At first I ran them only on the CPU, with no issues. After switching to a GPU, however, compilation took a long time whenever anisotropy was included. I therefore upgraded to the latest version of the file, which no longer has the long compilation times. However, during the calculation I now get the following error, which did not occur with the original version. Could you please help me find its cause? Alternatively, could you tell me about any changes, besides the modified allen-cahn file, in how 'state_rhs' is called between the previous and current versions?

Traceback (most recent call last):
File "/home/cgh/projects/HTTD-PF/application/SingleCrystal_2/NotContainBase.py", line 333, in
output_next.inspect_sol(eta, eta0_next, TEMP, ts, increment + step_inds + 1)
File "/home/cgh/mambaforge3/lib/python3.10/site-packages/pf_am/post_process.py", line 35, in inspect_sol
raise ValueError(f"Found np.inf or np.nan in pf_sol - stop the program")
ValueError: Found np.inf or np.nan in pf_sol - stop the program
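The check in post_process.py only reports that pf_sol contains inf or NaN, not where it came from. JAX's jax_debug_nans flag raises a FloatingPointError at the first operation that produces a NaN, which helps to locate the source. The functions below are hypothetical and not part of pf_am. They illustrate one common NaN source in anisotropic phase-field terms, which is not confirmed for this issue: normalizing the gradient of eta divides by its norm, which is zero away from interfaces. The guarded version uses the "double where" pattern so that no NaN is ever formed.

```python
import jax
import jax.numpy as jnp

# Raise FloatingPointError as soon as a NaN is produced (slow; debugging only).
jax.config.update("jax_debug_nans", True)


def unit_normal_naive(grad_eta):
    # 0 / 0 -> NaN wherever grad_eta vanishes (e.g. far from the interface)
    return grad_eta / jnp.linalg.norm(grad_eta)


def unit_normal_safe(grad_eta, eps=1e-12):
    norm = jnp.linalg.norm(grad_eta)
    nonzero = norm > eps
    # Divide by 1.0 where the norm is ~0 so no NaN is formed, then zero it out.
    safe_norm = jnp.where(nonzero, norm, 1.0)
    return jnp.where(nonzero, grad_eta / safe_norm, 0.0)


try:
    jax.jit(unit_normal_naive)(jnp.zeros(2))
except FloatingPointError as e:
    print("NaN detected:", e)
```

Switch the flag back off with jax.config.update("jax_debug_nans", False) for production runs, since it adds overhead to every operation.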

Infinite while loop for non-linear case

In solver.py, the following lines

tol = 1e-6
while res_val > tol:
    dofs = linear_incremental_solver(problem, res_vec, A_fn, dofs,
                                     precond, use_petsc)
    res_vec, A_fn = newton_update_helper(dofs)
    # test_jacobi_precond(problem, jacobi_preconditioner(problem, dofs), A_fn)
    res_val = np.linalg.norm(res_vec)
    logger.debug(f"res l_2 = {res_val}")

can result in an infinite loop if the solver fails to converge.
This should be rectified so that we break out of the loop after a set number of iterations!
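A bounded version of the loop can be sketched generically with NumPy. residual_fn and jacobian_fn are hypothetical stand-ins for what newton_update_helper and linear_incremental_solver compute. This is not jax-am's solver. The sketch raises an error once max_iter updates pass without convergence. Returning the last iterate with a warning would be an equally valid design choice.

```python
import numpy as np


def newton_solve(residual_fn, jacobian_fn, x0, tol=1e-6, max_iter=50):
    """Newton iteration that stops after `max_iter` updates instead of looping forever."""
    x = np.asarray(x0, dtype=float)
    res_vec = residual_fn(x)
    res_val = np.linalg.norm(res_vec)
    for it in range(max_iter):
        if res_val <= tol:
            return x, it
        dx = np.linalg.solve(jacobian_fn(x), -res_vec)  # linear incremental step
        x = x + dx
        res_vec = residual_fn(x)
        res_val = np.linalg.norm(res_vec)
    if res_val <= tol:
        return x, max_iter
    raise RuntimeError(
        f"Newton solver did not converge in {max_iter} iterations (res l_2 = {res_val:.3e})"
    )


# Example: solve x^2 - 2 = 0 starting from x = 1.
root, n_iter = newton_solve(lambda x: x**2 - 2.0, lambda x: np.diag(2.0 * x), [1.0])
```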

Issues with installing and using jax-am

Hello,
I am having issues with installing and using jax-am and I would highly appreciate any help you can provide.

I was able to fully install jax-am on my Apple silicon MacBook by following the instructions at https://jax-am.readthedocs.io/en/latest/installation.html, but when I try to run the hyperelasticity example demos.fem.hyperelasticity.example, I get the following error:

[2023-07-20 01:01:53,936 - INFO] - Creating sparse matrix with scipy...
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/vt/Library/CloudStorage/OneDrive-purdue.edu/Research/jax-fem/jax-am/demos/fem/hyperelasticity/example.py", line 83, in <module>
    sol = solver(problem, use_petsc=True)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vt/Library/CloudStorage/OneDrive-purdue.edu/Research/jax-fem/jax-am/jax_am/fem/solver.py", line 823, in solver
    return solver_row_elimination(problem, linear, precond, initial_guess,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vt/Library/CloudStorage/OneDrive-purdue.edu/Research/jax-fem/jax-am/jax_am/fem/solver.py", line 358, in solver_row_elimination
    res_vec, A_fn = newton_update_helper(dofs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vt/Library/CloudStorage/OneDrive-purdue.edu/Research/jax-fem/jax-am/jax_am/fem/solver.py", line 342, in newton_update_helper
    A_fn = get_A_fn(problem, use_petsc)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/vt/Library/CloudStorage/OneDrive-purdue.edu/Research/jax-fem/jax-am/jax_am/fem/solver.py", line 300, in get_A_fn
    A = PETSc.Mat().createAIJ(size=A_sp_scipy.shape,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "petsc4py/PETSc/Mat.pyx", line 357, in petsc4py.PETSc.Mat.createAIJ
  File "petsc4py/PETSc/petscmat.pxi", line 836, in petsc4py.PETSc.Mat_AllocAIJ
  File "petsc4py/PETSc/petscmat.pxi", line 804, in petsc4py.PETSc.Mat_AllocAIJ_CSR
  File "petsc4py/PETSc/arraynpy.pxi", line 137, in petsc4py.PETSc.iarray_i
  File "petsc4py/PETSc/arraynpy.pxi", line 130, in petsc4py.PETSc.iarray
TypeError: Cannot cast array data from dtype('int64') to dtype('int32') according to the rule 'safe'
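This TypeError means petsc4py expects 32-bit integer indices, because PETSc.IntType is int32 in that build, while SciPy stored the CSR index arrays as int64. A possible workaround, not an official jax-am fix, is to cast indptr and indices to PETSc's integer type before calling createAIJ. The helper below is hypothetical. It uses only NumPy and SciPy so it runs without PETSc, and it takes the target dtype as a parameter.

```python
import numpy as np
import scipy.sparse as sp


def csr_arrays(A, index_dtype=np.int32):
    """Return (indptr, indices, data) with the index arrays cast to `index_dtype`.

    With petsc4py one would pass index_dtype=PETSc.IntType and use the result as
    PETSc.Mat().createAIJ(size=A.shape, csr=csr_arrays(A, PETSc.IntType)).
    """
    A = sp.csr_matrix(A)
    limit = np.iinfo(index_dtype).max
    # Refuse to cast if the indices would not fit in the narrower type.
    if A.nnz > limit or max(A.shape) > limit:
        raise OverflowError("matrix too large for the requested index type")
    indptr = A.indptr.astype(index_dtype, copy=False)
    indices = A.indices.astype(index_dtype, copy=False)
    return indptr, indices, A.data
```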

As an alternative, I tried installing jax-am on a linux cluster. However, there I receive another error during the installation:

  note: This error originates from a subprocess, and is likely not a problem with pip.
  ERROR: Failed building wheel for fenics-basix
Failed to build fenics-basix
ERROR: Could not build wheels for fenics-basix, which is required to install pyproject.toml-based projects

I believe I am following the instructions at https://jax-am.readthedocs.io/en/latest/installation.html verbatim: I start with a fresh conda environment, install pip, and then run "pip install .". Then I install petsc4py using conda. (I have also tried installing petsc4py before jax-am, but that didn't work either.)

Thank you for your attention.

A better preconditioner for jax solver

  • At the moment, for some BCs (especially in TopOpt), the PETSc solver converges but jax_solve does not.
  • From our tests, we concluded that this is due to the lack of a good preconditioner: the PETSc solver uses ILU as its preconditioner, while the JAX solver uses a simple Jacobi preconditioner.

Can we implement an ILU preconditioner for the JAX solver?
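The iterative solvers in jax.scipy.sparse.linalg (cg, bicgstab, gmres) accept a preconditioner as a callable through their M argument. An ILU preconditioner, or any other, would therefore plug in at the same place as Jacobi. The sketch below shows only the Jacobi baseline with bicgstab on a small dense matrix. It does not implement ILU. solve_with_jacobi is hypothetical, and this issue does not say which JAX routine jax_solve calls.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab


def solve_with_jacobi(A, b, tol=1e-6, maxiter=200):
    """Solve A x = b with BiCGSTAB, using the inverse diagonal of A as preconditioner M."""
    inv_diag = 1.0 / jnp.diag(A)

    def matvec(x):
        return A @ x

    def jacobi(r):
        # M approximates A^{-1}; Jacobi keeps only the diagonal of A.
        return inv_diag * r

    x, _ = bicgstab(matvec, b, tol=tol, maxiter=maxiter, M=jacobi)
    return x


A = jnp.array([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])
x = solve_with_jacobi(A, b)
```

An ILU version would replace jacobi with a function that applies the two triangular solves of the incomplete factorization to r. Both matvec and M must remain JAX-traceable functions.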

Improving FEniCSx benchmark

In

# Remark(Tianju): "problem" should be better defined outside of the for loop,
# but I wasn't able to find a way to assign Dirichlet values on the top boundary.

you state that you had trouble modifying the value of the Dirichlet BC, and therefore have to redefine the solver at every time step.

Consider:
https://jsdokken.com/dolfinx-tutorial/chapter2/hyperelasticity.html

It uses a dolfinx.fem.Constant to assign the traction value (you can do the same for a Dirichlet BC), so the solver can be defined once, outside of the loop.

If you need help implementing this, I can help you.

TypeError: get_bound_values() missing 1 required positional argument: 'values'

The error is raised at this line:

T_bc = get_bound_values(T,BCs,dX)

TypeError Traceback (most recent call last)
/home/xie/projects/jax-am/src/cfd/examples/AM/AM_basic.ipynb Cell 7 in <cell line: 4>()
[2] T0 = np.zeros_like(xc) + T_ref
[3] T0_top = T0[:,:,[-1],:]
----> [4] k = update_cond(cond_func,T0,[[1,1,1,1,1,1],[0.,0.,0.,0.,0.,0]],dX)
[5] dt = 1e-5
[6] t = 0.

File ~/projects/jax-am/src/cfd/setupAM.py:29, in update_cond(temp_dependent_cond, T, BCs, dX)
28 def update_cond(temp_dependent_cond,T,BCs,dX):
---> 29 T_bc = get_bound_values(T,BCs,dX)
30 cond_surf_x = np.concatenate((temp_dependent_cond((T_bc[0]+T[[0],:,:])/2.),
31 temp_dependent_cond((T[1:,:,:] + T[:-1,:,:])/2.),
32 temp_dependent_cond((T_bc[1]+T[[-1],:,:])/2.)),axis=0)
34 cond_surf_y = np.concatenate((temp_dependent_cond((T_bc[2]+T[:,[0],:])/2.),
35 temp_dependent_cond((T[:,1:,:] + T[:,:-1,:])/2.),
36 temp_dependent_cond((T_bc[3]+T[:,[-1],:])/2.)),axis=1)

TypeError: get_bound_values() missing 1 required positional argument: 'values'
