Giter Club home page Giter Club logo

padiff's Introduction

PaDiff

Paddle Automatically Diff precision toolkits.

最近更新

  • 添加了api级别对齐检查,可以通过设置环境变量来关闭:export PADIFF_API_CHECK=OFF
  • 更新对齐策略:自顶向下对齐
  • 更新模型遍历策略:现在会尽可能滤过wrap layer,大部分情况无需手动调用LayerMap
  • 提供了新接口:assign_weight接口,将torch模型的权重拷贝到paddle模型
  • 优化权重初始化过程以及对齐报错信息: 现在会打印树形结构,并标注出错的位置
  • 更新了optimizer的使用方法:可以传入一个lambda(需要在lambda内自行clear grad)
  • 提供了自定义初始化接口:无法直接对齐的模型,现在能够通过提供一个自定义初始化函数进行初始化

简介

PaDiff是基于PaddlePaddle与PyTorch的模型精度对齐工具。传入Paddle与Torch模型,对齐训练中间结果以及训练后的模型权重,并提示精度diff第一次出现的位置。

安装

PaDiff v0.1 版本已发布,可通过如下命令安装:

pip install padiff

尝鲜版或开发者推荐clone源码并使用如下命令安装:

python setup.py install

快速开始

from padiff import auto_diff
import torch
import paddle

class SimpleLayer(paddle.nn.Layer):
  def __init__(self):
      super(SimpleLayer, self).__init__()
      self.linear1 = paddle.nn.Linear(100, 10)
  def forward(self, x):
      x = self.linear1(x)
      return x

class SimpleModule(torch.nn.Module):
  def __init__(self):
      super(SimpleModule, self).__init__()
      self.linear1 = torch.nn.Linear(100, 10)
  def forward(self, x):
      x = self.linear1(x)
      return x

layer = SimpleLayer()
module = SimpleModule()

inp = paddle.rand((100, 100)).numpy().astype("float32")
inp = ({'x': paddle.to_tensor(inp)},
     {'x': torch.as_tensor(inp) })

auto_diff(layer, module, inp, auto_weights=True, options={'atol': 1e-4, 'rtol':0, 'compare_mode': 'strict', 'single_step':False})

已支持 Special Init 的组件

  • MultiHeadAttention
  • LSTM
  • BatchNorm2D

padiff's People

Contributors

feifei-111 avatar linjieccc avatar aurelius84 avatar 2742195759 avatar xiaoguanghu01 avatar wenmuzhou avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.