jaxwell's Issues

Convergence issue when using complex current density at an angle.

Hi,
First, thank you for sharing this software!
I run into a problem when I input a current density with nonzero Jx and Jy that are out of phase, for example Jx = b*1j and Jy = a. The goal is to produce right- or left-handed circular polarization.

Maybe I'm doing something wrong. I have already tried changing the number of iterations and the accuracy (eps).

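Here is the code that builds the source term b:

import numpy as np

# omega (the angular frequency) is defined elsewhere in the script.
depth = 137
b = np.zeros((172, 172, depth), dtype=np.complex128)
b = b.copy(), b.copy(), b.copy()  # one array per component (x, y, z)

ax = 1j / np.sqrt(2)
ay = 1 / np.sqrt(2)
# The plane at z = 49 (41 + 8) is the source.
b[0][:,:,41+8] = -1j * omega * ax
b[1][:,:,41+8] = -1j * omega * ay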

PMLs 40 cells thick are placed on every side of the structure.

I will link more information soon.
Have a nice day!
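For reference, a minimal sketch of how such a source can be assembled, assuming (as the jaxwell docstring says) that `b` holds the `-iωJ` term with one array per component. The helper `circular_sheet_source`, its `handedness` flag, the small grid, and the value of `omega` are all hypothetical illustrations, not part of jaxwell. It only shows that Jx and Jy end up with equal magnitude and a ±90° relative phase; it does not address the convergence problem itself. Whether `handedness=+1` is "right" or "left" depends on the time convention and the propagation direction.

```python
import numpy as np


def circular_sheet_source(shape, z_index, omega, handedness=1):
    """Return (bx, by, bz) = -1j * omega * J for a current sheet at z_index.

    Hypothetical helper. J = (handedness * 1j / sqrt(2), 1 / sqrt(2), 0):
    Jx and Jy have equal magnitude and a +/-90 degree relative phase, which
    is what gives circular polarization. Flipping `handedness` flips the
    sense of rotation.
    """
    if handedness not in (1, -1):
        raise ValueError("handedness must be +1 or -1")
    ax = handedness * 1j / np.sqrt(2)
    ay = 1 / np.sqrt(2)
    # Three independent zero-initialised complex arrays, one per component.
    bx, by, bz = (np.zeros(shape, dtype=np.complex128) for _ in range(3))
    bx[:, :, z_index] = -1j * omega * ax
    by[:, :, z_index] = -1j * omega * ay
    return bx, by, bz


# Hypothetical values: a small grid and an arbitrary angular frequency.
omega = 2 * np.pi / 1.55
bx, by, bz = circular_sheet_source((8, 8, 16), z_index=5, omega=omega)

# bx / by equals Jx / Jy (the -1j * omega factor cancels), so its phase
# is the relative phase of the two current components.
relative_phase = np.angle(bx[0, 0, 5] / by[0, 0, 5])
print(np.degrees(relative_phase))  # approximately 90.0 for handedness=+1
```

Building each component with its own `np.zeros` call plays the same role as the three `.copy()` calls in the original snippet: the components must not share memory.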

Colab Example Erroring

After changing `from jax.experimental import optimizers` to `from jax.example_libraries import optimizers` for compatibility with newer JAX versions, the example Colab fails with the following error:

TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'
---------------------------------------------------------------------------

UnfilteredStackTrace                      Traceback (most recent call last)

<ipython-input-8-ad210ce07e97> in <module>
      6 
----> 7 print(f'Objective: {f(theta, currents):.3f}')
      8 

<ipython-input-7-19568bb5066e> in f(theta, currents)
     38     '''The function `f` to optimize over.'''
---> 39     x, _, _ = _model(theta, currents)
     40     return loss(x)

<ipython-input-7-19568bb5066e> in _model(theta, currents)
     32     # Simulate.
---> 33     x, err = jaxwell.solve(params, z, b)
     34 

/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.8/dist-packages/jax/_src/custom_derivatives.py in __call__(self, *args, **kwargs)
    553       flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
--> 554       out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
    555                                         *args_flat, out_trees=out_trees)

/usr/local/lib/python3.8/dist-packages/jax/_src/custom_derivatives.py in bind(self, fun, fwd, bwd, out_trees, *args)
    673     bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
--> 674     outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
    675                                              out_trees=out_trees)

/usr/local/lib/python3.8/dist-packages/jax/core.py in process_custom_vjp_call(***failed resolving arguments***)
    730     with new_sublevel():
--> 731       return fun.call_wrapped(*tracers)
    732 

/usr/local/lib/python3.8/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    166     try:
--> 167       ans = self.f(*args, **dict(self.params, **kwargs))
    168     except:

/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py in solve(params, z, b)
     49   '''
---> 50   x, err = solve_impl(z, b, params=params)
     51   return x, err[-1]

/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py in solve_impl(z, b, adjoint, params, monitor_fn, monitor_every_n)
    114   for i in range(params.max_iters):
--> 115     p, r, x, err = iter(p, r, x, z)
    116     errs.append(err)

/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    161     try:
--> 162       return fun(*args, **kwargs)
    163     except Exception as e:

/usr/local/lib/python3.8/dist-packages/jax/_src/api.py in cache_miss(*args, **kwargs)
    621         jax.config.jax_debug_nans or jax.config.jax_debug_infs):
--> 622       execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
    623       out_flat = call_bind_continuation(execute(*args_flat))

/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py in _xla_call_impl_lazy(***failed resolving arguments***)
    235     arg_specs = [(None, getattr(x, '_device', None)) for x in args]
--> 236   return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
    237                       *arg_specs)

/usr/local/lib/python3.8/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    302     else:
--> 303       ans = call(fun, *args)
    304       cache[key] = (ans, fun.stores)

/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
    358   else:
--> 359     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
    360                               keep_unused, *arg_specs).compile().unsafe_call

/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    313     with TraceAnnotation(name, **decorator_kwargs):
--> 314       return func(*args, **kwargs)
    315     return wrapper

/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
    444                         "for jit in {elapsed_time} sec"):
--> 445     jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
    446         fun, pe.debug_info_final(fun, "jit"))

/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
    313     with TraceAnnotation(name, **decorator_kwargs):
--> 314       return func(*args, **kwargs)
    315     return wrapper

/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final2(fun, debug_info)
   2076     with core.new_sublevel():
-> 2077       jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2078     del fun, main

/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
   2026     in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2027     ans = fun.call_wrapped(*in_tracers_)
   2028     out_tracers = map(trace.full_raise, ans)

/usr/local/lib/python3.8/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    166     try:
--> 167       ans = self.f(*args, **dict(self.params, **kwargs))
    168     except:

/usr/local/lib/python3.8/dist-packages/jaxwell/cocg.py in iter(p, r, x, z)
     25     alpha = rho / vecfield.dot(p, v)
---> 26     x += alpha * p
     27     r -= alpha * v

/usr/local/lib/python3.8/dist-packages/jax/core.py in __mul__(self, other)
    604   def __rsub__(self, other): return self.aval._rsub(self, other)
--> 605   def __mul__(self, other): return self.aval._mul(self, other)
    606   def __rmul__(self, other): return self.aval._rmul(self, other)

/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4937     if isinstance(other, _rejected_binop_types):
-> 4938       raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4939                       f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")

UnfilteredStackTrace: TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)

<ipython-input-8-ad210ce07e97> in <module>
      5 currents = -1 * np.ones((1, 20), np.complex128)
      6 
----> 7 print(f'Objective: {f(theta, currents):.3f}')
      8 
      9 

<ipython-input-7-19568bb5066e> in f(theta, currents)
     37   def f(theta, currents):
     38     '''The function `f` to optimize over.'''
---> 39     x, _, _ = _model(theta, currents)
     40     return loss(x)
     41 

<ipython-input-7-19568bb5066e> in _model(theta, currents)
     31 
     32     # Simulate.
---> 33     x, err = jaxwell.solve(params, z, b)
     34 
     35     return x, err, theta

/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py in solve(params, z, b)
     48     b: Same as `z` but for the `-iωJ` term.
     49   '''
---> 50   x, err = solve_impl(z, b, params=params)
     51   return x, err[-1]
     52 

/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py in solve_impl(z, b, adjoint, params, monitor_fn, monitor_every_n)
    113   errs = []
    114   for i in range(params.max_iters):
--> 115     p, r, x, err = iter(p, r, x, z)
    116     errs.append(err)
    117     if i % monitor_every_n == 0:

/usr/local/lib/python3.8/dist-packages/jaxwell/cocg.py in iter(p, r, x, z)
     24     v = A(p, z)
     25     alpha = rho / vecfield.dot(p, v)
---> 26     x += alpha * p
     27     r -= alpha * v
     28     beta = vecfield.dot(r, r) / rho

/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py in deferring_binary_op(self, other)
   4936       return binary_op(*args)
   4937     if isinstance(other, _rejected_binop_types):
-> 4938       raise TypeError(f"unsupported operand type(s) for {opchar}: "
   4939                       f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
   4940     return NotImplemented
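The last frame suggests the cause: in `jaxwell/cocg.py`, `alpha * p` multiplies a JAX tracer by a `VecField`, and JAX's `deferring_binary_op` raises `TypeError` for operand types in `_rejected_binop_types` instead of returning `NotImplemented`, so Python never falls back to a `VecField.__rmul__`. Assuming `VecField` is a tuple subclass (for example a `NamedTuple`), which the error is consistent with, the pure-Python sketch below reproduces that rule without importing JAX. `MockTracer`, `_REJECTED_BINOP_TYPES`, and this `VecField` are hypothetical stand-ins, not the real classes.

```python
from typing import NamedTuple

# Hypothetical stand-in for the rule in JAX's deferring_binary_op: operands
# of these types make the tracer raise TypeError instead of returning
# NotImplemented (which would let Python try the other operand's __rmul__).
_REJECTED_BINOP_TYPES = (list, tuple)


class MockTracer:
    """Hypothetical scalar standing in for a JAX tracer (not JAX itself)."""

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        if isinstance(other, _REJECTED_BINOP_TYPES):
            raise TypeError(
                "unsupported operand type(s) for *: "
                f"{type(self).__name__!r} and {type(other).__name__!r}")
        if isinstance(other, (int, float, complex)):
            return MockTracer(self.value * other)
        return NotImplemented

    def __rmul__(self, other):
        # number * MockTracer, reached after the number's __mul__ gives up.
        if isinstance(other, (int, float, complex)):
            return MockTracer(other * self.value)
        return NotImplemented


class VecField(NamedTuple):
    """Hypothetical 3-component field; a tuple subclass, as assumed above."""
    x: complex
    y: complex
    z: complex

    def __mul__(self, scalar):
        # VecField * scalar: scale each component.
        return VecField(self.x * scalar, self.y * scalar, self.z * scalar)

    def __rmul__(self, scalar):
        # Never reached for MockTracer * VecField: MockTracer raises first.
        return self.__mul__(scalar)


alpha = MockTracer(2.0)
p = VecField(1.0, 2j, 3.0)

try:
    alpha * p  # same shape of expression as `x += alpha * p` in cocg.py
except TypeError as err:
    print(err)  # unsupported operand type(s) for *: 'MockTracer' and 'VecField'

scaled = p * alpha                          # VecField on the left: works
mapped = VecField(*(alpha * c for c in p))  # component-wise scaling also works
```

If that diagnosis holds, one possible change in `cocg.py` (untested against jaxwell itself) is to avoid putting the tracer on the left of a `VecField`: either scale component-wise, e.g. `jax.tree_util.tree_map(lambda f: alpha * f, p)`, which works because NamedTuples are pytrees, or write `p * alpha` if `VecField` defines `__mul__` for scalars. The same applies to `alpha * v` on the next line and to any other scalar-times-`VecField` expression.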
