I don’t know what your exact use case is, but you could use `view` to reinterpret the `float32` parameters as `int32`, manipulate the bits you want, and copy the weight back to the module as seen here:
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3, bias=False)
print(conv.weight)
# Parameter containing:
# tensor([[[[-0.2380,  0.0261,  0.2133],
#           [-0.3071, -0.3317,  0.3193],
#           [-0.0558,  0.3280,  0.1207]]]], requires_grad=True)
orig_weight = conv.weight.clone()
# reinterpret as int32
weight_int32 = orig_weight.view(torch.int32)
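# (view(torch.int32) reinterprets the raw bytes without copying, which works
#  here because float32 and int32 both use 4-byte elements)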
print(weight_int32)
# tensor([[[[-1099712250,  1020618774,  1046115288],
#           [-1096991409, -1096167765,  1050902816],
#           [-1117487605,  1051193350,  1039613558]]]], dtype=torch.int32)
print(format(weight_int32[0, 0, 0, 1], "b"))
# 111100110101010110100000010110
# corresponds to the float32 number as seen here: https://float.exposed/0x3cd56816
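# (side note, my own addition: you can decode the IEEE 754 fields from the
#  raw int32 to verify the layout; bit 31 is the sign, bits 23-30 the biased
#  exponent, bits 0-22 the mantissa. For negative entries you would first
#  mask with bits & 0xFFFFFFFF.)
bits = int(weight_int32[0, 0, 0, 1])
sign = (bits >> 31) & 0x1
exponent = (bits >> 23) & 0xFF
mantissa = bits & 0x7FFFFF
print(sign, exponent, mantissa)
# 0 121 5597206
# value = (-1)**0 * 2**(121 - 127) * (1 + 5597206 / 2**23) ≈ 0.0261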
# manipulate bit
bit = 1 << 24
weight_int32[0, 0, 0, 1] ^= bit
print(format(weight_int32[0, 0, 0, 1], "b"))
# 111101110101010110100000010110
print(weight_int32[0, 0, 0, 1])
# tensor(1037395990, dtype=torch.int32)
weight_float32 = weight_int32.view(torch.float32)
print(weight_float32[0, 0, 0, 1])
# tensor(0.1042) corresponds to https://float.exposed/0x3dd56816
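# bit 24 sits in the exponent field, so the flip raised the stored exponent
# from 121 to 123, i.e. the value was scaled by 2**2 (0.0261 * 4 ≈ 0.1042)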
# set manipulated weight to conv layer
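# (copying under no_grad is needed because conv.weight is a leaf parameter
#  with requires_grad=True and may not be modified in-place while autograd
#  is recording)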
with torch.no_grad():
conv.weight.copy_(weight_float32)
print(conv.weight)
# Parameter containing:
# tensor([[[[-0.2380,  0.1042,  0.2133],
#           [-0.3071, -0.3317,  0.3193],
#           [-0.0558,  0.3280,  0.1207]]]], requires_grad=True)
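If you need to do this repeatedly or for other layers, the same pattern can be wrapped in a small helper. This is only a sketch under my assumptions about your use case (flip_weight_bit and its signature are made up for illustration); unlike the clone/copy_ version above, it mutates the parameter storage directly through the int32 view:

def flip_weight_bit(module, index, bit_position):
    # view the float32 weight storage as int32 (no copy, shares memory)
    with torch.no_grad():
        weight_int32 = module.weight.data.view(torch.int32)
        # XOR toggles the selected bit in-place, updating the parameter directly
        weight_int32[index] ^= 1 << bit_position

conv = nn.Conv2d(1, 1, 3, bias=False)
flip_weight_bit(conv, (0, 0, 0, 1), 24)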
Note that bit manipulations in the floating point representation are not trivial to interpret and can easily create huge or tiny numbers (flipping an exponent bit scales the value by a power of two, as above), but as I said, I don’t know what your use case is.
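For example, flipping the most significant exponent bit (bit 30) of a small float32 value scales it by 2**128 straight into the 1e36 range (a quick sketch; the printed value is approximate):

x = torch.tensor([0.0261])
# flip bit 30, the top bit of the exponent field, through an int32 view
x.view(torch.int32)[0] ^= 1 << 30
print(x)
# tensor([8.8814e+36])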