Comments (11)

mblondel commented on May 18, 2024

Thanks a lot for your interest in JAXopt! We support pytrees for equality-constrained QPs but not for general QPs yet. For the former, we do so by using linear operators / matvecs. For example, the quadratic form 0.5 x^T Q x can be written as 0.5 * tree_vdot(x, matvec_Q(params_Q, x)). This reduces to the array case with matvec_Q(Q, x) = jnp.dot(Q, x). There's a test using matvecs here (albeit not using pytrees). For general QPs, we will be able to support pytrees once we implement our own general solvers. In the meantime, you need to flatten your pytrees.
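To make the matvec formulation concrete, here is a minimal sketch of the quadratic form for both the array case and a pytree case. The `tree_vdot` below is a hand-written stand-in built on `jax.tree_util`, not JAXopt's own implementation, and `quad_form`, `matvec_tree`, and the example values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def tree_vdot(tree_x, tree_y):
    """Sum of leaf-wise inner products (stand-in, not JAXopt's tree_vdot)."""
    products = jax.tree_util.tree_map(jnp.vdot, tree_x, tree_y)
    return sum(jax.tree_util.tree_leaves(products))


def quad_form(matvec_Q, params_Q, x):
    """Evaluates 0.5 * x^T Q x using only a matvec, so x may be a pytree."""
    return 0.5 * tree_vdot(x, matvec_Q(params_Q, x))


# Array case: the matvec is a plain matrix-vector product.
Q = jnp.array([[2.0, 0.5], [0.5, 1.0]])
x = jnp.array([1.0, 2.0])
dense_value = quad_form(lambda Q, u: jnp.dot(Q, u), Q, x)

# Pytree case: x is a dict and "Q" scales each leaf by a constant
# (a hypothetical block-diagonal operator).
x_tree = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([3.0])}
scales = {"a": 2.0, "b": 4.0}
matvec_tree = lambda s, u: jax.tree_util.tree_map(lambda si, ui: si * ui, s, u)
tree_value = quad_form(matvec_tree, scales, x_tree)

print(dense_value, tree_value)
```

The same `quad_form` handles both inputs because it never touches `Q` directly; it only needs the matvec output to have the same tree structure as `x`.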

from jaxopt.

FerranAlet commented on May 18, 2024

Thanks for the fast reply!

My bad, I hadn't seen the matvec commit from yesterday and was reading stale code from a few days ago. I re-did the changes on the new code and showed my point on the test you pointed me to. I've updated my original question with the appropriate details for the matvec test.

I also wasn't very clear; adding more details:

  • Both my research code and the test line that failed use equality-only KKT, without any inequalities.
  • My research code (that gave the PyTreeDef error mentioned above) is not a QuadraticProgramming problem. I just meant I used that example to see how to use idf.make_kkt_optimality_fun and idf.custom_root. There, I implemented my own solver to find the primal and dual variables.
  • In both my research code and the slightly modified QuadraticProgramming, the "forward" computation of finding the primal and dual solutions works just fine. For instance, in the modified QP it passes the test on line 104 because it returns the correct primal variable, just within a list. Problems come later: in my research code they appeared when I asked for the gradient (where it has to use the KKT optimality function). In the QP tests they appeared when checking whether the optimality function is all zeros (the error mentioned above).
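For readers following along, this is roughly what "the optimality function is all zeros" means for an equality-only problem. It is a sketch in plain JAX under my own assumptions, not the code behind idf.make_kkt_optimality_fun: `lagrangian` and `kkt_residual` are hypothetical names, and the dense KKT solve is used only to get a reference solution for the small QP from the test file.

```python
import jax
import jax.numpy as jnp

# Problem data taken from the equality-only test in this thread.
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1.0]])
c = jnp.array([1.0, 1.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])


def lagrangian(x, nu):
    """Objective plus the equality constraint weighted by the dual variable."""
    return 0.5 * x @ Q @ x + c @ x + nu @ (A @ x - b)


def kkt_residual(x, nu):
    """Returns (stationarity, primal feasibility); both vanish at a solution."""
    stationarity = jax.grad(lagrangian)(x, nu)  # Q x + c + A^T nu
    feasibility = A @ x - b
    return stationarity, feasibility


# Reference solution from the dense KKT system [[Q, A^T], [A, 0]].
kkt_matrix = jnp.block([[Q, A.T], [A, jnp.zeros((1, 1))]])
solution = jnp.linalg.solve(kkt_matrix, jnp.concatenate([-c, b]))
x_star, nu_star = solution[:2], solution[2:]

stationarity, feasibility = kkt_residual(x_star, nu_star)
error = jnp.sqrt(jnp.sum(stationarity ** 2) + jnp.sum(feasibility ** 2))
print(error)
```

A correct primal/dual pair should give an error close to zero, so a non-zero residual with a correct primal solution points at how the residual is assembled rather than at the solver.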

mblondel commented on May 18, 2024

Thanks for the clarifications. The ideal would be a minimal script to reproduce the issue. Without one, I can only speculate about potential issues.

However, it doesn't, this test line raises an assert for an array that should be all zeros and instead is:

Have you compared your primal and dual solutions to another solver (for instance QuadraticProgramming) by flattening your pytrees? Note that l2_optimality_error needs to be passed a tuple containing both the primal and dual solutions.
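One way to do the flattening for such a comparison is `jax.flatten_util.ravel_pytree`, which turns a pytree into a single 1-D array and returns a function to undo it. This is only a sketch: `primal_tree` and `reference` are made-up values standing in for a pytree-valued solution and the output of an array-based solver.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical primal solution wrapped in a list (a one-leaf pytree).
primal_tree = [jnp.array([0.25, 0.75])]

# Flatten to a single 1-D array; `unravel` restores the original structure.
flat_primal, unravel = ravel_pytree(primal_tree)

# Hypothetical reference solution from an array-based solver.
reference = jnp.array([0.25, 0.75])
matches = bool(jnp.allclose(flat_primal, reference, atol=1e-6))

restored = unravel(flat_primal)
print(matches, restored)
```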

FerranAlet commented on May 18, 2024

Yes, I have compared my solutions. The solver returns the same tuple except that the first DeviceArray (primal solution) is within a list, as expected. That's why the test on the line above (that checks the solution is correct) passes.

I attach the code; you will see it's a minimal change from the original one. I've marked the differences with #CHANGED comments. Some changes were necessary to run it in Google Colab, but they are not fundamental to my point. I've also marked the test that passes with #PASSES, and the one that fails with #FAILS.

"""Quadratic programming in JAX."""

from typing import Any
from typing import Callable
from typing import Optional
from typing import Tuple

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from jaxopt._src import base
from jaxopt._src import implicit_diff as idf
from jaxopt._src import linear_solve
from jaxopt._src import tree_util


ArrayPair = Tuple[jnp.ndarray, jnp.ndarray]


def _check_params(params_obj, params_eq=None, params_ineq=None):
  if params_obj is None:
    raise ValueError("params_obj should be a tuple (Q, c)")
  Q, c = params_obj
  if Q.shape[0] != Q.shape[1]:
    raise ValueError("Q must be a square matrix.")
  if Q.shape[1] != c.shape[0]:
    raise ValueError("Q.shape[1] != c.shape[0]")

  if params_eq is not None:
    A, b = params_eq
    if A.shape[0] != b.shape[0]:
      raise ValueError("A.shape[0] != b.shape[0]")
    if A.shape[1] != Q.shape[1]:
      raise ValueError("Q.shape[1] != A.shape[1]")

  if params_ineq is not None:
    G, h = params_ineq
    if G.shape[0] != h.shape[0]:
      raise ValueError("G.shape[0] != h.shape[0]")
    if G.shape[1] != Q.shape[1]:
      raise ValueError("G.shape[1] != Q.shape[1]")


def _matvec_and_rmatvec(matvec, x, y):
  """Returns both matvec(x) = dot(A, x) and rmatvec(y) = dot(A.T, y)."""
  matvec_x, vjp = jax.vjp(matvec, x)
  rmatvec_y, = vjp(y)
  return matvec_x, rmatvec_y


def _solve_eq_constrained_qp(init_params,
                             matvec_Q,
                             c,
                             matvec_A,
                             b,
                             maxiter):
  """Solves 0.5 * x^T Q x + c^T x subject to Ax = b.
  This solver returns both the primal solution (x) and the dual solution.
  """

  def matvec(u):
    primal_u, dual_u = u
    mv_A, rmv_A = _matvec_and_rmatvec(matvec_A, primal_u, dual_u)
    return (tree_util.tree_add(matvec_Q(primal_u), rmv_A), mv_A)

  minus_c = tree_util.tree_scalar_mul(-1, c)

  # Solves the following linear system:
  # [[Q A^T]  [primal_var = [-c
  #  [A 0  ]]  dual_var  ]    b]
  return linear_solve.solve_cg(matvec, (minus_c, b), init=init_params,
                               maxiter=maxiter)


def _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq):
  """Solve 0.5 * x^T Q x + c^T x subject to Gx <= h, Ax = b."""

  # CVXPY runs on CPU. Hopefully, we can implement our own pure JAX solvers
  # and remove this dependency in the future.
  # TODO(frostig,mblondel): experiment with `jax.experimental.host_callback`
  # to "support" other devices (GPU/TPU) in the interim, by calling into the
  # host CPU and running cvxpy there.
  import cvxpy as cp

  Q, c = params_obj
  A, b = params_eq
  G, h = params_ineq

  x = cp.Variable(len(c))
  objective = 0.5 * cp.quad_form(x, Q) + c.T @ x
  constraints = [A @ x == b, G @ x <= h]
  pb = cp.Problem(cp.Minimize(objective), constraints)
  pb.solve()
  print("Primal:", [jnp.array(x.value)])
  return ([jnp.array(x.value)], jnp.array(pb.constraints[0].dual_value), #CHANGED
          jnp.array(pb.constraints[1].dual_value))


def _create_matvec(matvec, M):
  if matvec is not None:
    # M = params_M
    return lambda u: matvec(M, u)
  else:
    return lambda u: jnp.dot(M, u)


def _make_quadratic_prog_optimality_fun(matvec_Q, matvec_A):
  """Makes the optimality function for quadratic programming.
  Returns:
    optimality_fun(params, params_obj, params_eq, params_ineq) where
      params = (primal_var, eq_dual_var, ineq_dual_var)
      params_obj = (Q, c)
      params_eq = (A, b)
      params_ineq = (G, h) or None
  """
  def obj_fun(primal_var, params_obj):
    Q, c = params_obj
    _matvec_Q = _create_matvec(matvec_Q, Q)
    return (0.5 * tree_util.tree_vdot(primal_var[0], _matvec_Q(primal_var[0])) + #CHANGED
            tree_util.tree_vdot(primal_var[0], c)) #CHANGED

  def eq_fun(primal_var, params_eq):
    A, b = params_eq
    _matvec_A = _create_matvec(matvec_A, A)
    return tree_util.tree_sub(_matvec_A(primal_var[0]), b) #CHANGED

  def ineq_fun(primal_var, params_ineq):
    G, h = params_ineq
    return jnp.dot(G, primal_var[0]) - h #CHANGED

  return idf.make_kkt_optimality_fun(obj_fun, eq_fun, ineq_fun)


@dataclass
class QuadraticProgramming:
  """Quadratic programming solver.
  The objective function is::
    0.5 * x^T Q x + c^T x subject to Gx <= h, Ax = b.
  Attributes:
    matvec_Q: a Callable matvec_Q(params_Q, u).
      By default, matvec_Q(Q, u) = dot(Q, u), where Q = params_Q.
    matvec_A: a Callable matvec_A(params_A, u).
      By default, matvec_A(A, u) = dot(A, u), where A = params_A.
    maxiter: maximum number of iterations.
  """

  # TODO(mblondel): add matvec_G when we implement our own QP solvers.
  matvec_Q: Optional[Callable] = None
  matvec_A: Optional[Callable] = None
  maxiter: int = 1000

  def run(self,
          init_params: Optional[Tuple] = None,
          params_obj: Optional[ArrayPair] = None,
          params_eq: Optional[ArrayPair] = None,
          params_ineq: Optional[ArrayPair] = None) -> base.OptStep:
    """Runs the quadratic programming solver in CVXPY.
    The returned params contains both the primal and dual solutions.
    Args:
      init_params: ignored.
      params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
      params_eq: (A, b) or (params_A, b) if matvec_A is provided.
      params_ineq: = (G, h) or None if no inequality constraints.
    Return type:
      base.OptStep
    Returns:
      (params, state), ``params = (primal_var, dual_var_eq, dual_var_ineq)``
    """
    if self.matvec_Q is None and self.matvec_A is None:
      _check_params(params_obj, params_eq, params_ineq)

    Q, c = params_obj
    A, b = params_eq

    matvec_Q = _create_matvec(self.matvec_Q, Q)
    matvec_A = _create_matvec(self.matvec_A, A)

    if params_ineq is None:
      primal, dual_eq = _solve_eq_constrained_qp(init_params,
                                                 matvec_Q, c,
                                                 matvec_A, b,
                                                 self.maxiter)
      print("Primal:", [primal]) #CHANGED
      params = ([primal], dual_eq, None) #CHANGED
    else:
      params = _solve_constrained_qp_cvxpy(params_obj, params_eq, params_ineq)

    # No state needed currently as we use CVXPY.
    return base.OptStep(params=params, state=None)

  def l2_optimality_error(
      self,
      params: Any,
      params_obj: ArrayPair,
      params_eq: ArrayPair,
      params_ineq: Optional[ArrayPair] = None) -> base.OptStep:
    """Computes the L2 norm of the KKT residuals."""
    pytree = self.optimality_fun(params, params_obj, params_eq, params_ineq)
    print("Pytree:", pytree) #CHANGED
    return tree_util.tree_l2_norm(pytree)

  def __post_init__(self):
    self.optimality_fun = _make_quadratic_prog_optimality_fun(self.matvec_Q,
                                                              self.matvec_A)

    # Set up implicit diff.
    decorator = idf.custom_root(self.optimality_fun, has_aux=True)
    # pylint: disable=g-missing-from-attributes
    self.run = decorator(self.run)

import jax
from jax import test_util as jtu
import jax.numpy as jnp

from jaxopt import projection
# CHANGED: removed some imports so that it uses the modified QuadraticProgramming from above, not the original one
import numpy as onp


class QuadraticProgTest(jtu.JaxTestCase):

  def test_matvec_and_rmatvec(self):
    rng = onp.random.RandomState(0)
    A = rng.randn(5, 4)
    matvec = lambda x: jnp.dot(A, x)
    x = rng.randn(4)
    y = rng.randn(5)
    mv_A, rmv_A = _matvec_and_rmatvec(matvec, x, y)
    self.assertArraysAllClose(mv_A, jnp.dot(A, x))
    self.assertArraysAllClose(rmv_A, jnp.dot(A.T, y))

  def _check_derivative_A_and_b(self, solver, params, A, b):
    def fun(A, b):
      # reduce the primal variables to a scalar value for test purpose.
      hyperparams = dict(params_obj=params["params_obj"],
                         params_eq=(A, b),
                         params_ineq=params["params_ineq"])
      return jnp.sum(solver.run(**hyperparams).params[0])

    # Derivative w.r.t. A.
    rng = onp.random.RandomState(0)
    V = rng.rand(*A.shape)
    V /= onp.sqrt(onp.sum(V ** 2))
    eps = 1e-4
    deriv_jax = jnp.vdot(V, jax.grad(fun)(A, b))
    deriv_num = (fun(A + eps * V, b) - fun(A - eps * V, b)) / (2 * eps)
    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

    # Derivative w.r.t. b.
    v = rng.rand(*b.shape)
    v /= onp.sqrt(onp.sum(v ** 2))
    eps = 1e-4
    deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=1)(A, b))
    deriv_num = (fun(A, b + eps * v) - fun(A, b - eps * v)) / (2 * eps)
    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)

  def test_qp_eq_and_ineq(self):
    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
    c = jnp.array([1.0, 1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    G = jnp.array([[-1.0, 0.0], [0.0, -1.0]])
    h = jnp.array([0.0, 0.0])
    qp = QuadraticProgramming()
    hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=(G, h))
    sol = qp.run(**hyperparams).params
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
    self._check_derivative_A_and_b(qp, hyperparams, A, b)

  def test_qp_eq_only(self):
    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
    c = jnp.array([1.0, 1.0])
    A = jnp.array([[1.0, 1.0]])
    b = jnp.array([1.0])
    qp = QuadraticProgramming()
    hyperparams = dict(params_obj=(Q, c), params_eq=(A, b), params_ineq=None)
    sol = qp.run(**hyperparams).params
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
    self._check_derivative_A_and_b(qp, hyperparams, A, b)

  def test_projection_hyperplane(self):
    x = jnp.array([1.0, 2.0])
    a = jnp.array([-0.5, 1.5])
    b = 0.3
    # Find ||y-x||^2 such that jnp.dot(y, a) = b.
    expected = projection.projection_hyperplane(x, (a, b))

    matvec_Q = lambda params_Q, u: u
    matvec_A = lambda params_A, u: jnp.dot(a, u).reshape(1)
    qp = QuadraticProgramming(matvec_Q=matvec_Q, matvec_A=matvec_A)
    # In this example, params_Q = params_A = None.
    hyperparams = dict(params_obj=(None, -x),
                       params_eq=(None, jnp.array([b])))
    sol = qp.run(**hyperparams).params
    primal_sol = sol[0][0] #CHANGED
    self.assertArraysAllClose(primal_sol, expected) #PASSES
    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0) #FAILS
    print("Test passed")

  def test_projection_simplex(self):
    def _projection_simplex_qp(x, s=1.0):
      Q = jnp.eye(len(x))
      A = jnp.array([jnp.ones_like(x)])
      b = jnp.array([s])
      G = -jnp.eye(len(x))
      h = jnp.zeros_like(x)
      hyperparams = dict(params_obj=(Q, -x), params_eq=(A, b),
                         params_ineq=(G, h))

      qp = QuadraticProgramming()
      # Returns the primal solution only.
      return qp.run(**hyperparams).params[0]

    rng = onp.random.RandomState(0)
    x = jnp.array(rng.randn(10).astype(onp.float32))
    p = projection.projection_simplex(x)
    p2 = _projection_simplex_qp(x)
    self.assertArraysAllClose(p, p2, atol=1e-4)

    J = jax.jacrev(projection.projection_simplex)(x)
    J2 = jax.jacrev(_projection_simplex_qp)(x)
    self.assertArraysAllClose(J, J2, atol=1e-5)

QPT = QuadraticProgTest() #CHANGED to run in Colab
QPT.test_projection_hyperplane() #CHANGED to run in Colab

mblondel commented on May 18, 2024

I think I fixed the issue in #21. There was a + sign somewhere where we should have used a tree_add instead. I also added a KKTSolution named tuple to make user code more readable. Now we can do this:

sol = qp.run(**hyperparams).params
print(sol.primal)
print(sol.dual_eq)

Let me know if this fixes the issue on your side as well.
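As an illustration of why `+` and a tree-wise add can differ, consider a primal variable wrapped in a list. The thread does not show the exact line changed in #21, so this is only a guess at the kind of failure involved; `tree_add` here is a stand-in written with `jax.tree_util.tree_map`, and the gradient values are invented.

```python
import operator

import jax
import jax.numpy as jnp

# Two hypothetical gradient terms, each a one-leaf pytree (a list).
grad_objective = [jnp.array([1.0, 2.0])]
grad_constraint = [jnp.array([0.5, -2.0])]

# Python's + concatenates lists, silently changing the tree structure.
concatenated = grad_objective + grad_constraint


def tree_add(tree_x, tree_y):
    """Leaf-wise addition of two pytrees with identical structure."""
    return jax.tree_util.tree_map(operator.add, tree_x, tree_y)


# The tree-wise add keeps the structure and sums the leaves.
summed = tree_add(grad_objective, grad_constraint)

print(len(concatenated), len(summed), summed[0])
```

With bare arrays the two operations coincide, which would explain why array-only tests never noticed the difference.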

FerranAlet commented on May 18, 2024

Yes, now it works; thanks!

FerranAlet commented on May 18, 2024

Actually, I just realized that the same bug is probably also happening in line 228 (and possibly 229, but I'm less sure) on that pull request. It just didn't affect our tests because they're equality-only.

mblondel commented on May 18, 2024

I was planning to do it when we get pytree support for general QPs, in order to properly test it, but we can try to be future-proof I guess (I just force-pushed).

FerranAlet commented on May 18, 2024

Awesome. I think preemptively solving the bug is useful since people (including myself) may use make_kkt_optimality_fun with inequalities without using QPs. Thanks for solving it!

mblondel commented on May 18, 2024

BTW, make_kkt_optimality_fun is not public API at the moment (it's in _src, which is supposed to be private stuff). I'm guessing you would like us to expose make_kkt_optimality_fun?

FerranAlet commented on May 18, 2024

Yes, that would be great; thanks!!
