Comments (8)
from ferminet.
Thanks for the quick response!
I could work around this issue by modifying the `logsumdet` function.
The following should work as a drop-in replacement:
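import functools
from typing import Optional, Sequence, Tuple

import jax.numpy as jnp

# `slogdet` is assumed to already be in scope in the module being patched.


def logsumdet(
    xs: Sequence[jnp.ndarray],
    w: Optional[jnp.ndarray] = None) -> Tuple[jnp.ndarray, jnp.ndarray]:
  # Special case if there is only one electron in any channel.
  # We can avoid the log(0) issue by not going into the log domain.
  dets = [x.reshape(-1) for x in xs if x.shape[-1] == 1]
  dets = functools.reduce(
      lambda a, b: a * b, dets
  ) if len(dets) > 0 else 1
  # Larger matrices still go through the log domain.
  slogdets = [slogdet(x) for x in xs if x.shape[-1] > 1]
  maxlogdet = 0
  if len(slogdets) > 0:
    # Multiply signs and add log-determinants across channels.
    sign_in, logdet = functools.reduce(
        lambda a, b: (a[0] * b[0], a[1] + b[1]), slogdets
    )
    # Subtract the largest log-determinant before exponentiating.
    maxlogdet = jnp.max(logdet)
    det = sign_in * dets * jnp.exp(logdet - maxlogdet)
  else:
    det = dets
  # Sum the determinants, optionally weighted by w.
  if w is None:
    result = jnp.sum(det)
  else:
    result = jnp.dot(det, w)
  sign_out = jnp.sign(result)
  log_out = jnp.log(jnp.abs(result)) + maxlogdet
  return sign_out, log_out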
This avoids working in the log domain if it's not necessary.
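To illustrate one way the log domain can go wrong for 1x1 matrices, here is a minimal sketch. It uses NumPy as a stand-in for `jax.numpy`, and the helper names `log_domain_sum` and `direct_sum` are hypothetical. It shows a possible failure mode only; it is not confirmed to be what triggered the NaN in this thread.

```python
import numpy as np


def log_domain_sum(mats):
    """Sum of determinants via slogdet plus a max-shift (hypothetical helper)."""
    with np.errstate(all="ignore"):
        sign, logdet = np.linalg.slogdet(mats)
        maxlogdet = np.max(logdet)
        # If every logdet is -inf, then logdet - maxlogdet is -inf - (-inf) = nan.
        total = np.sum(sign * np.exp(logdet - maxlogdet))
        return np.sign(total), np.log(np.abs(total)) + maxlogdet


def direct_sum(mats):
    """Same quantity for 1x1 matrices, without entering the log domain."""
    with np.errstate(all="ignore"):
        total = np.sum(mats.reshape(-1))
        return np.sign(total), np.log(np.abs(total))


# A batch of 1x1 "determinants" that are all exactly zero.
zeros = np.zeros((2, 1, 1))
# A batch where only one entry is zero.
mixed = np.array([[[0.5]], [[0.0]]])

print(log_domain_sum(zeros))  # log part is nan
print(direct_sum(zeros))      # log part is -inf, but not nan
print(log_domain_sum(mixed), direct_sum(mixed))  # both paths agree
```

When at least one entry is nonzero the two paths give the same sign and log value; they only differ in the degenerate all-zero case, where the max-shift produces `-inf - (-inf)`.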
from ferminet.
I'm curious about your technical setup (jax, jaxlib, and CUDA versions). The hydrogen atom works fine for me (several thousand iterations), though I agree the jax version is not as careful around this as the TF version.
The energy looks essentially converged by this point so I'm surprised you're hitting a NaN. Given the wavefunction shouldn't(!) contain any nodes, it can only happen far from the nucleus, so an MCMC configuration must have wandered quite a way away. (As an aside, the default step size is way too small for hydrogen -- you should increase it to at least 0.5).
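To get a feel for how the step size affects sampling, here is a toy Metropolis sampler for the exact hydrogen 1s density, |psi|^2 proportional to exp(-2r) in atomic units. This is only a sketch and not FermiNet's MCMC code: the function name `acceptance_rate`, the Gaussian proposal, and the small width 0.02 used for comparison are all assumptions made for illustration.

```python
import math
import random


def acceptance_rate(step, n_steps=20000, seed=0):
    """Fraction of accepted Metropolis moves for a given proposal width."""
    rng = random.Random(seed)
    pos = [1.0, 0.0, 0.0]  # start one bohr from the nucleus
    r = 1.0
    accepted = 0
    for _ in range(n_steps):
        # Isotropic Gaussian proposal with standard deviation `step`.
        new = [x + rng.gauss(0.0, step) for x in pos]
        r_new = math.sqrt(sum(x * x for x in new))
        # Acceptance ratio |psi(new)|^2 / |psi(old)|^2 = exp(-2 (r_new - r)).
        if rng.random() < math.exp(-2.0 * (r_new - r)):
            pos, r = new, r_new
            accepted += 1
    return accepted / n_steps


small = acceptance_rate(0.02)
large = acceptance_rate(0.5)
print(f"step=0.02: pmove={small:.2f}   step=0.5: pmove={large:.2f}")
```

In this toy model a very small step accepts almost every move, which means each step barely changes the configuration; a larger step trades some acceptance for faster exploration.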
from ferminet.
Sure, here are the details:
| Lib | Version |
|---|---|
| jax | 0.2.10 |
| jaxlib | 0.1.62 |
| CUDA | 10.2 |
While the mean of the energy looked good, it wasn't converged yet (high variance). I checked the position of the electron and it was within a radius of 3 Å of the nucleus, which isn't very far afaik.
The error also occurs at higher step sizes.
from ferminet.
This error also occurs for lithium far from the optimum.
I0306 02:47:37.867998 139629635032896 train.py:461] Step 00093: -6.6119 E_h, pmove=0.76
I0306 02:47:38.068941 139629635032896 train.py:461] Step 00094: -6.6105 E_h, pmove=0.76
I0306 02:47:38.270185 139629635032896 train.py:461] Step 00095: -6.6636 E_h, pmove=0.76
I0306 02:47:38.471880 139629635032896 train.py:461] Step 00096: nan E_h, pmove=0.76
I0306 02:47:38.671543 139629635032896 train.py:461] Step 00097: nan E_h, pmove=0.00
I0306 02:47:38.870149 139629635032896 train.py:461] Step 00098: nan E_h, pmove=0.00
Though only if one sets `full_det` to `False`. As far as I can tell, #23 fixes this.
On a side note: is there a particular reason why `full_det` defaults to `True`? Isn't a wavefunction only antisymmetric with respect to permutation of electrons of the same spin? Also, it does not align with the definition of FermiNet in the papers.
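For context on this question, a small NumPy sketch of the underlying linear algebra may help. It assumes that `full_det=False` corresponds to a product of per-spin determinants and `full_det=True` to a single dense determinant over all electrons; that mapping, the variable names, and the convention of one row per electron are assumptions of this sketch, not statements about the FermiNet code.

```python
import numpy as np

rng = np.random.default_rng(0)
n_up, n_down = 2, 1

# Per-spin orbital matrices (random stand-ins, one row per electron).
up = rng.normal(size=(n_up, n_up))
down = rng.normal(size=(n_down, n_down))

# Block-diagonal matrix over all electrons: off-diagonal blocks are zero.
block = np.zeros((n_up + n_down, n_up + n_down))
block[:n_up, :n_up] = up
block[n_up:, n_up:] = down

# The product of per-spin determinants equals the block-diagonal determinant.
product_of_dets = np.linalg.det(up) * np.linalg.det(down)
block_det = np.linalg.det(block)

# A dense matrix fills in the off-diagonal blocks as well.
dense = rng.normal(size=(n_up + n_down, n_up + n_down))
# Swapping the two spin-up rows still flips the sign of the determinant.
swapped = dense[[1, 0, 2]]

print(product_of_dets, block_det)
print(np.linalg.det(dense), np.linalg.det(swapped))
```

Under these assumptions the per-spin product is the special case of the dense determinant with zero off-diagonal blocks, and the dense form remains antisymmetric under exchange of same-spin rows.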
from ferminet.
Fixed in #23.
from ferminet.