
Comments (7)

williamyang1991 commented on May 17, 2024

Yes, you need to create two ControlNet models (e.g. model1 and model2) and load a different saved .pth checkpoint into each.

You will also need to extract the ControlNet features for each of the ControlNet models, which means modifying cond and the sample call:

cond = {

samples, _ = ddim_v_sampler.sample(
)

To use the two ControlNets simultaneously, you need to add both control features to the SD features instead of a single control feature, which requires modifying the code in cldm.py:
https://github.com/lllyasviel/ControlNet/blob/ed85cd1e25a5ed592f7d8178495b4483de0331bf/cldm/cldm.py#L35
change h += control.pop() to h += control1.pop() + control2.pop()

https://github.com/lllyasviel/ControlNet/blob/ed85cd1e25a5ed592f7d8178495b4483de0331bf/cldm/cldm.py#L41
change h = torch.cat([h, hs.pop() + control.pop()], dim=1) to h = torch.cat([h, hs.pop() + control1.pop() + control2.pop()], dim=1)
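The two cldm.py edits above generalize to any number of ControlNets. A minimal sketch, assuming the UNet forward is changed to receive a list of per-ControlNet residual lists (called controls here) instead of a single control list; pop_sum is a hypothetical helper, not part of ControlNet:

```python
from typing import Any, List


def pop_sum(controls: List[List[Any]]) -> Any:
    """Pop the last residual from each ControlNet's list and return their sum.

    Generalizes control1.pop() + control2.pop() to N ControlNets. It only uses
    list.pop() and +, so it works with torch tensors as well as plain numbers.
    """
    if not controls:
        raise ValueError("need at least one control list")
    if len({len(c) for c in controls}) != 1:
        # Checked before popping, so a mismatch leaves every list untouched.
        raise ValueError("every ControlNet must provide the same number of residuals")
    total = controls[0].pop()
    for control in controls[1:]:
        # Out-of-place add, so the popped tensor itself is never modified.
        total = total + control.pop()
    return total


# Inside the UNet forward in cldm.py, the two edited lines would then read:
#   h += pop_sum(controls)
#   h = torch.cat([h, hs.pop() + pop_sum(controls)], dim=1)

controls = [[1, 2, 3], [10, 20, 30]]
print(pop_sum(controls))  # 33
print(controls)           # [[1, 2], [10, 20]]
```

Both lists are popped in lockstep, so the i-th residual of one ControlNet is always added to the i-th residual of the other, which is what the two-ControlNet edit does.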

from rerender_a_video.

hzhou17 commented on May 17, 2024

I noticed the following code:

    if cfg.control_type == 'HED':
        model.load_state_dict(
            load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
    elif cfg.control_type == 'canny':
        model.load_state_dict(
            load_state_dict('./models/control_v11p_sd15_canny.pth',
                            location='cuda'))

How should I change the code to load both of them?
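Following the suggestion to create one model per checkpoint (model1 and model2), here is a minimal sketch that builds a separate model for each control type and loads its own checkpoint. The paths are the ones in the snippet above. build_model and load_weights are injected so the sketch runs without the repo. In ControlNet they would presumably wrap cldm.model's create_model(config_path) and load_state_dict(path, location='cuda'), but that wiring and the config path are assumptions left to you:

```python
from typing import Any, Callable, Dict, List, Sequence

# Checkpoint paths copied from the webUI snippet above.
CONTROL_CKPTS: Dict[str, str] = {
    'HED': './models/control_sd15_hed.pth',
    'canny': './models/control_v11p_sd15_canny.pth',
}


def load_control_models(
    control_types: Sequence[str],
    build_model: Callable[[], Any],
    load_weights: Callable[[str], Dict[str, Any]],
) -> List[Any]:
    """Return one freshly built model per control type, each with its own weights.

    build_model:  zero-argument factory for a new model (assumed to wrap
                  ControlNet's create_model with your config).
    load_weights: reads a checkpoint path into a state dict (assumed to wrap
                  ControlNet's load_state_dict(path, location='cuda')).
    """
    models = []
    for control_type in control_types:
        if control_type not in CONTROL_CKPTS:
            raise KeyError(f'no checkpoint registered for {control_type!r}')
        # A new instance per type, so the two models never share weights.
        model = build_model()
        model.load_state_dict(load_weights(CONTROL_CKPTS[control_type]))
        models.append(model)
    return models
```

Usage would be along the lines of model1, model2 = load_control_models(['canny', 'HED'], build_model, load_weights). Keep in mind that two full model instances hold two full copies of the parameters.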


hzhou17 commented on May 17, 2024

@williamyang1991 Thank you very much for the reply! But I still need a bit more assistance...

I changed the cldm.py code as you described above, and changed the webUI code like this:

def update_detector(self, control_type, canny_low=100, canny_high=200):
    #if self.detector_type == control_type:
    #    return

    if control_type == 'canny':
        canny_detector = CannyDetector()
        low_threshold = canny_low
        high_threshold = canny_high

        def apply_canny(x):
            return canny_detector(x, low_threshold, high_threshold)

        self.detector1 = apply_canny

        midas = MidasDetector()

        def apply_midas(x):
            detected_map, _ = midas(x)
            return detected_map

        self.detector2 = apply_midas
        ......

        ddim_v_sampler = global_state.ddim_v_sampler
        model = ddim_v_sampler.model

        ### Changed Here
        detector1 = global_state.detector1
        detector2 = global_state.detector2

        controller = global_state.controller
        model.control_scales = [cfg.control_strength] * 13
        ......

        detected_map1 = detector1(img)
        detected_map1 = HWC3(detected_map1)

        control1 = torch.from_numpy(
            detected_map1.copy()).float().cuda() / 255.0
        control1 = torch.stack([control1 for _ in range(num_samples)], dim=0)
        control1 = einops.rearrange(control1, 'b h w c -> b c h w').clone()

        detected_map2 = detector2(img)
        detected_map2 = HWC3(detected_map2)

        control2 = torch.from_numpy(
            detected_map2.copy()).float().cuda() / 255.0
        control2 = torch.stack([control2 for _ in range(num_samples)], dim=0)
        control2 = einops.rearrange(control2, 'b h w c -> b c h w').clone()


        cond = {
            'c_concat': [control1, control2],
            'c_crossattn': [
                model.get_learned_conditioning(
                    [cfg.prompt + ', ' + cfg.a_prompt] * num_samples)
            ]
        }
        un_cond = {
            'c_concat': [control1, control2],
            'c_crossattn':
            [model.get_learned_conditioning([cfg.n_prompt] * num_samples)]
        }

But I ran into this error:

Traceback (most recent call last):
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/queueing.py", line 388, in call_prediction
    output = await route_utils.call_process_api(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/route_utils.py", line 219, in call_process_api
    output = await app.get_blocks().process_api(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/blocks.py", line 1437, in process_api
    result = await self.call_function(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/blocks.py", line 1109, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/to_thread.py", line 33, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
    return await future
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 807, in run
    result = context.run(func, *args)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/gradio/utils.py", line 650, in wrapper
    response = f(*args, **kwargs)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "webUI.py", line 367, in process1
    x_samples, x_samples_np = generate_first_img(img, first_strength)
  File "webUI.py", line 342, in generate_first_img
    samples, _ = ddim_v_sampler.sample(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 212, in sample
    samples, intermediates = self.ddim_sampling(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 329, in ddim_sampling
    outs = self.p_sample_ddim(
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/heran/Rerender_A_Video/src/ddim_v_hacked.py", line 381, in p_sample_ddim
    model_t = self.model.apply_model(x, t, c)
  File "/home/heran/Rerender_A_Video/deps/ControlNet/cldm/cldm.py", line 337, in apply_model
    control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/heran/Rerender_A_Video/deps/ControlNet/cldm/cldm.py", line 288, in forward
    guided_hint = self.input_hint_block(hint, emb, context)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/heran/Rerender_A_Video/deps/ControlNet/ldm/modules/diffusionmodules/openaimodel.py", line 86, in forward
    x = layer(x)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/heran/anaconda3/envs/rerender/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [16, 3, 3, 3], expected input[1, 6, 512, 704] to have 3 channels, but got 6 channels instead

I guess the RuntimeError means the two conditions are not being combined properly. Would you please take a look and give some suggestions? I'd really appreciate it~
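The traceback supports that guess: apply_model concatenates every hint in cond['c_concat'] along dim 1 before calling the single control model, so two 3-channel maps (canny and MiDaS, each passed through HWC3) become one 6-channel tensor, while the first conv of input_hint_block has weight [16, 3, 3, 3], i.e. it expects 3 input channels. A shape-only sketch of that arithmetic (no torch needed; both helpers are hypothetical):

```python
from typing import Sequence, Tuple

Shape = Tuple[int, int, int, int]  # (N, C, H, W)


def cat_dim1_shape(shapes: Sequence[Shape]) -> Shape:
    """Shape produced by torch.cat(tensors, dim=1) for NCHW tensors."""
    n, _, h, w = shapes[0]
    for s in shapes:
        if (s[0], s[2], s[3]) != (n, h, w):
            raise ValueError(f'non-channel dims must match, got {s}')
    return (n, sum(s[1] for s in shapes), h, w)


def conv_accepts(weight: Tuple[int, int, int, int], inp: Shape, groups: int = 1) -> bool:
    """A Conv2d weight is [out_ch, in_ch // groups, kH, kW]; input must have in_ch channels."""
    return inp[1] == weight[1] * groups


canny_hint = (1, 3, 512, 704)
midas_hint = (1, 3, 512, 704)
hint = cat_dim1_shape([canny_hint, midas_hint])
print(hint)                                     # (1, 6, 512, 704)
print(conv_accepts((16, 3, 3, 3), hint))        # False: the RuntimeError above
print(conv_accepts((16, 3, 3, 3), canny_hint))  # True: one hint per ControlNet fits
```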


hzhou17 commented on May 17, 2024

Sorry about the lengthy post... I just realized that I did not change this:

        samples, _ = ddim_v_sampler.sample(

What should I change here?


williamyang1991 commented on May 17, 2024

You need to change

    control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)

since you now have two control models and two features in cond['c_concat'].


hzhou17 commented on May 17, 2024

@williamyang1991 Thank you very much for the reply. I really appreciate it...

I found that code in cldm.py, but I don't know how to change it. Would you be so kind as to show me? I browsed through https://github.com/Mikubill/sd-webui-controlnet, but could not find how multi-ControlNet is implemented there.

def apply_model(self, x_noisy, t, cond, *args, **kwargs):
    assert isinstance(cond, dict)
    diffusion_model = self.model.diffusion_model

    cond_txt = torch.cat(cond['c_crossattn'], 1)

    if cond['c_concat'] is None:
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
    else:
        ### Here
        control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
        control = [c * scale for c, scale in zip(control, self.control_scales)]
        eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

    return eps
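One possible way to change the marked line (a sketch, not the maintainer's code): keep the hints in cond['c_concat'] as a list instead of concatenating them, run each ControlNet on its own hint, scale each model's residuals, and sum them block by block. Because h += control1.pop() + control2.pop() adds the i-th residuals of both lists, pre-summing the lists should give the same result, so this variant passes a single list to diffusion_model and could leave the UNet forward unchanged. multi_control, self.control_models and self.control_scales_list are hypothetical names you would have to add (e.g. holding the control_model of model1 and model2 and their scale lists):

```python
from typing import Any, Callable, List, Sequence


def multi_control(
    control_models: Sequence[Callable[..., List[Any]]],
    hints: Sequence[Any],
    control_scales: Sequence[Sequence[float]],
    x_noisy: Any,
    t: Any,
    cond_txt: Any,
) -> List[Any]:
    """Run each ControlNet on its own hint and sum the scaled residuals per block."""
    if not control_models or not (len(control_models) == len(hints) == len(control_scales)):
        raise ValueError('need exactly one hint and one scale list per ControlNet')
    combined = None
    for model, hint, scales in zip(control_models, hints, control_scales):
        # Each hint keeps its own channels; no torch.cat across hints.
        control = model(x=x_noisy, hint=hint, timesteps=t, context=cond_txt)
        scaled = [c * s for c, s in zip(control, scales)]
        if combined is None:
            combined = scaled
        elif len(scaled) != len(combined):
            raise ValueError('ControlNets returned different numbers of residuals')
        else:
            combined = [a + b for a, b in zip(combined, scaled)]
    return combined


# Sketch of the else-branch of apply_model using the hypothetical attributes:
#     control = multi_control(self.control_models, cond['c_concat'],
#                             self.control_scales_list, x_noisy, t, cond_txt)
#     eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt,
#                           control=control, only_mid_control=self.only_mid_control)
```

The function only uses *, + and list operations, so it works unchanged on torch tensors; the test below exercises it with plain floats and fake ControlNets.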


williamyang1991 commented on May 17, 2024

I'm sorry, I'm under deadline pressure and cannot help you with every detail.
And I'm not familiar with https://github.com/Mikubill/sd-webui-controlnet.

The main idea is to track down every place where cond['c_concat'] is used and modify the corresponding code.
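As a debugging aid while tracking those uses, here is a minimal sketch of a check that could be run on both cond and un_cond before sampling. When guidance is enabled the sampler presumably calls apply_model with un_cond as well, so both dicts need one hint per ControlNet. check_hints is hypothetical and only assumes each hint exposes a .shape in NCHW order:

```python
from typing import Any, Dict


def check_hints(cond: Dict[str, Any], num_control_models: int, channels_per_hint: int = 3) -> None:
    """Fail early unless cond['c_concat'] holds one hint per ControlNet with the expected channels."""
    hints = cond.get('c_concat')
    if hints is None:
        return  # apply_model treats a None c_concat as "no control"
    if len(hints) != num_control_models:
        raise ValueError(f'expected {num_control_models} hints, got {len(hints)}')
    for i, hint in enumerate(hints):
        if hint.shape[1] != channels_per_hint:
            raise ValueError(f'hint {i} has {hint.shape[1]} channels, expected {channels_per_hint}')


# e.g. right before ddim_v_sampler.sample(...):
#     check_hints(cond, 2)
#     check_hints(un_cond, 2)
```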

