
Comments (7)

dpfau commented on July 19, 2024


eliasteve commented on July 19, 2024

I see. I was thinking that, apart from setting system.ndim=1 in the config, I would need to rewrite the parts of the code implementing the envelopes, hamiltonian and feature layer and then "plug them in" to the rest of the code by setting the appropriate variables in the config (network.make_envelope_fn and so on). Does this sound right? Thanks.
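
Concretely, I had in mind something along these lines (the pbc_1d module paths are placeholders for code I would write):

from ferminet import base_config

my_config = base_config.default()
my_config.system.ndim = 1
my_config.network.make_feature_layer_fn = "pbc_1d.feature_layer_1d.make_pbc_feature_layer_1d"
my_config.network.make_envelope_fn = "ferminet.pbc.envelopes.make_multiwave_envelope"
my_config.system.make_local_energy_fn = "pbc_1d.hamiltonian_1d.local_energy_free_particles"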


dpfau commented on July 19, 2024


eliasteve commented on July 19, 2024

Ok, thanks a lot for your replies! I'm closing the issue for now.


eliasteve commented on July 19, 2024

So, it seems I have actually run into a bit of an issue. I tried adapting the parts of the original code implementing the envelopes, hamiltonian and feature layer to 1D, instead of rewriting them from scratch, to make sure the code would work before starting to optimize it. This is the code for the feature layer (which mostly works as-is with ndim = 1; the only thing I rewrote was the periodic norm):

from typing import Optional, Tuple

import chex
from ferminet import networks
import jax.numpy as jnp

def periodic_norm_1d(metric: jnp.ndarray, scaled_r: jnp.ndarray) -> jnp.ndarray:
    del metric  # unused
    # Divide by pi to recover the Euclidean norm in the limit of small scaled_r.
    return jnp.sin(jnp.pi * scaled_r[..., 0]) / jnp.pi


#Same function as in the original code, just uses my norm.
def make_pbc_feature_layer_1d(
    natoms: Optional[int] = None,
    nspins: Optional[Tuple[int, ...]] = None,
    ndim: int = 1,
    rescale_inputs: bool = False,
    lattice: Optional[jnp.ndarray] = None,
    include_r_ae: bool = False,
) -> networks.FeatureLayer:
  """Returns the init and apply functions for periodic features.

  Args:
      natoms: number of atoms.
      nspins: tuple of the number of spin-up and spin-down electrons.
      ndim: dimension of the system.
      rescale_inputs: If true, rescales r_ae for stability. Note that unlike in
        the OBC case, we do not rescale r_ee as well.
      lattice: Matrix whose columns are the primitive lattice vectors of the
        system, shape (ndim, ndim).
      include_r_ae: Flag to enable electron-atom distance features. Set to False
        to avoid cusps with ghost atoms in, e.g., homogeneous electron gas.
  """

  del nspins

  if lattice is None:
    lattice = jnp.eye(ndim)

  # Calculate reciprocal vectors, factor 2pi omitted
  reciprocal_vecs = jnp.linalg.inv(lattice)
  lattice_metric = lattice.T @ lattice

  def init() -> Tuple[Tuple[int, int], networks.Param]:
    if include_r_ae:
      return (natoms * (2 * ndim + 1), 2 * ndim + 1), {}
    else:
      return (natoms * (2 * ndim), 2 * ndim + 1), {}

  def apply(ae, r_ae, ee, r_ee) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # One e features in phase coordinates, (s_ae)_i = k_i . ae
    s_ae = jnp.einsum('il,jkl->jki', reciprocal_vecs, ae)
    # Two e features in phase coordinates
    s_ee = jnp.einsum('il,jkl->jki', reciprocal_vecs, ee)
    # Periodized features
    ae = jnp.concatenate(
        (jnp.sin(2 * jnp.pi * s_ae), jnp.cos(2 * jnp.pi * s_ae)), axis=-1)
    ee = jnp.concatenate(
        (jnp.sin(2 * jnp.pi * s_ee), jnp.cos(2 * jnp.pi * s_ee)), axis=-1)
    # Distance features defined on orthonormal projections
    r_ae = periodic_norm_1d(lattice_metric, s_ae)
    if rescale_inputs:
      r_ae = jnp.log(1 + r_ae)
    # Don't take gradients through |0|
    n = ee.shape[0]
    s_ee += jnp.eye(n)[..., None]
    r_ee = periodic_norm_1d(lattice_metric, s_ee) * (1.0 - jnp.eye(n))

    if include_r_ae:
      ae_features = jnp.concatenate((r_ae[..., None], ae), axis=2)
    else:
      ae_features = ae
    ae_features = jnp.reshape(ae_features, [jnp.shape(ae_features)[0], -1])
    ee_features = jnp.concatenate((r_ee[..., None], ee), axis=2)
    return ae_features, ee_features

  return networks.FeatureLayer(init=init, apply=apply)

This is the code for the envelopes (I just rewrote make_kpoints to give 1D "vectors", which gives the correct values for the arguments of the sines and cosines):

from typing import Optional, Tuple

import jax.numpy as jnp

def make_kpoints_1d(
    lattice: jnp.ndarray,
    spins: Tuple[int, int],
    min_kpoints: Optional[int] = None,
) -> jnp.ndarray:
    if min_kpoints is None:
        min_kpoints = sum(spins)
    elif min_kpoints < sum(spins):
        raise ValueError("Number of kpoints must be greater than number of electrons")

    # k-points are of the form i*mod, with i an integer between -n and n inclusive.
    mod = 2 * jnp.pi / lattice[0][0]
    if min_kpoints % 2 == 1:
        n = min_kpoints
    else:
        n = min_kpoints + 1  # In 1D, always generate an odd number of points.
    n //= 2

    kpoints = mod * jnp.arange(-n, n + 1, 1)
    # Rotate so the first k-point is zero: then the envelope evaluates to 1 at
    # the beginning of training.
    kpoints = jnp.concatenate((kpoints[n:], kpoints[:n]))
    return kpoints[:, None]

The systems I've been testing the code on are ideal gases, so the hamiltonian is just the kinetic energy part. Again the code is very similar to the original, since the coordinate vector used when computing the laplacian is flattened anyway:

from ferminet import networks
from ferminet import hamiltonian
import jax.numpy as jnp
from typing import Sequence

def local_energy_free_particles(
    f: networks.FermiNetLike,
    charges: jnp.ndarray,
    nspins: Sequence[int],
    use_scan: bool = False,
    complex_output: bool = False,
) -> hamiltonian.LocalEnergy:
    """Just a wrapper with a signature that matches the one of 
    hamiltonian.MakeLocalEnergy"""

    del nspins, charges
    ke = hamiltonian.local_kinetic_energy(
        f,
        use_scan=use_scan,
        complex_output=complex_output
    )
    def _e_l(params, key, data):
        del key
        kinetic = ke(params, data)
        return kinetic

    return _e_l

This is one of the scripts I'm trying to run:

import sys
import subprocess as sp
import os
import time

from absl import logging
import jax
import jax.numpy as jnp

from ferminet import base_config as cfg
from ferminet.utils import system as system
from ferminet import train
from ferminet.pbc import envelopes

from pbc_1d import envelopes_1d as env
from pbc_1d import feature_layer_1d as fl
from pbc_1d import hamiltonian_1d as ham

def main():
    # Optional, for also printing training progress to STDOUT.
    # If running a script, you can also just use the --alsologtostderr flag.
    logging.get_absl_handler().python_handler.stream = sys.stdout
    logging.set_verbosity(logging.INFO)

    lattice = jnp.array([[5]])  # Matrix with lattice vectors as columns
    placeholder_atom = system.Atom(
      symbol = "X", #"Element" with atomic number 0
      charge = 0,
      coords = (0.,)
    )

    my_config = cfg.default()

    my_config.system.ndim = 1
    my_config.system.molecule = [placeholder_atom]
    systems = [(1, 1)]

    my_config.pretrain.iterations = 0
    my_config.optim.lr.rate = 1e-4

    my_config.network.make_feature_layer_fn = "pbc_1d.feature_layer_1d.make_pbc_feature_layer_1d"
    my_config.network.make_feature_layer_kwargs = {"include_r_ae":False, "lattice":lattice}
    my_config.network.make_envelope_fn = "ferminet.pbc.envelopes.make_multiwave_envelope"
    my_config.system.make_local_energy_fn = "pbc_1d.hamiltonian_1d.local_energy_free_particles"
    #The following is only used for the potential energy
    #my_config.system.make_local_energy_kwargs = {"lattice": lattice}
    my_config.network.determinants = 16
    my_config.log.save_frequency = 5.
    my_config.optim.iterations = 100

    my_config.debug.deterministic = True

    A = 1e4
    C = 1.
    B = 1
    
    my_config.optim.lr.rate = B/(A**C)
    my_config.optim.lr.decay = C
    my_config.optim.lr.delay = A

    for spin_pair in systems:
        print("Running spin_pair", spin_pair)
        my_config.system.electrons = spin_pair
        # If even, make_kpoints_1d will add one
        min_kpoints = spin_pair[0] + spin_pair[1] + 1
        kpoints = env.make_kpoints_1d(lattice, my_config.system.electrons, min_kpoints = min_kpoints)
        print("kpoints:", kpoints)
        my_config.network.make_envelope_kwargs = {"kpoints":kpoints}
        train.train(my_config)

if __name__ == "__main__":
    main()

There are a couple of indicators that there is a problem with the code:

  1. Calculations for 1 and 2 electrons (unpolarized) seem to converge correctly to zero energy, but the ones for 3, 4 and 5 electrons (unpolarized) seem to converge to an energy lower than the theoretical value. I have also run ideal gas calculations in 3D, which seem to converge to the theoretical value. Here are plots for the energy of the 1D systems:
    energy (1, 0).pdf
    energy (1, 1).pdf
    energy (2, 1).pdf
    energy (2, 2).pdf
    energy (3, 2).pdf

  2. On some of the calculations (2, 5 electrons, unpolarized) the energy fluctuates widely in the first few hundred iterations, then it stabilizes. Here are blow-ups of the relevant plots:
    energy (1, 1) beginning.pdf
    energy (3, 2) beginning.pdf

I've also plotted the wavefunctions for 1, 2 electrons unpolarized:
wave function (1, 0)
wave function (1, 1)

I ran a calculation for the 2 electron polarized gas to make sure that the wavefunction would come out antisymmetric, and it does:
wave function (2, 0)
(Here energy (2, 0).pdf is the energy during training for this last calculation)

Any ideas what the problem might be? Thanks, and sorry for the long post.


jsspencer commented on July 19, 2024

If your calculation is non-variational, then the likely causes are:

  • your wavefunction is not actually anti-symmetric. After network initialisation, generate some random positions and evaluate the network. Then swap the positions of two electrons of the same spin and re-evaluate the network -- this should give the negative of the values obtained in the first evaluation (a sketch of such a check is given after this list). If not, you have a bug somewhere.
  • you are not sampling psi^2 correctly. This can happen at the start of a calculation. Run the MCMC burn in for longer. [This is a good idea anyway...]
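
A minimal sketch of the antisymmetry check in the first bullet might look like the following (here `logpsi_fn(params, pos)` is a stand-in for whatever callable evaluates log|psi| on the flattened electron coordinates; the real ferminet network takes more arguments, so treat this as an illustration of the idea rather than the exact API):

import jax
import jax.numpy as jnp

def check_antisymmetry(logpsi_fn, params, nspins, ndim=1, seed=0):
    """Evaluate log|psi| before and after swapping two same-spin electrons.

    For an antisymmetric wavefunction, psi changes sign under the swap, so the
    two log|psi| values should be identical (and any sign/phase output of the
    network should flip).
    """
    nel = sum(nspins)
    pos = jax.random.uniform(jax.random.PRNGKey(seed), (nel * ndim,))

    # Swap the coordinates of the first two spin-up electrons
    # (requires nspins[0] >= 2).
    r = pos.reshape(nel, ndim)
    r_swapped = r.at[0].set(r[1]).at[1].set(r[0])

    return logpsi_fn(params, pos), logpsi_fn(params, r_swapped.reshape(-1))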

Finally you might just need to train for longer.

[The N=(2,1) case looks to converge to 0.4, though hard to tell from the plot. If so, this is a bit suspicious...]

You can also test your Hamiltonian on the exact wavefunction.
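
As a toy version of that last check, the local-kinetic-energy formula can be applied directly to an exact single-particle wavefunction and compared against the analytic answer. A minimal pure-JAX sketch (bypassing the ferminet wrappers entirely, using psi(x) = cos(kx), for which the local kinetic energy is k^2/2 everywhere):

import jax
import jax.numpy as jnp

# Exact periodic test wavefunction on a box of length L: psi(x) = cos(k*x)
# with k = 2*pi/L. Its local kinetic energy is k**2 / 2 everywhere.
L = 5.0
k = 2 * jnp.pi / L

def log_abs_psi(x):
    return jnp.log(jnp.abs(jnp.cos(k * x)))

def local_kinetic(x):
    # For a real wavefunction: -1/2 psi''/psi = -1/2 ((log|psi|)'' + ((log|psi|)')**2).
    dlog = jax.grad(log_abs_psi)
    d2log = jax.grad(dlog)
    return -0.5 * (d2log(x) + dlog(x) ** 2)

print(local_kinetic(jnp.asarray(0.3)), 0.5 * k ** 2)  # the two values should agree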


eliasteve commented on July 19, 2024

Yeah, I thought the N=(2, 1) case converging to 0.4 was suspicious too (the theoretical energy is supposed to be 0.79, and the 3D case seems to converge to that value pretty well): if the energy had been higher than the theoretical one I could have thought it just hadn't converged yet, but I wouldn't expect it to end up lower. The 3D cases also converged to within a few hundredths of a Hartree of the theoretical value in on the order of 10000 iterations.

I have run the (2, 1) calculation again with 1000 and 5000 burn-in steps (the original one was with the default of 100) and I get very similar results to the original calculation.

I have also tested the wavefunction for antisymmetry and it seems to be OK; however, there may be a problem with the periodic boundary conditions. I will run tests on that (along the lines of the sketch below) and also try testing the hamiltonian with exact wavefunctions. I am marking this as closed for now and I'll reopen the issue if I run into any more problems. Thanks a lot for the help.
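
For reference, the kind of periodicity test I have in mind is along these lines (again with `logpsi_fn(params, pos)` standing in for whatever callable evaluates the network on flattened 1D coordinates, so the exact signature is an assumption):

import jax
import jax.numpy as jnp

def check_periodicity(logpsi_fn, params, nelectrons, box_length=5.0, seed=0):
    """Translate one electron by a full lattice vector; log|psi| should not change."""
    pos = jax.random.uniform(
        jax.random.PRNGKey(seed), (nelectrons,), maxval=box_length)
    shifted = pos.at[0].add(box_length)  # move electron 0 by one period
    return logpsi_fn(params, pos), logpsi_fn(params, shifted)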

