hle: Improve safety (#2778)

* timezone: Make timezone implementation safe

* hle: Do not use TrimEnd to parse ASCII strings

This adds an util that handle reading an ASCII string in a safe way.
Previously it was possible to read malformed data that could cause
various undefined behaviours in multiple services.

* hid: Remove an useless unsafe modifier on keyboard update

* Address gdkchan's comment

* Address gdkchan's comment
This commit is contained in:
Mary 2021-10-25 00:13:20 +02:00 committed by GitHub
parent b4dc33efc2
commit 51fa1b2cb0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 141 additions and 172 deletions

View file

@ -1,3 +1,4 @@
using Ryujinx.HLE.Utilities;
using System.IO;
using System.Text;
@ -30,10 +31,10 @@ namespace Ryujinx.HLE.FileSystem.Content
reader.ReadBytes(2); // Padding
PlatformString = Encoding.ASCII.GetString(reader.ReadBytes(0x20)).TrimEnd('\0');
Hex = Encoding.ASCII.GetString(reader.ReadBytes(0x40)).TrimEnd('\0');
VersionString = Encoding.ASCII.GetString(reader.ReadBytes(0x18)).TrimEnd('\0');
VersionTitle = Encoding.ASCII.GetString(reader.ReadBytes(0x80)).TrimEnd('\0');
PlatformString = StringUtils.ReadInlinedAsciiString(reader, 0x20);
Hex = StringUtils.ReadInlinedAsciiString(reader, 0x40);
VersionString = StringUtils.ReadInlinedAsciiString(reader, 0x18);
VersionTitle = StringUtils.ReadInlinedAsciiString(reader, 0x80);
}
}
}

View file

@ -8,7 +8,7 @@ namespace Ryujinx.HLE.HOS.Services.Hid
{
public KeyboardDevice(Switch device, bool active) : base(device, active) { }
public unsafe void Update(KeyboardInput keyState)
public void Update(KeyboardInput keyState)
{
ref RingLifo<KeyboardState> lifo = ref _device.Hid.SharedMemory.Keyboard;

View file

@ -218,11 +218,7 @@ namespace Ryujinx.HLE.HOS.Services.Sockets.Sfdnsres
private ResultCode GetHostByNameRequestImpl(ServiceCtx context, ulong inputBufferPosition, ulong inputBufferSize, ulong outputBufferPosition, ulong outputBufferSize, ulong optionsBufferPosition, ulong optionsBufferSize)
{
byte[] rawName = new byte[inputBufferSize];
context.Memory.Read(inputBufferPosition, rawName);
string name = Encoding.ASCII.GetString(rawName).TrimEnd('\0');
string name = MemoryHelper.ReadAsciiString(context.Memory, inputBufferPosition, (int)inputBufferSize);
// TODO: Use params.
bool enableNsdResolve = (context.RequestData.ReadInt32() & 1) != 0;

View file

@ -116,7 +116,7 @@ namespace Ryujinx.HLE.HOS.Services.Time
// SetupTimeZoneManager(nn::time::LocationName location_name, nn::time::SteadyClockTimePoint timezone_update_timepoint, u32 total_location_name_count, nn::time::TimeZoneRuleVersion timezone_rule_version, buffer<nn::time::TimeZoneBinary, 0x21> timezone_binary)
public ResultCode SetupTimeZoneManager(ServiceCtx context)
{
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
SteadyClockTimePoint timeZoneUpdateTimePoint = context.RequestData.ReadStruct<SteadyClockTimePoint>();
uint totalLocationNameCount = context.RequestData.ReadUInt32();
UInt128 timeZoneRuleVersion = context.RequestData.ReadStruct<UInt128>();

View file

@ -1,6 +1,7 @@
using Ryujinx.Common.Logging;
using Ryujinx.Cpu;
using Ryujinx.HLE.HOS.Services.Time.TimeZone;
using Ryujinx.HLE.Utilities;
using System;
using System.Text;
@ -35,7 +36,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
return ResultCode.PermissionDenied;
}
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
return _timeZoneContentManager.SetDeviceLocationName(locationName);
}
@ -97,7 +98,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
throw new InvalidOperationException();
}
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
ResultCode resultCode = _timeZoneContentManager.LoadTimeZoneRule(out TimeZoneRule rules, locationName);

View file

@ -125,7 +125,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.StaticService
(ulong bufferPosition, ulong bufferSize) = context.Request.GetBufferType0x21();
string locationName = Encoding.ASCII.GetString(context.RequestData.ReadBytes(0x24)).TrimEnd('\0');
string locationName = StringUtils.ReadInlinedAsciiString(context.RequestData, 0x24);
ResultCode result;

View file

@ -2,6 +2,7 @@
using Ryujinx.Common.Utilities;
using Ryujinx.HLE.Utilities;
using System;
using System.Buffers.Binary;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
@ -107,40 +108,24 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
public int TransitionTime;
}
private static int Detzcode32(byte[] bytes)
private static int Detzcode32(ReadOnlySpan<byte> bytes)
{
if (BitConverter.IsLittleEndian)
{
Array.Reverse(bytes, 0, bytes.Length);
}
return BitConverter.ToInt32(bytes, 0);
return BinaryPrimitives.ReadInt32BigEndian(bytes);
}
private static unsafe int Detzcode32(int* data)
private static int Detzcode32(int value)
{
int result = *data;
if (BitConverter.IsLittleEndian)
{
byte[] bytes = BitConverter.GetBytes(result);
Array.Reverse(bytes, 0, bytes.Length);
result = BitConverter.ToInt32(bytes, 0);
return BinaryPrimitives.ReverseEndianness(value);
}
return result;
return value;
}
private static unsafe long Detzcode64(long* data)
private static long Detzcode64(ReadOnlySpan<byte> bytes)
{
long result = *data;
if (BitConverter.IsLittleEndian)
{
byte[] bytes = BitConverter.GetBytes(result);
Array.Reverse(bytes, 0, bytes.Length);
result = BitConverter.ToInt64(bytes, 0);
}
return result;
return BinaryPrimitives.ReadInt64BigEndian(bytes);
}
private static bool DifferByRepeat(long t1, long t0)
@ -148,7 +133,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return (t1 - t0) == SecondsPerRepeat;
}
private static unsafe bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
private static bool TimeTypeEquals(TimeZoneRule outRules, byte aIndex, byte bIndex)
{
if (aIndex < 0 || aIndex >= outRules.TypeCount || bIndex < 0 || bIndex >= outRules.TypeCount)
{
@ -158,17 +143,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
TimeTypeInfo a = outRules.Ttis[aIndex];
TimeTypeInfo b = outRules.Ttis[bIndex];
fixed (char* chars = outRules.Chars)
{
return a.GmtOffset == b.GmtOffset &&
a.IsDaySavingTime == b.IsDaySavingTime &&
a.IsStandardTimeDaylight == b.IsStandardTimeDaylight &&
a.IsGMT == b.IsGMT &&
StringUtils.CompareCStr(chars + a.AbbreviationListIndex, chars + b.AbbreviationListIndex) == 0;
}
return a.GmtOffset == b.GmtOffset &&
a.IsDaySavingTime == b.IsDaySavingTime &&
a.IsStandardTimeDaylight == b.IsStandardTimeDaylight &&
a.IsGMT == b.IsGMT &&
StringUtils.CompareCStr(outRules.Chars[a.AbbreviationListIndex..], outRules.Chars[b.AbbreviationListIndex..]) == 0;
}
private static int GetQZName(char[] name, int namePosition, char delimiter)
private static int GetQZName(ReadOnlySpan<char> name, int namePosition, char delimiter)
{
int i = namePosition;
@ -403,7 +385,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return 0;
}
private static bool ParsePosixName(Span<char> name, out TimeZoneRule outRules, bool lastDitch)
private static bool ParsePosixName(ReadOnlySpan<char> name, out TimeZoneRule outRules, bool lastDitch)
{
outRules = new TimeZoneRule
{
@ -414,9 +396,10 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
};
int stdLen;
Span<char> stdName = name;
int namePosition = 0;
int stdOffset = 0;
ReadOnlySpan<char> stdName = name;
int namePosition = 0;
int stdOffset = 0;
if (lastDitch)
{
@ -433,7 +416,8 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int stdNamePosition = namePosition;
namePosition = GetQZName(name.ToArray(), namePosition, '>');
namePosition = GetQZName(name, namePosition, '>');
if (name[namePosition] != '>')
{
return false;
@ -465,7 +449,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int destLen = 0;
int dstOffset = 0;
Span<char> destName = name.Slice(namePosition);
ReadOnlySpan<char> destName = name.Slice(namePosition);
if (TzCharsArraySize < charCount)
{
@ -903,7 +887,7 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
return ParsePosixName(name.ToCharArray(), out outRules, false);
}
internal static unsafe bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
internal static bool ParseTimeZoneBinary(out TimeZoneRule outRules, Stream inputData)
{
outRules = new TimeZoneRule
{
@ -967,12 +951,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
timeCount = 0;
fixed (byte* workBufferPtrStart = workBuffer)
{
byte* p = workBufferPtrStart;
Span<byte> p = workBuffer;
for (int i = 0; i < outRules.TimeCount; i++)
{
long at = Detzcode64((long*)p);
long at = Detzcode64(p);
outRules.Types[i] = 1;
if (timeCount != 0 && at <= outRules.Ats[timeCount - 1])
@ -988,13 +971,15 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
outRules.Ats[timeCount++] = at;
p += TimeTypeSize;
p = p[TimeTypeSize..];
}
timeCount = 0;
for (int i = 0; i < outRules.TimeCount; i++)
{
byte type = *p++;
byte type = p[0];
p = p[1..];
if (outRules.TypeCount <= type)
{
return false;
@ -1011,18 +996,20 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
for (int i = 0; i < outRules.TypeCount; i++)
{
TimeTypeInfo ttis = outRules.Ttis[i];
ttis.GmtOffset = Detzcode32((int*)p);
p += 4;
ttis.GmtOffset = Detzcode32(p);
p = p[sizeof(int)..];
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
ttis.IsDaySavingTime = *p != 0;
p++;
ttis.IsDaySavingTime = p[0] != 0;
p = p[1..];
int abbreviationListIndex = p[0];
p = p[1..];
int abbreviationListIndex = *p++;
if (abbreviationListIndex >= outRules.CharCount)
{
return false;
@ -1033,12 +1020,9 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
outRules.Ttis[i] = ttis;
}
fixed (char* chars = outRules.Chars)
{
Encoding.ASCII.GetChars(p, outRules.CharCount, chars, outRules.CharCount);
}
Encoding.ASCII.GetChars(p[..outRules.CharCount].ToArray()).CopyTo(outRules.Chars.AsSpan());
p += outRules.CharCount;
p = p[outRules.CharCount..];
outRules.Chars[outRules.CharCount] = '\0';
for (int i = 0; i < outRules.TypeCount; i++)
@ -1049,14 +1033,14 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
}
else
{
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
outRules.Ttis[i].IsStandardTimeDaylight = *p++ != 0;
outRules.Ttis[i].IsStandardTimeDaylight = p[0] != 0;
p = p[1..];
}
}
for (int i = 0; i < outRules.TypeCount; i++)
@ -1067,17 +1051,18 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
}
else
{
if (*p >= 2)
if (p[0] >= 2)
{
return false;
}
outRules.Ttis[i].IsGMT = *p++ != 0;
outRules.Ttis[i].IsGMT = p[0] != 0;
p = p[1..];
}
}
long position = (p - workBufferPtrStart);
long position = (workBuffer.Length - p.Length);
long nRead = streamLength - position;
if (nRead < 0)
@ -1107,77 +1092,75 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
int abbreviationCount = 0;
charCount = outRules.CharCount;
fixed (char* chars = outRules.Chars)
Span<char> chars = outRules.Chars;
for (int i = 0; i < tempRules.TypeCount; i++)
{
for (int i = 0; i < tempRules.TypeCount; i++)
ReadOnlySpan<char> tempChars = tempRules.Chars;
ReadOnlySpan<char> tempAbbreviation = tempChars[tempRules.Ttis[i].AbbreviationListIndex..];
int j;
for (j = 0; j < charCount; j++)
{
fixed (char* tempChars = tempRules.Chars)
if (StringUtils.CompareCStr(chars[j..], tempAbbreviation) == 0)
{
char* tempAbbreviation = tempChars + tempRules.Ttis[i].AbbreviationListIndex;
int j;
for (j = 0; j < charCount; j++)
{
if (StringUtils.CompareCStr(chars + j, tempAbbreviation) == 0)
{
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
break;
}
}
if (j >= charCount)
{
int abbreviationLength = StringUtils.LengthCstr(tempAbbreviation);
if (j + abbreviationLength < TzMaxChars)
{
for (int x = 0; x < abbreviationLength; x++)
{
chars[j + x] = tempAbbreviation[x];
}
charCount = j + abbreviationLength + 1;
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
}
}
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
break;
}
}
if (abbreviationCount == tempRules.TypeCount)
if (j >= charCount)
{
outRules.CharCount = charCount;
// Remove trailing
while (1 < outRules.TimeCount && (outRules.Types[outRules.TimeCount - 1] == outRules.Types[outRules.TimeCount - 2]))
int abbreviationLength = StringUtils.LengthCstr(tempAbbreviation);
if (j + abbreviationLength < TzMaxChars)
{
outRules.TimeCount--;
}
int i;
for (i = 0; i < tempRules.TimeCount; i++)
{
if (outRules.TimeCount == 0 || outRules.Ats[outRules.TimeCount - 1] < tempRules.Ats[i])
for (int x = 0; x < abbreviationLength; x++)
{
break;
chars[j + x] = tempAbbreviation[x];
}
}
while (i < tempRules.TimeCount && outRules.TimeCount < TzMaxTimes)
charCount = j + abbreviationLength + 1;
tempRules.Ttis[i].AbbreviationListIndex = j;
abbreviationCount++;
}
}
}
if (abbreviationCount == tempRules.TypeCount)
{
outRules.CharCount = charCount;
// Remove trailing
while (1 < outRules.TimeCount && (outRules.Types[outRules.TimeCount - 1] == outRules.Types[outRules.TimeCount - 2]))
{
outRules.TimeCount--;
}
int i;
for (i = 0; i < tempRules.TimeCount; i++)
{
if (outRules.TimeCount == 0 || outRules.Ats[outRules.TimeCount - 1] < tempRules.Ats[i])
{
outRules.Ats[outRules.TimeCount] = tempRules.Ats[i];
outRules.Types[outRules.TimeCount] = (byte)(outRules.TypeCount + (byte)tempRules.Types[i]);
outRules.TimeCount++;
i++;
break;
}
}
for (i = 0; i < tempRules.TypeCount; i++)
{
outRules.Ttis[outRules.TypeCount++] = tempRules.Ttis[i];
}
while (i < tempRules.TimeCount && outRules.TimeCount < TzMaxTimes)
{
outRules.Ats[outRules.TimeCount] = tempRules.Ats[i];
outRules.Types[outRules.TimeCount] = (byte)(outRules.TypeCount + (byte)tempRules.Types[i]);
outRules.TimeCount++;
i++;
}
for (i = 0; i < tempRules.TypeCount; i++)
{
outRules.Ttis[outRules.TypeCount++] = tempRules.Ttis[i];
}
}
}
@ -1467,17 +1450,11 @@ namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
{
calendarAdditionalInfo.IsDaySavingTime = rules.Ttis[ttiIndex].IsDaySavingTime;
unsafe
{
fixed (char* timeZoneAbbreviation = &rules.Chars[rules.Ttis[ttiIndex].AbbreviationListIndex])
{
int timeZoneSize = Math.Min(StringUtils.LengthCstr(timeZoneAbbreviation), 8);
for (int i = 0; i < timeZoneSize; i++)
{
calendarAdditionalInfo.TimezoneName[i] = timeZoneAbbreviation[i];
}
}
}
ReadOnlySpan<char> timeZoneAbbreviation = rules.Chars.AsSpan()[rules.Ttis[ttiIndex].AbbreviationListIndex..];
int timeZoneSize = Math.Min(StringUtils.LengthCstr(timeZoneAbbreviation), 8);
timeZoneAbbreviation[..timeZoneSize].CopyTo(calendarAdditionalInfo.TimezoneName.AsSpan());
}
return result;

View file

@ -1,34 +1,19 @@
using System.Runtime.InteropServices;
using Ryujinx.Common.Memory;
using System.Runtime.InteropServices;
namespace Ryujinx.HLE.HOS.Services.Time.TimeZone
{
[StructLayout(LayoutKind.Sequential, Pack = 0x4, Size = 0x2C)]
struct TzifHeader
{
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public char[] Magic;
public char Version;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 15)]
public byte[] Reserved;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] TtisGMTCount;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] TtisSTDCount;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] LeapCount;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] TimeCount;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] TypeCount;
[MarshalAs(UnmanagedType.ByValArray, SizeConst = 4)]
public byte[] CharCount;
public Array4<byte> Magic;
public byte Version;
private Array15<byte> _reserved;
public int TtisGMTCount;
public int TtisSTDCount;
public int LeapCount;
public int TimeCount;
public int TypeCount;
public int CharCount;
}
}

View file

@ -36,6 +36,15 @@ namespace Ryujinx.HLE.Utilities
return output;
}
public static string ReadInlinedAsciiString(BinaryReader reader, int maxSize)
{
byte[] data = reader.ReadBytes(maxSize);
int stringSize = Array.IndexOf<byte>(data, 0);
return Encoding.ASCII.GetString(data, 0, stringSize < 0 ? maxSize : stringSize);
}
public static byte[] HexToBytes(string hexString)
{
// Ignore last character if HexLength % 2 != 0.
@ -107,7 +116,7 @@ namespace Ryujinx.HLE.Utilities
}
}
public static unsafe int CompareCStr(char* s1, char* s2)
public static int CompareCStr(ReadOnlySpan<char> s1, ReadOnlySpan<char> s2)
{
int s1Index = 0;
int s2Index = 0;
@ -121,7 +130,7 @@ namespace Ryujinx.HLE.Utilities
return s2[s2Index] - s1[s1Index];
}
public static unsafe int LengthCstr(char* s)
public static int LengthCstr(ReadOnlySpan<char> s)
{
int i = 0;