Hi, Thank you for your reply. I found out that it is related to following part related to plotting my results.
I commented this part and solved the problem. since it was not so important, I ignored that part.
# def plot_constellation(
# ax,
# constellation,
# channel,
# decoder,
# grid_step=0.05,
# noise_samples=1000):
# """
# Plot a constellation with its decoder and channel noise.
# :param ax: Matplotlib axes to plot on.
# :param constellation: Constellation to plot.
# :param channel: Channel model to use for generating noise.
# :param decoder: Decoder function able to map the constellation points back
# to the original messages.
# :param grid_step: Grid step used for drawing the decision regions,
# expressed as percentage of the total plot width (or equivalently height).
# Lower steps makes more precise grids but takes more time to compute.
# :param noise_samples: Number of noisy points to sample and plot.
# """
# ax.grid()
# order = len(constellation)
# color_map = ListedColormap(seaborn.color_palette('husl', n_colors=order))
# color_norm = matplotlib.colors.BoundaryNorm(range(order + 1), order)
# # Extend axes symmetrically around zero so that they fit data
# axis_extent = max(
# abs(constellation.min()),
# abs(constellation.max())
# ) * 1.05
# ax.set_xlim(-axis_extent, axis_extent)
# ax.set_ylim(-axis_extent, axis_extent)
# # Hide borders but keep ticks
# for direction in ['left', 'bottom', 'right', 'top']:
# ax.axis[direction].line.set_color('#00000000')
# # Show zero-centered axes without ticks
# for direction in ['xzero', 'yzero']:
# axis = ax.axis[direction]
# axis.set_visible(True)
# axis.set_axisline_style('-|>')
# axis.major_ticklabels.set_visible(False)
# # Add axis names
# ax.annotate(
# 'I', (1, 0.5), xycoords='axes fraction',
# xytext=(25, 0), textcoords='offset points',
# va='center', ha='right'
# )
# ax.annotate(
# 'Q', (0.5, 1), xycoords='axes fraction',
# xytext=(0, 25), textcoords='offset points',
# va='center', ha='center'
# )
# # Plot decision regions
# regions_extent = 2 * axis_extent
# step = grid_step * regions_extent
# grid_range = torch.arange(-regions_extent, regions_extent, step)
# grid_y, grid_x = torch.meshgrid(grid_range, grid_range)
# grid_points = torch.stack((grid_x, grid_y), dim=-1).flatten(end_dim=1)
# grid_images = decoder(grid_points).argmax(dim=-1).reshape(grid_x.shape)
# ax.imshow(
# grid_images,
# extent=(
# -regions_extent, regions_extent,
# -regions_extent, regions_extent
# ),
# aspect='auto',
# origin='lower',
# cmap=color_map,
# norm=color_norm,
# alpha=0.2
# )
# # Plot constellation
# ax.scatter(
# *zip(*constellation.tolist()),
# zorder=10,
# s=60,
# c=range(order),
# edgecolor='black',
# cmap=color_map,
# norm=color_norm,
# )
# # Plot constellation center
# center = constellation.sum(dim=0) / order
# ax.scatter(
# center[0], center[1],
# marker='X',
# )
# # Plot channel noise
# if noise_samples > 0:
# noisy_vectors = channel(constellation.repeat(noise_samples, 1))
# ax.scatter(
# *zip(*noisy_vectors.tolist()),
# marker='.',
# s=5,
# c=list(range(order)) * noise_samples,
# cmap=color_map,
# norm=color_norm,
# alpha=0.3,
# zorder=8
# )