Start parallel versions of general TSV serialization/parsing

They mostly work, but are not actually parallelized yet and likely have some edge cases.
Also, the soon-to-be parallel version of parsing is very slow compared to the original.
This commit is contained in:
Nathan McRae 2024-02-25 22:35:56 -08:00
parent 4ddb8dc44d
commit 78eaa5dbab
2 changed files with 796 additions and 7 deletions

View File

@ -332,7 +332,7 @@ public class SaneTsv
fields.Add(fieldBytes.ToArray());
if (numFields == 0)
if (fields.Count == 0)
{
throw new Exception("Found 0 fields on last line. Possibly because of extra \\n after last record");
}
@ -355,6 +355,407 @@ public class SaneTsv
return parsed;
}
protected static Tsv<T> ParseParallel<T>(byte[] inputBuffer, FormatType format) where T : TsvRecord, new()
{
Tsv<T> parsed;
if (format == FormatType.COMMENTED_TSV)
{
parsed = new CommentedTsv<T>();
}
else
{
parsed = new Tsv<T>();
}
parsed.Records = new List<T>();
var headerTypes = new List<Type>();
var headerNames = new List<string>();
var headerPropertyInfos = new List<PropertyInfo>();
int columnCount = 0;
foreach (PropertyInfo property in typeof(T).GetProperties())
{
TypedTsvColumnAttribute attribute = (TypedTsvColumnAttribute)Attribute.GetCustomAttribute(property, typeof(TypedTsvColumnAttribute));
if (attribute == null)
{
continue;
}
headerNames.Add(attribute.ColumnName ?? property.Name);
headerTypes.Add(attribute.ColumnType ?? GetColumnFromType(property.PropertyType));
headerPropertyInfos.Add(property);
// TODO: Check that the property type and given column type are compatible
columnCount++;
}
var fieldBytes = new List<byte>();
var fields = new List<byte[]>();
var currentComment = new StringBuilder();
int numFields = -1;
int line = 1;
int currentLineStart = 0;
for (int i = 0; i < inputBuffer.Count(); i++)
{
if (inputBuffer[i] == '\\')
{
if (i + 1 == inputBuffer.Count())
{
throw new Exception($"Found '\\' at end of input");
}
if (inputBuffer[i + 1] == 'n')
{
fieldBytes.Add((byte)'\n');
i++;
}
else if (inputBuffer[i + 1] == '\\')
{
fieldBytes.Add((byte)'\\');
i++;
}
else if (inputBuffer[i + 1] == 't')
{
fieldBytes.Add((byte)'\t');
i++;
}
else if (inputBuffer[i + 1] == '#')
{
fieldBytes.Add((byte)'#');
i++;
}
else
{
throw new Exception($"Expected 'n', 't', '#', or '\\' after '\\' at line {line} column {i - currentLineStart}");
}
}
else if (inputBuffer[i] == '\t')
{
// end of field
fields.Add(fieldBytes.ToArray());
fieldBytes.Clear();
}
else if (inputBuffer[i] == '\n')
{
fields.Add(fieldBytes.ToArray());
fieldBytes.Clear();
int numTypesBlank = 0;
for (int j = 0; j < fields.Count; j++)
{
string columnString;
try
{
columnString = Encoding.UTF8.GetString(fields[j]);
}
catch (Exception e)
{
throw new Exception($"Header {fields.Count} is not valid UTF-8", e);
}
string columnTypeString;
string columnName;
if (columnString.Contains(':'))
{
if (format == FormatType.SIMPLE_TSV)
{
throw new Exception($"Header {fields.Count} contain ':', which is not allowed for column names");
}
columnTypeString = columnString.Split(":").Last();
columnName = columnString.Substring(0, columnString.Length - columnTypeString.Length - 1);
}
else
{
if (format > FormatType.SIMPLE_TSV)
{
throw new Exception($"Header {fields.Count} has no type");
}
columnTypeString = "";
columnName = columnString;
}
Type type;
switch (columnTypeString)
{
case "":
numTypesBlank++;
type = typeof(StringType);
break;
case "string":
type = typeof(StringType);
break;
case "boolean":
type = typeof(BooleanType);
break;
case "float32":
type = typeof(Float32Type);
break;
case "float32-le":
type = typeof(Float32LEType);
break;
case "float64":
type = typeof(Float64Type);
break;
case "float64-le":
type = typeof(Float64LEType);
break;
case "uint32":
type = typeof(UInt32Type);
break;
case "uint64":
type = typeof(UInt64Type);
break;
case "int32":
type = typeof(Int32Type);
break;
case "int64":
type = typeof(Int64Type);
break;
case "binary":
type = typeof(BinaryType);
break;
default:
throw new Exception($"Invalid type '{columnTypeString}' for column {j}");
}
// TODO: Allow lax parsing (only worry about parsing columns that are given in the specifying type
if (headerNames[j] != columnName)
{
throw new Exception($"Column {j} has name {columnName}, but expected {headerNames[j]}");
}
if (headerTypes[j] != type)
{
throw new Exception($"Column {j} has type {type}, but expected {headerTypes[j]}");
}
}
if (currentComment.Length > 0)
{
if (parsed is CommentedTsv<T> commentedParsed)
{
commentedParsed.FileComment = currentComment.ToString();
currentComment.Clear();
}
else
{
throw new Exception("Found a file comment, but parser wasn't expecting a comment");
}
}
fields.Clear();
line++;
currentLineStart = i + 1;
// Done parsing header
break;
}
else if (inputBuffer[i] == '#')
{
if (i == currentLineStart && format >= FormatType.COMMENTED_TSV)
{
int j = i;
for (; j < inputBuffer.Length && inputBuffer[j] != '\n'; j++) { }
if (j < inputBuffer.Length)
{
var commentBytes = new byte[j - i - 1];
Array.Copy(inputBuffer, i + 1, commentBytes, 0, j - i - 1);
if (currentComment.Length > 0)
{
currentComment.Append('\n');
}
currentComment.Append(Encoding.UTF8.GetString(commentBytes));
i = j;
currentLineStart = i + 1;
line++;
}
else
{
throw new Exception("Comments at end of file are not allowed");
}
}
else
{
throw new Exception($"Found unescaped '#' at line {line}, column {i - currentLineStart}");
}
}
else
{
fieldBytes.Add(inputBuffer[i]);
}
}
T[] parsedRecords = ParseParallel<T>(inputBuffer, format, headerPropertyInfos.ToArray(), headerTypes.ToArray(), currentLineStart - 1, inputBuffer.Length);
parsed.Records.AddRange(parsedRecords);
return parsed;
}
// This approach is slightly different than others. We skip the record that startIndex is in and
// include the record that endIndex is in. We do this because in order to include the record
// startIndex is in we'd have to go back to the start of the record's comment, and to know
// exactly where that comment started we'd have to go back to the start of the record before that
// (not including that other record's comment).
protected static T[] ParseParallel<T>(byte[] inputBuffer, FormatType format, PropertyInfo[] headerPropertyInfos, Type[] headerTypes, int startIndex, int endIndex) where T : TsvRecord, new()
{
var fieldBytes = new List<byte>();
var fields = new List<byte[]>();
var currentComment = new StringBuilder();
List<T> parsed = new List<T>();
bool parsingLastRecord = false;
int relativeLine = 0;
int i = startIndex;
while (i < inputBuffer.Length - 1 && inputBuffer[i] != '\n' && inputBuffer[i + 1] != '#')
{
i++;
}
if (i >= inputBuffer.Length - 1)
{
return Array.Empty<T>();
}
// Start parsing after \n
i++;
int currentLineStart = i;
for (; i < inputBuffer.Length && (i < endIndex || parsingLastRecord); i++)
{
if (inputBuffer[i] == '\\')
{
if (i + 1 == inputBuffer.Count())
{
throw new Exception($"Found '\\' at end of input");
}
if (inputBuffer[i + 1] == 'n')
{
fieldBytes.Add((byte)'\n');
i++;
}
else if (inputBuffer[i + 1] == '\\')
{
fieldBytes.Add((byte)'\\');
i++;
}
else if (inputBuffer[i + 1] == 't')
{
fieldBytes.Add((byte)'\t');
i++;
}
else if (inputBuffer[i + 1] == '#')
{
fieldBytes.Add((byte)'#');
i++;
}
else
{
throw new Exception($"Expected 'n', 't', '#', or '\\' after '\\' at line {relativeLine} column {i - currentLineStart}");
}
}
else if (inputBuffer[i] == '\t')
{
// end of field
fields.Add(fieldBytes.ToArray());
fieldBytes.Clear();
}
else if (inputBuffer[i] == '\n')
{
fields.Add(fieldBytes.ToArray());
fieldBytes.Clear();
if (headerTypes.Length != fields.Count)
{
throw new Exception($"Expected {headerTypes.Length} fields on line {relativeLine}, but found {fields.Count}");
}
else
{
string comment = null;
if (currentComment.Length > 0)
{
comment = currentComment.ToString();
currentComment.Clear();
}
parsed.Add(ParseCurrentRecord<T>(headerTypes.ToArray(), headerPropertyInfos.ToArray(), fields, comment, relativeLine));
fields.Clear();
}
parsingLastRecord = false;
relativeLine++;
currentLineStart = i + 1;
}
else if (inputBuffer[i] == '#')
{
if (i == currentLineStart && format >= FormatType.COMMENTED_TSV)
{
int j = i;
for (; j < inputBuffer.Length && inputBuffer[j] != '\n'; j++) { }
if (j < inputBuffer.Length)
{
var commentBytes = new byte[j - i - 1];
Array.Copy(inputBuffer, i + 1, commentBytes, 0, j - i - 1);
if (currentComment.Length > 0)
{
currentComment.Append('\n');
}
currentComment.Append(Encoding.UTF8.GetString(commentBytes));
i = j;
currentLineStart = i + 1;
relativeLine++;
}
else
{
throw new Exception("Comments at end of file are not allowed");
}
}
else
{
throw new Exception($"Found unescaped '#' at line {relativeLine}, column {i - currentLineStart}");
}
}
else
{
fieldBytes.Add(inputBuffer[i]);
}
if (i == endIndex - 1)
{
parsingLastRecord = true;
}
}
fields.Add(fieldBytes.ToArray());
if (fields.Count == 0)
{
// TODO
throw new Exception("Not sure when this will happen. THis might actuall be fine");
}
if (fields.Count != headerTypes.Length)
{
throw new Exception($"Expected {headerTypes} fields on line {relativeLine}, but found {fields.Count}");
}
else
{
string comment = null;
if (currentComment.Length > 0)
{
comment = currentComment.ToString();
currentComment.Clear();
}
parsed.Add(ParseCurrentRecord<T>(headerTypes.ToArray(), headerPropertyInfos.ToArray(), fields, comment, relativeLine));
fields.Clear();
}
return parsed.ToArray();
}
protected static T ParseCurrentCommentedRecord<T>(Type[] columnTypes, PropertyInfo[] properties, List<byte[]> fields, string comment, int line) where T : CommentedTsvRecord, new()
{
return (T)ParseCurrentRecord<T>(columnTypes, properties, fields, comment, line);
@ -1540,6 +1941,288 @@ public class SaneTsv
return bytes.ToArray();
}
protected static byte[] SerializeTsvParallel<T>(IList<T> data, FormatType tsvFormat)
{
var bytes = new List<byte>();
var headerTypes = new List<Type>();
var headerNames = new List<string>();
var headerPropertyInfos = new List<PropertyInfo>();
int columnCount = 0;
// Serialize header
foreach (PropertyInfo property in typeof(T).GetProperties())
{
TsvColumnAttribute attribute = (TsvColumnAttribute)Attribute.GetCustomAttribute(property, typeof(TsvColumnAttribute));
if (attribute == null)
{
continue;
}
string headerName = attribute.ColumnName ?? property.Name;
headerNames.Add(headerName);
Type headerType = attribute.ColumnType ?? GetColumnFromType(property.PropertyType);
if (tsvFormat == FormatType.SIMPLE_TSV && headerType != typeof(StringType))
{
throw new Exception($"Serializing Simple TSV requires all columns be of type string, but column '{headerName}' has type '{headerType}'");
}
headerTypes.Add(headerType);
headerPropertyInfos.Add(property);
// TODO: Check that the property type and given column type are compatible
columnCount++;
}
// Serialize header
for (int i = 0; i < headerNames.Count; i++)
{
for (int j = i + 1; j < headerNames.Count; j++)
{
if (headerNames[i] == headerNames[j])
{
throw new Exception("Column names in header must be unique");
}
}
byte[] nameEncoded = Encoding.UTF8.GetBytes(headerNames[i]);
for (int j = 0; j < nameEncoded.Length; j++)
{
if (nameEncoded[j] == '\n')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'n');
}
else if (nameEncoded[j] == '\t')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'t');
}
else if (nameEncoded[j] == '\\')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'\\');
}
else if (nameEncoded[j] == '#')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'#');
}
else
{
bytes.Add(nameEncoded[j]);
}
}
if (tsvFormat != FormatType.SIMPLE_TSV)
{
bytes.Add((byte)':');
try
{
bytes.AddRange(Encoding.UTF8.GetBytes(GetNameFromColumn(headerTypes[i])));
}
catch (Exception e)
{
throw new Exception($"Invalid header type for column {i}", e);
}
}
if (i == headerNames.Count - 1)
{
bytes.Add((byte)'\n');
}
else
{
bytes.Add((byte)'\t');
}
}
// Serialize data
SerializeTsvParallel<T>(data, bytes, headerPropertyInfos.ToArray(), headerTypes.ToArray(), tsvFormat, 0, data.Count);
return bytes.ToArray();
}
protected static void SerializeTsvParallel<T>(IList<T> data, List<byte> bytes, PropertyInfo[] headerPropertyInfos, Type[] headerTypes, FormatType tsvFormat, int startIndex, int endIndex)
{
// Serialize data
for (int i = 0; i < data.Count; i++)
{
for (int j = 0; j < headerTypes.Length; j++)
{
object datum = headerPropertyInfos[j].GetValue(data[i]);
try
{
byte[] fieldEncoded = null;
// Some fields definitely don't need escaping, so we add them directly to bytes
bool skipEscaping = false;
if (headerTypes[j] == typeof(StringType))
{
fieldEncoded = Encoding.UTF8.GetBytes((string)datum);
}
else if (headerTypes[j] == typeof(BooleanType))
{
bytes.AddRange((bool)datum ? TrueEncoded : FalseEncoded);
skipEscaping = true;
}
else if (headerTypes[j] == typeof(Float32Type))
{
if (datum is float f)
{
if (float.IsNegativeInfinity(f))
{
bytes.AddRange(Encoding.UTF8.GetBytes("-inf"));
}
else if (float.IsPositiveInfinity(f))
{
bytes.AddRange(Encoding.UTF8.GetBytes("+inf"));
}
else
{
// See https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings#round-trip-format-specifier-r
bytes.AddRange(Encoding.UTF8.GetBytes(((float)datum).ToString("G9")));
}
}
else
{
throw new InvalidCastException();
}
skipEscaping = true;
}
else if (headerTypes[j] == typeof(Float32LEType))
{
if (LittleEndian)
{
fieldEncoded = BitConverter.GetBytes((float)datum);
}
else
{
byte[] floatBytes = BitConverter.GetBytes((float)datum);
fieldEncoded = new byte[sizeof(float)];
for (int k = 0; k < sizeof(float); k++)
{
fieldEncoded[k] = floatBytes[sizeof(float) - 1 - k];
}
}
}
else if (headerTypes[j] == typeof(Float64Type))
{
if (datum is double d)
{
if (double.IsNegativeInfinity(d))
{
bytes.AddRange(Encoding.UTF8.GetBytes("-inf"));
}
else if (double.IsPositiveInfinity(d))
{
bytes.AddRange(Encoding.UTF8.GetBytes("+inf"));
}
else
{
// See https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings#round-trip-format-specifier-r
bytes.AddRange(Encoding.UTF8.GetBytes((d).ToString("G17")));
}
}
else
{
throw new InvalidCastException();
}
skipEscaping = true;
}
else if (headerTypes[j] == typeof(Float64LEType))
{
if (LittleEndian)
{
fieldEncoded = BitConverter.GetBytes((double)datum);
}
else
{
byte[] doubleBytes = BitConverter.GetBytes((double)datum);
fieldEncoded = new byte[sizeof(double)];
for (int k = 0; k < sizeof(double); k++)
{
fieldEncoded[k] = doubleBytes[sizeof(double) - 1 - k];
}
}
}
else if (headerTypes[j] == typeof(UInt32Type))
{
bytes.AddRange(Encoding.UTF8.GetBytes(((UInt32)datum).ToString()));
skipEscaping = true;
}
else if (headerTypes[j] == typeof(UInt64Type))
{
bytes.AddRange(Encoding.UTF8.GetBytes(((UInt64)datum).ToString()));
skipEscaping = true;
}
else if (headerTypes[j] == typeof(Int32Type))
{
bytes.AddRange(Encoding.UTF8.GetBytes(((Int32)datum).ToString()));
skipEscaping = true;
}
else if (headerTypes[j] == typeof(Int64Type))
{
bytes.AddRange(Encoding.UTF8.GetBytes(((Int64)datum).ToString()));
skipEscaping = true;
}
else if (headerTypes[j] == typeof(BinaryType))
{
fieldEncoded = (byte[])datum;
}
else
{
throw new Exception($"Unexpected column type {headerTypes[j]} for column {j}");
}
if (!skipEscaping)
{
for (int k = 0; k < fieldEncoded.Length; k++)
{
if (fieldEncoded[k] == '\n')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'n');
}
else if (fieldEncoded[k] == '\t')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'t');
}
else if (fieldEncoded[k] == '\\')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'\\');
}
else if (fieldEncoded[k] == '#')
{
bytes.Add((byte)'\\');
bytes.Add((byte)'#');
}
else
{
bytes.Add(fieldEncoded[k]);
}
}
}
if (j < headerTypes.Length - 1)
{
bytes.Add((byte)'\t');
}
else if (i < data.Count - 1)
{
bytes.Add((byte)'\n');
}
}
catch (InvalidCastException e)
{
throw new Exception($"Record {i}, field {j} expected type compatible with {GetNameFromColumn(headerTypes[j])}", e);
}
}
}
}
public class SimpleTsvRecord
{
public string[] ColumnNames { get; }

View File

@ -1,7 +1,7 @@
using NathanMcRae;
using System.Text;
internal class Program
internal class Program : SaneTsv
{
public class TestRecord : SaneTsv.TsvRecord
{
@ -349,7 +349,7 @@ internal class Program
}
{
string testName = "Check parallel serialization";
string testName = "Check parallel Simple TSV serialization";
int N = 100000;
var records = new StringTestRecord[N];
@ -398,7 +398,7 @@ internal class Program
}
{
string testName = "Check parallel parsing";
string testName = "Check Simple TSV parallel parsing";
int N = 100000;
var records = new StringTestRecord[N];
@ -423,8 +423,8 @@ internal class Program
(string[] headers, string[][] data) = SaneTsv.ParseSimpleTsvParallel(serialized);
TimeSpan parallelTime = DateTime.Now - lastTime;
Console.WriteLine($"Unparallel serialization time: {unparallelTime}");
Console.WriteLine($"Parallel serialization time: {parallelTime}");
Console.WriteLine($"Unparallel parse time: {unparallelTime}");
Console.WriteLine($"Parallel parse time: {parallelTime}");
bool matching = true;
for (int j = 0; j < Math.Min(headers2.Length, headers.Length); j++)
@ -458,6 +458,112 @@ internal class Program
}
}
{
string testName = "Check parallel serialization";
int N = 1000;
var records = new BoolTestRecord[N];
var rand = new Random(1);
for (int i = 0; i < N; i++)
{
byte[] bytes = new byte[rand.Next(50)];
rand.NextBytes(bytes);
records[i] = new BoolTestRecord()
{
Column1 = rand.NextDouble() > 0.5,
column2 = bytes,
Column3 = rand.Next().ToString(),
};
}
DateTime lastTime = DateTime.Now;
byte[] serialized1 = SaneTsv.SerializeTsv<BoolTestRecord>(records, FormatType.COMMENTED_TSV);
TimeSpan unparallelTime = DateTime.Now - lastTime;
lastTime = DateTime.Now;
byte[] serialized2 = SaneTsv.SerializeTsvParallel<BoolTestRecord>(records, FormatType.COMMENTED_TSV);
TimeSpan parallelTime = DateTime.Now - lastTime;
Console.WriteLine($"Unparallel serialization time: {unparallelTime}");
Console.WriteLine($"Parallel serialization time: {parallelTime}");
bool matching = true;
for (int i = 0; i < Math.Min(serialized1.Length, serialized2.Length); i++)
{
if (serialized1[i] != serialized2[i])
{
matching = false;
break;
}
}
if (matching)
{
Console.WriteLine($"Passed {testName}");
}
else
{
Console.WriteLine($"Failed {testName}");
}
}
{
string testName = "Check parallel parsing";
int N = 1000;
var records = new BoolTestRecord[N];
var rand = new Random(1);
for (int i = 0; i < N; i++)
{
byte[] bytes = new byte[rand.Next(50)];
rand.NextBytes(bytes);
records[i] = new BoolTestRecord()
{
Column1 = rand.NextDouble() > 0.5,
column2 = bytes,
Column3 = rand.Next().ToString(),
};
}
byte[] serialized2 = SaneTsv.SerializeTsvParallel<BoolTestRecord>(records, FormatType.COMMENTED_TSV);
DateTime lastTime = DateTime.Now;
CommentedTsv<BoolTestRecord> parsed = (CommentedTsv<BoolTestRecord>)SaneTsv.Parse<BoolTestRecord>(serialized2, FormatType.COMMENTED_TSV);
TimeSpan unparallelTime = DateTime.Now - lastTime;
lastTime = DateTime.Now;
CommentedTsv<BoolTestRecord> parsed2 = (CommentedTsv<BoolTestRecord>)SaneTsv.ParseParallel<BoolTestRecord>(serialized2, FormatType.COMMENTED_TSV);
TimeSpan parallelTime = DateTime.Now - lastTime;
Console.WriteLine($"Unparallel parsing time: {unparallelTime}");
Console.WriteLine($"Parallel parsing time: {parallelTime}");
bool matching = parsed.FileComment == parsed2.FileComment;
matching &= parsed.Records.Count == parsed2.Records.Count;
for (int i = 0; matching && i < parsed.Records.Count; i++)
{
matching &= parsed.Records[i].Comment == parsed2.Records[i].Comment;
matching &= parsed.Records[i].Column1 == parsed2.Records[i].Column1;
matching &= parsed.Records[i].column2.Length == parsed2.Records[i].column2.Length;
for (int j = 0; matching && j < parsed.Records[i].column2.Length; j++)
{
matching &= parsed.Records[i].column2[j] == parsed2.Records[i].column2[j];
}
}
if (matching)
{
Console.WriteLine($"Passed {testName}");
}
else
{
Console.WriteLine($"Failed {testName}");
}
}
Console.WriteLine("Done with tests");
}
}