How can I automatically tune a CNN without running into the following error?

Hello. I am trying to use a genetic algorithm to tune the hyperparameters of my neural network, including the number of layers, the filter sizes, the kernel sizes, the paddings, etc. At some iteration of the genetic algorithm the following error is thrown in the terminal.

Given input size: (61x1x1). Calculated output size: (61x0x0). Output size is too small

Every time the reported sizes are different but the essence of the error is the same.
How would I be able to automatically tune a CNN without ending up in this error?

Your architecture search would have to check whether some activations could become empty tensors, e.g. because the pooling or conv layers downsample them too aggressively.
I don’t know how your workflow is set up or what determines whether a layer should be added or which config should be used.

As @ptrblck says, you have to generate each conv/pool configuration before committing to it and mimic the calculations of the pool/conv layers in your NAS, so that you discard upfront any candidate that pools or strides too much to leave a positive size at the end.

In this way you can prune bad configs early, as well as generate them much quicker.

To calculate the output size of your tensor after a conv layer, use this formula (this is how it looks for conv1d at least):
L_out = floor((L_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

Also, pooling follows the same kind of formula with its own kernel size, stride and padding, so it roughly divides the output size by its stride.
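
A minimal sketch of that arithmetic in Python, in case it helps. The function names are made up for illustration, and the layer stack is assumed to be a plain list of (kernel_size, stride, padding) tuples applied along one spatial axis:

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one axis, following the L_out formula above."""
    # Python's // floors toward negative infinity, matching floor() in the formula
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def pool_out_size(size, kernel_size, stride=None, padding=0):
    """Same arithmetic with dilation 1; stride defaults to the kernel size."""
    stride = kernel_size if stride is None else stride
    return conv_out_size(size, kernel_size, stride, padding)


def final_size(size, layers):
    """Walk a stack of (kernel_size, stride, padding) tuples.

    Returns the final size, or None as soon as any layer would produce
    an output smaller than 1 (i.e. the config should be discarded).
    """
    for kernel_size, stride, padding in layers:
        size = conv_out_size(size, kernel_size, stride, padding)
        if size < 1:
            return None
    return size


# One way to reach a size of 0 from a size-1 input with pooling settings
# like the ones in this thread (kernel 4, stride 2, padding 1):
print(pool_out_size(1, kernel_size=4, stride=2, padding=1))  # 0
```

Calling final_size on every candidate before building the model is enough to throw away the configurations that would trigger the "Output size is too small" error.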

Also I’d skip the randomizing of the padding, and just use:
padding = ceil(kernel_size - stride) // 2

That will narrow your hyper-parameter space a bit without sacrificing much.

Thanks a lot for your replies @ptrblck and @superunification. It is clear to me now that I was doing it in an overly simplistic way, without mimicking the calculations of the pool/conv layers in my NAS, etc.

The way I am doing it now is the following. I use this genetic algorithm as my NAS routine, and I feed it a very long vector hyperparameters__ as an input. Then I retrieve from the vector the number of layers, the kernel sizes, where to place pooling layers, etc. It looks roughly like this:

    N_conv_max__ = int(hyperparameters__[-2]) # the number of conv-layers
    N_fc_max__ = int(hyperparameters__[-1])  # the number of fc-layers
    conv_layer_out_channels = hyperparameters__[0*N_conv_max:0*N_conv_max+N_conv_max__]
    conv_layer_kernel_sizes = hyperparameters__[1*N_conv_max:1*N_conv_max+N_conv_max__]
    conv_layer_strides = hyperparameters__[2*N_conv_max:2*N_conv_max+N_conv_max__]
    conv_layer_paddings = hyperparameters__[3*N_conv_max:3*N_conv_max+N_conv_max__]
    max_pool_layer_numbers = hyperparameters__[4*N_conv_max:4*N_conv_max+N_conv_max__]
    max_pool_kernel_sizes = hyperparameters__[5*N_conv_max:5*N_conv_max+N_conv_max__]
    max_pool_strides = hyperparameters__[6*N_conv_max:6*N_conv_max+N_conv_max__]
    max_pool_paddings = hyperparameters__[7*N_conv_max:7*N_conv_max+N_conv_max__]
    batch_norm_2d_layer_numbers = hyperparameters__[8*N_conv_max:8*N_conv_max+N_conv_max__]
    linear_layer_out_features = hyperparameters__[9*N_conv_max:9*N_conv_max+N_fc_max__]
    batch_norm_1d_layer_numbers = hyperparameters__[9*N_conv_max+N_fc_max:9*N_conv_max+N_fc_max+N_fc_max__]

where N_conv_max is the maximum possible number of conv-layers that the genetic algorithm can try, and N_fc_max is the maximum possible number of fc-layers that the genetic algorithm can try. I set them to 15 and 2 respectively.

Then, I set a restricting interval for each parameter (i.e. each element of hyperparameters__) to tell the genetic algorithm to sample values only from those intervals. These are the intervals I use (including the rest of the NAS routine):

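    # imports assumed here: the ga(...) call below matches the
    # `geneticalgorithm` package from PyPI
    import numpy as np
    import matplotlib.pyplot as plt
    from geneticalgorithm import geneticalgorithm as ga

    N_dims = N_conv_max*9+N_fc_max*2+2 # last 2 for N_conv_max and N_fc_max values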
    
    # 1 conv_layer_out_channels
    a1 = [[3,7]]+[[10,20]]+[[10,32]]+[[22,41]]+[[21,50]]+[[30,60]]+[[30,70]]+[[30,80]]+\
        [[40,90]]+[[40,100]]+[[50,110]]+[[50,120]]+[[50,130]]+[[60,140]]+[[60,150]]
    # 2 conv_layer_kernel_sizes
    a2 = [[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+\
        [[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]+[[3,9]]
    # 3 conv_layer_strides
    a3 = [[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+\
        [[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]+[[1,5]]
    # 4 conv_layer_paddings
    a4 = [[0,1]]+[[3,5]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+\
        [[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]+[[5,10]]
    # 5 max_pool_layer_numbers
    a5 = [[0,0]]+[[0,1]]+[[0,0]]+[[0,1]]+[[0,0]]+[[0,1]]+[[0,0]]+[[0,1]]+[[0,0]]+\
        [[0,1]]+[[0,0]]+[[0,1]]+[[0,0]]+[[0,1]]+[[0,0]]
    # 6 max_pool_kernel_sizes
    a6 = [[4,4]]*N_conv_max
    # 7 max_pool_strides
    a7 = [[1,2]]*N_conv_max
    # 8 max_pool_paddings
    a8 = [[1,2]]*N_conv_max
    # 9 batch_norm_2d_layer_numbers
    a9 = [[0,1]]*N_conv_max
    # 10 linear_layer_out_features
    a10 = [[1,1500]]*N_fc_max
    # 11 batch_norm_1d_layer_numbers
    a11 = [[0,1]]*N_fc_max
    # 12
    a12 = [[1,N_conv_max]]
    # 13
    a13 = [[1,N_fc_max]]
    
    varbound=np.array(a1+a2+a3+a4+a5+a6+a7+a8+a9+a10+a11+a12+a13) 
    vartype=np.array([['int']]*N_dims)

    algorithm_param = {'max_num_iteration': 1000,\
                       'population_size':100,\
                       'mutation_probability':0.1,\
                       'elit_ratio': 0.01,\
                       'crossover_probability': 0.5,\
                       'parents_portion': 0.3,\
                       'crossover_type':'uniform',\
                       'max_iteration_without_improv':None}
    
    model=ga(function=function_to_optimize,\
                dimension=N_dims,\
                variable_type_mixed=vartype,\
                function_timeout=3600,\
                variable_boundaries=varbound,\
                algorithm_parameters=algorithm_param)

    model.run()
    convergence=model.report
    solution=model.output_dict    
    plt.plot(convergence)
    plt.legend([solution['function']])

Finally, I use the cost-value at the last epoch as the optimisation criterion for the genetic algorithm.

Could you guys @superunification @ptrblck comment on this approach, and which specific NAS framework would you recommend?

I’m afraid it’s not easy for me to follow what’s going on in those snippets, but it looks like all the parameters are decoupled from each other (as opposed to tuples of [kernel_size, padding, stride, …]) and fed into ga, so I don’t know how you’d accomplish the pruning we talked about earlier.
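
To make the coupling idea concrete, here is a rough sketch of how a flat vector could be regrouped into per-layer tuples and checked before any training happens. Everything in it is an assumption for illustration: the vector layout (five fields of n_conv_max entries plus the layer count at the end), the names ConvBlock, decode, is_viable and objective, the fixed 2x2 pooling, and the idea of returning a large penalty so that a library GA steers away from invalid candidates.

```python
from typing import NamedTuple


class ConvBlock(NamedTuple):
    out_channels: int
    kernel_size: int
    stride: int
    padding: int
    use_pool: bool


def decode(vector, n_conv_max):
    """Regroup [channels..., kernels..., strides..., paddings..., pool flags..., n_layers]."""
    n_layers = int(vector[-1])
    fields = [vector[i * n_conv_max:(i + 1) * n_conv_max] for i in range(5)]
    return [
        ConvBlock(int(fields[0][j]), int(fields[1][j]), int(fields[2][j]),
                  int(fields[3][j]), bool(fields[4][j]))
        for j in range(n_layers)
    ]


def out_size(size, kernel_size, stride, padding):
    """Conv/pool output size along one axis (dilation 1)."""
    return (size + 2 * padding - (kernel_size - 1) - 1) // stride + 1


def is_viable(blocks, input_size, pool_kernel=2, pool_stride=2):
    """True if every intermediate activation keeps a size of at least 1."""
    size = input_size
    for block in blocks:
        size = out_size(size, block.kernel_size, block.stride, block.padding)
        if size < 1:
            return False
        if block.use_pool:
            size = out_size(size, pool_kernel, pool_stride, 0)
            if size < 1:
                return False
    return True


PENALTY = 1e6  # hypothetical cost returned for configs that cannot be built


def objective(vector, train_and_eval, n_conv_max=3, input_size=28):
    """Cost function for the GA: skip training entirely for invalid configs."""
    blocks = decode(vector, n_conv_max)
    if not is_viable(blocks, input_size):
        return PENALTY
    return train_and_eval(blocks)
```

The penalty route wastes GA evaluations on invalid candidates, which is one reason generating only valid configurations in the first place tends to work better.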

I found that it’s generally easier to just write my own genetic algorithm, as it’s conceptually dead easy to get the basics working (init pop, mutate, select, breed, repeat steps 2-4), and the particular mutators are almost always problem-specific.
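
For reference, the basic loop really is short. The sketch below is a generic version with a toy problem attached, not the code I actually use: evolve and the toy_* functions are illustrative names, lower fitness is assumed to be better, and in a real NAS you would cache fitness values rather than recompute them in the sort.

```python
import random


def evolve(init, mutate, crossover, fitness,
           pop_size=20, generations=30, keep=0.5, seed=0):
    """Basic GA loop: init population, then repeat select / breed / mutate.

    Lower fitness is better. Survivors are carried over unchanged, so the
    best individual never gets worse from one generation to the next.
    """
    rng = random.Random(seed)
    population = [init(rng) for _ in range(pop_size)]
    for _ in range(generations):
        # select: keep the best fraction of the population
        population.sort(key=fitness)
        survivors = population[:max(2, int(pop_size * keep))]
        # breed + mutate: refill the population from random survivor pairs
        children = []
        while len(survivors) + len(children) < pop_size:
            parent_a, parent_b = rng.sample(survivors, 2)
            children.append(mutate(rng, crossover(rng, parent_a, parent_b)))
        population = survivors + children
    return min(population, key=fitness)


# Toy problem (purely illustrative): four integer genes in [0, 9],
# with the optimum at all fives.
def toy_init(rng):
    return [rng.randint(0, 9) for _ in range(4)]


def toy_fitness(genes):
    return sum((g - 5) ** 2 for g in genes)


def toy_crossover(rng, a, b):
    # uniform crossover: each gene comes from either parent with equal odds
    return [x if rng.random() < 0.5 else y for x, y in zip(a, b)]


def toy_mutate(rng, genes):
    # nudge one gene by +/-1, clipped to the allowed range
    child = list(genes)
    i = rng.randrange(len(child))
    child[i] = min(9, max(0, child[i] + rng.choice([-1, 1])))
    return child


best = evolve(toy_init, toy_mutate, toy_crossover, toy_fitness)
```

For architecture search, init, mutate and crossover are where the problem-specific viability checks go.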

Secondly, it’s just easier to figure out problems when the algorithm isn’t a black box in some library.

Like, for example, does this ga algorithm sample from a uniform or a Gaussian distribution along the provided ranges? Quite often it’s best to sample from a Gaussian, so long as its mean is around where the optimal parameter value lies (you can shift the mean to the best-performing parameter value and reduce the variance as training progresses, kind of like simulated annealing). In contrast, sampling from uniform distributions could make good configurations really sparse.
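
A small sketch of that kind of sampling, under assumptions of my own: sample_param is a made-up name, the spread is expressed as a fraction of the allowed range, and it shrinks linearly from start_frac to end_frac over the run. Values are rounded and clipped so they stay valid integer hyperparameters.

```python
import random


def sample_param(rng, best_value, low, high, generation, total_generations,
                 start_frac=0.5, end_frac=0.05):
    """Draw an integer near best_value; the spread narrows as the run progresses."""
    # progress goes from 0.0 (first generation) to 1.0 (last generation)
    progress = min(1.0, generation / max(1, total_generations - 1))
    frac = start_frac + (end_frac - start_frac) * progress
    sigma = frac * (high - low)
    value = round(rng.gauss(best_value, sigma))
    # clip to the allowed interval
    return min(high, max(low, value))


rng = random.Random(0)
# kernel size in [3, 9], centred on the best value found so far (say 5)
early_draw = sample_param(rng, 5, 3, 9, generation=0, total_generations=50)
late_draw = sample_param(rng, 5, 3, 9, generation=49, total_generations=50)
```

Clipping piles some probability mass onto the interval ends; resampling until the value falls inside the range is an alternative if that matters.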

Thanks for your reply @superunification. It seems that writing my own GA would be more reasonable.
In addition to computing L_out and the padding, and the other considerations you have mentioned, is there anything else (other technical considerations) that I should keep in mind or foresee when implementing this adaptive approach, and that you could tell me off the top of your head?

You can have a look at my code snippet for generating CNN archs. It’s for 1-D stuff, so you’ll need to adapt it accordingly. I’m mostly a Java/C++ programmer, so full disclaimer: this stuff isn’t the most pythonic.

import math
from random import Random

# note: `util` is my own helper module (not shown here)


def get_random_arch_config(window_size: int,
                           param_budget: int,
                           trials_before_bailing: int, 
                           num_layers: int) -> list:
    """    
    :param window_size: 
        the L dimension of the input tensor, used to be sure the 
        candidate architectures don't pool too much and end up with 
        < 1 length outputs 
    :param param_budget: 
        the allowable number of estimated parameters to use in the configuration. 
        Candidate configurations with more than this budget are ignored and retried. 
    :param trials_before_bailing:
        how many random configurations to try, from which to
        choose the best configuration.
        If no random configs are acceptable, bail out and we'll
        use the default
    :param num_layers:
        the number of conv blocks to generate for each candidate
    :return:
        a list of conv_config configurations for conv blocks, each its own list with the format:
             [out_channel_log2, kernel_size, pool_stride]
    """

    rand = Random()
    best_conv_configs = None
    best_budget = 1e7
    num_trials = -1
    remaining_budget = param_budget

    # find the best (in terms of using most parameters up to the budget limit)
    # of 'trials_before_bailing' random architectures
    for s in range(trials_before_bailing):
        layer_configs = []

        remaining_budget = param_budget
        in_chans = 1
        comp_input_size = window_size

        for layer_num in range(num_layers):
            kernel_size = round(1.5 ** rand.randint(0, 10))
            out_channels_log2 = rand.randint(4, 11)
            out_chans = 2 ** out_channels_log2
            stride = util.kernel_size_to_stride(kernel_size, comp_input_size)

            # final layer needs to do 1x1 convolutions and no more reduction
            if layer_num == num_layers - 1:
                kernel_size = 1
                stride = 1
                out_chans = window_size
                out_channels_log2 = round(math.log(window_size, 2))

            # enforces the constraint of ending up with input_size=1 by last layer
            if layer_num == num_layers - 2:
                stride = 1
                pool_stride = comp_input_size
            else:
                pool_stride = 1 if comp_input_size == 1 else 2 ** rand.randint(0, 2)

            padding = math.ceil(kernel_size - stride) // 2

            total_ops = util.calc_conv_ops(comp_input_size, in_chans, out_chans, kernel_size, stride)
            remaining_budget -= total_ops

            temp_input_size = comp_input_size
            comp_input_size = util.calc_output_size(comp_input_size, kernel_size, stride, pool_stride, padding)

            if comp_input_size <= 0 or remaining_budget < 0:
                break

            layer = {"in_ch": in_chans, "out_ch_log2": out_channels_log2, "in_size": temp_input_size,
                     "out_size": comp_input_size, "kernel_size": kernel_size, "pool_stride": pool_stride}

            layer_configs.append(layer)
            in_chans = out_chans

        if comp_input_size <= 0:
            continue

        if 0 < remaining_budget < best_budget:
            best_conv_configs = layer_configs
            best_budget = remaining_budget
            num_trials = s

        # return early if it's this optimal
        if 0 < remaining_budget < param_budget * 0.1:
            break

    conv_configs = []
    if best_conv_configs:
        for config in best_conv_configs:
            conv_configs.append([config["out_ch_log2"], config["kernel_size"], config["pool_stride"]])
    return conv_configs

My crossover function is simple but it might save you time:


import copy
from random import Random
from typing import List


def crossover_config(rand: Random, 
                     window_size: int, 
                     parent_1: dict, 
                     parent_2: dict, 
                     param_budget: int) -> list:
    viable = False
    new_conv_config = []
    cost = param_budget + 1
    tries = 0
    p1_conv_blocks: List[List] = parent_1["conv_config"]
    p2_conv_blocks: List[List] = parent_2["conv_config"]

    while cost > param_budget or not viable:
        new_conv_config = []
        sht = p1_conv_blocks if len(p1_conv_blocks) < len(p2_conv_blocks) else p2_conv_blocks
        lon = p1_conv_blocks if len(p1_conv_blocks) >= len(p2_conv_blocks) else p2_conv_blocks

        for i in range(len(sht)):
            source = p1_conv_blocks[i] if rand.random() < 0.5 else p2_conv_blocks[i]
            new_conv_config.append(copy.deepcopy(source))


        # give the child the average depth of the parents by appending
        # blocks from the longer parent's tail
        for i in range((len(lon) - len(sht)) // 2):
            new_conv_config.append(copy.deepcopy(lon[len(sht) + i]))

        cost, viable = util.calc_config_param_cost_and_viability(window_size, new_conv_config)

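        # bail and choose the first parent's config after too many failed
        # attempts (non-viable or over budget), so the loop cannot run forever
        if tries > 100 and (cost > param_budget or not viable):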
            new_conv_config = copy.deepcopy(p1_conv_blocks)
            break

        tries += 1
    return new_conv_config

Some ideas:

  • Make a utility function for checking the validity of your network, and use this whenever you mutate / crossover to check that the newly created architecture is still viable.
  • You probably want to write a hash function for your architecture config to check whether it’s the same as another one in your population (if it is, you’ll just duplicate effort in running it).
  • Inherit conv layer’s input channels param from the output channels param of the previous conv layer (you definitely don’t want to randomize that or only a small fraction of architectures will be self consistent)
  • Make liberal use of copy.deepcopy(xyz) when mutating / doing crossover – you don’t want to be changing other configurations by accident.
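
As a sketch of the hashing idea from the list above, assuming configs are nested lists of ints in the [out_channel_log2, kernel_size, pool_stride] format used earlier. The names config_key and deduplicate are illustrative only:

```python
import hashlib
import json


def config_key(conv_config):
    """Stable string key for a config made of nested lists of ints."""
    payload = json.dumps(conv_config, separators=(",", ":"))
    return hashlib.sha1(payload.encode("utf-8")).hexdigest()


def deduplicate(configs):
    """Keep the first occurrence of each distinct config, preserving order."""
    seen = set()
    unique = []
    for config in configs:
        key = config_key(config)
        if key not in seen:
            seen.add(key)
            unique.append(config)
    return unique


population = [[[5, 3, 2], [6, 5, 1]], [[5, 3, 2], [6, 5, 1]], [[4, 7, 1]]]
unique_population = deduplicate(population)  # the repeated config is dropped
```

A string key like this can also be used to cache evaluation results on disk, so an architecture that reappears in a later generation is not trained twice.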

Make your own mutation operators, such as:

  • delete conv-block
  • duplicate last block
  • random adjustment to kernel size, pool stride, out channels
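
The three operators above could look roughly like this. It is a sketch under assumptions: blocks are [out_channel_log2, kernel_size, pool_stride] lists, the value ranges are borrowed from the generator earlier in the thread, and is_viable is a validity-check callback you supply yourself. Every operator works on a deep copy so the parent is never modified.

```python
import copy
import random


def delete_block(rng, config):
    """Remove one random block, but never the last remaining one."""
    child = copy.deepcopy(config)
    if len(child) > 1:
        del child[rng.randrange(len(child))]
    return child


def duplicate_last_block(rng, config):
    """Append an independent copy of the final block."""
    child = copy.deepcopy(config)
    child.append(copy.deepcopy(child[-1]))
    return child


def nudge_block(rng, config):
    """Randomly adjust out channels, kernel size or pool stride of one block."""
    child = copy.deepcopy(config)
    block = rng.choice(child)
    field = rng.randrange(3)
    if field == 0:    # out_channel_log2, kept within 4..11
        block[0] = min(11, max(4, block[0] + rng.choice([-1, 1])))
    elif field == 1:  # kernel size, never below 1
        block[1] = max(1, block[1] + rng.choice([-2, 2]))
    else:             # pool stride
        block[2] = rng.choice([1, 2, 4])
    return child


def mutate(rng, config, is_viable, max_tries=20):
    """Apply a random operator; fall back to a copy of the parent if none is viable."""
    operators = [delete_block, duplicate_last_block, nudge_block]
    for _ in range(max_tries):
        child = rng.choice(operators)(rng, config)
        if is_viable(child):
            return child
    return copy.deepcopy(config)
```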

Also, if you come up with good improvements – I’d like to know! I had trouble generating configurations longer than, say 14 blocks deep, due to the sparsity of viable random configurations. I considered adding some notion of “ambition” as the layers are generated, reducing this quantity if large amounts of the param budget were used early.

Thanks @superunification! I will try. By the way, would you have a recommendation on how the convergence criterion should be chosen? So far I have used the value of the cost function obtained after the last epoch. However, this might not necessarily be the best choice, as the cost function can be lower at an intermediate epoch. What is the best way to deal with this?

@superunification @ptrblck Another question I had is how the process of scaling up the network should go. Currently I try various configurations for the first layer and choose the one with the minimum cost function, then I freeze the configuration of that layer and add the second layer. Then I search for the configuration of the second layer that gives me the minimum of the cost function, and so on. Is this the correct approach? Are there other approaches for finding the best network configuration?

For the criterion I settled on log(eval loss + 1) * log(training wall clock time + 1) to compare between networks, even ones trained for different amounts of time. Each generation I keep the 2/3rds of the population with the lowest values of that metric. Essentially it measures how good the trade-off between accuracy and model complexity was: if the model takes longer to run each epoch, it had better converge faster per epoch.
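
In code, that criterion and the selection step come down to something like the following sketch. The dictionary keys eval_loss and train_seconds and the function names are just illustrative:

```python
import math


def score(eval_loss, train_seconds):
    """log(eval loss + 1) * log(training wall clock time + 1); lower is better."""
    return math.log(eval_loss + 1.0) * math.log(train_seconds + 1.0)


def select_survivors(population, fraction=2 / 3):
    """Keep the given fraction of individuals with the lowest score."""
    ranked = sorted(
        population,
        key=lambda ind: score(ind["eval_loss"], ind["train_seconds"]),
    )
    keep = max(1, round(len(ranked) * fraction))
    return ranked[:keep]


population = [
    {"name": "a", "eval_loss": 0.5, "train_seconds": 10.0},
    {"name": "b", "eval_loss": 0.5, "train_seconds": 100.0},
    {"name": "c", "eval_loss": 2.0, "train_seconds": 10.0},
]
survivors = select_survivors(population)  # keeps "a" and "b", drops "c"
```

Note that "b" survives ahead of "c" even though it trained ten times longer, because its loss is so much lower.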

Architectures with superior inductive bias for your problem will converge faster at the beginning of training, at least that’s my intuition.

It’s important to do just a couple epochs when first running a new architecture, test that criterion, and see if it’s hopeless (say, twice as bad as your nth best) ASAP so you don’t waste time.
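
That early check can be as simple as the sketch below. It assumes scores are positive with lower being better; the name is_hopeless and the defaults n=5 and factor=2.0 are placeholders to tune for your problem.

```python
def is_hopeless(candidate_score, population_scores, n=5, factor=2.0):
    """True if the candidate is more than `factor` times worse than the nth best score."""
    if len(population_scores) < n:
        # not enough reference points yet, so give the candidate a chance
        return False
    nth_best = sorted(population_scores)[n - 1]
    return candidate_score > factor * nth_best


known_scores = [0.5, 1.0, 1.2, 0.8, 2.0]
print(is_hopeless(3.0, known_scores, n=3))  # True: 3.0 > 2 * 1.0
print(is_hopeless(1.5, known_scores, n=3))  # False
```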

On your scaling approach, that sounds alright, but beware that it has a greedy aspect to it: with that policy you’re only allowing architectures that perform well on their first few layers to be candidates for enlarging, whereas there could be larger architectures which perform well as a unit but not with just their first layers, and those you’d never test. You can really only tell whether this is a viable assumption by experiment, but I think it’s a good start.

You could, similarly, set the calculation ops budget to 1m, get the best N networks and seed them as the initial population of the next phase, where you increase the ops budget to 2m, and so on. In my code, an op budget only allows a network so many layers, so large a kernel size, etc., so by increasing it, mutations could build upon well-performing networks by adding new layers, wider kernel windows and such.

I also had arbitrary residual connections randomly sprinkled in, which was a huge complication. But at the very least you should have a residual connection spanning the input to the output of the last conv block before the fully connected layers.