pinnsformer's Issues

The derivative in the code

I have been closely examining the implementation of the Navier-Stokes equation in your code, and I appreciate the effort put into this work. I have a question about how the second-order derivative u_xx is computed.

In the code, u_xx is calculated directly from u with respect to x_train, as shown below:

u_xx = torch.autograd.grad(u, x_train, ...)

However, I believe that to obtain the second-order derivative u_xx, we should first compute the first-order derivative u_x with respect to x_train, and then differentiate u_x with respect to x_train again, like this:

u_x = torch.autograd.grad(u, x_train, ..., create_graph=True)[0]
u_xx = torch.autograd.grad(u_x, x_train, ...)[0]
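A minimal, self-contained check of the nested pattern, using the toy function u = sin(x) as a stand-in for the network (an assumption for illustration only; the real code differentiates the model output). Because each entry of u depends only on the matching entry of x_train, passing `grad_outputs` of ones returns the pointwise derivative:

```python
import torch

# Toy stand-in for the network: u(x) = sin(x), applied pointwise to a batch.
x_train = torch.linspace(0.0, 3.0, 7).reshape(-1, 1).requires_grad_(True)
u = torch.sin(x_train)

# First derivative; create_graph=True keeps u_x differentiable.
u_x = torch.autograd.grad(u, x_train, grad_outputs=torch.ones_like(u), create_graph=True)[0]
# Second derivative: differentiate u_x (not u) with respect to x_train.
u_xx = torch.autograd.grad(u_x, x_train, grad_outputs=torch.ones_like(u_x))[0]

print(torch.allclose(u_x, torch.cos(x_train)))    # True
print(torch.allclose(u_xx, -torch.sin(x_train)))  # True
```

A single `torch.autograd.grad(u, x_train, ...)` call returns u_x (here cos x); only the second call, applied to u_x, yields u_xx.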

Could you please confirm whether my understanding is correct? If so, would it be possible to update the code accordingly?

Thank you for your time and consideration. I look forward to your response.

Confusion about `1d_wave_pinn_ntk`

Hello, great work!

I am puzzled by how J1, J2 and J3 are computed in the code below, given the NTK algorithm for PINNs described in "When and why PINNs fail to train: A neural tangent kernel perspective".

Based on my understanding, J1 should be the Jacobian of the residual $u_{tt} - 4u_{xx}$ with respect to the parameters, and J2 should be the Jacobian of the initial condition on $u_t$ with respect to the parameters. Also, shouldn't the initial condition on $u$ be combined with the boundary conditions to form J3?
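This reading can be written out as follows. It is a sketch that assumes the spatial domain is $x \in [0, 1]$, treats $t = 0$ as the initial line, and uses the trace-based weights from the paper cited above; each training point contributes one row of its Jacobian, and $\mathbf{K}_k = \mathbf{J}_k \mathbf{J}_k^\top$ is the corresponding NTK block:

```latex
\mathbf{J}_1[i,:] = \frac{\partial}{\partial \theta}\Big(u_{tt} - 4u_{xx}\Big)(x^r_i, t^r_i;\theta),
\qquad
\mathbf{J}_2[i,:] = \frac{\partial u_t}{\partial \theta}(x^0_i, 0;\theta),
\qquad
\mathbf{J}_3 =
\begin{bmatrix}
\partial_\theta u(x^0_i, 0;\theta) \\
\partial_\theta u(0, t^b_i;\theta) \\
\partial_\theta u(1, t^b_i;\theta)
\end{bmatrix},

\mathbf{K}_k = \mathbf{J}_k \mathbf{J}_k^\top,
\qquad
\lambda_k = \frac{\sum_{m=1}^{3} \operatorname{Tr}(\mathbf{K}_m)}{\operatorname{Tr}(\mathbf{K}_k)},
\qquad k = 1, 2, 3.
```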

...
...
for i in tqdm(range(1000)):
    if i % 50 == 0:
        J1 = torch.zeros((D1, n_params))
        J2 = torch.zeros((D2, n_params))
        J3 = torch.zeros((D3, n_params))

        batch_ind = np.random.choice(len(x_res), kernel_size, replace=False)
        x_train, t_train = x_res[batch_ind], t_res[batch_ind]

        pred_res = model(x_train, t_train)
        pred_left = model(x_left, t_left)
        pred_upper = model(x_upper, t_upper)
        pred_lower = model(x_lower, t_lower)

        for j in range(len(x_train)):
            model.zero_grad()
            pred_res[j].backward(retain_graph=True)
            J1[j, :] = torch.cat([p.grad.view(-1) for p in model.parameters()])

        for j in range(len(x_left)):
            model.zero_grad()
            pred_left[j].backward(retain_graph=True)
            J2[j, :] = torch.cat([p.grad.view(-1) for p in model.parameters()])

        for j in range(len(x_lower)):
            model.zero_grad()
            pred_lower[j].backward(retain_graph=True)
            pred_upper[j].backward(retain_graph=True)
            J3[j, :] = torch.cat([p.grad.view(-1) for p in model.parameters()])
        ...
        ...

Here is my rough modification of the code; I am not sure whether it is correct.

        J1 = torch.zeros((D1, n_params))
        J2 = torch.zeros((D2, n_params))
        J3 = torch.zeros((D3, n_params))

        batch_ind = np.random.choice(len(x_res), kernel_size, replace=False)
        x_train, t_train = x_res[batch_ind], t_res[batch_ind]

        # x_res, t_res and t_left must be created with requires_grad=True
        # for the derivatives below to exist.
        pred_res = model(x_train, t_train)
        pred_left = model(x_left, t_left)
        pred_upper = model(x_upper, t_upper)
        pred_lower = model(x_lower, t_lower)

        # Residual of the wave equation; create_graph=True keeps every derivative
        # differentiable with respect to the model parameters.
        u_x = torch.autograd.grad(pred_res, x_train, grad_outputs=torch.ones_like(pred_res), retain_graph=True, create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x_train, grad_outputs=torch.ones_like(u_x), retain_graph=True, create_graph=True)[0]
        u_t = torch.autograd.grad(pred_res, t_train, grad_outputs=torch.ones_like(pred_res), retain_graph=True, create_graph=True)[0]
        u_tt = torch.autograd.grad(u_t, t_train, grad_outputs=torch.ones_like(u_t), retain_graph=True, create_graph=True)[0]
        wave_opt = u_tt - 4 * u_xx  # wave operator
        del u_x, u_xx, u_t, u_tt

        # u_t on the initial line (the "left" points)
        pred_t = torch.autograd.grad(pred_left, t_left, grad_outputs=torch.ones_like(pred_left), retain_graph=True, create_graph=True)[0]

        # Parameters that do not influence an output (e.g. the last-layer bias for
        # u_tt - 4 * u_xx) keep p.grad = None; pad them with zeros of the same size
        # so that every row has exactly n_params entries.
        for j in range(len(x_train)):
            model.zero_grad()
            wave_opt[j].backward(retain_graph=True)
            J1[j, :] = torch.cat([p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in model.parameters()])

        for j in range(len(x_left)):
            model.zero_grad()
            pred_t[j].backward(retain_graph=True)
            J2[j, :] = torch.cat([p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in model.parameters()])

        for j in range(len(x_lower)):
            model.zero_grad()
            # The three backward calls accumulate into p.grad, so each row is the
            # gradient of the sum pred_left[j] + pred_lower[j] + pred_upper[j].
            pred_left[j].backward(retain_graph=True)
            pred_lower[j].backward(retain_graph=True)
            pred_upper[j].backward(retain_graph=True)
            J3[j, :] = torch.cat([p.grad.view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in model.parameters()])
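For comparison, here is a minimal runnable sketch of the same Jacobian construction on a hypothetical two-input MLP (`ToyMLP`, `jacobian_rows` and `d` are illustrative names, not from the repository, and the $[0, 1]^2$ domain is assumed). Unlike the J3 loop above, it gives one row per initial or boundary point instead of summing several points into a single row, and it zero-fills unused parameters through `allow_unused=True`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyMLP(nn.Module):
    """Hypothetical stand-in for the PINN model: u_theta(x, t)."""

    def __init__(self, width: int = 16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, width), nn.Tanh(), nn.Linear(width, 1))

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))


def jacobian_rows(outputs, params):
    """Row j is d outputs[j] / d theta; parameters that do not affect it give zeros."""
    rows = []
    for out_j in outputs.reshape(-1):
        grads = torch.autograd.grad(out_j, params, retain_graph=True, allow_unused=True)
        rows.append(torch.cat([
            (g if g is not None else torch.zeros_like(p)).reshape(-1)
            for g, p in zip(grads, params)
        ]))
    return torch.stack(rows)


def d(y, z):
    """Pointwise derivative dy/dz, kept differentiable for higher orders."""
    return torch.autograd.grad(y, z, grad_outputs=torch.ones_like(y), create_graph=True)[0]


model = ToyMLP()
params = [p for p in model.parameters() if p.requires_grad]

# Residual points (x, t) in [0, 1]^2.
x_res = torch.rand(8, 1, requires_grad=True)
t_res = torch.rand(8, 1, requires_grad=True)
u = model(x_res, t_res)
wave_opt = d(d(u, t_res), t_res) - 4 * d(d(u, x_res), x_res)  # u_tt - 4 u_xx

# Initial line t = 0: u(x, 0) and u_t(x, 0).
x_ic = torch.rand(5, 1)
t_ic = torch.zeros(5, 1, requires_grad=True)
u_ic = model(x_ic, t_ic)
u_t_ic = d(u_ic, t_ic)

# Spatial boundaries x = 0 and x = 1.
t_bc = torch.rand(5, 1)
u_lo = model(torch.zeros(5, 1), t_bc)
u_up = model(torch.ones(5, 1), t_bc)

J1 = jacobian_rows(wave_opt, params)
J2 = jacobian_rows(u_t_ic, params)
J3 = jacobian_rows(torch.cat([u_ic, u_lo, u_up]), params)  # one row per point

# Tr(J J^T) equals the squared Frobenius norm of J, so K need not be formed.
traces = [(J * J).sum() for J in (J1, J2, J3)]
total = sum(traces)
weights = [total / tr for tr in traces]
print(J1.shape, J2.shape, J3.shape)
print([round(w.item(), 3) for w in weights])
```

Stacking rows makes $\mathbf{J}_3 \mathbf{J}_3^\top$ a block with one entry per pair of constraint points; accumulating several `backward()` calls into one row instead computes the Jacobian of their sum.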
