您好,欢迎来到华佗小知识。
搜索
您的当前位置:首页DGCRN 模型网络模型详解 / DGCRN模型项目代码详解 (3) —— net.py

DGCRN 模型网络模型详解 / DGCRN模型项目代码详解 (3) —— net.py

来源:华佗小知识

 这是一份用于动态图卷积循环网络DGCRN模型项目代码理解与运行的入门教程,采用论文与代码结合的方式阐述动态图的实现过程与图卷积GCN在RNN中的运用。如果文中某方面解读有误请在评论区指出。

 如果对模型功能与数学计算过程还有疑惑请移步以下链接学习:


1 项目文件总览

        如上图所示,整体项目除数据集外共包含五个文件,其中 train.py, trainer.py 为训练模型相关的文件,包含重要的参数设置与“课程学习”等功能的实现,也是整体项目中最复杂的部分;net.py 文件为 DGCRN 模型整体网络架构,是模型的核心部分;layer.py 文件为模型需要用到的图卷积层;util.py 为实现整体功能的必要函数集合,包含数据集的构建与损失计算等。 


2 重要细节解释

在学习本节之前强烈建议先了解layer层图卷积的定义:

         在论文中,为了应对更为一般的有向图情况,作者采用双向图卷积来进行计算,公式如下:

故在代码中,相关的图卷积计算都是以此类形式呈现。

        现在梳理每个RNN步骤中需要用到的图卷积次数:

        对于生成动态邻接矩阵的模块,一共需要 8 个类型为 ‘hpyer’ 的GCN层。由于定义了节点嵌入,所以在无向图模式下需要 2个 hpyer GCN 来获取空间特征:

同时由于我们使用双向图卷积,故卷积层数量提升到 4个 hpyer GCN。并且在 seq2seq架构下,encoder-decoder的卷积层数量相同,故数量提升到 8个 hpyer GCN。此时就完整解释了代码的以下定义:

        self.GCN1_tg = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                           'hyper')

        self.GCN2_tg = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                           'hyper')

        self.GCN1_tg_de = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                              'hyper')

        self.GCN2_tg_de = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                              'hyper')

        self.GCN1_tg_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                             'hyper')

        self.GCN2_tg_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                             'hyper')

        self.GCN1_tg_de_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                                'hyper')

        self.GCN2_tg_de_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                                'hyper')

         对于RNN架构中的GCN计算,一共需要 12 个类型为 ‘RNN’ 的GCN层。由于我们使用动态图图卷积替换了GRU中的MLP层,故一个时间步需要的基本图卷积个数为 3个

同时由于我们使用双向图卷积,故卷积层数量提升到 6个 RNN GCN。并且在 seq2seq架构下,encoder-decoder的卷积层数量相同,故数量提升到 12个 RNN GCN。此时就完整解释了代码的以下定义:

        self.gz1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gz2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')

        self.gz1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gz2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')

3 DGCRN网络模型详解

3.1 参数定义与初始化部分

class DGCRN(nn.Module):
    def __init__(self,
                 gcn_depth,
                 num_nodes,
                 device,
                 predefined_A=None,
                 dropout=0.3,
                 subgraph_size=20,
                 node_dim=40,
                 middle_dim=2,
                 seq_length=12,
                 in_dim=2,
                 out_dim=12,
                 layers=3,
                 list_weight=[0.05, 0.95, 0.95],
                 tanhalpha=3,
                 cl_decay_steps=4000,
                 rnn_size=,
                 hyperGNN_dim=16):

参数解释: 

        self.emb1 = nn.Embedding(self.num_nodes, node_dim)
        self.emb2 = nn.Embedding(self.num_nodes, node_dim)
        self.lin1 = nn.Linear(node_dim, node_dim)
        self.lin2 = nn.Linear(node_dim, node_dim)

        定义两个节点嵌入和线性层,其中节点嵌入给每个节点一个长度为 node_dim 维度的向量,线性层,用于对嵌入进行变换,输入与输出维度都是 node_dim

        dims_hyper = [
            self.hidden_size + in_dim, hyperGNN_dim, middle_dim, node_dim
        ]

        self.GCN1_tg = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                           'hyper')

        self.GCN2_tg = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                           'hyper')

        self.GCN1_tg_de = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                              'hyper')

        self.GCN2_tg_de = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                              'hyper')

        self.GCN1_tg_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                             'hyper')

        self.GCN2_tg_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                             'hyper')

        self.GCN1_tg_de_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                                'hyper')

        self.GCN2_tg_de_1 = gcn(dims_hyper, gcn_depth, dropout, *list_weight,
                                'hyper')
        
        self.alpha = tanhalpha

        在 layer.py 中我们得知,hyper GCN线性层的维度信息储存在一个长度为 4 的列表中,以上代码 dims_hyper 即是这个储存维度信息的列表,结构如图所示。代码同时定义了重要参数  。

        self.device = device
        self.k = subgraph_size
        dims = [in_dim + self.hidden_size, self.hidden_size]

        self.gz1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gz2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc1 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc2 = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')

        self.gz1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gz2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gr2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc1_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')
        self.gc2_de = gcn(dims, gcn_depth, dropout, *list_weight, 'RNN')

        已知论文中 DGCRM 模块的结构示意图如下:

        参数 subgraph_size 决定了图卷积子图个数 。

        在 layer.py 中我们得知,RNN GCN线性层的维度信息储存在一个长度为 2 的列表中,以上代码 dims 即是这个储存维度信息的列表,结构如图所示。

        self.use_curriculum_learning = True
        self.cl_decay_steps = cl_decay_steps
        self.gcn_depth = gcn_depth
  • self.use_curriculum_learning = True:这行代码设置了一个布尔值变量 use_curriculum_learning,表示是否启用 课程学习(Curriculum Learning)策略。
  • cl_decay_steps 是控制课程学习衰减的步数参数,通常用于设定学习难度逐渐增加的速度。这个参数决定了衰减的步长或者更新频率,通常与训练进度(如迭代步数)相关。
  • gcn_depth 是图卷积网络(GCN)中的深度,表示图卷积层的数量。是RNN GCN 与 Hyper GCN 的重要参数。

3.2 功能函数部分

3.2.1 preprocessing(self, adj, predefined_A)

    def preprocessing(self, adj, predefined_A):
        adj = adj + torch.eye(self.num_nodes).to(self.device)
        adj = adj / torch.unsqueeze(adj.sum(-1), -1)
        return [adj, predefined_A]

        这段代码是一个 预处理 函数,用于处理输入的邻接矩阵 adj。函数的最终返回值是包含预处理后的邻接矩阵和预定义邻接矩阵的列表。

        函数先给邻接矩阵自连接,后进行归一化处理:

        返回的是经过预处理后的邻接矩阵和预定义邻接矩阵:

3.2.2 initHidden(self, batch_size, hidden_size)

def initHidden(self, batch_size, hidden_size):
    use_gpu = torch.cuda.is_available()  # 检查是否可用GPU
    if use_gpu:  # 如果使用GPU
        Hidden_State = Variable(
            torch.zeros(batch_size, hidden_size).to(self.device))  # 初始化为零矩阵,存储隐藏状态
        Cell_State = Variable(
            torch.zeros(batch_size, hidden_size).to(self.device))  # 初始化为零矩阵,存储单元状态

        nn.init.orthogonal(Hidden_State)  # 使用正交初始化方法初始化隐藏状态
        nn.init.orthogonal(Cell_State)  # 使用正交初始化方法初始化单元状态

        return Hidden_State, Cell_State  # 返回隐藏状态和单元状态
    else:  # 如果不使用GPU(使用CPU)
        Hidden_State = Variable(torch.zeros(batch_size, hidden_size))  # 初始化为零矩阵,存储隐藏状态
        Cell_State = Variable(torch.zeros(batch_size, hidden_size))  # 初始化为零矩阵,存储单元状态
        return Hidden_State, Cell_State  # 返回隐藏状态和单元状态

        这是一个用于初始化 GRU 隐藏状态与细胞状态的函数,由于在第一个时步没有前一个时间步的隐藏状态与细胞状态,所以我们要对其进行初始化。 GRU中这两个指标的计算公式如下:

        在 GPU 模式下,隐藏状态 Hidden_State 和单元状态 Cell_State 都通过 torch.zeros() 初始化为大小为 [batch_size, hidden_size] 的全零矩阵。
        如果不使用GPU(即使用CPU),同样会初始化为零矩阵,但不调用 to(self.device),因此数据将默认在CPU上进行计算。
        nn.init.orthogonal() 是一种初始化方法,它通过正交初始化为张量填充权重。正交初始化能够保证每一层的权重矩阵是正交的(即行与行或列与列之间是正交的),有助于训练过程中的稳定性。

  • 隐藏状态:可以看作是GRU记住的“短期”信息,它随着时间步骤不断更新。
  • 单元状态:可以看作是GRU记住的“长期”信息。它决定了记忆的更新和保持。

3.2.3  step()

    def step(self,
             input,
             Hidden_State,
             Cell_State,
             predefined_A,
             type='encoder',
             idx=None,
             i=None):

        x = input

        x = x.transpose(1, 2).contiguous()

        nodevec1 = self.emb1(self.idx)
        nodevec2 = self.emb2(self.idx)

        hyper_input = torch.cat(
            (x, Hidden_State.view(-1, self.num_nodes, self.hidden_size)), 2)

        if type == 'encoder':

            filter1 = self.GCN1_tg(hyper_input,
                                   predefined_A[0]) + self.GCN1_tg_1(
                                       hyper_input, predefined_A[1])
            filter2 = self.GCN2_tg(hyper_input,
                                   predefined_A[0]) + self.GCN2_tg_1(
                                       hyper_input, predefined_A[1])

        if type == 'decoder':

            filter1 = self.GCN1_tg_de(hyper_input,
                                      predefined_A[0]) + self.GCN1_tg_de_1(
                                          hyper_input, predefined_A[1])
            filter2 = self.GCN2_tg_de(hyper_input,
                                      predefined_A[0]) + self.GCN2_tg_de_1(
                                          hyper_input, predefined_A[1])

        nodevec1 = torch.tanh(self.alpha * torch.mul(nodevec1, filter1))
        nodevec2 = torch.tanh(self.alpha * torch.mul(nodevec2, filter2))

        a = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul(
            nodevec2, nodevec1.transpose(2, 1))

        adj = F.relu(torch.tanh(self.alpha * a))

        adp = self.preprocessing(adj, predefined_A[0])
        adpT = self.preprocessing(adj.transpose(1, 2), predefined_A[1])

        Hidden_State = Hidden_State.view(-1, self.num_nodes, self.hidden_size)
        Cell_State = Cell_State.view(-1, self.num_nodes, self.hidden_size)

        combined = torch.cat((x, Hidden_State), -1)

        if type == 'encoder':
            z = F.sigmoid(self.gz1(combined, adp) + self.gz2(combined, adpT))
            r = F.sigmoid(self.gr1(combined, adp) + self.gr2(combined, adpT))

            temp = torch.cat((x, torch.mul(r, Hidden_State)), -1)
            Cell_State = F.tanh(self.gc1(temp, adp) + self.gc2(temp, adpT))
        elif type == 'decoder':
            z = F.sigmoid(
                self.gz1_de(combined, adp) + self.gz2_de(combined, adpT))
            r = F.sigmoid(
                self.gr1_de(combined, adp) + self.gr2_de(combined, adpT))

            temp = torch.cat((x, torch.mul(r, Hidden_State)), -1)
            Cell_State = F.tanh(
                self.gc1_de(temp, adp) + self.gc2_de(temp, adpT))

        Hidden_State = torch.mul(z, Hidden_State) + torch.mul(
            1 - z, Cell_State)

        return Hidden_State.view(-1, self.hidden_size), Cell_State.view(
            -1, self.hidden_size)

        step 函数是 DGCRN的核心部分,它定义了在一个时间步内的网络前向传播运算过程,包括动态图的生成与GRU的计算。它的作用是在网络的一个时间步上,基于当前的输入和状态,计算新的隐藏状态和单元状态。该函数通过输入 input、Hidden_State、Cell_State 等参数,执行图卷积操作和GRU更新,生成更新后的隐藏状态和单元状态。

        我们以实际例子来展示函数的运行过程,定义图结构节点个数为 ,输入特征个数为 ,批次 batch_size 为  :

        首先模型输入为:

x = input

        经过以此转置将特征转化到最后一个维度上:        

x = x.transpose(1, 2).contiguous()

         生成两个节点嵌入矩阵,其中每个节点嵌入向量维度为  :

        nodevec1 = self.emb1(self.idx)
        nodevec2 = self.emb2(self.idx)

        生成动态矩阵输入数据的特征拼接操作,将 和 Hidden_State 在特征维度进行拼接,得到一个新的张量 hyper_input。

其中,,:

        hyper_input = torch.cat(
            (x, Hidden_State.view(-1, self.num_nodes, self.hidden_size)), 2)

        对拼接数据进行双向图卷积操作提取空间特征(已知hyper GCN输出维度为 node_dim = 40):

        if type == 'encoder':

            filter1 = self.GCN1_tg(hyper_input,
                                   predefined_A[0]) + self.GCN1_tg_1(
                                       hyper_input, predefined_A[1])
            filter2 = self.GCN2_tg(hyper_input,
                                   predefined_A[0]) + self.GCN2_tg_1(
                                       hyper_input, predefined_A[1])

        if type == 'decoder':

            filter1 = self.GCN1_tg_de(hyper_input,
                                      predefined_A[0]) + self.GCN1_tg_de_1(
                                          hyper_input, predefined_A[1])
            filter2 = self.GCN2_tg_de(hyper_input,
                                      predefined_A[0]) + self.GCN2_tg_de_1(
                                          hyper_input, predefined_A[1])

        生成完整的节点嵌入,使用哈达玛乘积:

        nodevec1 = torch.tanh(self.alpha * torch.mul(nodevec1, filter1))
        nodevec2 = torch.tanh(self.alpha * torch.mul(nodevec2, filter2))

        生成动态邻接矩阵:

        a = torch.matmul(nodevec1, nodevec2.transpose(2, 1)) - torch.matmul(
            nodevec2, nodevec1.transpose(2, 1))

        adj = F.relu(torch.tanh(self.alpha * a))

        储存动态邻接矩阵:

        adp = self.preprocessing(adj, predefined_A[0])
        adpT = self.preprocessing(adj.transpose(1, 2), predefined_A[1])

        我们已经值得函数 preprocessing 的作用是生成一个两个元素的列表,列表的第一个元素为传入的第一个邻接矩阵的自连接归一化矩阵,列表的第二个原始为传入第二个邻接矩阵。

        此代码中, predefined_A 列表包含内容如下:

        此时,代码中两个列表储存矩阵如下:

        其中  和  都是动态邻接矩阵处理后的结果。

        Hidden_State = Hidden_State.view(-1, self.num_nodes, self.hidden_size)
        Cell_State = Cell_State.view(-1, self.num_nodes, self.hidden_size)

        这两行代码主要是使用 .view() 方法改变 Hidden_State Cell_State 张量的形状,以确保它们符合后续操作的输入要求。

        view(-1, self.num_nodes, self.hidden_size):这里 -1 的意思是让 PyTorch 自动推算出 batch_size,根据 num_nodes hidden_size 计算总的元素数量,确保最终的张量大小不变。

  • -1:自动推算出 batch_size(通常情况下,batch_size 是已知的,但 -1 让我们不需要手动指定它)。
  • self.num_nodes:保留节点数维度。
  • self.hidden_size:保留隐藏状态的特征维度。
  • 最终的张量形状会是 (batch_size, num_nodes, hidden_size),这是符合后续神经网络层输入要求的形状。
        combined = torch.cat((x, Hidden_State), -1)

        if type == 'encoder':
            z = F.sigmoid(self.gz1(combined, adp) + self.gz2(combined, adpT))
            r = F.sigmoid(self.gr1(combined, adp) + self.gr2(combined, adpT))

            temp = torch.cat((x, torch.mul(r, Hidden_State)), -1)
            Cell_State = F.tanh(self.gc1(temp, adp) + self.gc2(temp, adpT))
        elif type == 'decoder':
            z = F.sigmoid(
                self.gz1_de(combined, adp) + self.gz2_de(combined, adpT))
            r = F.sigmoid(
                self.gr1_de(combined, adp) + self.gr2_de(combined, adpT))

            temp = torch.cat((x, torch.mul(r, Hidden_State)), -1)
            Cell_State = F.tanh(
                self.gc1_de(temp, adp) + self.gc2_de(temp, adpT))

        Hidden_State = torch.mul(z, Hidden_State) + torch.mul(
            1 - z, Cell_State)

         这一部分是采用我们预设好的动态双向图卷积替换掉GRU中的MLP的计算过程,首先计算重置门部分:

 r = F.sigmoid(self.gr1(combined, adp) + self.gr2(combined, adpT))

         接下来计算更新门部分:

z = F.sigmoid(self.gz1(combined, adp) + self.gz2(combined, adpT))

         接下来计算被重置门处理后的信息 temp:

temp = torch.cat((x, torch.mul(r, Hidden_State)), -1)

 

         接下来计算候选隐藏状态细胞状态):

 Cell_State = F.tanh(self.gc1(temp, adp) + self.gc2(temp, adpT))

        接下来计算隐藏状态

 Hidden_State = torch.mul(z, Hidden_State) + torch.mul(1 - z, Cell_State)

 

        return Hidden_State.view(-1, self.hidden_size), Cell_State.view(
            -1, self.hidden_size)

        最终函数返回两个输出:Hidden_State Cell_State。

3.2.4 _compute_sampling_threshold(self, batches_seen)

def _compute_sampling_threshold(self, batches_seen):
    return self.cl_decay_steps / (
        self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

        _compute_sampling_threshold 是一个用于计算 "样本采样阈值" 的函数。它的作用是在训练过程中调整“课程学习”策略中的采样阈值,以控制新任务的学习难度。

其中:

  • 是“衰减步骤”,对应 cl_decay_steps;
  • 是“已见批次数”,对应 batches_seen;
  • exp 是自然指数函数。

初期训练:当训练的批次数  较小时,指数项  较小,因此采样阈值接近 1,模型主要学习简单样本。

后期训练:随着训练批次  增多,指数项  变大,采样阈值开始减少,最终趋近于 0,模型逐渐接触更复杂的样本进行学习。

3.3 前向传播部分

    def forward(self,
                input,
                idx=None,
                ycl=None,
                batches_seen=None,
                task_level=12):

参数解释:

  • input:输入数据,通常是一个 4D 张量,包含批次、时间步长、节点数等信息。
  • idx:用于标识某些特殊的索引或位置。
  • ycl:目标数据,用于监督学习任务,维度也是 4D 张量。
  • batches_seen:表示已经处理的批次数,在使用课程学习时会用到。
  • task_level:任务的级别或步骤数,这里是 12 表示需要进行 12 个预测步骤。
        predefined_A = self.predefined_A
        x = input
        batch_size = x.size(0)
        Hidden_State, Cell_State = self.initHidden(batch_size * self.num_nodes,
                                                   self.hidden_size)

        初始化输入数据、邻接矩阵、批次数、隐藏状态与细胞状态。

编码器部分(Encoder):

for i in range(self.seq_length):
    Hidden_State, Cell_State = self.step(torch.squeeze(x[..., i]),
                                         Hidden_State, Cell_State,
                                         predefined_A, 'encoder', idx,
                                         i)
  • for i in range(self.seq_length):对每个时间步进行迭代,seq_length 表示序列的长度。
  • torch.squeeze(x[..., i]):从输入数据中提取第 i 个时间步的数据。
  • Hidden_State, Cell_State = self.step(...):调用 step 函数进行一次 GRU 的前向传播,返回更新后的隐藏状态和细胞状态。

if outputs is None:
    outputs = Hidden_State.unsqueeze(1)
else:
    outputs = torch.cat((outputs, Hidden_State.unsqueeze(1)), 1)

        通过 outputs 变量将每个时间步的 Hidden_State 拼接起来,形成最终的输出序列:

  • outputs = Hidden_State.unsqueeze(1):如果是第一个时间步,将隐藏状态添加到输出中。
  • torch.cat(...):后续时间步的隐藏状态会被追加到 outputs 中。

        循环12次后输出结果如下:

解码器部分(Decoder):

decoder_input = go_symbol
outputs_final = []

for i in range(task_level):
    decoder_input = torch.cat([decoder_input, timeofday[..., i]], dim=1)
    Hidden_State, Cell_State = self.step(decoder_input, Hidden_State, Cell_State, predefined_A, 'decoder', idx, None)
    decoder_output = self.fc_final(Hidden_State)
  • decoder_input = go_symbol:解码器的输入初始化为零张量 go_symbol。
  • time_of_day = ycl[:, 1:, :, :]:从目标数据 ycl 中提取时间段数据。time_of_day 代表一天中的其他时间特征。
  • torch.cat([decoder_input, timeofday[..., i]], dim=1):将目标数据拼接到解码器的输入中。
  • Hidden_State, Cell_State = self.step(...):再次调用 step 函数,使用更新后的输入和状态来生成新的隐藏状态。
  • decoder_output = self.fc_final(Hidden_State):通过 fc_final 层将隐藏状态映射到输出空间,产生最终的预测。

 

decoder_input = decoder_output.view(batch_size, self.num_nodes, self.output_dim).transpose(1, 2)

        decoder_input 更新为新的预测值,用于下一个时间步的解码。

课程学习(Curriculum Learning):

if self.training and self.use_curriculum_learning:
    # 检查当前是否处于训练模式,并且是否启用了课程学习策略。
    c = np.random.uniform(0, 1)
    # 生成一个在 [0, 1) 范围内的随机数 c。
    
    if c < self._compute_sampling_threshold(batches_seen):
        # 计算一个采样阈值,并判断随机数 c 是否小于该阈值。
        decoder_input = ycl[:, :1, :, i]
        # 如果 c 小于阈值,则将目标数据 ycl 的前一个时间步作为解码器的输入。

        该部分的作用是在训练时按一定的概率选择是否使用目标数据作为解码器的输入。

        outputs_final = torch.stack(outputs_final, dim=1)

        outputs_final = outputs_final.view(batch_size, self.num_nodes,
                                           task_level,
                                           self.output_dim).transpose(1, 2)

        return outputs_final

        将每个时间步的解码器输出堆叠起来,得到一个包含所有预测结果的张量通过 view 和 transpose 调整输出的形状,最后返回最终的输出 outputs_final


因篇幅问题不能全部显示,请点此查看更多更全内容

Copyright © 2019- huatuo0.cn 版权所有 湘ICP备2023017654号-2

违法及侵权请联系:TEL:199 18 7713 E-MAIL:2724546146@qq.com

本站由北京市万商天勤律师事务所王兴未律师提供法律服务