SPIR-V: Fix tessellation control shader output types (#3807)

* SPIR-V: Fix tessellation control shader output types

* Shader cache version bump
This commit is contained in:
gdkchan 2022-10-29 13:45:30 -03:00 committed by GitHub
parent 4e34170a84
commit 59cdf310bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 6 deletions

View file

@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
private const ushort FileFormatVersionMajor = 1; private const ushort FileFormatVersionMajor = 1;
private const ushort FileFormatVersionMinor = 2; private const ushort FileFormatVersionMinor = 2;
private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor; private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
private const uint CodeGenVersion = 3781; private const uint CodeGenVersion = 3807;
private const string SharedTocFileName = "shared.toc"; private const string SharedTocFileName = "shared.toc";
private const string SharedDataFileName = "shared.data"; private const string SharedDataFileName = "shared.data";

View file

@ -262,6 +262,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
Instruction ioVariable, elemIndex; Instruction ioVariable, elemIndex;
Instruction invocationId = null;
if (Config.Stage == ShaderStage.TessellationControl && isOutAttr)
{
invocationId = Load(TypeS32(), Inputs[AttributeConsts.InvocationId]);
}
bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd; bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;
if (isUserAttr && if (isUserAttr &&
@ -273,7 +280,17 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());
var vecIndex = Constant(TypeU32(), (attr - AttributeConsts.UserAttributeBase) >> 4); var vecIndex = Constant(TypeU32(), (attr - AttributeConsts.UserAttributeBase) >> 4);
if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr)) bool isArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr);
if (invocationId != null && isArray)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, vecIndex, elemIndex);
}
else if (invocationId != null)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, vecIndex, elemIndex);
}
else if (isArray)
{ {
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex); return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex);
} }
@ -308,12 +325,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
if ((type & (AggregateType.Array | AggregateType.Vector)) == 0) if ((type & (AggregateType.Array | AggregateType.Vector)) == 0)
{ {
return isIndexed ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index) : ioVariable; if (invocationId != null)
{
return isIndexed
? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index)
: AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId);
}
else
{
return isIndexed ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index) : ioVariable;
}
} }
elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex()); elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());
if (isIndexed) if (invocationId != null && isIndexed)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, elemIndex);
}
else if (invocationId != null)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, elemIndex);
}
else if (isIndexed)
{ {
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, elemIndex); return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, elemIndex);
} }
@ -327,12 +361,29 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
{ {
var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input; var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
Instruction invocationId = null;
if (Config.Stage == ShaderStage.TessellationControl && isOutAttr)
{
invocationId = Load(TypeS32(), Inputs[AttributeConsts.InvocationId]);
}
elemType = AggregateType.FP32; elemType = AggregateType.FP32;
var ioVariable = isOutAttr ? OutputsArray : InputsArray; var ioVariable = isOutAttr ? OutputsArray : InputsArray;
var vecIndex = ShiftRightLogical(TypeS32(), attrIndex, Constant(TypeS32(), 2)); var vecIndex = ShiftRightLogical(TypeS32(), attrIndex, Constant(TypeS32(), 2));
var elemIndex = BitwiseAnd(TypeS32(), attrIndex, Constant(TypeS32(), 3)); var elemIndex = BitwiseAnd(TypeS32(), attrIndex, Constant(TypeS32(), 3));
if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr)) bool isArray = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr);
if (invocationId != null && isArray)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, index, vecIndex, elemIndex);
}
else if (invocationId != null)
{
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, invocationId, vecIndex, elemIndex);
}
else if (isArray)
{ {
return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex); return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex);
} }

View file

@ -473,6 +473,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var attrType = context.TypeVector(context.TypeFP32(), (LiteralInteger)4); var attrType = context.TypeVector(context.TypeFP32(), (LiteralInteger)4);
attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)MaxAttributes)); attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)MaxAttributes));
if (context.Config.Stage == ShaderStage.TessellationControl)
{
attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive));
}
var spvType = context.TypePointer(StorageClass.Output, attrType); var spvType = context.TypePointer(StorageClass.Output, attrType);
var spvVar = context.Variable(spvType, StorageClass.Output); var spvVar = context.Variable(spvType, StorageClass.Output);
@ -543,6 +548,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
} }
} }
if (context.Config.Stage == ShaderStage.TessellationControl && isOutAttr && !perPatch)
{
attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive));
}
var spvType = context.TypePointer(storageClass, attrType); var spvType = context.TypePointer(storageClass, attrType);
var spvVar = context.Variable(spvType, storageClass); var spvVar = context.Variable(spvType, storageClass);
@ -634,6 +644,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)arraySize)); attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), (LiteralInteger)arraySize));
} }
if (context.Config.Stage == ShaderStage.TessellationControl && isOutAttr)
{
attrType = context.TypeArray(attrType, context.Constant(context.TypeU32(), context.Config.ThreadsPerInputPrimitive));
}
var spvType = context.TypePointer(storageClass, attrType); var spvType = context.TypePointer(storageClass, attrType);
var spvVar = context.Variable(spvType, storageClass); var spvVar = context.Variable(spvType, storageClass);

View file

@ -37,7 +37,12 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
Config = config; Config = config;
if (config.GpPassthrough) if (config.Stage == ShaderStage.TessellationControl)
{
// Required to index outputs.
Info.Inputs.Add(AttributeConsts.InvocationId);
}
else if (config.GpPassthrough)
{ {
int passthroughAttributes = config.PassthroughAttributes; int passthroughAttributes = config.PassthroughAttributes;
while (passthroughAttributes != 0) while (passthroughAttributes != 0)