您的当前位置:首页正文

KAN文件kan层 阅读

2024-11-09 来源:个人技术集锦


class下属性介绍

这段代码定义了一个名为KANLayer的PyTorch模块类,它继承自nn.Module。以下是对类定义和属性的逐行解释:

class KANLayer(nn.Module):
  • 定义一个名为KANLayer的类,该类继承自nn.Module。这意味着KANLayer是一个自定义的神经网络层。

"""
KANLayer class
  • 这是一个多行字符串,用作类的文档字符串,描述KANLayer类。

Attributes:
  • 这部分开始列举类的属性,并给出每个属性的说明。

    -----------
        in_dim: int
            input dimension
  • in_dim:一个整数,表示输入数据的维度。

        out_dim: int
            output dimension
  • out_dim:一个整数,表示输出数据的维度。

        num: int
            the number of grid intervals
  • num:一个整数,表示网格区间的数量。

        k: int
            the piecewise polynomial order of splines
  • k:一个整数,表示样条函数的分段多项式阶数。

        noise_scale: float
            spline scale at initialization
  • noise_scale:一个浮点数,表示样条函数在初始化时的缩放比例

        coef: 2D torch.tensor
            coefficients of B-spline bases
  • coef:一个二维的PyTorch张量,表示B样条基的系数。

        scale_base_mu: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu
  • scale_base_mu:一个浮点数,表示残差函数b(x)从正态分布N(μ, σ^2)中抽取的均值μ。
  • mu = sigma_base_mu
     

        scale_base_sigma: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma
  • scale_base_sigma:一个浮点数,表示残差函数b(x)从正态分布N(μ, σ^2)中抽取的方差σ^2。

        scale_sp: float
            mangitude of the spline function spline(x)
  • scale_sp:一个浮点数,表示样条函数spline(x)的幅度。

        base_fun: fun
            residual function b(x)
  • base_fun:一个函数,表示残差函数b(x)。

        mask: 1D torch.float
            mask of spline functions. setting some element of the mask to zero means setting the corresponding activation to zero function.
  • mask:一个一维的PyTorch浮点张量,表示样条函数的掩码。
  • 将掩码中的某些元素设置为零意味着将相应的激活函数设置为零函数。

        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
  • grid_eps:一个在[0,1]范围内的浮点数,是一个用于update_grid_from_samples的超参数。
  • grid_eps = 1时,网格是均匀的;
  • grid_eps = 0时,网格使用样本的百分位数进行划分。
  • 0 < grid_eps < 1在两个极端之间插值。

        device: str
            device
  • device:一个字符串,表示设备(例如CPU或GPU)。

这个类的属性描述了KANLayer层的各种配置和状态,但类的方法(构造函数__init__和其他可能的方法)没有在给出的代码段中定义。这些属性将被用于配置和优化神经网络层的行为。

init

KANLayer类的构造函数__init__的实现,用于初始化KANLayer类的实例。 

def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.5, scale_base_mu=0.0, scale_base_sigma=1.0, scale_sp=1.0, base_fun=torch.nn.SiLU(), grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, save_plot_data = True, device='cpu', sparse_init=False):
  • 定义构造函数,接受多个参数来配置KANLayer的属性,包括输入和输出维度、网格间隔数量、多项式阶数、噪声缩放比例、残差函数的缩放参数、样条函数的缩放参数、残差函数、网格参数、网格范围、是否可训练参数、设备信息以及是否使用稀疏初始化。

super(KANLayer, self).__init__()
  • 调用父类nn.Module的构造函数来初始化基类。

# size 
self.out_dim = out_dim
self.in_dim = in_dim
self.num = num
self.k = k
  • 设置类的属性,包括输出维度、输入维度、网格间隔数量和多项式阶数。

grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1)
  • 创建一个均匀分布的网格,并将其扩展到输入维度。

grid = extend_grid(grid, k_extend=k)
  • 调用一个名为extend_grid的函数来扩展网格,以适应多项式的阶数。
  • 需要注意的是,这个 extend_grid 函数不是 PyTorch 或其他常见库的标准函数,它可能是特定项目或代码库中的自定义函数。

self.grid = torch.nn.Parameter(grid).requires_grad_(False)
  • 将网格转换为PyTorch参数,但不计算梯度。

这行代码是使用PyTorch库时常见的操作,它涉及到创建一个不可训练的参数。以下是代码中每个部分的详细解释:

将所有这些部分结合起来,以下是对这行代码的解释:

  • self.grid = torch.nn.Parameter(grid):这行代码将grid张量转换为一个模型参数,并将其赋值给self.gridself表示这个属性属于一个类的实例。

  • .requires_grad_(False):这行代码随后将self.gridrequires_grad属性设置为False,意味着在模型的反向传播过程中,这个参数不会积累梯度,也不会被优化器更新。

通常,这种做法用于以下场景:

  • 固定参数:如果你有某些参数不希望在训练过程中被修改,你可以将其requires_grad设置为False。这在你想要保留某些预先定义的参数(如卷积核或坐标网格)时很有用。

  • 优化效率:如果某些参数不需要梯度,设置requires_gradFalse可以减少内存消耗,并可能加速计算,因为不需要为这些参数计算和存储梯度。

在深度学习模型中,这样的操作可能用于定义一个静态的坐标网格,该网格在训练过程中保持不变,但仍然作为模型的一部分参与前向传播计算。

noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num
  • 初始化噪声张量,并将其缩放到[-0.5, 0.5]的范围。

  1. torch.rand(self.num+1, self.in_dim, self.out_dim)

    • torch.rand:这是一个PyTorch函数,用于生成一个服从均匀分布的随机张量,其元素值在01之间(不包括1)。
    • self.num+1:这表示张量在第一个维度上的大小,其中self.num是某个类的属性,表示一个数字,这里张量的第一个维度比self.num1
    • self.in_dim:这是张量在第二个维度上的大小,也是类的属性。
    • self.out_dim:这是张量在第三个维度上的大小,同样是类的属性。

    因此,这行代码创建了一个形状为(self.num+1, self.in_dim, self.out_dim)的张量,其元素是均匀分布在01之间的随机数。

  2. - 1/2

    • 这一步将张量中的每个元素减去0.5,使得新的元素值在-0.50.5之间。
  3. * noise_scale / num

    • noise_scale:这看起来是某个类的属性或变量,用于控制噪声的大小或强度。
    • num:这可能是类的另一个属性或变量,表示一个数值。
    • 将上一步得到的张量乘以noise_scale / num,这进一步调整了噪声张量的大小,使得噪声的范围被缩放到由noise_scalenum确定的值。

综上所述,这行代码的目的是创建一个三维张量,其元素是均匀分布在-0.50.5之间的随机数,然后根据noise_scalenum对噪声进行缩放。这样的噪声张量可能在训练深度学习模型时用作输入数据的一部分,以增加模型的鲁棒性或用于其他目的。

self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k))
  • 使用curve2coef函数将曲线转换为系数,并将其设置为可训练的参数。

这行代码是使用PyTorch库在定义一个神经网络模型时创建一个模型参数。以下是代码中每个部分的详细解释:

  1. self.coef

    • 这是一个实例变量,用于存储模型中的一个参数。
  2. torch.nn.Parameter

    • 这是一个PyTorch类,用于将一个张量转换为模型参数。这意味着这个张量会被添加到模型的参数列表中,并且在训练过程中可以被优化。
  3. curve2coef

    • 这是一个函数,它将输入的曲线(在这里是网格和噪声)转换为系数。这个函数可能是自定义的,因为PyTorch标准库中没有名为curve2coef的函数。该函数的细节取决于其具体实现,但通常它用于从给定的输入中提取或计算系数。
  4. self.grid[:,k:-k].permute(1,0)

    • self.grid:这是一个实例变量,表示一个网格张量。
    • [:,k:-k]:这是一个切片操作,它在第一个维度上取所有元素,但在第二个维度上跳过前k个和后k个元素。
    • .permute(1,0):这是一个张量操作,用于交换张量的两个维度。在这里,它将第一个维度和第二个维度交换位置。
  5. noises

    • 这是之前创建的噪声张量。
  6. self.grid

    • 这是网格张量,这里作为curve2coef函数的输入之一。
  7. k

    • 这是一个变量,表示在第二个维度上要跳过的元素数量。

将这些部分组合起来,这行代码的目的是使用curve2coef函数从self.grid的一个切片和noises张量中计算系数,然后将结果转换为模型的一个参数。这个参数可能用于模型中的某些计算,比如卷积操作或者特征变换。

请注意,curve2coef函数的具体实现没有给出,因此我们无法提供更详细的解释。这个函数可能是根据特定的应用场景或数学模型定制的。

if sparse_init:
    self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(False)
else:
    self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(False)
  • 根据是否使用稀疏初始化,创建一个掩码张量,并将其设置为不可训练的参数。

self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
                     scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)
  • 初始化残差函数的缩放参数,并根据sb_trainable决定是否可训练。

self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * self.mask).requires_grad_(sp_trainable)  # make scale trainable
  • 初始化样条函数的缩放参数,并根据sp_trainable决定是否可训练。

self.base_fun = base_fun
  • 设置残差函数。

self.grid_eps = grid_eps
  • 设置网格参数grid_eps

self.to(device)
  • 将所有模型参数移动到指定的设备上(例如CPU或GPU)。

这个构造函数设置了KANLayer层的各种参数,初始化了网格、系数、掩码、缩放参数等,并将它们作为可训练或不可训练的参数。这些参数将在后续的模型训练过程中被优化。

forward

KANLayer类的forward方法,它定义了该层的前向传播过程。 

def forward(self, x):
  • 定义forward方法,它接受输入张量x

'''
KANLayer forward given input x
        
Args:
-----
    x : 2D torch.float
        inputs, shape (number of samples, input dimension)
        
Returns:
--------
    y : 2D torch.float
        outputs, shape (number of samples, output dimension)
    preacts : 3D torch.float
        fan out x into activations, shape (number of samples, output dimension, input dimension)
    postacts : 3D torch.float
        the outputs of activation functions with preacts as inputs
    postspline : 3D torch.float
        the outputs of spline functions with preacts as inputs
  • 这是方法的文档字符串,描述了输入参数和返回值的类型和形状。
  • 详细——
KANLayer forward given input x
  • 这一行说明这个forward方法是在给定输入x的情况下,KANLayer层的前向传播过程。

Args:
-----
    x : 2D torch.float
        inputs, shape (number of samples, input dimension)
  • 这部分描述了forward方法的输入参数:
    • x:一个二维的PyTorch浮点张量,代表输入数据。
    • shape (number of samples, input dimension):输入张量x的形状,其中第一维是样本数量,第二维是输入维度。

Returns:
--------
    y : 2D torch.float
        outputs, shape (number of samples, output dimension)
  • 这部分描述了forward方法返回的第一个值:
    • y:一个二维的PyTorch浮点张量,代表输出数据。
    • shape (number of samples, output dimension):输出张量y的形状,其中第一维是样本数量,第二维是输出维度。

    preacts : 3D torch.float
        fan out x into activations, shape (number of samples, output dimension, input dimension)
  • 这部分描述了forward方法返回的第二个值:
    • preacts:一个三维的PyTorch浮点张量,代表在激活函数之前的中间激活值。
    • shape (number of samples, output dimension, input dimension)preacts张量的形状,其中第一维是样本数量,第二维是输出维度,第三维是输入维度。

    postacts : 3D torch.float
        the outputs of activation functions with preacts as inputs
  • 这部分描述了forward方法返回的第三个值:
    • postacts:一个三维的PyTorch浮点张量,代表激活函数的输出,其中preacts是输入。
    • 由于preacts已经是三维的,postacts的形状将与preacts相同

    postspline : 3D torch.float
        the outputs of spline functions with preacts as inputs
  • 这部分描述了forward方法返回的第四个值:
    • postspline:一个三维的PyTorch浮点张量,代表样条函数的输出,其中preacts是输入。
    • postspline的形状也与preacts相同。

总的来说,这个文档字符串清楚地说明了forward方法期望的输入和它将返回的四个张量的形状和含义。

batch = x.shape[0]
  • 获取输入张量x的批处理大小。

preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim)
  • 创建preacts张量,它是输入x的扩展版本,形状变为(batch, output dimension, input dimension)

base = self.base_fun(x) # (batch, in_dim)
  • 应用残差函数base_fun到输入x上,得到基础激活值。

y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
  • 使用coef2curve函数将系数转换为曲线,这里x_eval是输入xgrid是网格,coef是系数,k是多项式阶数。

postspline = y.clone().permute(0,2,1)
  • 创建postspline张量,它是y的转置版本,形状变为(batch, input dimension, output dimension)

y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y
  • 将基础激活值与样条函数的输出相加,并分别乘以它们的缩放参数。

y = self.mask[None,:,:] * y
  • 应用掩码masky上,以选择性地激活或抑制某些输出。

postacts = y.clone().permute(0,2,1)
  • 创建postacts张量,它是y的转置版本,形状变为(batch, output dimension, input dimension)

y = torch.sum(y, dim=1)
  • y在第二个维度(即输入维度)上进行求和,以得到最终的输出。

return y, preacts, postacts, postspline
  • 返回最终的输出y,以及中间激活值preactspostactspostspline

这个forward方法实现了KANLayer层的前向传播,它计算了输入x的激活值,并返回了最终的输出以及中间激活值,这些中间激活值可能用于调试或进一步的分析。

这里的输入x变成输出y的过程涉及多个步骤,每个步骤都对输入x进行了一些操作,最终得到了输出y。以下是详细的步骤:

  1. 批处理大小获取

    • 首先通过batch = x.shape[0]获取输入张量x的批处理大小(即样本数量)。
  2. 创建preacts张量

    • 接着,preacts = x[:,None,:].clone().expand(batch, self.out_dim, self.in_dim)将输入x扩展为一个三维张量,其形状为(batch, output dimension, input dimension)。这是通过在x的第二维增加一个维度,然后复制并扩展到指定的形状来实现的。
  3. 应用残差函数

    • base = self.base_fun(x)应用一个残差函数base_fun到输入x上,得到基础激活值base,其形状为(batch, input dimension)
  4. 将系数转换为曲线

    • y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)使用coef2curve函数将系数转换为曲线,生成y张量,这是基于输入x、网格grid、系数coef以及多项式阶数k
  5. 创建postspline张量

    • postspline = y.clone().permute(0,2,1)创建y的转置版本,其形状变为(batch, input dimension, output dimension)
  6. 组合基础激活值和样条函数输出

    • y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y将基础激活值base与样条函数的输出y相加,并分别乘以它们的缩放参数scale_basescale_sp
  7. 应用掩码

    • y = self.mask[None,:,:] * y应用一个掩码masky上,以选择性地激活或抑制某些输出。
  8. 创建postacts张量

    • postacts = y.clone().permute(0,2,1)创建y的转置版本,其形状变为(batch, output dimension, input dimension)
  9. 求和得到最终输出

    • y = torch.sum(y, dim=1)y在第二个维度(即输入维度)上进行求和,以得到最终的输出y,其形状为(batch, output dimension)
  10. 返回输出和中间激活值

    • 最后,return y, preacts, postacts, postspline返回最终的输出y以及中间激活值preactspostactspostspline

综上所述,输入x通过一系列的扩展、函数应用、缩放、掩码应用和求和操作,被转换成了最终的输出y。这个过程涉及到多个中间步骤,每个步骤都对最终的输出有贡献。

下面重点关注spline文件讲了什么,待续

这个文件后面还有  初始化网格、更新网格等模块,待续

Top