I expect the batched `audio` tensor to have shape `torch.Size([4, 1, 40, 99])`, but I am getting `torch.Size([4, 1, 257, 197])`.

Here 4 is the batch size, and the channel dimension is 1 because I am extracting only the audio.

Audio extractor:

# Slice `spec` along dim 1 into four single-channel pieces (dim 1 must have
# size exactly 4); only the first piece, `audio`, is used afterwards.
audio, audioR, audio2, audio2R = spec.split([1, 1, 1, 1], dim=1)

The value I need is `audio`.

Decoder:

# Parameter list of the decoder; the enclosing `def <name>(` line and the
# closing `):` are missing from this paste.
vid_path,   # passed to load_audio and used in the "Bad audio" log message
    container,   # presumably an opened PyAV container; closed after video decoding
    sampling_rate,   # with num_frames/target_fps: clip length in source frames
    num_frames,   # number of frames requested for the clip
    clip_idx,   # forwarded to get_start_end_idx to choose the clip
    num_clips=10,   # forwarded to get_start_end_idx
    target_fps=30,   # frame rate the clip length is expressed in
    aug_audio=[1, 1, 3, 6],   # NOTE(review): mutable default shared across calls; safe only if load_audio never mutates it
    decode_audio=True,   # if False, no spectrogram is computed (None is returned in its place)
    num_sec=1,   # forwarded to load_audio (seconds of audio, presumably)
    aud_sample_rate=48000,   # forwarded to load_audio as sample_rate
    aud_spec_type=1,   # forwarded to load_audio; its meaning is defined there (not shown) — TODO confirm
    use_volume_jittering=False,   # forwarded to load_audio
    use_temporal_jittering=False,   # forwarded to load_audio
    z_normalize=False,   # forwarded to load_audio


fps = float(container.streams.video[0].average_rate)
    frames_length = container.streams.video[0].frames
    duration = container.streams.video[0].duration

    if duration is None:
        # If failed to fetch the decoding information, decode the entire video.
        decode_all_video = True
        video_start_pts, video_end_pts = 0, math.inf
    else:
        # Perform selective decoding.
        decode_all_video = False
        start_idx, end_idx = get_start_end_idx(
            frames_length,
            sampling_rate * num_frames / target_fps * fps,
            clip_idx,
            num_clips,
        )
        timebase = duration / frames_length
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    frames = None
    # If video stream was found, fetch video frames from the video.
    if container.streams.video:
        video_frames, max_pts = pyav_decode_stream(
            container,
            video_start_pts,
            video_end_pts,
            container.streams.video[0],
            {"video": 0},
        )
        container.close()

        frames = [frame.to_rgb().to_ndarray() for frame in video_frames]
        frames = torch.as_tensor(np.stack(frames))

    # Get wav
    if decode_audio:
        try: 
            # Get spectogram
            fr_sec = start_idx / fps
            spec = load_audio(
                vid_path, 
                fr_sec, 
                num_sec=num_sec, 
                sample_rate=aud_sample_rate, 
                aug_audio=aug_audio, 
                aud_spec_type=aud_spec_type, 
                use_volume_jittering=use_volume_jittering,
                use_temporal_jittering=use_temporal_jittering,
                z_normalize=z_normalize,
            )
        except:
            print(f"Bad audio of video: {vid_path}", flush=True)
            if spec is not None:
                print(f"Bad spec shape of {vid_path}: {spec.shape}", flush=True)
            if wav is not None:
                print(f"Bad wav shape of {vid_path}: {wav.shape}", flush=True)
            return None, None, None, None
        return frames, spec, fps, decode_all_video
    else:
        return frames, None, fps, decode_all_video