Comments (7)
from ferminet.
I see. I was thinking that, apart from setting system.ndim=1 in the config, I would need to rewrite the parts of the code implementing the envelopes, hamiltonian and feature layer and then "plug them in" to the rest of the code by setting the appropriate variables in the config (network.make_envelope_fn and so on). Does this sound right? Thanks.
from ferminet.
Ok, thanks a lot for your replies! I'm closing the issue for now.
from ferminet.
So, it seems I have actually run into a bit of an issue. Rather than rewriting them from scratch, I adapted the parts of the original code implementing the envelopes, Hamiltonian and feature layer to 1D, to make sure the code would work before starting to optimize it. This is the code for the feature layer (which mostly works as-is with ndim = 1; the only thing I rewrote was the periodic norm):
from typing import Optional, Tuple

import chex
from ferminet import networks
import jax.numpy as jnp


def periodic_norm_1d(metric: jnp.ndarray, scaled_r: jnp.ndarray) -> jnp.ndarray:
  del metric  # unused
  # Divide by pi to recover the euclidean norm in the limit of small scaled_r
  return jnp.sin(jnp.pi * scaled_r[..., 0]) / jnp.pi


# Same function as in the original code, just uses my norm.
def make_pbc_feature_layer_1d(
    natoms: Optional[int] = None,
    nspins: Optional[Tuple[int, ...]] = None,
    ndim: int = 1,
    rescale_inputs: bool = False,
    lattice: Optional[jnp.ndarray] = None,
    include_r_ae: bool = False,
) -> networks.FeatureLayer:
  """Returns the init and apply functions for periodic features.

  Args:
    natoms: number of atoms.
    nspins: tuple of the number of spin-up and spin-down electrons.
    ndim: dimension of the system.
    rescale_inputs: If true, rescales r_ae for stability. Note that unlike in
      the OBC case, we do not rescale r_ee as well.
    lattice: Matrix whose columns are the primitive lattice vectors of the
      system, shape (ndim, ndim).
    include_r_ae: Flag to enable electron-atom distance features. Set to False
      to avoid cusps with ghost atoms in, e.g., homogeneous electron gas.
  """
  del nspins

  if lattice is None:
    lattice = jnp.eye(ndim)

  # Calculate reciprocal vectors, factor 2pi omitted
  reciprocal_vecs = jnp.linalg.inv(lattice)
  lattice_metric = lattice.T @ lattice

  def init() -> Tuple[Tuple[int, int], networks.Param]:
    if include_r_ae:
      return (natoms * (2 * ndim + 1), 2 * ndim + 1), {}
    else:
      return (natoms * (2 * ndim), 2 * ndim + 1), {}

  def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # One e features in phase coordinates, (s_ae)_i = k_i . ae
    s_ae = jnp.einsum('il,jkl->jki', reciprocal_vecs, ae)
    # Two e features in phase coordinates
    s_ee = jnp.einsum('il,jkl->jki', reciprocal_vecs, ee)
    # Periodized features
    ae = jnp.concatenate(
        (jnp.sin(2 * jnp.pi * s_ae), jnp.cos(2 * jnp.pi * s_ae)), axis=-1)
    ee = jnp.concatenate(
        (jnp.sin(2 * jnp.pi * s_ee), jnp.cos(2 * jnp.pi * s_ee)), axis=-1)
    # Distance features defined on orthonormal projections
    r_ae = periodic_norm_1d(lattice_metric, s_ae)
    if rescale_inputs:
      r_ae = jnp.log(1 + r_ae)
    # Don't take gradients through |0|
    n = ee.shape[0]
    s_ee += jnp.eye(n)[..., None]
    r_ee = periodic_norm_1d(lattice_metric, s_ee) * (1.0 - jnp.eye(n))

    if include_r_ae:
      ae_features = jnp.concatenate((r_ae[..., None], ae), axis=2)
    else:
      ae_features = ae
    ae_features = jnp.reshape(ae_features, [jnp.shape(ae_features)[0], -1])
    ee_features = jnp.concatenate((r_ee[..., None], ee), axis=2)
    return ae_features, ee_features

  return networks.FeatureLayer(init=init, apply=apply)
This is the code for the envelopes (I just rewrote make_kpoints to give 1D "vectors", and this gives the correct values for the arguments of the sine and cosine):
from typing import Optional, Tuple

import jax.numpy as jnp


def make_kpoints_1d(
    lattice: jnp.ndarray,
    spins: Tuple[int, int],
    min_kpoints: Optional[int] = None,
) -> jnp.ndarray:
  if min_kpoints is None:
    min_kpoints = sum(spins)
  elif min_kpoints < sum(spins):
    raise ValueError("Number of kpoints must be at least the number of electrons")
  # kpts will be of the form i*mod, with i an integer between -n and n inclusive
  mod = 2 * jnp.pi / lattice[0][0]
  # In 1d always generate an odd number of pts
  if min_kpoints % 2 == 1:
    n = min_kpoints
  else:
    n = min_kpoints + 1
  n //= 2
  kpoints = mod * jnp.arange(-n, n + 1, 1)
  # If the first kpt is zero the envelope evaluates to 1 at the beginning of training
  kpoints = jnp.concatenate((kpoints[n:], kpoints[:n]))
  return kpoints[:, None]
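To make the resulting ordering concrete, here is a small NumPy restatement of the same logic. The name `kpoints_1d_numpy` and its scalar `box_length` argument are illustrative only and are not part of FermiNet; the box length 5 is the value used in the run script in this post.

```python
import numpy as np


def kpoints_1d_numpy(box_length, n_electrons, min_kpoints=None):
    """Illustrative NumPy restatement of the ordering used in make_kpoints_1d."""
    if min_kpoints is None:
        min_kpoints = n_electrons
    elif min_kpoints < n_electrons:
        raise ValueError("min_kpoints must be at least the number of electrons")
    mod = 2 * np.pi / box_length
    # Always use an odd number of points so that +k and -k come in pairs around 0.
    n_total = min_kpoints if min_kpoints % 2 == 1 else min_kpoints + 1
    n = n_total // 2
    k = mod * np.arange(-n, n + 1)
    # Rotate so that k = 0 comes first, followed by the positive and then negative k.
    k = np.concatenate((k[n:], k[:n]))
    return k[:, None]


kpts = kpoints_1d_numpy(box_length=5.0, n_electrons=2, min_kpoints=3)
# In units of 2*pi/L the points are the integers 0, 1, -1.
multiples = kpts[:, 0] / (2 * np.pi / 5.0)
print(kpts.shape, multiples)
```

For two electrons and `min_kpoints=3`, this gives an array of shape (3, 1) with k = 0 first, which is the property the comment about the envelope evaluating to 1 relies on.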
The systems I've been testing the code on are ideal gases, so the Hamiltonian is just the kinetic energy part. Again the code is very similar to the original, since the coordinate vector is flattened anyway when computing the Laplacian:
from typing import Sequence

from ferminet import hamiltonian
from ferminet import networks
import jax.numpy as jnp


def local_energy_free_particles(
    f: networks.FermiNetLike,
    charges: jnp.ndarray,
    nspins: Sequence[int],
    use_scan: bool = False,
    complex_output: bool = False,
) -> hamiltonian.LocalEnergy:
  """Just a wrapper with a signature that matches the one of
  hamiltonian.MakeLocalEnergy"""
  del nspins, charges

  ke = hamiltonian.local_kinetic_energy(
      f,
      use_scan=use_scan,
      complex_output=complex_output
  )

  def _e_l(params, key, data):
    del key
    kinetic = ke(params, data)
    return kinetic

  return _e_l
This is one of the scripts I'm trying to run:
import sys
import subprocess as sp
import os
import time

from absl import logging
import jax
import jax.numpy as jnp
from ferminet import base_config as cfg
from ferminet.utils import system as system
from ferminet import train
from ferminet.pbc import envelopes
from pbc_1d import envelopes_1d as env
from pbc_1d import feature_layer_1d as fl
from pbc_1d import hamiltonian_1d as ham


def main():
  # Optional, for also printing training progress to STDOUT.
  # If running a script, you can also just use the --alsologtostderr flag.
  logging.get_absl_handler().python_handler.stream = sys.stdout
  logging.set_verbosity(logging.INFO)

  lattice = jnp.array([[5]])  # Matrix with lattice vectors as columns
  placeholder_atom = system.Atom(
      symbol="X",  # "Element" with atomic number 0
      charge=0,
      coords=(0.,)
  )

  my_config = cfg.default()
  my_config.system.ndim = 1
  my_config.system.molecule = [placeholder_atom]
  systems = [(1, 1)]
  my_config.pretrain.iterations = 0
  my_config.optim.lr.rate = 1e-4
  my_config.network.make_feature_layer_fn = "pbc_1d.feature_layer_1d.make_pbc_feature_layer_1d"
  my_config.network.make_feature_layer_kwargs = {"include_r_ae": False, "lattice": lattice}
  my_config.network.make_envelope_fn = "ferminet.pbc.envelopes.make_multiwave_envelope"
  my_config.system.make_local_energy_fn = "pbc_1d.hamiltonian_1d.local_energy_free_particles"
  # The following is only used for the potential energy
  # my_config.system.make_local_energy_kwargs = {"lattice": lattice}
  my_config.network.determinants = 16
  my_config.log.save_frequency = 5.
  my_config.optim.iterations = 100
  my_config.debug.deterministic = True

  A = 1e4
  C = 1.
  B = 1
  my_config.optim.lr.rate = B / (A**C)
  my_config.optim.lr.decay = C
  my_config.optim.lr.delay = A

  for spin_pair in systems:
    print("Running spin_pair", spin_pair)
    my_config.system.electrons = spin_pair
    # If even, make_kpoints will add one
    min_kpoints = spin_pair[0] + spin_pair[1] + 1
    kpoints = env.make_kpoints_1d(lattice, my_config.system.electrons, min_kpoints=min_kpoints)
    print("kpoints:", kpoints)
    my_config.network.make_envelope_kwargs = {"kpoints": kpoints}
    train.train(my_config)


if __name__ == "__main__":
  main()
There are a couple of indicators that there is a problem with the code:
- Calculations for 1, 2 electrons (unpolarized) seem to correctly converge to zero energy, but the ones for 3, 4, 5 (unpolarized) seem to converge to an energy lower than the one obtained theoretically. I have also run ideal gas calculations in 3D which seem to converge to the theoretical value. Here are plots for the energy of the 1D systems:
energy (1, 0).pdf
energy (1, 1).pdf
energy (2, 1).pdf
energy (2, 2).pdf
energy (3, 2).pdf
- On some of the calculations (2, 5 electrons, unpolarized) the energy fluctuates widely in the first few hundred iterations, then it stabilizes. Here are blow-ups of the relevant plots:
energy (1, 1) beginning.pdf
energy (3, 2) beginning.pdf
I've also plotted the wavefunctions for 1, 2 electrons unpolarized:
I ran a calculation for the 2 electron polarized gas to make sure that the wavefunction would come out antisymmetric, and it does:
(Here energy (2, 0).pdf is the energy during training for this last calculation)
Any ideas what the problem might be? Thanks, and sorry for the long post.
from ferminet.
If your calculation is non-variational, the likely causes are:
- your wavefunction is not actually antisymmetric. After network initialisation, generate some random positions and evaluate the network. Then swap the positions of two electrons of the same spin and re-evaluate the network: this should give the negative of the values obtained in the first evaluation. If not, you have a bug somewhere.
- you are not sampling psi^2 correctly. This can happen at the start of a calculation. Run the MCMC burn in for longer. [This is a good idea anyway...]
Finally you might just need to train for longer.
[The N=(2,1) case looks to converge to 0.4, though hard to tell from the plot. If so, this is a bit suspicious...]
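The swap test described in the first bullet can be sketched as follows. This is a minimal illustration in plain NumPy, not FermiNet code: `toy_sign_log_psi` is a hypothetical stand-in for the network (a two-electron, same-spin Slater determinant of the orbitals 1 and cos(kx)), and it is assumed to return the sign and log|psi| separately. With the real network, the same comparison would be made on its own outputs.

```python
import numpy as np

BOX_LENGTH = 5.0  # assumed box length, matching lattice = [[5]] in the script


def toy_sign_log_psi(x):
    """Sign and log|psi| of a 2x2 Slater determinant (toy stand-in for the network)."""
    k = 2 * np.pi / BOX_LENGTH
    # Rows index electrons, columns index orbitals: phi_0(x) = 1, phi_1(x) = cos(k x).
    orbitals = np.stack([np.ones_like(x), np.cos(k * x)], axis=-1)
    return np.linalg.slogdet(orbitals)


def swap_two(x, i, j):
    """Return a copy of x with electrons i and j exchanged."""
    swapped = x.copy()
    swapped[[i, j]] = swapped[[j, i]]
    return swapped


rng = np.random.default_rng(0)
x = rng.uniform(0.0, BOX_LENGTH, size=2)

sign_a, log_a = toy_sign_log_psi(x)
sign_b, log_b = toy_sign_log_psi(swap_two(x, 0, 1))

# Antisymmetry: the sign flips while the magnitude is unchanged.
print("sign flips:", sign_a == -sign_b)
print("same magnitude:", np.isclose(log_a, log_b))
```

If either check fails for two electrons of the same spin, the wavefunction is not antisymmetric. Exchanging electrons of opposite spin is not expected to flip the sign.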
You can also test your Hamiltonian on the exact wavefunction.
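As a rough sketch of such a test, the snippet below evaluates the local kinetic energy of an exactly known free-particle state and compares it with the analytic value. It is an assumption-laden illustration: it uses NumPy finite differences instead of the JAX code in the post, the box length 5 is taken from the script, and the trial state (two same-spin electrons in the orbitals 1 and cos(kx)) is chosen here for convenience. The identity used is that psi''/psi equals (log|psi|)'' + ((log|psi|)')^2 for each coordinate.

```python
import numpy as np

BOX_LENGTH = 5.0  # assumed, matching lattice = [[5]] in the script
K = 2 * np.pi / BOX_LENGTH


def log_abs_psi(x):
    """log|psi| for psi(x0, x1) = cos(K x1) - cos(K x0), an exact free-particle state."""
    return np.log(np.abs(np.cos(K * x[1]) - np.cos(K * x[0])))


def local_kinetic_energy_fd(log_abs_fn, x, h=1e-4):
    """-1/2 * sum_i [d^2/dx_i^2 log|psi| + (d/dx_i log|psi|)^2] via central differences."""
    x = np.asarray(x, dtype=float)
    f0 = log_abs_fn(x)
    total = 0.0
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = h
        f_plus = log_abs_fn(x + step)
        f_minus = log_abs_fn(x - step)
        grad = (f_plus - f_minus) / (2 * h)
        lap = (f_plus - 2 * f0 + f_minus) / h**2
        total += lap + grad**2
    return -0.5 * total


# A configuration chosen away from the node x0 = x1.
x = np.array([0.7, 3.1])
e_local = local_kinetic_energy_fd(log_abs_psi, x)
e_exact = 0.5 * K**2  # one electron at k = 0, one at |k| = 2*pi/L

print(e_local, e_exact)
```

For an exact eigenstate the local energy is the same at every configuration, so repeating the comparison at several random points is a simple way to spot an error in a kinetic-energy routine.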
from ferminet.
Yeah, I thought the N=(2, 1) case converging to 0.4 was suspicious too (the theoretical energy is supposed to be 0.79, and the 3D case seems to converge to this value pretty well). If the energy had been higher than the theoretical one, I could have thought it just hadn't converged yet, but that would not explain an energy lower than the theoretical value. The 3D cases also seemed to converge to within a few hundredths of a Hartree of the theoretical value in a number of iterations of the order of 10000.
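For reference, the theoretical value quoted here can be reproduced with a short calculation. The sketch below assumes atomic units, a periodic box of length 5 (the lattice used in the script), and that each spin species independently fills the lowest available momenta k = 2*pi*m/L; the function name `free_fermion_energy_1d` is illustrative.

```python
import numpy as np


def free_fermion_energy_1d(nspins, box_length):
    """Ground-state kinetic energy of non-interacting electrons in a 1D periodic box."""
    total = 0.0
    for n in nspins:
        # Allowed momenta are 2*pi*m/L; occupy the n integers m of smallest |m|
        # (ordering 0, -1, 1, -2, 2, ...).
        m = np.array(sorted(range(-n, n + 1), key=abs)[:n], dtype=float)
        total += 0.5 * np.sum((2 * np.pi * m / box_length) ** 2)
    return total


for spins in [(1, 0), (1, 1), (2, 1), (2, 2), (3, 2)]:
    print(spins, free_fermion_energy_1d(spins, box_length=5.0))
```

For N=(2, 1) this gives 0.5 * (2*pi/5)^2, about 0.79, and zero for (1, 0) and (1, 1), in line with the values discussed in this thread.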
I have run the (2, 1) calculation again with 1000 and 5000 burn-in steps (the original one was with the default of 100) and I get very similar results to the original calculation.
I have also tested the wavefunction for antisymmetry and it seems to be ok. However, there may be a problem with the periodic boundary conditions: I will run tests on that and also try testing the Hamiltonian with exact wavefunctions. I am marking this as closed for now and I'll reopen the issue if I run into any more problems. Thanks a lot for the help.
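One possible periodicity test, sketched here in NumPy, is to shift a scaled coordinate by one lattice period and check that each input feature is unchanged. The snippet applies this to the expression used in `periodic_norm_1d`, sin(pi*s)/pi, and to a variant with an absolute value. Whether this is what causes the low energies is not established in this thread; the absolute-value variant is included only for comparison, and both function names are illustrative.

```python
import numpy as np


def signed_feature(s):
    """The expression used in periodic_norm_1d: sin(pi s) / pi."""
    return np.sin(np.pi * s) / np.pi


def abs_feature(s):
    """Comparison variant (an assumption, not from the post): |sin(pi s)| / pi."""
    return np.abs(np.sin(np.pi * s)) / np.pi


# Scaled coordinates within one cell, and the same points shifted by one period.
s = np.linspace(-0.45, 0.45, 7)
signed_change = signed_feature(s + 1.0) - signed_feature(s)
abs_change = abs_feature(s + 1.0) - abs_feature(s)

print("signed feature periodic:", np.allclose(signed_change, 0.0))
print("abs feature periodic:", np.allclose(abs_change, 0.0))
print("signed feature flips sign:", np.allclose(signed_feature(s + 1.0), -signed_feature(s)))
```

Since sin(pi*(s + 1)) = -sin(pi*s), the signed expression changes sign under a shift by one period and is negative for negative s, whereas the absolute-value variant is unchanged. The same shift test can be applied to the full output of the feature layer.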