
Comments (5)

tpppppub commented on June 12, 2024

@fionser would you mind taking a look?


fionser commented on June 12, 2024

@kanav99 Which specific OT have you used inside the MakeSPUContext function?
For now, the Cheetah back-end supports three kinds of OT (see the config sketch after this list):

  • EMP_Ferret: the Ferret OT implementation from the EMP library
  • YACL_Ferret: the Ferret OT implemented by the SPU team
  • YACL_Softspoken: the SoftSpoken OT, an IKNP variant
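
The kind is picked through the Cheetah section of the RuntimeConfig; roughly (mirroring the enum spellings used in the snippets below, which may differ slightly across SPU versions):

spu::RuntimeConfig rt_config;
rt_config.set_protocol(spu::ProtocolKind::CHEETAH);
// Pick one of the three OT implementations:
rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
    spu::CheetahOtKind::EMP_Ferret);
// ... or spu::CheetahOtKind::YACL_Ferret
// ... or spu::CheetahOtKind::YACL_Softspoken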

Here is my standalone test file

#include "libspu/device/io.h"
#include "libspu/kernel/hlo/basic_binary.h"
#include "libspu/mpc/utils/simulate.h"


// Rank 0 provides the plaintext; ColocatedIo secret-shares it so that both
// parties obtain a share of the value named "x-0".
template <typename T>
spu::Value infeed(spu::SPUContext* hctx, const xt::xarray<T>& ds) {
  spu::device::ColocatedIo cio(hctx);
  if (hctx->lctx()->Rank() == 0) {
    cio.hostSetVar(fmt::format("x-{}", hctx->lctx()->Rank()), ds);
  }
  cio.sync();
  auto x = cio.deviceGetVar("x-0");
  return x;
}

TEST_P(ObjectivesTest, TestLess) {
  using namespace spu;
  using namespace spu::kernel;
  using namespace spu::mpc;
  //
  // preparing input here ...
  //

  // Simulate a 2PC run: the lambda body is executed once per party.
  spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
    spu::RuntimeConfig rt_config;
    rt_config.set_protocol(ProtocolKind::CHEETAH);
    rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
        CheetahOtKind::YACL_Ferret);
    rt_config.set_field(FM32);
    rt_config.set_fxp_fraction_bits(12);

    auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
    auto ctx = _ctx.get();
    spu::mpc::Factory::RegisterProtocol(ctx, lctx);

    auto x = infeed<double>(ctx, _x);
    auto y = infeed<double>(ctx, _x);
    int64_t numel = x.numel();

    auto b0 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;
    auto comp = hlo::Less(ctx, x, y);

    auto b1 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;

    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);

    b0 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;
    comp = hlo::Less(ctx, x, y);
    b1 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;

    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);
  });
}

The results are:

[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
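
If the goal is to benchmark the steady-state per-element cost (excluding whatever one-time traffic the very first call incurs), one option is to add a warm-up call before resetting the counters. A rough sketch, reusing the names from the test above:

    // Warm-up: run the op once so that any one-time setup traffic is not
    // attributed to the measured call.
    (void)hlo::Less(ctx, x, y);

    // Helper returning the total bytes moved over the link so far.
    auto link_bytes = [&] {
      return ctx->lctx()->GetStats()->sent_bytes +
             ctx->lctx()->GetStats()->recv_bytes;
    };

    auto b0 = link_bytes();
    auto comp = hlo::Less(ctx, x, y);
    auto b1 = link_bytes();
    SPDLOG_INFO("steady-state bytes per element: {}", (b1 - b0) * 1. / numel);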


fionser commented on June 12, 2024

@kanav99 Here is another 2PC example (you can simply replace the main function here).
Yes, something seems to go wrong if we are only benchmarking a single value, so it might be better to use a longer vector, e.g., n = 100 as in the following example.

// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=0 --lr=0.2 --field=1
// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=1 --lr=0.3 --field=1

std::unique_ptr<spu::SPUContext> MakeSPUContext() {
  auto lctx = MakeLink(Parties.getValue(), Rank.getValue());

  spu::RuntimeConfig config;
  config.set_protocol(spu::ProtocolKind::CHEETAH);
  config.mutable_cheetah_2pc_config()->set_enable_mul_lsb_error(true);
  // or replace with `EMP_Ferret` / `YACL_Softspoken` to switch the OT implementation
  config.mutable_cheetah_2pc_config()->set_ot_kind(
      spu::CheetahOtKind::YACL_Ferret);

  config.set_field(static_cast<spu::FieldType>(Field.getValue()));
  config.set_fxp_fraction_bits(18);
  config.set_fxp_div_goldschmidt_iters(1);
  config.set_enable_hal_profile(EngineTrace.getValue());
  auto hctx = std::make_unique<spu::SPUContext>(config, lctx);
  spu::mpc::Factory::RegisterProtocol(hctx.get(), lctx);
  return hctx;
}

int main(int argc, char** argv) {
  SPDLOG_INFO("in process");
  llvm::cl::ParseCommandLineOptions(argc, argv);
  // YACL_Ferret 
  auto sctx = MakeSPUContext();
  std::vector<size_t> shape = {100};
  xt::xarray<double> input(shape);
  std::fill_n(input.data(), input.size(), LearningRate.getValue());

  auto x = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 0);
  auto y = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 1);

  size_t b0 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;
  auto comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  size_t b1 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;

  printf("size %zd\n", input.size());

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  auto comp_revealed = spu::kernel::hal::dump_public_as<float>(
      sctx.get(),
      spu::kernel::hlo::Cast(sctx.get(), comp, spu::VIS_PUBLIC, comp.dtype()));

  SPDLOG_INFO("comp_revealed: {}", comp_revealed[0]);
  return 0;
}

Yes, the results still do not quite make sense to me; I would expect the 2nd run to also communicate about 41 bytes.
But if we enlarge the vector size (e.g., n > 200), it works as I expect.

[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
size 100
[2024-04-25 15:12:11.232] [info] [squirrel_demo_main.cc:283] bytes communicated: 12107.21
[2024-04-25 15:12:11.418] [info] [squirrel_demo_main.cc:291] bytes communicated: 2703.88
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:299] bytes communicated: 41.28
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:305] comp_revealed: 0
I0425 15:12:11   259 external/com_github_brpc_brpc/src/brpc/server.cpp:1218] Server[yacl::link::transport::internal::ReceiverServiceImpl] is going to quit
[2024-04-25 15:12:11.427] [warning] [channel.h:162] Channel destructor is called before WaitLinkTaskFinish, try stop send thread
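
As an aside, the repeated byte-counting around each hlo::Less call above can be factored into a small helper. A minimal sketch (the name measure_bytes_per_elem is made up here; it only uses the link stats already queried above):

// Report the link traffic caused by running `op`, normalized per element.
template <typename Fn>
double measure_bytes_per_elem(spu::SPUContext* sctx, int64_t numel, Fn&& op) {
  auto stats = sctx->lctx()->GetStats();
  size_t b0 = stats->sent_bytes + stats->recv_bytes;
  op();
  size_t b1 = stats->sent_bytes + stats->recv_bytes;
  return (b1 - b0) * 1. / numel;
}

// Usage inside main():
//   double cost = measure_bytes_per_elem(sctx.get(), x.numel(), [&] {
//     comp = spu::kernel::hlo::Less(sctx.get(), x, y);
//   });
//   SPDLOG_INFO("bytes communicated: {}", cost);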


kanav99 commented on June 12, 2024

Thanks for the detailed response. Working with a large batch works for me!

I can close this issue, but do you want me to keep it open as-is? I realize that you do see an issue here.


fionser commented on June 12, 2024

@kanav99 I will close this issue. The "weird" stats might be due to some implementation decisions :).

