Comments (2)
You need to pad the weight to 4 channels and then set the weights, e.g. https://github.com/facebookincubator/AITemplate/blob/main/examples/01_resnet-50/weight_utils.py#L147
from aitemplate.
Closing this. It would still be nice to add this to the documentation.
Using this works:
from collections import OrderedDict

import torch


def map_pt_params(ait_model, pt_model):
    ait_model.name_parameter_tensor()
    pt_params = dict(pt_model.named_parameters())
    mapped_pt_params = OrderedDict()
    for name, _ in ait_model.named_parameters():
        # Names should be valid C++ variables
        ait_name = name.replace(".", "_")
        assert name in pt_params
        params = pt_params[name]
        if len(params.shape) == 4:
            # NCHW->NHWC (permute keeps H and W in order; transpose(1, -1) would swap them)
            params = params.permute(0, 2, 3, 1).contiguous()
            # Pad for few channels (3 -> 4)
            if params.shape[-1] == 3:
                params = torch.nn.functional.pad(params, (0, 1))
        mapped_pt_params[ait_name] = params
    return mapped_pt_params
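A quick way to sanity-check the layout conversion and channel padding on a single tensor is sketched below. It assumes a conv weight with 3 input channels and a non-square kernel, so that a swapped height/width would be visible in the output shape. The helper name `to_ait_conv_weight` is hypothetical and is not part of AITemplate; only standard PyTorch calls are used.

```python
import torch
import torch.nn.functional as F


def to_ait_conv_weight(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert one PyTorch conv weight to a padded NHWC layout."""
    # (out, in, kH, kW) -> (out, kH, kW, in); permute keeps kH and kW in order
    weight = weight.permute(0, 2, 3, 1).contiguous()
    # Zero-pad the trailing channel dimension from 3 to 4
    if weight.shape[-1] == 3:
        weight = F.pad(weight, (0, 1))
    return weight


# Dummy weight: 64 output channels, 3 input channels, 7x5 kernel
w = torch.randn(64, 3, 7, 5)
out = to_ait_conv_weight(w)
print(tuple(out.shape))  # (64, 7, 5, 4)
```

The original three channels are preserved in `out[..., :3]` and the added fourth channel is all zeros, so the padded weight should not change the convolution result as long as the extra input channel is zero as well.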