AMP on CPU ~500x slower and high memory allocation

I’ve been running the following snippet on the CPU:

import torch
import torch.nn as nn

from torch.profiler import ProfilerActivity
from torch.profiler import profile
from torch.profiler import record_function

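# Two 3D convolutions (1 -> 32 and 32 -> 32 channels), kernel 3, stride 1, padding 1
layer1 = nn.Conv3d( 1, 32, 3, 1, 1 )
layer2 = nn.Conv3d( 32, 32, 3, 1, 1 )

# Input of shape (N, C, D, H, W) = (1, 1, 256, 256, 1)
x = torch.randn( 1, 1, 256, 256, 1 )

with profile( activities = [ ProfilerActivity.CPU ], record_shapes = True, profile_memory = True ) as prof:
    with record_function( "model run" ):
        # Set enabled = False for the comparison run without AMP
        with torch.autocast( device_type = "cpu", dtype = torch.bfloat16, enabled = True ):
            res = layer1( x )
            res = layer2( res )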

print( prof.key_averages().table( sort_by = "cpu_time_total" ) )

With AMP enabled, I get the following result:

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 aten::conv3d         0.00%      65.000us       199.98%       26.008s        6.502s      16.18 Mb           0 b             4  
                    model run         0.01%     772.000us       100.00%       13.005s       13.005s       8.18 Mb        -128 b             1  
            aten::convolution         0.00%      87.000us        99.99%       13.004s        6.502s       8.00 Mb           0 b             2  
           aten::_convolution         0.00%      46.000us        99.99%       13.003s        6.502s       8.00 Mb           0 b             2  
            aten::slow_conv3d         0.00%      23.000us        99.99%       13.003s        6.502s       8.00 Mb           0 b             2  
    aten::slow_conv3d_forward        99.98%       13.002s        99.99%       13.003s        6.502s       8.00 Mb    -111.38 Mb             2  
                  aten::copy_         0.01%       1.451ms         0.01%       1.451ms     207.286us           0 b           0 b             7  
                     aten::to         0.00%      16.000us         0.01%     822.000us     164.400us     183.81 Kb           0 b             5  
               aten::_to_copy         0.00%      79.000us         0.01%     806.000us     161.200us     183.81 Kb     128.00 Kb             5  
          aten::empty_strided         0.00%      55.000us         0.00%      55.000us      11.000us      55.81 Kb      55.81 Kb             5  
                aten::resize_         0.00%      45.000us         0.00%      45.000us      22.500us       8.00 Mb       8.00 Mb             2  
                aten::reshape         0.00%      33.000us         0.00%      43.000us      21.500us           0 b           0 b             2  
                  aten::empty         0.00%      31.000us         0.00%      31.000us       7.750us     111.38 Mb     111.38 Mb             4  
                   aten::view         0.00%      15.000us         0.00%      15.000us       7.500us           0 b           0 b             2  
         aten::_reshape_alias         0.00%      10.000us         0.00%      10.000us       5.000us           0 b           0 b             2  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 13.005s

And with AMP disabled:

----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   model run         1.67%     419.000us       100.00%      25.111ms      25.111ms      16.00 Mb           0 b             1  
                aten::conv3d         0.12%      31.000us        98.33%      24.692ms      12.346ms      16.00 Mb           0 b             2  
           aten::convolution         0.40%     100.000us        98.21%      24.661ms      12.331ms      16.00 Mb           0 b             2  
          aten::_convolution         0.20%      49.000us        97.81%      24.561ms      12.280ms      16.00 Mb           0 b             2  
    aten::mkldnn_convolution        97.36%      24.447ms        97.61%      24.512ms      12.256ms      16.00 Mb           0 b             2  
                 aten::empty         0.16%      41.000us         0.16%      41.000us      10.250us      16.00 Mb      16.00 Mb             4  
           aten::as_strided_         0.10%      24.000us         0.10%      24.000us      12.000us           0 b           0 b             2  
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 25.111ms
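For reference, the two "Self CPU time total" lines can be compared directly. This sketch uses a hypothetical helper, `to_seconds`, written only for this comparison (it is not part of `torch.profiler`), to turn the profiler's duration strings into seconds; with the totals above, the AMP run comes out at roughly 518 times the FP32 time.

```python
# Convert PyTorch profiler duration strings (e.g. "13.005s", "25.111ms",
# "772.000us") to seconds so the totals of two runs can be compared.
# to_seconds is a hypothetical helper, not a torch.profiler API.
UNIT_TO_SECONDS = {"s": 1.0, "ms": 1e-3, "us": 1e-6}


def to_seconds(text: str) -> float:
    """Parse a duration like '25.111ms' into seconds."""
    text = text.strip()
    # Check the two-letter suffixes first so "ms"/"us" are not read as "s".
    for suffix in ("ms", "us", "s"):
        if text.endswith(suffix):
            return float(text[: -len(suffix)]) * UNIT_TO_SECONDS[suffix]
    raise ValueError(f"unrecognised duration: {text!r}")


amp_total = to_seconds("13.005s")     # "Self CPU time total" with AMP
fp32_total = to_seconds("25.111ms")   # "Self CPU time total" without AMP
slowdown = amp_total / fp32_total
print(f"AMP run is about {slowdown:.0f}x slower")
```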

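The AMP run allocates far more memory (aten::empty allocates 111.38 Mb with AMP versus 16.00 Mb without) and takes much longer (13.005 s versus 25.111 ms) than the standard run. It also dispatches to different kernels: without AMP the convolutions go through aten::mkldnn_convolution, while with AMP they go through aten::slow_conv3d_forward, which accounts for almost all of the additional time. Is this expected behavior for AMP on CPU?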
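The memory numbers look consistent with the slow path unfolding the input before a matrix multiply. This is a back-of-the-envelope sketch under an assumption I have not confirmed in the PyTorch source: that aten::slow_conv3d_forward builds an im2col/vol2col-style column buffer of shape (C_in * 3 * 3 * 3, 256 * 256 * 1) in bfloat16 for each layer. Under that assumption the two buffers add up exactly to the 111.38 Mb reported for aten::empty, and bfloat16 versus float32 outputs explain the 8.00 Mb versus 16.00 Mb.

```python
# Back-of-the-envelope check of the "CPU Mem" numbers in the two tables.
# Assumption (not confirmed by the profiler output): aten::slow_conv3d_forward
# unfolds the input into a column buffer of shape
# (C_in * kD * kH * kW, D_out * H_out * W_out), as in an im2col/vol2col scheme.
MIB = 1024 * 1024
BYTES = {"bfloat16": 2, "float32": 4}


def conv3d_out_size(size, kernel=3, stride=1, padding=1):
    # Standard convolution output-size formula (dilation 1).
    return (size + 2 * padding - kernel) // stride + 1


# Input x has shape (N, C, D, H, W) = (1, 1, 256, 256, 1).
spatial_out = 1
for s in (256, 256, 1):
    spatial_out *= conv3d_out_size(s)  # 256 * 256 * 1 = 65536 output positions

kernel_volume = 3 * 3 * 3  # 27
in_channels = (1, 32)      # layer1 and layer2
out_channels = 32          # both layers

# Hypothetical column buffers for both layers, stored in bfloat16.
column_buffers = sum(c * kernel_volume * spatial_out for c in in_channels) * BYTES["bfloat16"]
# Layer outputs: (1, 32, 256, 256, 1) per layer, two layers.
outputs_bf16 = 2 * out_channels * spatial_out * BYTES["bfloat16"]
outputs_fp32 = 2 * out_channels * spatial_out * BYTES["float32"]

print(f"column buffers : {column_buffers / MIB:.2f} MiB")  # compare aten::empty with AMP
print(f"outputs (bf16) : {outputs_bf16 / MIB:.2f} MiB")    # compare slow_conv3d_forward
print(f"outputs (fp32) : {outputs_fp32 / MIB:.2f} MiB")    # compare mkldnn_convolution
```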
I should mention that the script was run on a Ryzen 7 3800X, which doesn’t have native bfloat16 support, so I assume bfloat16 is being emulated via other data types. Is torch (2.0.1) able to detect this and choose this execution path, or is this the default for AMP on CPU?
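One way to see what the CPU itself reports, as a sketch only: my understanding is that on x86 PyTorch's oneDNN (mkldnn) bfloat16 convolution path is gated on AVX-512 support (avx512bw, avx512vl, avx512dq), which Zen 2 parts such as the 3800X lack. That flag set is an assumption, not something confirmed here. The hypothetical helpers `cpu_flags` and `likely_has_fast_bf16` read the `flags` line of Linux's /proc/cpuinfo and check for those subsets.

```python
# Sketch: check whether a Linux CPU advertises the AVX-512 subsets that,
# as far as I know, the oneDNN bfloat16 path in PyTorch relies on for x86.
# Assumption: the required set is avx512bw + avx512vl + avx512dq; this may
# differ between PyTorch/oneDNN versions.
REQUIRED_FLAGS = {"avx512bw", "avx512vl", "avx512dq"}


def cpu_flags(cpuinfo_text: str) -> set:
    """Collect the flags from the first 'flags' line of /proc/cpuinfo text."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags"):
            _, _, value = line.partition(":")
            return set(value.split())
    return set()


def likely_has_fast_bf16(cpuinfo_text: str) -> bool:
    """True if all assumed AVX-512 prerequisites appear in the flags."""
    return REQUIRED_FLAGS.issubset(cpu_flags(cpuinfo_text))


if __name__ == "__main__":
    try:
        with open("/proc/cpuinfo") as f:
            text = f.read()
    except OSError:
        text = ""  # not Linux, or /proc is unavailable
    print("AVX-512 bf16 prerequisites present:", likely_has_fast_bf16(text))
```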