Comments (7)
Could you directly modify this unit test to see if there is anything abnormal? https://github.com/facebookincubator/AITemplate/blob/main/tests/unittest/ops/test_conv.py
from aitemplate.
When I change that UT as follows, it passes:
    import unittest

    import torch

    from aitemplate.compiler import compile_model, ops
    from aitemplate.frontend import Tensor
    from aitemplate.testing import detect_target


    class ConvTestCase(unittest.TestCase):
        def test_fp16(self, batch=1):
            target = detect_target()
            # AITemplate conv2d works on NHWC input and OHWI weights.
            X = Tensor(
                shape=[1, 384, 384, 4],
                dtype="float16",
                name="input_0",
                is_input=True,
            )
            W = Tensor(
                shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
            )
            OP = ops.conv2d(stride=4, pad=3, dilate=1)
            Y = OP(X, W)
            Y._attrs["name"] = "output_0"
            Y._attrs["is_output"] = True
            module = compile_model(Y, target, "./tmp", "conv2d")

            # PyTorch reference in NCHW (input) / OIHW (weight) layout.
            X_pt = torch.randn(1, 4, 384, 384).cuda().half()
            W_pt = torch.randn(256, 4, 7, 7).cuda().half()
            Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)

            # Convert inputs to NHWC, run AIT, then permute the output back to NCHW.
            x = X_pt.permute((0, 2, 3, 1)).contiguous()
            w = W_pt.permute((0, 2, 3, 1)).contiguous()
            y = torch.empty([1, 96, 96, 256]).cuda().half()
            module.run_with_tensors({"input_0": x, "input_1": w}, [y])
            y_transpose = y.permute((0, 3, 1, 2))
            if target.name() == "cuda":
                self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
            else:
                self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))


    if __name__ == "__main__":
        unittest.main()
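For reference, the preallocated output shape [1, 96, 96, 256] follows from the usual convolution output-size formula, out = floor((in + 2*pad - dilate*(kernel - 1) - 1) / stride) + 1, the same one PyTorch documents for Conv2d. A small sketch that checks it for the shapes in the test above (conv_out_size is just an illustrative helper, not an AITemplate or PyTorch function):

```python
def conv_out_size(size: int, kernel: int, stride: int, pad: int, dilate: int = 1) -> int:
    """Spatial output size of a 2D convolution along one dimension."""
    # A dilated kernel covers dilate * (kernel - 1) + 1 input positions.
    effective_kernel = dilate * (kernel - 1) + 1
    return (size + 2 * pad - effective_kernel) // stride + 1


# Shapes from the test: 384x384 input, 7x7 kernel, stride=4, pad=3, dilate=1.
h_out = conv_out_size(384, kernel=7, stride=4, pad=3, dilate=1)
w_out = conv_out_size(384, kernel=7, stride=4, pad=3, dilate=1)
# NHWC output buffer: [batch, H_out, W_out, out_channels]
print([1, h_out, w_out, 256])  # [1, 96, 96, 256]
```

If the stride, pad, or input size is changed while experimenting, the `torch.empty(...)` buffer passed to `run_with_tensors` has to change with it.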
I will have to look further to see what the difference is between the module and OP versions.
This visualization tool is very helpful for investigating: https://facebookincubator.github.io/AITemplate/tutorial/how_to_visualize.html
The visualization is pretty simple.
The only attribute is op_type. For nn.Conv2dBias it is conv2d_bias, so I changed test_conv_bias accordingly, but it also passes the UT:
    import unittest

    import torch

    from aitemplate.compiler import compile_model, ops
    from aitemplate.frontend import Tensor
    from aitemplate.testing import detect_target


    class ConvBiasTestCase(unittest.TestCase):
        def test_fp16(self, batch=4):
            target = detect_target()
            # NHWC input, OHWI weights, and a 1-D bias of length out_channels.
            X = Tensor(
                shape=[1, 384, 384, 4],
                dtype="float16",
                name="input_0",
                is_input=True,
            )
            W = Tensor(
                shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
            )
            B = Tensor(shape=[256], dtype="float16", name="input_2", is_input=True)
            OP = ops.conv2d_bias(stride=4, pad=3, dilate=1)
            Y = OP(X, W, B)
            Y._attrs["name"] = "output_0"
            Y._attrs["is_output"] = True
            module = compile_model(Y, target, "./tmp", "conv2d_bias")

            # PyTorch reference in NCHW; bias shaped (1, C, 1, 1) so it broadcasts.
            X_pt = torch.randn(1, 4, 384, 384).cuda().half()
            W_pt = torch.randn(256, 4, 7, 7).cuda().half()
            B_pt = torch.randn(1, 256, 1, 1).cuda().half()
            Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)
            Y_pt = Y_pt + B_pt

            # Convert inputs to NHWC and squeeze the bias to shape [256] for AIT.
            x = X_pt.permute((0, 2, 3, 1)).contiguous()
            w = W_pt.permute((0, 2, 3, 1)).contiguous()
            inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}
            y = torch.empty([1, 96, 96, 256]).cuda().half()
            module.run_with_tensors(inputs, [y])
            y_transpose = y.permute((0, 3, 1, 2))
            if target.name() == "cuda":
                self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
            else:
                self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))


    if __name__ == "__main__":
        unittest.main()
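Both tests keep the PyTorch reference in NCHW (weights in OIHW) and hand AITemplate NHWC tensors (weights in OHWI), converting with permute((0, 2, 3, 1)) on the way in and permute((0, 3, 1, 2)) on the way out. A pure-Python sketch of how those permutations act on the shapes, and that the second one undoes the first (permute_shape and compose are hypothetical helpers, not PyTorch or AITemplate APIs):

```python
# Dims passed to tensor.permute(...) in the tests.
NCHW_TO_NHWC = (0, 2, 3, 1)  # x (NCHW) and w (OIHW) before run_with_tensors
NHWC_TO_NCHW = (0, 3, 1, 2)  # y (NHWC) before comparing against Y_pt


def permute_shape(shape, dims):
    """Shape that tensor.permute(dims) would produce: output dim i is input dim dims[i]."""
    return tuple(shape[d] for d in dims)


def compose(first, second):
    """Single permutation equivalent to .permute(first).permute(second)."""
    return tuple(first[d] for d in second)


print(permute_shape((1, 4, 384, 384), NCHW_TO_NHWC))  # (1, 384, 384, 4): matches X
print(permute_shape((256, 4, 7, 7), NCHW_TO_NHWC))    # (256, 7, 7, 4): matches W
print(permute_shape((1, 96, 96, 256), NHWC_TO_NCHW))  # (1, 256, 96, 96): matches Y_pt
print(compose(NCHW_TO_NHWC, NHWC_TO_NCHW))            # (0, 1, 2, 3): round trip is identity
```

A mismatch between these layouts (for example, passing an NCHW tensor where NHWC is expected) would be one thing to rule out when comparing the module and OP versions.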
I think this was because I had to delete the temp folder.
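Both tests call compile_model with "./tmp" as the working directory, so artifacts from earlier runs stay on disk between experiments. A minimal sketch of clearing them before recompiling, assuming each test's build output lives in a subdirectory named after the test ("conv2d", "conv2d_bias"); remove_stale_build is a hypothetical helper, not part of AITemplate:

```python
import os
import shutil


def remove_stale_build(workdir: str, test_name: str) -> bool:
    """Delete <workdir>/<test_name> if it exists; return True if something was removed.

    Assumption: compile_model(Y, target, workdir, test_name) writes its build
    output into this subdirectory. Call this before compile_model to force a
    clean rebuild.
    """
    build_dir = os.path.join(workdir, test_name)
    if os.path.isdir(build_dir):
        shutil.rmtree(build_dir)
        return True
    return False
```

For example, calling `remove_stale_build("./tmp", "conv2d")` right before `compile_model(Y, target, "./tmp", "conv2d")` makes each run start from an empty build directory; deleting all of ./tmp by hand has the same effect.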