Giter Club home page Giter Club logo

Comments (1)

PFery4 avatar PFery4 commented on May 24, 2024

The temporal encoding is implemented by the `PositionalAgentEncoding` class in `agentformer.py`:

class PositionalAgentEncoding(nn.Module):
    """Apply sinusoidal time encodings and, optionally, per-agent encodings.

    The input ``x`` is indexed along dim 0 as a flattened (time, agent) sequence
    in time-major / agent-minor order, i.e. rows
    ``[t0a0, t0a1, ..., t0a(A-1), t1a0, ...]``. This is the order produced by
    ``repeat_interleave`` for the time table and ``repeat`` for the agent table.
    The feature size is the last dim of ``x``.

    Args:
        d_model: feature size of the encodings.
        dropout: dropout probability applied to the output.
        max_t_len: number of time steps for which encodings are precomputed.
        max_a_len: number of agent slots for which encodings are precomputed
            (or learned).
        concat: if True, concatenate ``[x, time_enc(, agent_enc)]`` along the
            last dim and project back to ``d_model`` with a linear layer.
            Otherwise the encodings are added to ``x``.
        use_agent_enc: also apply a per-agent encoding.
        agent_enc_learn: make the agent encoding a learnable parameter
            (``randn * 0.1`` init) instead of a fixed sinusoidal buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_t_len=200, max_a_len=200,
                 concat=False, use_agent_enc=False, agent_enc_learn=False):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.concat = concat
        self.d_model = d_model
        self.use_agent_enc = use_agent_enc
        if concat:
            # Input features + time encoding (+ agent encoding) -> d_model.
            self.fc = nn.Linear((3 if use_agent_enc else 2) * d_model, d_model)
        self.register_buffer('pe', self.build_pos_enc(max_t_len))
        if use_agent_enc:
            if agent_enc_learn:
                self.ae = nn.Parameter(torch.randn(max_a_len, 1, d_model) * 0.1)
            else:
                # The fixed agent table reuses the same sinusoidal construction.
                self.register_buffer('ae', self.build_pos_enc(max_a_len))

    def build_pos_enc(self, max_len):
        """Return a ``(max_len, 1, d_model)`` sinusoidal encoding table.

        Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
        ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``. An odd
        ``d_model`` is supported: the last (even-indexed) column has a sine
        with no cosine partner.
        """
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2).float() * (-np.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])
        return pe.unsqueeze(1)

    def build_agent_enc(self, max_len):
        """Return a sinusoidal agent-encoding table (same as ``build_pos_enc``).

        Kept for backward compatibility. It was previously a verbatim copy of
        ``build_pos_enc``.
        """
        return self.build_pos_enc(max_len)

    def get_pos_enc(self, num_t, num_a, t_offset):
        """Return time encodings of shape ``(num_t * num_a, 1, d_model)``.

        Covers steps ``t_offset .. t_offset + num_t - 1``. Each step's row is
        repeated ``num_a`` times consecutively, one copy per agent.

        Raises:
            ValueError: if the requested window is outside the precomputed table.
        """
        if t_offset < 0 or t_offset + num_t > self.pe.shape[0]:
            raise ValueError(
                f"time window [{t_offset}, {t_offset + num_t}) exceeds the "
                f"precomputed length {self.pe.shape[0]} (max_t_len)"
            )
        pe = self.pe[t_offset: num_t + t_offset, :]
        return pe.repeat_interleave(num_a, dim=0)

    def get_agent_enc(self, num_t, num_a, a_offset, agent_enc_shuffle):
        """Return agent encodings tiled over ``num_t`` time steps.

        If ``agent_enc_shuffle`` is None, agents use consecutive slots starting
        at ``a_offset``. Otherwise ``agent_enc_shuffle`` indexes the slots
        directly. The selected block is repeated ``num_t`` times along dim 0,
        matching the agent-minor row order of ``x``.

        Raises:
            ValueError: if the consecutive slot window is outside the table.
        """
        if agent_enc_shuffle is None:
            if a_offset < 0 or a_offset + num_a > self.ae.shape[0]:
                raise ValueError(
                    f"agent window [{a_offset}, {a_offset + num_a}) exceeds the "
                    f"agent table size {self.ae.shape[0]} (max_a_len)"
                )
            ae = self.ae[a_offset: num_a + a_offset, :]
        else:
            ae = self.ae[agent_enc_shuffle]
        return ae.repeat(num_t, 1, 1)

    def forward(self, x, num_a, agent_enc_shuffle=None, t_offset=0, a_offset=0):
        """Encode ``x`` and apply dropout.

        Args:
            x: tensor whose dim 0 has length ``num_t * num_a`` (time-major,
                agent-minor) and whose last dim is ``d_model``. Dim 1 is
                broadcast or repeated over; presumably the batch dim
                (TODO confirm).
            num_a: number of agents per time step.
            agent_enc_shuffle: optional index tensor selecting agent slots.
            t_offset: first time index into the time-encoding table.
            a_offset: first agent slot when ``agent_enc_shuffle`` is None.

        Returns:
            A new tensor; ``x`` itself is never modified.

        Raises:
            ValueError: if ``x.shape[0]`` is not a positive multiple of
                ``num_a``, or an offset window is out of range.
        """
        if num_a <= 0 or x.shape[0] % num_a != 0:
            raise ValueError(
                f"x.shape[0]={x.shape[0]} is not a multiple of num_a={num_a}"
            )
        num_t = x.shape[0] // num_a
        pos_enc = self.get_pos_enc(num_t, num_a, t_offset)
        agent_enc = None
        if self.use_agent_enc:
            agent_enc = self.get_agent_enc(num_t, num_a, a_offset, agent_enc_shuffle)
        if self.concat:
            feat = [x, pos_enc.repeat(1, x.size(1), 1)]
            if agent_enc is not None:
                feat.append(agent_enc.repeat(1, x.size(1), 1))
            x = self.fc(torch.cat(feat, dim=-1))
        else:
            # Out-of-place add: `x += ...` would overwrite the caller's tensor,
            # and autograd rejects it when x is a leaf that requires grad.
            x = x + pos_enc
            if agent_enc is not None:
                x = x + agent_enc
        return self.dropout(x)

Hope this helps!

from agentformer.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.