From 78eaa5dbabc4cfe1ccb52f44615458743acb182f Mon Sep 17 00:00:00 2001
From: Nathan McRae
Date: Sun, 25 Feb 2024 22:35:56 -0800
Subject: [PATCH] Start parallel versions of general TSV serialization/parsing

They mostly work, but are not actually parallelized yet and likely have
some edge cases. Also, the soon-to-be parallel version of parsing is
very slow compared to the original.
---
 SaneTsv/SaneTsv.cs             | 687 ++++++++++++++++++++++++++++++++-
 SaneTsv/SaneTsvTest/Program.cs | 116 +++++-
 2 files changed, 796 insertions(+), 7 deletions(-)

diff --git a/SaneTsv/SaneTsv.cs b/SaneTsv/SaneTsv.cs
index 71bef74..8c4493b 100644
--- a/SaneTsv/SaneTsv.cs
+++ b/SaneTsv/SaneTsv.cs
@@ -332,7 +332,7 @@ public class SaneTsv
 
 		fields.Add(fieldBytes.ToArray());
 
-		if (numFields == 0)
+		if (fields.Count == 0)
 		{
 			throw new Exception("Found 0 fields on last line. Possibly because of extra \\n after last record");
 		}
@@ -355,6 +355,407 @@ public class SaneTsv
 		return parsed;
 	}
 
+	protected static Tsv<T> ParseParallel<T>(byte[] inputBuffer, FormatType format) where T : TsvRecord, new()
+	{
+		Tsv<T> parsed;
+		if (format == FormatType.COMMENTED_TSV)
+		{
+			parsed = new CommentedTsv<T>();
+		}
+		else
+		{
+			parsed = new Tsv<T>();
+		}
+		parsed.Records = new List<T>();
+
+		var headerTypes = new List<Type>();
+		var headerNames = new List<string>();
+		var headerPropertyInfos = new List<PropertyInfo>();
+		int columnCount = 0;
+
+		foreach (PropertyInfo property in typeof(T).GetProperties())
+		{
+			TypedTsvColumnAttribute attribute = (TypedTsvColumnAttribute)Attribute.GetCustomAttribute(property, typeof(TypedTsvColumnAttribute));
+			if (attribute == null)
+			{
+				continue;
+			}
+
+			headerNames.Add(attribute.ColumnName ?? property.Name);
+			headerTypes.Add(attribute.ColumnType ?? GetColumnFromType(property.PropertyType));
+			headerPropertyInfos.Add(property);
+			// TODO: Check that the property type and given column type are compatible
+			columnCount++;
+		}
+
+		var fieldBytes = new List<byte>();
+		var fields = new List<byte[]>();
+		var currentComment = new StringBuilder();
+
+		int numFields = -1;
+		int line = 1;
+		int currentLineStart = 0;
+		for (int i = 0; i < inputBuffer.Length; i++)
+		{
+			if (inputBuffer[i] == '\\')
+			{
+				if (i + 1 == inputBuffer.Length)
+				{
+					throw new Exception($"Found '\\' at end of input");
+				}
+				if (inputBuffer[i + 1] == 'n')
+				{
+					fieldBytes.Add((byte)'\n');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == '\\')
+				{
+					fieldBytes.Add((byte)'\\');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == 't')
+				{
+					fieldBytes.Add((byte)'\t');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == '#')
+				{
+					fieldBytes.Add((byte)'#');
+					i++;
+				}
+				else
+				{
+					throw new Exception($"Expected 'n', 't', '#', or '\\' after '\\' at line {line} column {i - currentLineStart}");
+				}
+			}
+			else if (inputBuffer[i] == '\t')
+			{
+				// end of field
+				fields.Add(fieldBytes.ToArray());
+				fieldBytes.Clear();
+			}
+			else if (inputBuffer[i] == '\n')
+			{
+				fields.Add(fieldBytes.ToArray());
+				fieldBytes.Clear();
+
+				int numTypesBlank = 0;
+
+				for (int j = 0; j < fields.Count; j++)
+				{
+					string columnString;
+					try
+					{
+						columnString = Encoding.UTF8.GetString(fields[j]);
+					}
+					catch (Exception e)
+					{
+						throw new Exception($"Header {j} is not valid UTF-8", e); // report the offending column, not the column total
+					}
+
+					string columnTypeString;
+					string columnName;
+					if (columnString.Contains(':'))
+					{
+						if (format == FormatType.SIMPLE_TSV)
+						{
+							throw new Exception($"Header {j} contain ':', which is not allowed for column names");
+						}
+						columnTypeString = columnString.Split(":").Last();
+						columnName = columnString.Substring(0, columnString.Length - columnTypeString.Length - 1);
+					}
+					else
+					{
+						if (format > FormatType.SIMPLE_TSV)
+						{
+							throw new Exception($"Header {j} has no type");
+						}
+						columnTypeString = "";
+						columnName = columnString;
+					}
+
+					Type type;
+
+					switch (columnTypeString)
+					{
+						case "":
+							numTypesBlank++;
+							type = typeof(StringType);
+							break;
+						case "string":
+							type = typeof(StringType);
+							break;
+						case "boolean":
+							type = typeof(BooleanType);
+							break;
+						case "float32":
+							type = typeof(Float32Type);
+							break;
+						case "float32-le":
+							type = typeof(Float32LEType);
+							break;
+						case "float64":
+							type = typeof(Float64Type);
+							break;
+						case "float64-le":
+							type = typeof(Float64LEType);
+							break;
+						case "uint32":
+							type = typeof(UInt32Type);
+							break;
+						case "uint64":
+							type = typeof(UInt64Type);
+							break;
+						case "int32":
+							type = typeof(Int32Type);
+							break;
+						case "int64":
+							type = typeof(Int64Type);
+							break;
+						case "binary":
+							type = typeof(BinaryType);
+							break;
+						default:
+							throw new Exception($"Invalid type '{columnTypeString}' for column {j}");
+					}
+
+					// TODO: Allow lax parsing (only worry about parsing columns that are given in the specifying type)
+
+					if (headerNames[j] != columnName)
+					{
+						throw new Exception($"Column {j} has name {columnName}, but expected {headerNames[j]}");
+					}
+
+					if (headerTypes[j] != type)
+					{
+						throw new Exception($"Column {j} has type {type}, but expected {headerTypes[j]}");
+					}
+				}
+
+				if (currentComment.Length > 0)
+				{
+					if (parsed is CommentedTsv<T> commentedParsed)
+					{
+						commentedParsed.FileComment = currentComment.ToString();
+						currentComment.Clear();
+					}
+					else
+					{
+						throw new Exception("Found a file comment, but parser wasn't expecting a comment");
+					}
+				}
+
+				fields.Clear();
+
+				line++;
+				currentLineStart = i + 1;
+
+				// Done parsing header
+				break;
+			}
+			else if (inputBuffer[i] == '#')
+			{
+				if (i == currentLineStart && format >= FormatType.COMMENTED_TSV)
+				{
+					int j = i;
+					for (; j < inputBuffer.Length && inputBuffer[j] != '\n'; j++) { }
+					if (j < inputBuffer.Length)
+					{
+						var commentBytes = new byte[j - i - 1];
+						Array.Copy(inputBuffer, i + 1, commentBytes, 0, j - i - 1);
+						if (currentComment.Length > 0)
+						{
+							currentComment.Append('\n');
+						}
+						currentComment.Append(Encoding.UTF8.GetString(commentBytes));
+						i = j;
+						currentLineStart = i + 1;
+						line++;
+					}
+					else
+					{
+						throw new Exception("Comments at end of file are not allowed");
+					}
+				}
+				else
+				{
+					throw new Exception($"Found unescaped '#' at line {line}, column {i - currentLineStart}");
+				}
+			}
+			else
+			{
+				fieldBytes.Add(inputBuffer[i]);
+			}
+		}
+
+		T[] parsedRecords = ParseParallel<T>(inputBuffer, format, headerPropertyInfos.ToArray(), headerTypes.ToArray(), currentLineStart - 1, inputBuffer.Length); // currentLineStart - 1 is the header's terminating '\n'
+
+		parsed.Records.AddRange(parsedRecords);
+
+		return parsed;
+	}
+
+	// This approach is slightly different than others. We skip the record that startIndex is in and
+	//   include the record that endIndex is in. We do this because in order to include the record
+	//   startIndex is in we'd have to go back to the start of the record's comment, and to know
+	//   exactly where that comment started we'd have to go back to the start of the record before that
+	//   (not including that other record's comment).
+	protected static T[] ParseParallel<T>(byte[] inputBuffer, FormatType format, PropertyInfo[] headerPropertyInfos, Type[] headerTypes, int startIndex, int endIndex) where T : TsvRecord, new()
+	{
+		var fieldBytes = new List<byte>();
+		var fields = new List<byte[]>();
+		var currentComment = new StringBuilder();
+		List<T> parsed = new List<T>();
+		bool parsingLastRecord = false;
+
+		int relativeLine = 0;
+
+		int i = startIndex;
+		while (i < inputBuffer.Length - 1 && inputBuffer[i] != '\n') // seek the end of the line startIndex is in; a raw '\n' only ever ends a line
+		{
+			i++;
+		}
+
+		if (i >= inputBuffer.Length - 1)
+		{
+			return Array.Empty<T>();
+		}
+
+		// Start parsing after \n
+		i++;
+
+		int currentLineStart = i;
+
+		for (; i < inputBuffer.Length && (i < endIndex || parsingLastRecord); i++)
+		{
+			if (inputBuffer[i] == '\\')
+			{
+				if (i + 1 == inputBuffer.Length)
+				{
+					throw new Exception($"Found '\\' at end of input");
+				}
+				if (inputBuffer[i + 1] == 'n')
+				{
+					fieldBytes.Add((byte)'\n');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == '\\')
+				{
+					fieldBytes.Add((byte)'\\');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == 't')
+				{
+					fieldBytes.Add((byte)'\t');
+					i++;
+				}
+				else if (inputBuffer[i + 1] == '#')
+				{
+					fieldBytes.Add((byte)'#');
+					i++;
+				}
+				else
+				{
+					throw new Exception($"Expected 'n', 't', '#', or '\\' after '\\' at line {relativeLine} column {i - currentLineStart}");
+				}
+			}
+			else if (inputBuffer[i] == '\t')
+			{
+				// end of field
+				fields.Add(fieldBytes.ToArray());
+				fieldBytes.Clear();
+			}
+			else if (inputBuffer[i] == '\n')
+			{
+				fields.Add(fieldBytes.ToArray());
+				fieldBytes.Clear();
+
+				if (headerTypes.Length != fields.Count)
+				{
+					throw new Exception($"Expected {headerTypes.Length} fields on line {relativeLine}, but found {fields.Count}");
+				}
+				else
+				{
+					string comment = null;
+					if (currentComment.Length > 0)
+					{
+						comment = currentComment.ToString();
+						currentComment.Clear();
+					}
+					parsed.Add(ParseCurrentRecord<T>(headerTypes, headerPropertyInfos, fields, comment, relativeLine)); // already arrays; no per-record copy needed
+					fields.Clear();
+				}
+
+				parsingLastRecord = false;
+				relativeLine++;
+				currentLineStart = i + 1;
+			}
+			else if (inputBuffer[i] == '#')
+			{
+				if (i == currentLineStart && format >= FormatType.COMMENTED_TSV)
+				{
+					int j = i;
+					for (; j < inputBuffer.Length && inputBuffer[j] != '\n'; j++) { }
+					if (j < inputBuffer.Length)
+					{
+						var commentBytes = new byte[j - i - 1];
+						Array.Copy(inputBuffer, i + 1, commentBytes, 0, j - i - 1);
+						if (currentComment.Length > 0)
+						{
+							currentComment.Append('\n');
+						}
+						currentComment.Append(Encoding.UTF8.GetString(commentBytes));
+						i = j;
+						currentLineStart = i + 1;
+						relativeLine++;
+					}
+					else
+					{
+						throw new Exception("Comments at end of file are not allowed");
+					}
+				}
+				else
+				{
+					throw new Exception($"Found unescaped '#' at line {relativeLine}, column {i - currentLineStart}");
+				}
+			}
+			else
+			{
+				fieldBytes.Add(inputBuffer[i]);
+			}
+
+			if (i == endIndex - 1)
+			{
+				parsingLastRecord = true;
+			}
+		}
+
+		fields.Add(fieldBytes.ToArray());
+
+		if (fields.Count == 0)
+		{
+			// TODO: unreachable as written -- fields always holds at least one entry after the Add above
+			throw new Exception("Not sure when this will happen. THis might actuall be fine");
+		}
+		if (fields.Count != headerTypes.Length)
+		{
+			throw new Exception($"Expected {headerTypes.Length} fields on line {relativeLine}, but found {fields.Count}");
+		}
+		else
+		{
+			string comment = null;
+			if (currentComment.Length > 0)
+			{
+				comment = currentComment.ToString();
+				currentComment.Clear();
+			}
+			parsed.Add(ParseCurrentRecord<T>(headerTypes, headerPropertyInfos, fields, comment, relativeLine));
+			fields.Clear();
+		}
+
+		return parsed.ToArray();
+	}
+
 	protected static T ParseCurrentCommentedRecord<T>(Type[] columnTypes, PropertyInfo[] properties, List<byte[]> fields, string comment, int line) where T : CommentedTsvRecord, new()
 	{
 		return (T)ParseCurrentRecord<T>(columnTypes, properties, fields, comment, line);
@@ -1150,7 +1551,7 @@ public class SaneTsv
 			records.Add(fieldStrings);
 			fields.Clear();
 		}
-		
+
 		return records.ToArray();
 	}
 
@@ -1540,6 +1941,288 @@ public class SaneTsv
 		return bytes.ToArray();
 	}
 
+	protected static byte[] SerializeTsvParallel<T>(IList<T> data, FormatType tsvFormat)
+	{
+		var bytes = new List<byte>();
+
+		var headerTypes = new List<Type>();
+		var headerNames = new List<string>();
+		var headerPropertyInfos = new List<PropertyInfo>();
+		int columnCount = 0;
+
+		// Collect column names, types and properties from the attributed properties of T
+		foreach (PropertyInfo property in typeof(T).GetProperties())
+		{
+			TsvColumnAttribute attribute = (TsvColumnAttribute)Attribute.GetCustomAttribute(property, typeof(TsvColumnAttribute));
+			if (attribute == null)
+			{
+				continue;
+			}
+
+			string headerName = attribute.ColumnName ?? property.Name;
+			headerNames.Add(headerName);
+			Type headerType = attribute.ColumnType ??
GetColumnFromType(property.PropertyType);
+			if (tsvFormat == FormatType.SIMPLE_TSV && headerType != typeof(StringType))
+			{
+				throw new Exception($"Serializing Simple TSV requires all columns be of type string, but column '{headerName}' has type '{headerType}'");
+			}
+			headerTypes.Add(headerType);
+			headerPropertyInfos.Add(property);
+			// TODO: Check that the property type and given column type are compatible
+			columnCount++;
+		}
+
+		// Serialize header
+		for (int i = 0; i < headerNames.Count; i++)
+		{
+			for (int j = i + 1; j < headerNames.Count; j++)
+			{
+				if (headerNames[i] == headerNames[j])
+				{
+					throw new Exception("Column names in header must be unique");
+				}
+			}
+
+			byte[] nameEncoded = Encoding.UTF8.GetBytes(headerNames[i]);
+
+			for (int j = 0; j < nameEncoded.Length; j++)
+			{
+				if (nameEncoded[j] == '\n')
+				{
+					bytes.Add((byte)'\\');
+					bytes.Add((byte)'n');
+				}
+				else if (nameEncoded[j] == '\t')
+				{
+					bytes.Add((byte)'\\');
+					bytes.Add((byte)'t');
+				}
+				else if (nameEncoded[j] == '\\')
+				{
+					bytes.Add((byte)'\\');
+					bytes.Add((byte)'\\');
+				}
+				else if (nameEncoded[j] == '#')
+				{
+					bytes.Add((byte)'\\');
+					bytes.Add((byte)'#');
+				}
+				else
+				{
+					bytes.Add(nameEncoded[j]);
+				}
+			}
+
+			if (tsvFormat != FormatType.SIMPLE_TSV)
+			{
+				bytes.Add((byte)':');
+				try
+				{
+					bytes.AddRange(Encoding.UTF8.GetBytes(GetNameFromColumn(headerTypes[i])));
+				}
+				catch (Exception e)
+				{
+					throw new Exception($"Invalid header type for column {i}", e);
+				}
+			}
+
+			if (i == headerNames.Count - 1)
+			{
+				bytes.Add((byte)'\n');
+			}
+			else
+			{
+				bytes.Add((byte)'\t');
+			}
+		}
+
+		// Serialize data
+		SerializeTsvParallel(data, bytes, headerPropertyInfos.ToArray(), headerTypes.ToArray(), tsvFormat, 0, data.Count);
+
+		return bytes.ToArray();
+	}
+
+	protected static void SerializeTsvParallel<T>(IList<T> data, List<byte> bytes, PropertyInfo[] headerPropertyInfos, Type[] headerTypes, FormatType tsvFormat, int startIndex, int endIndex)
+	{
+		// Serialize records startIndex (inclusive) to endIndex (exclusive), appending to bytes
+		for (int i = startIndex; i < endIndex; i++)
+		{
+			for (int j = 0; j < headerTypes.Length; j++)
+			{
+				object datum = headerPropertyInfos[j].GetValue(data[i]);
+
+				try
+				{
+					byte[] fieldEncoded = null;
+					// Some fields definitely don't need escaping, so we add them directly to bytes
+					bool skipEscaping = false;
+
+					if (headerTypes[j] == typeof(StringType))
+					{
+						fieldEncoded = Encoding.UTF8.GetBytes((string)datum);
+					}
+					else if (headerTypes[j] == typeof(BooleanType))
+					{
+						bytes.AddRange((bool)datum ? TrueEncoded : FalseEncoded);
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(Float32Type))
+					{
+						if (datum is float f)
+						{
+							if (float.IsNegativeInfinity(f))
+							{
+								bytes.AddRange(Encoding.UTF8.GetBytes("-inf"));
+							}
+							else if (float.IsPositiveInfinity(f))
+							{
+								bytes.AddRange(Encoding.UTF8.GetBytes("+inf"));
+							}
+							else
+							{
+								// See https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings#round-trip-format-specifier-r
+								bytes.AddRange(Encoding.UTF8.GetBytes(((float)datum).ToString("G9"))); // NOTE(review): formats with the current culture -- confirm the parser expects the same
+							}
+						}
+						else
+						{
+							throw new InvalidCastException();
+						}
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(Float32LEType))
+					{
+						if (LittleEndian)
+						{
+							fieldEncoded = BitConverter.GetBytes((float)datum);
+						}
+						else
+						{
+							byte[] floatBytes = BitConverter.GetBytes((float)datum);
+							fieldEncoded = new byte[sizeof(float)];
+							for (int k = 0; k < sizeof(float); k++)
+							{
+								fieldEncoded[k] = floatBytes[sizeof(float) - 1 - k];
+							}
+						}
+					}
+					else if (headerTypes[j] == typeof(Float64Type))
+					{
+						if (datum is double d)
+						{
+							if (double.IsNegativeInfinity(d))
+							{
+								bytes.AddRange(Encoding.UTF8.GetBytes("-inf"));
+							}
+							else if (double.IsPositiveInfinity(d))
+							{
+								bytes.AddRange(Encoding.UTF8.GetBytes("+inf"));
+							}
+							else
+							{
+								// See https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings#round-trip-format-specifier-r
+								bytes.AddRange(Encoding.UTF8.GetBytes((d).ToString("G17")));
+							}
+						}
+						else
+						{
+							throw new InvalidCastException();
+						}
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(Float64LEType))
+					{
+						if (LittleEndian)
+						{
+							fieldEncoded = BitConverter.GetBytes((double)datum);
+						}
+						else
+						{
+							byte[] doubleBytes = BitConverter.GetBytes((double)datum);
+							fieldEncoded = new byte[sizeof(double)];
+							for (int k = 0; k < sizeof(double); k++)
+							{
+								fieldEncoded[k] = doubleBytes[sizeof(double) - 1 - k];
+							}
+						}
+					}
+					else if (headerTypes[j] == typeof(UInt32Type))
+					{
+						bytes.AddRange(Encoding.UTF8.GetBytes(((UInt32)datum).ToString()));
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(UInt64Type))
+					{
+						bytes.AddRange(Encoding.UTF8.GetBytes(((UInt64)datum).ToString()));
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(Int32Type))
+					{
+						bytes.AddRange(Encoding.UTF8.GetBytes(((Int32)datum).ToString()));
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(Int64Type))
+					{
+						bytes.AddRange(Encoding.UTF8.GetBytes(((Int64)datum).ToString()));
+						skipEscaping = true;
+					}
+					else if (headerTypes[j] == typeof(BinaryType))
+					{
+						fieldEncoded = (byte[])datum;
+					}
+					else
+					{
+						throw new Exception($"Unexpected column type {headerTypes[j]} for column {j}");
+					}
+
+					if (!skipEscaping)
+					{
+						for (int k = 0; k < fieldEncoded.Length; k++)
+						{
+							if (fieldEncoded[k] == '\n')
+							{
+								bytes.Add((byte)'\\');
+								bytes.Add((byte)'n');
+							}
+							else if (fieldEncoded[k] == '\t')
+							{
+								bytes.Add((byte)'\\');
+								bytes.Add((byte)'t');
+							}
+							else if (fieldEncoded[k] == '\\')
+							{
+								bytes.Add((byte)'\\');
+								bytes.Add((byte)'\\');
+							}
+							else if (fieldEncoded[k] == '#')
+							{
+								bytes.Add((byte)'\\');
+								bytes.Add((byte)'#');
+							}
+							else
+							{
+								bytes.Add(fieldEncoded[k]);
+							}
+						}
+					}
+
+					if (j < headerTypes.Length - 1)
+					{
+						bytes.Add((byte)'\t');
+					}
+					else if (i < data.Count - 1) // every record except the very last one in data is newline-terminated
+					{
+						bytes.Add((byte)'\n');
+					}
+				}
+				catch (InvalidCastException e)
+				{
+					throw new Exception($"Record {i}, field {j} expected type compatible with {GetNameFromColumn(headerTypes[j])}", e);
+				}
+			}
+		}
+	}
+
 	public class
SimpleTsvRecord
 	{
 		public string[] ColumnNames { get; }
diff --git a/SaneTsv/SaneTsvTest/Program.cs b/SaneTsv/SaneTsvTest/Program.cs
index b3bfd40..66929c1 100644
--- a/SaneTsv/SaneTsvTest/Program.cs
+++ b/SaneTsv/SaneTsvTest/Program.cs
@@ -1,7 +1,7 @@
 using NathanMcRae;
 using System.Text;
 
-internal class Program
+internal class Program : SaneTsv
 {
 	public class TestRecord : SaneTsv.TsvRecord
 	{
@@ -349,7 +349,7 @@ internal class Program
 		}
 
 		{
-			string testName = "Check parallel serialization";
+			string testName = "Check parallel Simple TSV serialization";
 
 			int N = 100000;
 			var records = new StringTestRecord[N];
@@ -398,7 +398,7 @@ internal class Program
 		}
 
 		{
-			string testName = "Check parallel parsing";
+			string testName = "Check Simple TSV parallel parsing";
 
 			int N = 100000;
 			var records = new StringTestRecord[N];
@@ -423,8 +423,8 @@ internal class Program
 			(string[] headers, string[][] data) = SaneTsv.ParseSimpleTsvParallel(serialized);
 			TimeSpan parallelTime = DateTime.Now - lastTime;
 
-			Console.WriteLine($"Unparallel serialization time: {unparallelTime}");
-			Console.WriteLine($"Parallel serialization time: {parallelTime}");
+			Console.WriteLine($"Unparallel parse time: {unparallelTime}");
+			Console.WriteLine($"Parallel parse time: {parallelTime}");
 
 			bool matching = true;
 			for (int j = 0; j < Math.Min(headers2.Length, headers.Length); j++)
@@ -458,6 +458,112 @@ internal class Program
 			}
 		}
 
+		{
+			string testName = "Check parallel serialization";
+
+			int N = 1000;
+			var records = new BoolTestRecord[N];
+			var rand = new Random(1);
+
+			for (int i = 0; i < N; i++)
+			{
+				byte[] bytes = new byte[rand.Next(50)];
+				rand.NextBytes(bytes);
+				records[i] = new BoolTestRecord()
+				{
+					Column1 = rand.NextDouble() > 0.5,
+					column2 = bytes,
+					Column3 = rand.Next().ToString(),
+				};
+			}
+
+			DateTime lastTime = DateTime.Now;
+			byte[] serialized1 = SaneTsv.SerializeTsv(records, FormatType.COMMENTED_TSV);
+			TimeSpan unparallelTime = DateTime.Now - lastTime;
+			lastTime = DateTime.Now;
+			byte[] serialized2 = SaneTsv.SerializeTsvParallel(records, FormatType.COMMENTED_TSV);
+			TimeSpan parallelTime = DateTime.Now - lastTime;
+
+			Console.WriteLine($"Unparallel serialization time: {unparallelTime}");
+			Console.WriteLine($"Parallel serialization time: {parallelTime}");
+
+			bool matching = serialized1.Length == serialized2.Length; // a truncated or over-long output must fail, not just a differing prefix
+			for (int i = 0; i < Math.Min(serialized1.Length, serialized2.Length); i++)
+			{
+				if (serialized1[i] != serialized2[i])
+				{
+					matching = false;
+					break;
+				}
+			}
+
+			if (matching)
+			{
+				Console.WriteLine($"Passed {testName}");
+			}
+			else
+			{
+				Console.WriteLine($"Failed {testName}");
+			}
+		}
+
+		{
+			string testName = "Check parallel parsing";
+
+			int N = 1000;
+			var records = new BoolTestRecord[N];
+			var rand = new Random(1);
+
+			for (int i = 0; i < N; i++)
+			{
+				byte[] bytes = new byte[rand.Next(50)];
+				rand.NextBytes(bytes);
+				records[i] = new BoolTestRecord()
+				{
+					Column1 = rand.NextDouble() > 0.5,
+					column2 = bytes,
+					Column3 = rand.Next().ToString(),
+				};
+			}
+
+			byte[] serialized2 = SaneTsv.SerializeTsvParallel(records, FormatType.COMMENTED_TSV);
+
+			DateTime lastTime = DateTime.Now;
+			CommentedTsv<BoolTestRecord> parsed = (CommentedTsv<BoolTestRecord>)SaneTsv.Parse<BoolTestRecord>(serialized2, FormatType.COMMENTED_TSV);
+			TimeSpan unparallelTime = DateTime.Now - lastTime;
+			lastTime = DateTime.Now;
+			CommentedTsv<BoolTestRecord> parsed2 = (CommentedTsv<BoolTestRecord>)SaneTsv.ParseParallel<BoolTestRecord>(serialized2, FormatType.COMMENTED_TSV);
+			TimeSpan parallelTime = DateTime.Now - lastTime;
+
+			Console.WriteLine($"Unparallel parsing time: {unparallelTime}");
+			Console.WriteLine($"Parallel parsing time: {parallelTime}");
+
+			bool matching = parsed.FileComment == parsed2.FileComment;
+
+			matching &= parsed.Records.Count == parsed2.Records.Count;
+
+			for (int i = 0; matching && i < parsed.Records.Count; i++)
+			{
+				matching &= parsed.Records[i].Comment == parsed2.Records[i].Comment;
+				matching &= parsed.Records[i].Column1 == parsed2.Records[i].Column1 && parsed.Records[i].Column3 == parsed2.Records[i].Column3; // also cover the string column
+				matching &= parsed.Records[i].column2.Length == parsed2.Records[i].column2.Length;
+				for (int j = 0; matching && j < parsed.Records[i].column2.Length; j++)
+				{
+					matching &= parsed.Records[i].column2[j] == parsed2.Records[i].column2[j];
+				}
+			}
+
+			if (matching)
+			{
+				Console.WriteLine($"Passed {testName}");
+			}
+			else
+			{
+				Console.WriteLine($"Failed {testName}");
+			}
+		}
+
+
 		Console.WriteLine("Done with tests");
 	}
 }