Comments (3)

hlu1 commented on July 17, 2024

Can you try commenting out this line: https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/compiler/transform/remove_no_ops.py#L163? I think you hit a bug in this pass. Also, are there expand ops in your model?


ioeddk commented on July 17, 2024

Is there any update on this problem? I'm facing a similar problem, and the line to comment out above is just a closing parenthesis. So in the function _is_compatible_with_broadcasting, should we simply return True?
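
For a quick test without editing the installed package, that check can be monkey-patched so it always passes. This is a debugging sketch only, assuming _is_compatible_with_broadcasting is a module-level helper in that file; it just mimics commenting the check out and is not a real fix:

# Debugging sketch: force the broadcasting-compatibility check in
# remove_no_ops to always pass (equivalent to commenting the check out).
from aitemplate.compiler.transform import remove_no_ops

remove_no_ops._is_compatible_with_broadcasting = lambda *args, **kwargs: True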

My code tries to build the backbone (YOLOPAFPN) of YOLOX:

class YOLOPAFPN(nn.Module):
    """
    YOLOv3 model. Darknet 53 is the default backbone of this model.
    """

    def __init__(
        self,
        depth=1.0,
        width=1.0,
        in_features=("dark3", "dark4", "dark5"),
        in_channels=[256, 512, 1024],
        depthwise=False,
        act="relu",
    ):
        super().__init__()
        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, out_features=in_features, act=act)
        self.in_features = in_features
        self.in_channels = in_channels
        Conv = DWConv if depthwise else BaseConv

        self.upsample = nn.upsampling2d(scale_factor=2, mode="nearest")
        self.lateral_conv0 = BaseConv(
            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
        )
        self.C3_p4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )  # cat

        self.reduce_conv1 = BaseConv(
            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
        )
        self.C3_p3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[0] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv2 = Conv(
            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
        )
        self.C3_n3 = CSPLayer(
            int(2 * in_channels[0] * width),
            int(in_channels[1] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

        # bottom-up conv
        self.bu_conv1 = Conv(
            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
        )
        self.C3_n4 = CSPLayer(
            int(2 * in_channels[1] * width),
            int(in_channels[2] * width),
            round(3 * depth),
            False,
            depthwise=depthwise,
            act=act,
        )

    def forward(self, input):
        """
        Args:
            inputs: input images.

        Returns:
            Tuple[Tensor]: FPN feature.
        """

        #  backbone
        out_features = self.backbone(input)
        features = [out_features[f] for f in self.in_features]
        [x2, x1, x0] = features

        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
        f_out0 = self.upsample(fpn_out0)  # 512/16
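        # NOTE: aitcat is assumed to be a thin wrapper around
        # aitemplate.compiler.ops.concatenate; it is not shown in this snippet.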
        f_out0 = aitcat()([f_out0, x1], 3)  # 512->1024/16
        f_out0 = self.C3_p4(f_out0)  # 1024->512/16

        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
        f_out1 = self.upsample(fpn_out1)  # 256/8
        f_out1 = aitcat()([f_out1, x2], 3)  # 256->512/8
        pan_out2 = self.C3_p3(f_out1)  # 512->256/8

        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
        p_out1 = aitcat()([p_out1, fpn_out1], 3)  # 256->512/16
        pan_out1 = self.C3_n3(p_out1)  # 512->512/16

        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
        p_out0 = aitcat()([p_out0, fpn_out0], 3)  # 512->1024/32
        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32

        # outputs = (pan_out2, pan_out1, pan_out0)
        return pan_out2

I have already replaced the convolution operations with aitemplate.frontend.nn.conv2dBias, and the model compiles successfully when the forward pass only computes up to f_out0.
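
Roughly, the replacement looks like the following. This is a sketch only; the exact constructor arguments may differ between AITemplate versions, and the BatchNorm inside the original YOLOX BaseConv is assumed to be folded into the conv weights when mapping the PyTorch state dict:

# Sketch of a BaseConv adapted to AITemplate; not the exact code from this
# post. Uses the fused Conv2dBiasRelu module, since act="relu" above.
from aitemplate.frontend import nn

class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, ksize, stride, act="relu"):
        super().__init__()
        pad = (ksize - 1) // 2
        # conv + bias + relu fused into a single op; BatchNorm is assumed to
        # be folded into the weights offline.
        self.conv = nn.Conv2dBiasRelu(in_channels, out_channels, ksize, stride, pad)

    def forward(self, x):
        return self.conv(x)
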
My error looks like this:

.
.
.
2023-07-17 19:50:15,384 DEBUG <aitemplate.compiler.transform.name_graph> before name_graph: func_cnt=101, tensor_cnt=0, len(func_name_to_tensor_cnt)=101, len(user_provided_dim)=546
2023-07-17 19:50:15,385 DEBUG <aitemplate.compiler.transform.name_graph> after name_graph: func_cnt=101, tensor_cnt=0, len(func_name_to_tensor_cnt)=101, len(user_provided_dim)=546
2023-07-17 19:50:15,565 DEBUG <aitemplate.utils.graph_utils> Dumped dedup_symbolic_name visualization to ./tmp/test_compile_yolox_backbone_1/dedup_symbolic_name_graph_vis.html
2023-07-17 19:50:15,569 INFO <aitemplate.compiler.transform.memory_planning> Workspace shared_size=0 unique_size=0
2023-07-17 19:50:15,569 INFO <aitemplate.compiler.transform.memory_planning> max_blob=3670016 constant_offset=0
Traceback (most recent call last):
  File "/workspaces/torchsparse-misc/conversion/yolox/test_ait_model.py", line 72, in <module>
    model_compiled = compile_module(
  File "/workspaces/torchsparse-misc/conversion/yolox/test_ait_model.py", line 44, in compile_module
    module = compile_model(y, target, "./tmp", model_name)
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/utils/misc.py", line 93, in inner_function
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/compiler.py", line 308, in compile_model
    _verify_outputs_still_in_graph(graph, output_tensors)
  File "/usr/local/lib/python3.8/dist-packages/aitemplate/compiler/compiler.py", line 95, in _verify_outputs_still_in_graph
    raise ValueError(
ValueError: Output output_0 was not found in the graph after optimizations.
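
For context, "output_0" is typically the name given to the tensor marked as the graph output before compile_model is called. The compile_module in the traceback is the poster's own helper; a typical version looks roughly like the following hypothetical sketch (names, shapes, and dtype are placeholders):

# Hypothetical sketch of a compile helper; not the actual test_ait_model.py.
from aitemplate.compiler import compile_model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

def compile_module(ait_model, batch_size=1, height=640, width=640, model_name="yolox_backbone"):
    # AIT expects NHWC inputs; the shape values here are placeholders.
    x = Tensor(
        shape=[batch_size, height, width, 3],
        dtype="float16",
        name="input_0",
        is_input=True,
    )
    y = ait_model(x)
    # Mark the graph output; this is where the name "output_0" comes from.
    y._attrs["is_output"] = True
    y._attrs["name"] = "output_0"
    target = detect_target()
    return compile_model(y, target, "./tmp", model_name)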


hlky commented on July 17, 2024

Ran into a similar issue while working on RRDBNet.

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(self.cat((x, x1), 3)))
        x3 = self.lrelu(self.conv3(self.cat((x, x1, x2), 3)))
        x4 = self.lrelu(self.conv4(self.cat((x, x1, x2, x3), 3)))
        x5 = self.conv5(self.cat((x, x1, x2, x3, x4), 3))
        # Empirically, we use 0.2 to scale the residual for better performance
        out = x5 * 0.2 + x
        return out

When using ops.concatenate directly, it appears that the subsequent conv layers are not used; only the first one makes it into the graph.

AIT dump (dump_ait_to_py):

    def model(self):
        conv2d_bias_0_0 = ops.conv2d_bias(dilate=1, group=1, pad=1, stride=1)(self.rdb_input, self.conv1_weight, self.conv1_bias)

        # Set outputs

        # End of setting outputs
        return 

Another indicator was that profiling only ran for the first conv layer.

Marking the tensor as is_input resolves the issue.
Wrapping the operator in a helper, as below, also appears to resolve it:

    def cat(self, tensors, dim):
        out = ops.concatenate()(tensors, dim)
        return out
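
For completeness, the is_input workaround mentioned above amounts to setting the attribute on the tensor directly. This sketch assumes it is the concatenate output that gets marked; whether that is the intended use of the attribute is unclear:

# Sketch of the is_input workaround; assumes the concatenate output is the
# tensor that needs to be marked.
out = ops.concatenate()((x, x1), 3)
out._attrs["is_input"] = True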
