Hi,
By the way, I did not understand the use of the hidden channel here.
Well, I just added it for simplicity, but you can think of it as an extra step in the reduction of the temporal data. If `HIDDEN_C=3`, the number of features before and after `conv_spatial` is identical, and the whole reduction happens in the first convolution. If `HIDDEN_C>3`, the reduction is softer: in terms of shape, the output of `conv_temp` sits somewhere between our input and the target output. But the question is: does this really make a difference?
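To make the shape argument concrete, here is a minimal sketch. The real model isn't shown in this thread, so everything beyond the names `conv_temp`, `conv_spatial` and `HIDDEN_C` is an assumption: I treat the input as `IN_C` channels over `T` time steps (hypothetical sizes), model `conv_temp` as a `Conv1d` along time that maps `IN_C` to `HIDDEN_C` channels, and `conv_spatial` as a pointwise `Conv1d` from `HIDDEN_C` to 3 output channels. `TinyNet` is a hypothetical stand-in, not the original architecture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes, not taken from the original model
IN_C, OUT_C, T = 22, 3, 128  # input channels, target channels, time steps


class TinyNet(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self, hidden_c: int):
        super().__init__()
        # conv_temp: filters along time and maps IN_C -> hidden_c channels
        self.conv_temp = nn.Conv1d(IN_C, hidden_c, kernel_size=5, padding=2)
        # conv_spatial: pointwise mix of hidden_c -> OUT_C channels
        self.conv_spatial = nn.Conv1d(hidden_c, OUT_C, kernel_size=1)

    def forward(self, x):
        h = F.relu(self.conv_temp(x))
        return self.conv_spatial(h), h  # final output and intermediate


x = torch.randn(8, IN_C, T)  # batch of 8 random inputs
for hidden_c in (3, 8):
    y, h = TinyNet(hidden_c)(x)
    print(hidden_c, tuple(h.shape), tuple(y.shape))
# 3 (8, 3, 128) (8, 3, 128)
# 8 (8, 8, 128) (8, 3, 128)
```

With `HIDDEN_C=3` the intermediate and final shapes coincide, so `conv_spatial` only remixes three channels; with a larger value the intermediate sits between the 22 input channels and the 3 output channels.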
If `HIDDEN_C` is close to 3, there is also the possibility ‘that information is lost, where certain points of the manifold collapse into each other’, to quote MobileNetV2. If you run the code below, you will see that when the ‘hidden’ dimension is only a small multiple of the original dimension, some information is lost, but less so in higher dimensions. So this could be another reason to choose `HIDDEN_C>3`, but in the end, just test what works best ^^
```python
import math

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

N_SEEDS = 5
ns = [2, 3, 6, 10, 15, 30]  # hidden dim = D * n for each panel
MAX_N = ns[-1]
D = 2  # dimension of the original data

# A 2D spiral with 300 points
x = torch.empty(300, D)
for i in range(x.shape[0]):
    x[i] = torch.tensor([i * math.cos(i / 10), i * math.sin(i / 10)])

fig, axs = plt.subplots(N_SEEDS, len(ns) + 1, figsize=(14, 6))
for seed in range(N_SEEDS):
    torch.manual_seed(seed)
    # One random projection per row; each panel uses its first D*n columns
    w = torch.randn(D, D * MAX_N)

    axs[seed, 0].set_title("original")
    axs[seed, 0].set_aspect("equal")
    axs[seed, 0].set_axis_off()
    axs[seed, 0].plot(x[:, 0], x[:, 1])

    for i, n in enumerate(ns, start=1):
        w_ = w[:, : D * n]
        # Embed into D*n dims, apply ReLU, then project back with the pseudo-inverse
        x_up = F.relu(x @ w_)
        x_down = x_up @ w_.pinverse()

        axs[seed, i].set_title(f"Output/dim={n}")
        axs[seed, i].set_aspect("equal")
        axs[seed, i].set_axis_off()
        axs[seed, i].plot(x_down[:, 0], x_down[:, 1])

plt.show()
```
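The plots give a visual impression; if you want a rough number instead, here is a small sketch without matplotlib. The metric is my own proxy, not the one from MobileNetV2: for each output I fit the best linear map back to the original spiral with `torch.linalg.lstsq` and report the fraction of the spiral it cannot recover, averaged over the same seeds. The expectation is that this residual shrinks as `dim` grows, in line with the plots, but run it and check for yourself.

```python
import torch
import torch.nn.functional as F

N_SEEDS = 5
ns = [2, 3, 6, 10, 15, 30]
D = 2

# Same 2D spiral as in the script above, built without a Python loop
t = torch.arange(300, dtype=torch.float32)
x = torch.stack([t * torch.cos(t / 10), t * torch.sin(t / 10)], dim=1)


def relative_residual(x_down: torch.Tensor, x: torch.Tensor) -> float:
    """Fraction of x's energy that the best linear map from x_down cannot recover."""
    a = torch.linalg.lstsq(x_down, x).solution  # (D, D) least-squares fit
    return ((x_down @ a - x).pow(2).sum() / x.pow(2).sum()).item()


errors = {n: 0.0 for n in ns}
for seed in range(N_SEEDS):
    torch.manual_seed(seed)
    w = torch.randn(D, D * ns[-1])
    for n in ns:
        w_ = w[:, : D * n]
        # Same up-projection, ReLU and pseudo-inverse round trip as above
        x_down = F.relu(x @ w_) @ w_.pinverse()
        errors[n] += relative_residual(x_down, x) / N_SEEDS

for n in ns:
    print(f"dim={n}: mean relative residual {errors[n]:.3f}")
```

Using a linear fit rather than a plain distance to `x` matters here: the round trip can rescale the spiral without destroying its shape, and a rescaled spiral should not count as lost information.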