Tuesday, 3 October 2023

TensorFlow: using ModelCheckpoint on a model with a TextVectorization layer

TL;DR: don't use the ".h5" or ".hdf5" extension in the checkpoint path.

(Tested on TensorFlow 2.12.0)

The error

Several weeks ago I started implementing some research papers on Text Classification using TensorFlow. Since this is "research", I'm not supposed to use ModelCheckpoint to keep each model at its best-performing epoch; instead, I just let the models run for $n$ epochs and record their best scores.

However, while moving the code from Kaggle to my local machine, I noticed that a model containing TensorFlow's TextVectorization layer can't be saved to a file with a .h5 or .hdf5 extension. This was reported as a bug back in 2020, but surprisingly it still hasn't been fixed! It also breaks ModelCheckpoint: ModelCheckpoint automatically calls model.save whenever the model reaches a new best on the monitored metric, which triggers the bug:
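The failure comes down to how Keras picks a file format from the path when model.save is called without save_format, which is how ModelCheckpoint calls it. The function below is a simplified, hypothetical model of that rule as I understand it for TensorFlow 2.12; it is not TensorFlow's actual code, and the name `inferred_save_format` is made up for illustration.

```python
def inferred_save_format(filepath: str) -> str:
    """Hypothetical, simplified model of how model.save() picks a format
    from the path when save_format is not given (approximating TF 2.12)."""
    if filepath.endswith(".keras"):
        # Newer Keras v3 format, selected by the ".keras" suffix in TF 2.12.
        return "keras"
    if filepath.endswith((".h5", ".hdf5")):
        # Legacy HDF5 format: this is the path where TextVectorization fails.
        return "h5"
    # Any other path (including no extension) falls back to a SavedModel directory.
    return "tf"


for path in ["best_model.h5", "best_model.hdf5", "best_model"]:
    print(path, "->", inferred_save_format(path))
```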


Note that the error message suggests saving the model in the TensorFlow SavedModel format by setting save_format="tf", but ModelCheckpoint doesn't take that argument, so this would require modifying the ModelCheckpoint class itself. Another workaround is to save the TextVectorization layer separately from the model and then load both of them whenever we need to run inference.

Of course, we won't try either of these. There must be a better solution that doesn't involve modifying the TensorFlow API.

The Way

After some "self-research" (which mostly involved scrolling through useless threads where desperate, no-life dudes screamed for help but got nothing back: this and this), I finally ran into something useful:


Source

So, just remove the ".h5" or ".hdf5" extension from the model path and voilà, it's done? It turns out that TensorFlow, by default, saves a model whose path has no extension in the SavedModel format, which means this whole 3-year-old bug can-be-solved-by-deleting-a-few-characters-on-the-end-user-side.
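In practice the fix is only a change to the checkpoint path. Below is a minimal sketch of a hypothetical helper, `savedmodel_checkpoint_path` (not part of TensorFlow), that drops a trailing .h5/.hdf5 so the checkpoint is written in SavedModel format, assuming the TensorFlow 2.12 behavior described above.

```python
import os

# Extensions that make Keras choose the legacy HDF5 format.
HDF5_EXTENSIONS = (".h5", ".hdf5")


def savedmodel_checkpoint_path(filepath: str) -> str:
    """Hypothetical helper: strip a trailing .h5/.hdf5 so that model.save()
    (and therefore ModelCheckpoint) falls back to the SavedModel format."""
    root, ext = os.path.splitext(filepath)
    if ext in HDF5_EXTENSIONS:
        return root
    # Leave paths without an HDF5 extension untouched.
    return filepath


print(savedmodel_checkpoint_path("checkpoints/best_model.h5"))
print(savedmodel_checkpoint_path("model-{epoch:02d}.hdf5"))
```

The returned string can then be passed as filepath to tf.keras.callbacks.ModelCheckpoint (for example together with save_best_only=True); the checkpoint is written as a SavedModel directory and can be reloaded with tf.keras.models.load_model. Template keys such as {epoch:02d} are left intact because only the final extension is removed.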

A working solution can be found here; check it yourself. Thank you, strangers on the internet!

Bonus

The original issue on TensorFlow's GitHub was closed in a stupid and awkward way. Either the author was too pissed off to keep asking for a change in the TensorFlow API, or no one really cares (about the bug).


Who the actual f would say "Yes"?
