下面是一个使用TensorFlow 2.0实现的可运行的word2vec代码示例:
import tensorflow as tf
from tensorflow.keras import layers
# 定义Word2Vec模型
class Word2Vec(tf.keras.Model):
    """Skip-gram style Word2Vec with separate target/context embedding tables."""

    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        # Two independent lookup tables; the target table is the one
        # conventionally kept as the final word vectors after training.
        self.target_embedding = layers.Embedding(vocab_size, embedding_dim)
        self.context_embedding = layers.Embedding(vocab_size, embedding_dim)

    def call(self, target, context):
        """Look up and return the embeddings for a (target, context) pair."""
        return self.target_embedding(target), self.context_embedding(context)
# 构建数据集(假设已经有了训练数据)
def build_dataset(texts, window_size):
    """Tokenize *texts* and build skip-gram (target, context) training pairs.

    Args:
        texts: list of raw text strings.
        window_size: number of words on EACH side of the target word that
            count as context.

    Returns:
        A tuple ``(dataset, vocabulary_size)`` where ``dataset`` is a list of
        ``(target_id, context_id)`` integer pairs and ``vocabulary_size``
        includes the index 0 reserved by the Keras Tokenizer.
    """
    # Build the word -> integer index mapping.
    tokenizer = tf.keras.preprocessing.text.Tokenizer()
    tokenizer.fit_on_texts(texts)
    word_index = tokenizer.word_index
    vocabulary_size = len(word_index) + 1  # +1 because index 0 is reserved

    # Build the training pairs.
    sequences = tokenizer.texts_to_sequences(texts)
    dataset = []
    for seq in sequences:
        for i in range(len(seq)):
            # BUG FIX: the right bound was min(i + window_size, len(seq)),
            # which excluded the word at distance `window_size` on the right
            # even though the left bound included it. Use an inclusive,
            # symmetric window: [i - window_size, i + window_size].
            lo = max(i - window_size, 0)
            hi = min(i + window_size + 1, len(seq))
            for j in range(lo, hi):
                if i != j:
                    dataset.append((seq[i], seq[j]))
    return dataset, vocabulary_size
# 训练Word2Vec模型
def train_word2vec(texts, embedding_dim=100, window_size=5, epochs=10):
    """Train a toy Word2Vec model on *texts* one pair at a time.

    Args:
        texts: list of raw text strings.
        embedding_dim: dimensionality of the learned word vectors.
        window_size: context window radius passed to ``build_dataset``.
        epochs: number of full passes over the generated pairs.

    Returns:
        The trained ``Word2Vec`` model.

    NOTE(review): CosineSimilarity returns -cos(target, context), so
    minimizing it pulls positive pairs together. Without negative sampling
    this objective can collapse all embeddings toward one direction — fine
    for a demo, not for real training.
    """
    dataset, vocab_size = build_dataset(texts, window_size)
    model = Word2Vec(vocab_size=vocab_size, embedding_dim=embedding_dim)
    loss_fn = tf.keras.losses.CosineSimilarity()
    optimizer = tf.keras.optimizers.Adam()

    for epoch in range(epochs):
        total_loss = 0.0
        for target, context in dataset:
            # BUG FIX: the original fed raw Python ints straight to the
            # model and loss. Wrap them into rank-1 tensors so the
            # Embedding layers and CosineSimilarity receive properly
            # batched inputs of shape (1, embedding_dim).
            target_t = tf.constant([target])
            context_t = tf.constant([context])
            with tf.GradientTape() as tape:
                target_embed, context_embed = model(target_t, context_t)
                loss = loss_fn(target_embed, context_embed)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            total_loss += float(loss.numpy())
        avg_loss = total_loss / len(dataset)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    return model
# Example usage
texts = ["I love machine learning", "Deep learning is fascinating", "Artificial intelligence is the future"]
model = train_word2vec(texts)

# BUG FIX: `tokenizer` was a local variable inside build_dataset() and was
# never exposed, so referencing it here raised a NameError. Re-fit an
# identical Tokenizer on the same texts to recover the same word_index.
tokenizer = tf.keras.preprocessing.text.Tokenizer()
tokenizer.fit_on_texts(texts)

# Extract the learned target-embedding matrix (rows indexed by word id).
word_embeddings = model.target_embedding.weights[0].numpy()

# Print each word alongside its embedding vector.
for word, index in tokenizer.word_index.items():
    print(f"{word}: {word_embeddings[index]}")
请注意,以上代码仅提供一个基本示例,实际使用时可能需要根据数据集和任务进行适当调整。
内容由零声教学AI助手提供,问题来源于学员提问