Comments (3)
Hello there @wise-east! You seem to be working with modified code to support the Falcon architecture; have you perhaps resolved #532 on your own?
I tested your config, changing only the model from tiiuae/falcon-7b-instruct to reciprocate/tiny-llama while keeping max_new_tokens=256, and I couldn't reproduce this error, so it seems to be Falcon-specific. Also, which version of transformers did you use?
from trlx.
Yes, I adapted modeling_ppo.py by adding a FalconModelBranch, based on what was done for the other branches.
from transformers.models.falcon import modeling_falcon


class FalconModelBranch(ModelBranch):
    def _prepare_attn_mask(
        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
    ) -> torch.BoolTensor:
        # create causal mask
        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        combined_attention_mask = None
        device = attention_mask.device
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = modeling_falcon._make_causal_mask(
                input_shape, device=device, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = modeling_falcon._expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_shape: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = False,
    ):
        """
        Reference: https://huggingface.co/tiiuae/falcon-7b-instruct/blob/main/modelling_RW.py
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = hidden_states.shape[:2]

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.decoder_blocks))

        head_mask = self.get_head_mask(head_mask, hf_get_num_hidden_layers(self.config))

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = modeling_falcon.build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.decoder_blocks, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    alibi,
                    causal_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.final_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return CausalLMOutputWithValue(
            logits=lm_logits,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
where modeling_falcon comes from https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py.
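For readers following the mask logic in _prepare_attn_mask, here is a minimal pure-Python sketch of the same idea, using nested lists instead of tensors. It assumes that True means "do not attend", which is what the `|` combination in the snippet suggests. The helper names (causal_mask, padding_mask, combine) are hypothetical and are not part of transformers.

```python
from typing import List

Mask = List[List[bool]]


def causal_mask(tgt_len: int, past_len: int = 0) -> Mask:
    """Row i is a query position; True marks a key position it must NOT see."""
    src_len = past_len + tgt_len
    return [[j > past_len + i for j in range(src_len)] for i in range(tgt_len)]


def padding_mask(attention_mask: List[int], tgt_len: int) -> Mask:
    """Expand a 1/0 keep-mask over key positions to tgt_len rows; True = padded."""
    return [[keep == 0 for keep in attention_mask] for _ in range(tgt_len)]


def combine(a: Mask, b: Mask) -> Mask:
    """Element-wise OR: a position is masked if either mask hides it."""
    return [[x or y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


# One left-padded sequence of length 3 (assumed example input).
attention_mask = [0, 1, 1]
combined = combine(padding_mask(attention_mask, 3), causal_mask(3))
for row in combined:
    print(row)
```

The real code additionally carries batch and head dimensions and skips the causal part when only a single new token is decoded (src_length == 1 in the snippet above).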
My pip show transformers result:
Name: transformers
Version: 4.32.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: [email protected]
License: Apache 2.0 License
Location: /opt/conda/envs/vllm/lib/python3.9/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, trl, trlx
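The same version information can also be read from inside Python, which is convenient when pasting environment details into an issue. This is a small sketch using only the standard library; the helper name installed_version is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Prints whatever is installed in the current environment, or None.
print("transformers:", installed_version("transformers"))
```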
Oh, that's good! However, I'm seeing a slightly different version of _expand_mask (https://github.com/huggingface/transformers/blob/aea761499f4b1193f2706f471442da6f9df65d65/src/transformers/models/falcon/modeling_falcon.py#L197), which doesn't take tgt_length as an argument on 4.32.1 or on master. If it's not too much work, could you open a PR with your changes? 🙏
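One way to guard against this kind of signature mismatch between library versions is to inspect the function before calling it. The sketch below uses only the standard library; accepts_kwarg is a hypothetical helper, and the two stand-in functions are illustrative placeholders, not the real transformers signatures.

```python
import inspect
from typing import Callable


def accepts_kwarg(fn: Callable, name: str) -> bool:
    """Return True if fn can be called with the keyword argument `name`."""
    params = inspect.signature(fn).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return True  # fn takes **kwargs, so any keyword is accepted
    param = params.get(name)
    return param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


# Stand-ins for two library versions of a mask helper (assumed shapes only).
def expand_mask_with_tgt_length(mask, tgt_length):
    return mask


def expand_mask_without_tgt_length(mask):
    return mask


print(accepts_kwarg(expand_mask_with_tgt_length, "tgt_length"))
print(accepts_kwarg(expand_mask_without_tgt_length, "tgt_length"))
```

In an environment with transformers installed, the same check could be applied to modeling_falcon._expand_mask to see whether the installed version accepts tgt_length before the branch calls it.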