import torch
import time
import os
# os.environ['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1'
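# (setting PYTORCH_NO_CUDA_MEMORY_CACHING=1 disables PyTorch's caching allocator, so every
#  allocation goes through cudaMalloc/cudaFree -- useful for isolating allocator effects)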
# x = torch.randn((32, 2, 4), device="cuda")
#
# start = time.time()
# for i in range(10000):
#     x *= x
# print(time.time()-start)
# 32, 0.046601057052612305
# 320, 0.044339895248413086
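# note: without torch.cuda.synchronize() these wall-clock numbers mostly measure kernel
# launch overhead (CUDA ops run asynchronously), which is presumably why the 32 and 320
# cases above come out nearly identical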
import torch.nn as nn
from latent_models import BART
model = BART(seq_len=512).cuda()
# start = time.time()
# model.module.encoder(torch.randint(0, 5, (32, 512), device="cuda"))
model.encoder(torch.zeros(32, 512, device="cuda", dtype=torch.int64))
# 32, 0.3491981029510498
# 64, 0.4617133140563965
# 128, 0.510127067565918
# 320, 0.7788600921630859
# 640, 1.2591266632080078
# start = time.time()
# print(model.module.forward(torch.zeros(32, 512, device="cuda", dtype=torch.int64),
#                            torch.zeros(32, 512, device="cuda", dtype=torch.int64)))
# print(time.time()-start)
# 16, 0.3350405693054199
# 32, 0.381974458694458
# 64, 0.47709083557128906
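# run the VAE training step several times on dummy latents, presumably as GPU warm-up
# (kernel compilation, cuDNN autotuning, allocator growth) before the timed section below;
# gradients simply accumulate since no optimizer/zero_grad is involved here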
start = time.time()
for _ in range(5):
    model.train_vae(torch.zeros(16, 512, 768, device="cuda")).backward()
# print(model.module.train_vae(torch.randn((16, 64, 768), device="cuda")).backward())
# print('time ', time.time()-start)
# 16, 0.38356733322143555
# 32, 0.3720133304595947
# 64, 0.39897871017456055
# 128, 0.43648314476013184
# 256, 0.5453002452850342
import torch.autograd.profiler as profiler
torch.cuda.synchronize()
start = time.time()
torch.cuda.empty_cache()
# @torch.jit.script
def testing(model):
    for i in range(1):
        with torch.no_grad():
            # latents = model.module.get_encoder_latents(torch.randint(0, 5, (128, 512), device="cuda"))
            latents = model.get_encoder_latents(torch.zeros(128, 512, device="cuda", dtype=torch.int64))
        # print('time ', time.time()-start)
        # time.sleep(1) 10280
        # loss = model.train_vae(latents)
        loss = model.train_vae(torch.randn((64, 512, 768), device="cuda"))
        loss.backward()
        # print(latents.grad)
        # for param in model.module.encoder.parameters():
        #     print(param.grad)


# run the timed encoder + VAE step (the results below suggest this was executed)
testing(model)
torch.cuda.synchronize()
print('time ', time.time() - start)
# seq_len 512
# 16, 0.13292932510375977 (both), 0.06869959831237793 (encoder), 0.11004400253295898 (vae)
# 32, 0.22198772430419922 (both), 0.12623834609985352 (encoder), 0.1743018627166748 (vae)
# 64, 0.40708303451538086 (both), 0.1575465202331543 (encoder), 0.30747365951538086 (vae)
# 128, 0.8132572174072266 (both), 0.2463982105255127 (encoder), 0.6213085651397705 (vae)
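
# Sketch (commented out): torch.autograd.profiler is imported above but never used; one way
# to get a per-op breakdown of testing(), assuming the model/inputs defined above, would be
# the legacy autograd profiler -- kept commented so the plain wall-clock timing is unaffected:
# with profiler.profile(use_cuda=True) as prof:
#     testing(model)
# print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))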