"""Split a single TFRecord file into several round-robin shard files."""

import contextlib

import tensorflow as tf

INPUT_FILE = "input_file.tfrecord"
# ``{index}`` is replaced by the 0-based shard number.
OUTPUT_PATTERN = "output_file-part-{index}.tfrecord"
NUM_SHARDS = 10


def shard_tfrecord(
    input_path=INPUT_FILE,
    num_shards=NUM_SHARDS,
    output_pattern=OUTPUT_PATTERN,
):
    """Write the records of *input_path* into *num_shards* TFRecord files.

    Record ``k`` (0-based) goes to shard ``k % num_shards``. This is the same
    assignment ``tf.data.Dataset.shard(num_shards, i)`` makes, but the input
    file is read only once instead of once per shard.

    Requires eager execution (TF 2.x), because records are pulled via
    ``Dataset.as_numpy_iterator``.

    Args:
        input_path: Path of the source TFRecord file (read uncompressed).
        num_shards: Number of output files to produce; must be >= 1.
        output_pattern: ``str.format`` pattern for output paths; ``{index}``
            is replaced with the shard number.

    Returns:
        List of the output paths, in shard order.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be >= 1, got {num_shards}")

    paths = [output_pattern.format(index=i) for i in range(num_shards)]

    # ExitStack guarantees every writer is flushed and closed, even if
    # reading the input fails partway through.
    with contextlib.ExitStack() as stack:
        writers = [stack.enter_context(tf.io.TFRecordWriter(p)) for p in paths]
        dataset = tf.data.TFRecordDataset(input_path)
        # Single pass: route each serialized record (bytes) round-robin.
        for k, record in enumerate(dataset.as_numpy_iterator()):
            writers[k % num_shards].write(record)

    return paths


if __name__ == "__main__":
    shard_tfrecord()