stanfordnqp / jaxwell
Jaxwell is JAX + Maxwell
License: GNU General Public License v3.0
Hi,
First, thank you for sharing this software!
I encounter an issue when I try to input a current density with nonzero Jx and Jy that are out of phase with each other,
for example Jx = b*1j and Jy = a. The objective is to produce right-handed or left-handed circular polarization.
Maybe I'm doing something wrong. I have tried changing the number of iterations and the accuracy (eps).
Here is the code for b:
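import numpy as np

# omega (the angular frequency) is assumed to be defined earlier in the script.
depth = 137
b = np.zeros((172, 172, depth), dtype=np.complex128)
# One independent array per field component: (x, y, z).
b = b.copy(), b.copy(), b.copy()
# Equal-magnitude amplitudes with a 90-degree phase offset between x and y.
ax = 1j / np.sqrt(2)
ay = 1 / np.sqrt(2)
# The plane at z=49 is a source
b[0][:, :, 41 + 8] = -1j * omega * ax
b[1][:, :, 41 + 8] = -1j * omega * ay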
PMLs of 40 blocks are used in each direction around the whole structure.
I will add more information soon.
Have a nice day!
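A quick sanity check on the amplitudes themselves may help narrow this down before looking at the solver. The sketch below uses only the standard library and is not part of jaxwell; `circular_amplitudes` is a hypothetical helper name. It confirms that `ax` and `ay` as defined above carry unit total power and a quarter-cycle phase offset, and that flipping the sign of `ax` gives the opposite handedness. Which sign corresponds to "right-handed" depends on the time and propagation convention in use, so that label is deliberately left out.

```python
import cmath
import math


def circular_amplitudes(sign=+1):
    """Return (ax, ay) with equal magnitude and a +/-90 degree phase offset.

    `sign` selects the handedness; which sign is "right-handed" depends on
    the time-dependence convention, which is not assumed here.
    """
    ax = sign * 1j / math.sqrt(2)
    ay = 1 / math.sqrt(2)
    return ax, ay


ax, ay = circular_amplitudes(+1)

# Phase of the x component relative to the y component, in degrees.
phase_deg = math.degrees(cmath.phase(ax / ay))

# Total power carried by the two components (normalized to 1).
power = abs(ax) ** 2 + abs(ay) ** 2

print(f"relative phase: {phase_deg:.1f} deg, total power: {power:.3f}")
```

If these values look right, the amplitudes are not the problem, and the source plane placement, PML thickness, or solver convergence would be the next things to check.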
Allow them to be used, or else do not expose them at all.
After changing `from jax.experimental import optimizers` to `from jax.example_libraries import optimizers` for compatibility, the example Colab errors out as follows:
TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-8-ad210ce07e97>](https://localhost:8080/#) in <module>
6
----> 7 print(f'Objective: {f(theta, currents):.3f}')
8
29 frames
[<ipython-input-7-19568bb5066e>](https://localhost:8080/#) in f(theta, currents)
38 '''The function `f` to optimize over.'''
---> 39 x, _, _ = _model(theta, currents)
40 return loss(x)
[<ipython-input-7-19568bb5066e>](https://localhost:8080/#) in _model(theta, currents)
32 # Simulate.
---> 33 x, err = jaxwell.solve(params, z, b)
34
[/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.8/dist-packages/jax/_src/custom_derivatives.py](https://localhost:8080/#) in __call__(self, *args, **kwargs)
553 flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
--> 554 out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
555 *args_flat, out_trees=out_trees)
[/usr/local/lib/python3.8/dist-packages/jax/_src/custom_derivatives.py](https://localhost:8080/#) in bind(self, fun, fwd, bwd, out_trees, *args)
673 bwd_ = lu.wrap_init(lambda *args: bwd.call_wrapped(*args))
--> 674 outs = top_trace.process_custom_vjp_call(self, fun, fwd, bwd_, tracers,
675 out_trees=out_trees)
[/usr/local/lib/python3.8/dist-packages/jax/core.py](https://localhost:8080/#) in process_custom_vjp_call(***failed resolving arguments***)
730 with new_sublevel():
--> 731 return fun.call_wrapped(*tracers)
732
[/usr/local/lib/python3.8/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
166 try:
--> 167 ans = self.f(*args, **dict(self.params, **kwargs))
168 except:
[/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py](https://localhost:8080/#) in solve(params, z, b)
49 '''
---> 50 x, err = solve_impl(z, b, params=params)
51 return x, err[-1]
[/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py](https://localhost:8080/#) in solve_impl(z, b, adjoint, params, monitor_fn, monitor_every_n)
114 for i in range(params.max_iters):
--> 115 p, r, x, err = iter(p, r, x, z)
116 errs.append(err)
[/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py](https://localhost:8080/#) in reraise_with_filtered_traceback(*args, **kwargs)
161 try:
--> 162 return fun(*args, **kwargs)
163 except Exception as e:
[/usr/local/lib/python3.8/dist-packages/jax/_src/api.py](https://localhost:8080/#) in cache_miss(*args, **kwargs)
621 jax.config.jax_debug_nans or jax.config.jax_debug_infs):
--> 622 execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
623 out_flat = call_bind_continuation(execute(*args_flat))
[/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_call_impl_lazy(***failed resolving arguments***)
235 arg_specs = [(None, getattr(x, '_device', None)) for x in args]
--> 236 return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
237 *arg_specs)
[/usr/local/lib/python3.8/dist-packages/jax/linear_util.py](https://localhost:8080/#) in memoized_fun(fun, *args)
302 else:
--> 303 ans = call(fun, *args)
304 cache[key] = (ans, fun.stores)
[/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, keep_unused, *arg_specs)
358 else:
--> 359 return lower_xla_callable(fun, device, backend, name, donated_invars, False,
360 keep_unused, *arg_specs).compile().unsafe_call
[/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
[/usr/local/lib/python3.8/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in lower_xla_callable(fun, device, backend, name, donated_invars, always_lower, keep_unused, *arg_specs)
444 "for jit in {elapsed_time} sec"):
--> 445 jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
446 fun, pe.debug_info_final(fun, "jit"))
[/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
313 with TraceAnnotation(name, **decorator_kwargs):
--> 314 return func(*args, **kwargs)
315 return wrapper
[/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_jaxpr_final2(fun, debug_info)
2076 with core.new_sublevel():
-> 2077 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2078 del fun, main
[/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py](https://localhost:8080/#) in trace_to_subjaxpr_dynamic2(fun, main, debug_info)
2026 in_tracers_ = [t for t, keep in zip(in_tracers, keep_inputs) if keep]
-> 2027 ans = fun.call_wrapped(*in_tracers_)
2028 out_tracers = map(trace.full_raise, ans)
[/usr/local/lib/python3.8/dist-packages/jax/linear_util.py](https://localhost:8080/#) in call_wrapped(self, *args, **kwargs)
166 try:
--> 167 ans = self.f(*args, **dict(self.params, **kwargs))
168 except:
[/usr/local/lib/python3.8/dist-packages/jaxwell/cocg.py](https://localhost:8080/#) in iter(p, r, x, z)
25 alpha = rho / vecfield.dot(p, v)
---> 26 x += alpha * p
27 r -= alpha * v
[/usr/local/lib/python3.8/dist-packages/jax/core.py](https://localhost:8080/#) in __mul__(self, other)
604 def __rsub__(self, other): return self.aval._rsub(self, other)
--> 605 def __mul__(self, other): return self.aval._mul(self, other)
606 def __rmul__(self, other): return self.aval._rmul(self, other)
[/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in deferring_binary_op(self, other)
4937 if isinstance(other, _rejected_binop_types):
-> 4938 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4939 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
UnfilteredStackTrace: TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
[<ipython-input-8-ad210ce07e97>](https://localhost:8080/#) in <module>
5 currents = -1 * np.ones((1, 20), np.complex128)
6
----> 7 print(f'Objective: {f(theta, currents):.3f}')
8
9
[<ipython-input-7-19568bb5066e>](https://localhost:8080/#) in f(theta, currents)
37 def f(theta, currents):
38 '''The function `f` to optimize over.'''
---> 39 x, _, _ = _model(theta, currents)
40 return loss(x)
41
[<ipython-input-7-19568bb5066e>](https://localhost:8080/#) in _model(theta, currents)
31
32 # Simulate.
---> 33 x, err = jaxwell.solve(params, z, b)
34
35 return x, err, theta
[/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py](https://localhost:8080/#) in solve(params, z, b)
48 b: Same as `z` but for the `-iωJ` term.
49 '''
---> 50 x, err = solve_impl(z, b, params=params)
51 return x, err[-1]
52
[/usr/local/lib/python3.8/dist-packages/jaxwell/fdfd.py](https://localhost:8080/#) in solve_impl(z, b, adjoint, params, monitor_fn, monitor_every_n)
113 errs = []
114 for i in range(params.max_iters):
--> 115 p, r, x, err = iter(p, r, x, z)
116 errs.append(err)
117 if i % monitor_every_n == 0:
[/usr/local/lib/python3.8/dist-packages/jaxwell/cocg.py](https://localhost:8080/#) in iter(p, r, x, z)
24 v = A(p, z)
25 alpha = rho / vecfield.dot(p, v)
---> 26 x += alpha * p
27 r -= alpha * v
28 beta = vecfield.dot(r, r) / rho
[/usr/local/lib/python3.8/dist-packages/jax/_src/numpy/lax_numpy.py](https://localhost:8080/#) in deferring_binary_op(self, other)
4936 return binary_op(*args)
4937 if isinstance(other, _rejected_binop_types):
-> 4938 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4939 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
4940 return NotImplemented
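The last frames suggest what is happening: in `alpha * p` the left operand is a JAX tracer, and its `__mul__` raises `TypeError` for operand types it rejects instead of returning `NotImplemented`, so Python never falls back to the right operand's `__rmul__`. The pure-Python sketch below reproduces that dispatch behaviour without JAX. `Vec` and `StrictScalar` are hypothetical stand-ins, and it assumes `VecField` is a tuple subclass, which the `_rejected_binop_types` check in the traceback hints at but does not confirm.

```python
from typing import NamedTuple


class Vec(NamedTuple):
    """Hypothetical stand-in for a tuple-based vector field."""
    x: float
    y: float
    z: float

    def __mul__(self, s):
        # Scale every component by s.
        return Vec(*(c * s for c in self))

    # Would handle `s * vec` if the left operand returned NotImplemented.
    __rmul__ = __mul__


class StrictScalar:
    """Hypothetical stand-in for a scalar that rejects tuple operands."""

    def __init__(self, v):
        self.v = v

    def __mul__(self, other):
        if isinstance(other, tuple):
            # Raising here (rather than returning NotImplemented) prevents
            # Python from ever trying other.__rmul__.
            raise TypeError("unsupported operand type(s) for *")
        return StrictScalar(self.v * other)

    def __rmul__(self, other):
        return StrictScalar(other * self.v)


alpha = StrictScalar(2.0)
p = Vec(1.0, 2.0, 3.0)

# Vector on the left: Vec.__mul__ runs and scales each component.
scaled = p * alpha
scaled_values = [c.v for c in scaled]

# Scalar on the left: StrictScalar.__mul__ raises before Vec.__rmul__ is tried.
try:
    alpha * p
    failed = False
except TypeError:
    failed = True

print(scaled_values, failed)
```

Under that assumption, one possible workaround is to put the vector field on the left in `cocg.py` (for example `p * alpha` instead of `alpha * p`), so that its own operator runs first. Whether that is sufficient for the full example has not been verified here.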