[TensorFlow] Generating TFRecord from image files

import tensorflow as tf

def _float_feature(value):
  """Wrap `value` in a `tf.train.Feature` backed by a float list.

  Args:
    value: the floats to store; handed unchanged to `tf.train.FloatList`
      as its `value` field.

  Returns:
    A `tf.train.Feature` whose `float_list` is built from `value`.
  """
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap `value` in a `tf.train.Feature` backed by an int64 list.

  Args:
    value: the integers to store; handed unchanged to `tf.train.Int64List`
      as its `value` field.

  Returns:
    A `tf.train.Feature` whose `int64_list` is built from `value`.
  """
  int64_list = tf.train.Int64List(value=value)
  return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
  """Wrap `value` in a `tf.train.Feature` backed by a bytes list.

  Args:
    value: the byte strings to store; handed unchanged to
      `tf.train.BytesList` as its `value` field.

  Returns:
    A `tf.train.Feature` whose `bytes_list` is built from `value`.
  """
  bytes_list = tf.train.BytesList(value=value)
  return tf.train.Feature(bytes_list=bytes_list)

def write_examples(image_data, output_path):
  """
  Create a tfrecord file containing one `tf.train.Example` per input tuple.

  Each example carries three features: 'label' (int64), 'path' (bytes) and
  'instance' (bytes). Only the image *path* is stored, not the image contents.

  Args:
    image_data (List[(image_file_path (str), label (int), instance_id (str))]): the data to store in the tfrecord file.
      The `image_file_path` should be the full path to the image, accessible by the machine that will be running the
      TensorFlow network. The `label` should be an integer in the range [0, number_of_classes). `instance_id` should be
      some unique identifier for this example (such as a database identifier). Text values are stored UTF-8 encoded;
      values that are already `bytes` are stored unchanged.
    output_path (str): the full path name for the tfrecord file.
  """

  def _as_bytes(text):
    # BytesList only accepts byte strings, so text (Python 3 `str`) must be
    # encoded first; values that are already bytes pass through untouched.
    return text if isinstance(text, bytes) else text.encode('utf-8')

  # The context manager closes the writer even if building or writing an
  # example raises, so the output file handle is never leaked.
  with tf.python_io.TFRecordWriter(output_path) as writer:
    for image_path, label, instance_id in image_data:

      example = tf.train.Example(features=tf.train.Features(
        feature={
          'label': _int64_feature([label]),
          'path': _bytes_feature([_as_bytes(image_path)]),
          'instance' : _bytes_feature([_as_bytes(instance_id)])
        }
      ))

      writer.write(example.SerializeToString())

Reference

https://gist.github.com/gvanhorn38/ac19b85a4f7b5fb9e82e04f4ac6d5566

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. (Log Out / Change)

Google photo

You are commenting using your Google account. (Log Out / Change)

Twitter picture

You are commenting using your Twitter account. (Log Out / Change)

Facebook photo

You are commenting using your Facebook account. (Log Out / Change)

Connecting to %s

This site uses Akismet to reduce spam. Learn how your comment data is processed.