I’m working on a project where I apply Kolmogorov-Arnold Networks (KAN) on the Mixer architecture (KAN-Mixer). You can check the issue here
This is the code that I tried to run
X = torch.rand(64, 1, 28, 28) # MNIST (64, 1, 28, 28), CIFAR10 (64, 3, 32, 32)
pe = PatchEmbedding(1, 128, 4) # MNIST (1, 128, 4), CIFAR10 (3, 128, 4)
t1 = Transformation1()
t2 = Transformation2()
ml = MixerLayer(128, 49, 256, 256) # MNIST (128, 49, 256, 256), CIFAR10 (128, 64, 256, 256)
print(f"Input shape: {X.shape}")
y1 = pe(X)
print(f"Patch Embedding output shape: {y1.shape}")
y2 = t2(y1)
print(f"T2 transformation output shape: {y2.shape}")
y3 = t1(y2)
print(f"T1 transformation output shape: {y3.shape}")
y4 = t1(y3)
print(f"T1 transformation output shape: {y4.shape}")
y5 = ml(y4) # <--- Problem starts here
print(f"Mixer Layer output shape: {y5.shape}")
This is the Error message
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 6
4 t1 = Transformation1()
5 t2 = Transformation2()
----> 6 ml = MixerLayer(128, 49, 256, 256) # MNIST (128, 49, 256, 256), CIFAR10 (128, 64, 256, 256)
8 print(f"Input shape: {X.shape}")
9 y1 = pe(X)
File ~/model.py:265, in MixerLayer.__init__(self, embedding_dim, num_patch, token_intermediate_dim, channel_intermediate_dim, dropout)
259 def __init__(self, embedding_dim, num_patch, token_intermediate_dim, channel_intermediate_dim, dropout=0.):
260 super().__init__()
262 self.token_mixer = nn.Sequential(
263 nn.LayerNorm(embedding_dim),
264 Transformation1(),
--> 265 KAN(num_patch, token_intermediate_dim, dropout),
266 Transformation1()
267 )
269 self.channel_mixer = nn.Sequential(
270 nn.LayerNorm(embedding_dim),
271 KAN(embedding_dim, channel_intermediate_dim, dropout),
272 )
File ~/model.py:223, in KAN.__init__(self, dim, intermediate_dim, dropout)
220 def __init__(self, dim, intermediate_dim, dropout = 0.):
221 super().__init__()
222 self.kan = nn.Sequential(
--> 223 KANLinear(dim, intermediate_dim),
224 KANLinear(intermediate_dim, dim),
225 )
File ~/model.py:60, in KANLinear.__init__(self, in_features, out_features, grid_size, spline_order)
56 self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))
58 self.base_activation = nn.SiLU()
---> 60 self.reset_parameters()
File ~/model.py:73, in KANLinear.reset_parameters(self)
71 # Compute the spline weight coefficients from the random noise
72 grid_points = self.grid.T[self.spline_order : -self.spline_order]
---> 73 spline_coefficients = self.curve2coeff(grid_points, random_noise)
75 # Copy the computed coefficients to the spline weight tensor
76 self.spline_weight.data.copy_(spline_coefficients)
File ~/model.py:177, in KANLinear.curve2coeff(self, input_tensor, output_tensor)
166 """
167 Compute the coefficients of the curve that interpolates the given points.
168
(...)
174 torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
175 """
176 # Compute the B-spline bases for the input tensor
--> 177 b_splines_bases = self.b_splines(
178 input_tensor
179 ) # (batch_size, input_dim, grid_size + spline_order)
181 # Transpose the B-spline bases and output tensor for matrix multiplication
182 transposed_bases = b_splines_bases.transpose(
183 0, 1
184 ) # (input_dim, batch_size, grid_size + spline_order)
File ~/model.py:149, in KANLinear.b_splines(self, x)
147 test = (input_tensor_expanded - expanded_grid[:, :, : -order - 1]) / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
148 print(f"left_term: {test.shape}")
--> 149 test2 = test * bases[:, :, :-1]
151 left_term = (
152 (input_tensor_expanded - expanded_grid[:, :, : -order - 1])
153 / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
154 ) * bases[:, :, :-1]
156 right_term = (
157 (expanded_grid[:, :, order + 1 :] - input_tensor_expanded)
158 / (expanded_grid[:, :, order + 1 :] - expanded_grid[:, :, 1:-order])
159 ) * bases[:, :, 1:]
RuntimeError: The size of tensor a (10) must match the size of tensor b (11) at non-singleton dimension 2
Problem starts here
class KANLinear(nn.Module):
# ...
def b_splines(self, x: torch.Tensor):
# ...
# Compute the B-spline bases recursively
for order in range(1, self.spline_order + 1):
input_tensor_expanded = input_tensor_expanded.repeat(1, 1, expanded_grid[:, :, : -order - 1].shape[-1])
test = (input_tensor_expanded - expanded_grid[:, :, : -order - 1]) / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
print(f"left_term: {test.shape}")
test2 = test * bases[:, :, :-1] # <--- Problem starts here
left_term = (
(input_tensor_expanded - expanded_grid[:, :, : -order - 1])
/ (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
) * bases[:, :, :-1]
# ...
return bases.contiguous()
# ...
Data’s dimension
expanded_grid
andbases
:torch.Size([6, 49, 12])
input_tensor_expanded
:torch.Size([6, 49])
After computing the B-Spline bases
input_tensor_expanded
, expanded_grid
, and test
: torch.Size([6, 49, 10])
bases
: torch.Size([6, 49, 12])
Problem
I need to ensure that the dimensions of bases
match those of input_tensor_expanded
and expanded_grid
to avoid the error I encountered. Currently, the dimension of test
is torch.Size([6, 49, 10])
and the dimension of bases is torch.Size([6, 49, 12])
. I need to figure out how to adjust the dimension of bases
to match the others.