|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import math |
| 4 | + |
class PositionalEncoding(nn.Module):
    """Add sinusoidal positional encodings to batch-first embeddings.

    Builds the fixed sin/cos table from "Attention Is All You Need" once at
    construction time and registers it as a (non-trainable) buffer.

    Args:
        d_model: Embedding dimension of the inputs.
        max_len: Largest sequence length that will ever be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # (max_len, d_model) table: one row of encodings per position.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Inverse frequencies 10000^(-2i/d_model) for the even dimensions.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model the cosine slice has one fewer column than
        # div_term, so trim it to avoid a shape-mismatch crash.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Shape (1, max_len, d_model) so it broadcasts over the batch axis of
        # batch-first input. BUG FIX: the original did
        # unsqueeze(0).transpose(0, 1) -> (max_len, 1, d_model) and sliced
        # with x.size(0); with the batch-first tensors produced by
        # nn.Embedding in this file, that sliced by *batch size* and added
        # only position 0's encoding to every token.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for positions 0..seq_len-1.

        Args:
            x: Tensor of shape (batch, seq_len, d_model), seq_len <= max_len.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x
| 23 | + |
class TransformerEmbedding(nn.Module):
    """Token embedding lookup followed by additive positional encoding.

    Args:
        vocab_size: Number of distinct token ids.
        d_model: Embedding dimension.
        max_len: Maximum sequence length passed to the positional encoder.
    """

    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len)

    def forward(self, x):
        """Embed token ids ``x`` and add positional information.

        Args:
            x: Integer tensor of token ids, shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        return self.position_encoding(self.token_embedding(x))
| 34 | + |
# Example usage: embed one tokenized sentence and inspect the output shape.
vocab_size = 30522  # vocabulary size (matches BERT's WordPiece vocab size)
d_model = 512  # Embedding size per token
max_len = 100  # Maximum sequence length the positional table supports

embedding_layer = TransformerEmbedding(vocab_size, d_model, max_len)
# Batch-first input: one sequence of 10 token ids, shape (1, 10).
input_ids = torch.tensor([[101, 19204, 2135, 1567, 2003, 2019, 2590, 3350, 1012, 102]])  # Example input
embeddings = embedding_layer(input_ids)

# nn.Embedding is batch-first, so this prints
# (batch_size, sequence_length, d_model) = (1, 10, 512) —
# not (sequence_length, batch_size, d_model).
print(embeddings.shape)
0 commit comments