当前位置: 首页 > news >正文

文本嵌入层

1、代码演示

embedding = nn.Embedding(10,3)
print(embedding)
input = torch.LongTensor([[1,2,3,4],[4,3,2,9]])
embedding(input)

2、构建Embeddings类来实现文本嵌入层

# 构建Embedding类来实现文本嵌入层
class Embeddings(nn.Module):
    def __init__(self,d_model,vocab):
        """
        :param d_model: 词嵌入的维度
        :param vocab: 词表的大小
        """
        super(Embeddings,self).__init__()
        self.lut = nn.Embedding(vocab,d_model)
        self.d_model = d_model
    def forward(self,x):
        """
        :param x: 因为Embedding层是首层,所以代表输入给模型的文本通过词汇映射后的张量
        :return:
        """
        return self.lut(x) * math.sqrt(self.d_model)
x = Variable(torch.LongTensor([[100,2,42,508],[491,998,1,221]]))
emb = Embeddings(512,1000)
embr = emb(x)
print(embr.shape)             # torch.Size([2, 4, 512])
print(embr)
print(embr[0][0].shape)       # torch.Size([512])

相关文章:

  • Qt raise()问题
  • 【QT】使用toBase64方法将.txt文件的明文变为非明文(类似加密)
  • Mysql生产随笔
  • vue下载在前端存放的pdf文件
  • 玩碎Java之CompletableFuture的例子
  • Java初始化大量数据到Neo4j中(二)
  • lambda的使用案例(1)
  • 探索视听新纪元: ChatGPT的最新语音和图像功能全解析
  • Flutter笔记:AnimationMean、AnimationMax 和 AnimationMin 三个类的用法
  • 朴素贝叶斯分类(下):数据挖掘十大算法之一
  • 了解ActiveMQ、RabbitMQ、RocketMQ和Kafka的特点
  • 嵌入式开源库之libmodbus学习笔记
  • 27、Flink 的SQL之SELECT (Pattern Recognition 模式检测)介绍及详细示例(7)
  • Linux网络编程- struct ifreq ioctl() 系统调用
  • Android 13 - Media框架(8)- MediaExtractor(2)
  • 机器学习第十四课--神经网络
  • stream对list数据进行多字段去重
  • 问答区混赏金的集合贴
  • 华为杯数学建模比赛经验分享
  • $nextTick解决echarts宽度固定为100%的问题
  • 诺奖得主等数十位经济学家发表宣言反对美关税政策
  • 境外机构来华发行熊猫债累计超9500亿元
  • 马上评|古籍书店焕新归来,“故纸陈香”滋养依旧
  • 北京一季度GDP为12159.9亿元,同比增长5.5%
  • 霸王茶姬成美股“中国茶饮第一股”:首日涨近16%,市值60亿美元
  • OpenAI推出全新推理模型o3、o4-mini,以及一个编程智能体