loss.backward() takes forever, until memory leaks, on TPU v3

This is a duplicate of a post on GitHub. Please see that post for more details.

I am fine-tuning XLNet for multi-label text classification on a TPU v3. With fewer than 100 labels, training went well and took about 10 s per iteration (batch size 16). With about 9,000 labels, however, training gets stuck at loss.backward() right after the forward pass of the first batch. I left it running for about 4 hours, until memory leaked. The model has about 131,032,033 parameters. Below is the metrics report.
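The post does not say which optimizer or precision is used. Assuming fp32 weights and gradients and an Adam-style optimizer (two moment buffers per parameter), a quick estimate of the memory taken by the model state alone (excluding activations and the 9,000-way output) looks like this. It is a back-of-envelope sketch, not a measurement.

```python
# Back-of-envelope estimate of model-state memory (activations excluded).
# Assumptions, not stated in the post: fp32 weights and gradients, and an
# Adam-style optimizer holding two fp32 moment buffers per parameter.
NUM_PARAMS = 131_032_033  # parameter count reported in the post
BYTES_PER_FP32 = 4
GIB = 1024 ** 3


def state_bytes(num_params: int, copies: int, bytes_per_value: int = BYTES_PER_FP32) -> int:
    """Bytes needed for `copies` full-size copies of the parameter vector."""
    return num_params * copies * bytes_per_value


weights = state_bytes(NUM_PARAMS, copies=1)
# weights + gradients + Adam first and second moments = 4 copies
weights_grads_adam = state_bytes(NUM_PARAMS, copies=4)

print(f"weights only:           {weights / GIB:.2f} GiB")
print(f"weights + grads + Adam: {weights_grads_adam / GIB:.2f} GiB")
```

Under these assumptions the weights take roughly 0.49 GiB and the full model state about 1.95 GiB, which is a useful baseline against which to compare the memory growth described above.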

Metric: CompileTime
TotalSamples: 21
Accumulator: 594ms960.157us
ValueRate: 092ms946.360us / second
Rate: 3.25085 / second
Percentiles: 1%=001ms048.724us; 5%=001ms143.656us; 10%=001ms297.037us; 20%=002ms503.725us; 50%=002ms686.961us; 80%=005ms211.937us; 90%=006ms457.997us; 95%=018ms385.346us; 99%=529ms270.667us
Metric: DeviceLockWait
TotalSamples: 21
Accumulator: 096.075us
ValueRate: 013.724us / second
Rate: 2.99983 / second
Percentiles: 1%=003.520us; 5%=003.626us; 10%=003.792us; 20%=003.934us; 50%=004.301us; 80%=005.171us; 90%=005.555us; 95%=005.833us; 99%=007.407us
Metric: ExecuteTime
TotalSamples: 21
Accumulator: 185ms133.715us
ValueRate: 029ms301.158us / second
Rate: 3.32368 / second
Percentiles: 1%=002ms989.937us; 5%=002ms014.393us; 10%=002ms107.005us; 20%=002ms137.575us; 50%=002ms313.633us; 80%=003ms652.595us; 90%=003ms785.078us; 95%=005ms872.769us; 99%=136ms111.333us
Metric: InboundData
TotalSamples: 20
Accumulator: 160.00B
ValueRate: 31.90B / second
Rate: 3.98692 / second
Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=8.00B; 80%=8.00B; 90%=8.00B; 95%=8.00B; 99%=8.00B
Metric: InputOutputAliasCount
TotalSamples: 21
Accumulator: 232.00
ValueRate: 33.19 / second
Rate: 3.00399 / second
Percentiles: 1%=1.00; 5%=1.00; 10%=1.00; 20%=1.00; 50%=1.00; 80%=1.00; 90%=1.00; 95%=1.00; 99%=212.00
Metric: IrValueTensorToXlaData
TotalSamples: 232
Accumulator: 05s153ms259.368us
ValueRate: 566ms014.728us / second
Rate: 25.482 / second
Percentiles: 1%=830.401us; 5%=889.505us; 10%=001ms033.071us; 20%=001ms173.958us; 50%=004ms322.203us; 80%=015ms407.568us; 90%=036ms242.661us; 95%=085ms422.525us; 99%=355ms438.993us
Metric: OutboundData
TotalSamples: 239
Accumulator: 984.20MB
ValueRate: 61.39MB / second
Rate: 14.9075 / second
Percentiles: 1%=4.00B; 5%=3.00KB; 10%=3.00KB; 20%=3.00KB; 50%=12.00KB; 80%=2.25MB; 90%=9.00MB; 95%=14.74MB; 99%=48.00MB
Metric: ReleaseDataHandlesTime
TotalSamples: 21
Accumulator: 027ms167.485us
ValueRate: 004ms305.010us / second
Rate: 3.3277 / second
Percentiles: 1%=507.093us; 5%=512.067us; 10%=564.039us; 20%=604.354us; 50%=684.380us; 80%=925.953us; 90%=004ms156.686us; 95%=005ms773.718us; 99%=006ms807.402us
Metric: TensorsGraphSize
TotalSamples: 21
Accumulator: 1420.00
ValueRate: 219.98 / second
Rate: 3.25324 / second
Percentiles: 1%=18.00; 5%=18.00; 10%=18.00; 20%=18.00; 50%=18.00; 80%=18.00; 90%=18.00; 95%=18.00; 99%=1060.00
Metric: TransferFromServerTime
TotalSamples: 20
Accumulator: 031ms137.437us
ValueRate: 006ms207.132us / second
Rate: 3.98693 / second
Percentiles: 1%=864.648us; 5%=937.089us; 10%=948.000us; 20%=986.957us; 50%=001ms342.661us; 80%=002ms528.235us; 90%=003ms378.802us; 95%=004ms301.261us; 99%=004ms301.261us
Metric: TransferToServerTime
TotalSamples: 239
Accumulator: 05s345ms883.669us
ValueRate: 340ms073.866us / second
Rate: 15.2066 / second
Percentiles: 1%=820.348us; 5%=879.701us; 10%=001ms020.336us; 20%=001ms163.830us; 50%=004ms304.142us; 80%=017ms207.125us; 90%=041ms485.225us; 95%=090ms678.567us; 99%=355ms415.786us
Metric: TransferToServerTransformTime
TotalSamples: 239
Accumulator: 769ms568.847us
ValueRate: 048ms935.319us / second
Rate: 14.9063 / second
Percentiles: 1%=048.697us; 5%=054.600us; 10%=059.095us; 20%=079.733us; 50%=280.509us; 80%=002ms895.032us; 90%=003ms721.550us; 95%=015ms214.894us; 99%=052ms474.712us
Counter: CreateCompileHandles
Value: 21
Counter: CreateDataHandles
Value: 501
Counter: CreateXlaTensor
Value: 139157
Counter: DestroyDataHandles
Value: 232
Counter: DestroyXlaTensor
Value: 78849
Counter: DeviceDataCacheMiss
Value: 4
Counter: MarkStep
Value: 1
Counter: ReleaseDataHandles
Value: 232
Counter: UncachedCompile
Value: 21
Counter: XRTAllocateFromTensor_Empty
Value: 35
Counter: XrtCompile_Empty
Value: 32
Counter: XrtExecuteChained_Empty
Value: 32
Counter: XrtExecute_Empty
Value: 32
Counter: XrtMemoryInfo_Empty
Value: 32
Counter: XrtRead_Empty
Value: 32
Counter: XrtReleaseAllocationHandle_Empty
Value: 32
Counter: XrtReleaseCompileHandle_Empty
Value: 32
Counter: XrtSessionCount
Value: 4
Counter: XrtSubTuple_Empty
Value: 32
Counter: aten::_local_scalar_dense
Value: 20
Counter: xla::_copy_from
Value: 272
Counter: xla::_softmax
Value: 5161
Counter: xla::_unsafe_view
Value: 15363
Counter: xla::add
Value: 851
Counter: xla::add_
Value: 243
Counter: xla::addcmul
Value: 240
Counter: xla::arange_out
Value: 120
Counter: xla::as_strided
Value: 222
Counter: xla::bernoulli_
Value: 513
Counter: xla::binary_cross_entropy_with_logits
Value: 1
Counter: xla::bmm
Value: 6241
Counter: xla::clamp_min_
Value: 1
Counter: xla::div_
Value: 513
Counter: xla::embedding
Value: 10
Counter: xla::empty
Value: 916
Counter: xla::empty_strided
Value: 222
Counter: xla::exp_
Value: 2
Counter: xla::expand
Value: 10082
Counter: xla::fill_
Value: 1
Counter: xla::gelu
Value: 120
Counter: xla::gt
Value: 20
Counter: xla::index_select
Value: 130
Counter: xla::log_
Value: 1
Counter: xla::max
Value: 10
Counter: xla::mean
Value: 1
Counter: xla::min
Value: 10
Counter: xla::mm
Value: 10322
Counter: xla::mul
Value: 754
Counter: xla::mul_
Value: 1
Counter: xla::native_batch_norm
Value: 240
Counter: xla::ne
Value: 10
Counter: xla::neg
Value: 13
Counter: xla::permute
Value: 6130
Counter: xla::rsub
Value: 10
Counter: xla::scatter_
Value: 10
Counter: xla::select
Value: 5071
Counter: xla::sigmoid
Value: 1
Counter: xla::slice
Value: 10692
Counter: xla::squeeze
Value: 10062
Counter: xla::stack
Value: 5
Counter: xla::sub
Value: 121
Counter: xla::sub_
Value: 1
Counter: xla::sum
Value: 1
Counter: xla::t
Value: 10322
Counter: xla::tanh
Value: 5041
Counter: xla::transpose
Value: 5076
Counter: xla::unsqueeze
Value: 3670
Counter: xla::view
Value: 41789
Counter: xla::zero_
Value: 10
Metric: XrtAllocateFromTensor
TotalSamples: 2166
Accumulator: 05s371ms945.919us
Mean: 003ms247.374us
StdDev: 005ms051.251us
Rate: 65.5897 / second
Percentiles: 25%=342.407us; 50%=002ms514.806us; 80%=005ms992.474us; 90%=007ms063.137us; 95%=015ms325.965us; 99%=019ms539.374us
Metric: XrtCompile
TotalSamples: 168
Accumulator: 05s519ms861.698us
Mean: 027ms897.986us
StdDev: 106ms354.463us
Rate: 5.10487 / second
Percentiles: 25%=167.783us; 50%=197.449us; 80%=242.641us; 90%=159ms036.450us; 95%=162ms310.616us; 99%=169ms269.848us
Metric: XrtExecute
TotalSamples: 168
Accumulator: 237ms422.502us
Mean: 001ms413.229us
StdDev: 478.713us
Rate: 5.10719 / second
Percentiles: 25%=001ms202.048us; 50%=001ms314.841us; 80%=001ms478.719us; 90%=002ms594.452us; 95%=002ms117.205us; 99%=004ms608.056us
Metric: XrtExecutorEvict
TotalSamples: 0
Accumulator: nanB
Mean: nanB
StdDev: nanB
Percentiles:
Metric: XrtReadLiteral
TotalSamples: 160
Accumulator: 073ms435.633us
Mean: 458.973us
StdDev: 261.592us
Rate: 4.89464 / second
Percentiles: 25%=351.912us; 50%=432.780us; 80%=531.511us; 90%=579.884us; 95%=659.538us; 99%=728.484us
Metric: XrtReleaseAllocation
TotalSamples: 205
Accumulator: 007ms322.887us
Mean: 035.721us
StdDev: 038.990us
Rate: 6.23232 / second
Percentiles: 25%=023.236us; 50%=027.079us; 80%=034.452us; 90%=046.374us; 95%=082.995us; 99%=267.722us
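The report above looks like the text printed by PyTorch/XLA's `torch_xla.debug.metrics.metrics_report()`. To inspect it programmatically, for example to pull out counters such as MarkStep (1 here) and UncachedCompile (21), or to compare accumulated times such as IrValueTensorToXlaData (05s153ms259.368us) with ExecuteTime (185ms133.715us), a small standard-library parser is enough. This is a minimal sketch: `parse_duration` and `parse_counters` are hypothetical helpers, and the text format is inferred from this report rather than from any documented specification.

```python
from __future__ import annotations

import re

# Duration strings in the report look like "05s153ms259.368us", "594ms960.157us"
# or "096.075us". This pattern is inferred from the report, not from a spec.
_DURATION_RE = re.compile(r"^(?:(\d+)s)?(?:(\d+)ms)?(?:([\d.]+)us)?$")


def parse_duration(text: str) -> float:
    """Convert a report duration such as '05s153ms259.368us' to seconds."""
    m = _DURATION_RE.match(text.strip())
    if not m or not any(m.groups()):
        raise ValueError(f"unrecognized duration: {text!r}")
    s, ms, us = m.groups()
    return int(s or 0) + int(ms or 0) / 1e3 + float(us or 0) / 1e6


def parse_counters(report: str) -> dict[str, int]:
    """Collect 'Counter: NAME' / 'Value: N' line pairs into a dict."""
    counters: dict[str, int] = {}
    name = None
    for line in report.splitlines():
        line = line.strip()
        if line.startswith("Counter:"):
            name = line[len("Counter:"):].strip()
        elif line.startswith("Value:") and name is not None:
            # "ValueRate:" lines do not match, because "Value:" needs the colon.
            counters[name] = int(line[len("Value:"):].strip())
            name = None
        elif line.startswith("Metric:"):
            name = None  # metric blocks have no bare "Value:" line
    return counters


if __name__ == "__main__":
    sample = "Counter: MarkStep\nValue: 1\nCounter: UncachedCompile\nValue: 21\n"
    print(parse_counters(sample))
    print(parse_duration("05s153ms259.368us"))
```

Pasting the whole report into `parse_counters` gives a dict that is easy to diff between the fewer-than-100-label run and the 9,000-label run.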