diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md b/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md new file mode 100644 index 00000000000..5ccc9f037f6 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Shipped.md @@ -0,0 +1,2 @@ +; Shipped analyzer releases +; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Unshipped.md b/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Unshipped.md new file mode 100644 index 00000000000..b27b97bcbc6 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/AnalyzerReleases.Unshipped.md @@ -0,0 +1,19 @@ +; Unshipped analyzer release +; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md + +### New Rules + +Rule ID | Category | Severity | Notes +--------|----------|----------|------- +MVVMTK0001 | Microsoft.Toolkit.Mvvm.SourceGenerators.INotifyPropertyChangedGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0002 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableObjectGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0003 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableRecipientGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0004 | Microsoft.Toolkit.Mvvm.SourceGenerators.INotifyPropertyChangedGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0005 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableObjectGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0006 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableObjectGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0007 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableRecipientGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0008 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservableRecipientGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0009 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservablePropertyGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0010 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservablePropertyGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0011 | Microsoft.Toolkit.Mvvm.SourceGenerators.ObservablePropertyGenerator | Error | See https://aka.ms/mvvmtoolkit/error +MVVMTK0012 | Microsoft.Toolkit.Mvvm.SourceGenerators.ICommandGenerator | Error | See https://aka.ms/mvvmtoolkit/error diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Attributes/NotNullWhenAttribute.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Attributes/NotNullWhenAttribute.cs new file mode 100644 index 00000000000..7e63f97aae5 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Attributes/NotNullWhenAttribute.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace System.Diagnostics.CodeAnalysis +{ + /// Specifies that when a method returns , the parameter will not be null even if the corresponding type allows it. + [AttributeUsage(AttributeTargets.Parameter, Inherited = false)] + internal sealed class NotNullWhenAttribute : Attribute + { + /// + /// Initializes a new instance of the class. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be null. + public NotNullWhenAttribute(bool returnValue) + { + ReturnValue = returnValue; + } + + /// + /// Gets a value indicating whether the annotated parameter will be null depending on the return value. + /// + public bool ReturnValue { get; } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs new file mode 100644 index 00000000000..4f4534b5818 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/INotifyPropertyChangedGenerator.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for the INotifyPropertyChangedAttribute type. + /// + [Generator] + public sealed class INotifyPropertyChangedGenerator : TransitiveMembersGenerator + { + /// + /// Initializes a new instance of the class. + /// + public INotifyPropertyChangedGenerator() + : base("Microsoft.Toolkit.Mvvm.ComponentModel.INotifyPropertyChangedAttribute") + { + } + + /// + protected override DiagnosticDescriptor TargetTypeErrorDescriptor => INotifyPropertyChangedGeneratorError; + + /// + protected override bool ValidateTargetType( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + [NotNullWhen(false)] out DiagnosticDescriptor? descriptor) + { + INamedTypeSymbol iNotifyPropertyChangedSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanged")!; + + // Check if the type already implements INotifyPropertyChanged + if (classDeclarationSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, iNotifyPropertyChangedSymbol))) + { + descriptor = DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError; + + return false; + } + + descriptor = null; + + return true; + } + + /// + protected override IEnumerable FilterDeclaredMembers( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + ClassDeclarationSyntax sourceDeclaration) + { + // If requested, only include the event and the basic methods to raise it, but not the additional helpers + if (attributeData.HasNamedArgument("IncludeAdditionalHelperMethods", false)) + { + return sourceDeclaration.Members.Where(static member => + { + return member + is EventFieldDeclarationSyntax + or MethodDeclarationSyntax { Identifier: { ValueText: "OnPropertyChanged" } }; + }); + } + + return sourceDeclaration.Members; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs new file mode 100644 index 00000000000..c1b1a37ffbc --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableObjectGenerator.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for the ObservableObjectAttribute type. + /// + [Generator] + public sealed class ObservableObjectGenerator : TransitiveMembersGenerator + { + /// + /// Initializes a new instance of the class. + /// + public ObservableObjectGenerator() + : base("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableObjectAttribute") + { + } + + /// + protected override DiagnosticDescriptor TargetTypeErrorDescriptor => ObservableObjectGeneratorError; + + /// + protected override bool ValidateTargetType( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + [NotNullWhen(false)] out DiagnosticDescriptor? descriptor) + { + INamedTypeSymbol + iNotifyPropertyChangedSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanged")!, + iNotifyPropertyChangingSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanging")!; + + // Check if the type already implements INotifyPropertyChanged... + if (classDeclarationSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, iNotifyPropertyChangedSymbol))) + { + descriptor = DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError; + + return false; + } + + // ...or INotifyPropertyChanging + if (classDeclarationSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i, iNotifyPropertyChangingSymbol))) + { + descriptor = DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError; + + return false; + } + + descriptor = null; + + return true; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.SyntaxReceiver.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.SyntaxReceiver.cs new file mode 100644 index 00000000000..f71dfa367ac --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.SyntaxReceiver.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + public sealed partial class ObservablePropertyGenerator + { + /// + /// An that selects candidate nodes to process. + /// + private sealed class SyntaxReceiver : ISyntaxContextReceiver + { + /// + /// The list of info gathered during exploration. + /// + private readonly List gatheredInfo = new(); + + /// + /// Gets the collection of gathered info to process. + /// + public IReadOnlyCollection GatheredInfo => this.gatheredInfo; + + /// + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + { + if (context.Node is FieldDeclarationSyntax { AttributeLists: { Count: > 0 } } fieldDeclaration && + context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservablePropertyAttribute") is INamedTypeSymbol attributeSymbol) + { + SyntaxTriviaList leadingTrivia = fieldDeclaration.GetLeadingTrivia(); + + foreach (VariableDeclaratorSyntax variableDeclarator in fieldDeclaration.Declaration.Variables) + { + if (context.SemanticModel.GetDeclaredSymbol(variableDeclarator) is IFieldSymbol fieldSymbol && + fieldSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol))) + { + this.gatheredInfo.Add(new Item(leadingTrivia, fieldSymbol)); + } + } + } + } + + /// + /// A model for a group of item representing a discovered type to process. + /// + /// The leading trivia for the field declaration. + /// The instance for the target field. + public sealed record Item(SyntaxTriviaList LeadingTrivia, IFieldSymbol FieldSymbol); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs new file mode 100644 index 00000000000..e5b9402e042 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservablePropertyGenerator.cs @@ -0,0 +1,573 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics.Contracts; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.CodeAnalysis.SymbolDisplayTypeQualificationStyle; +using static Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for the ObservablePropertyAttribute type. + /// + [Generator] + public sealed partial class ObservablePropertyGenerator : ISourceGenerator + { + /// + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(static () => new SyntaxReceiver()); + } + + /// + public void Execute(GeneratorExecutionContext context) + { + // Get the syntax receiver with the candidate nodes + if (context.SyntaxContextReceiver is not SyntaxReceiver syntaxReceiver || + syntaxReceiver.GatheredInfo.Count == 0) + { + return; + } + + // Sets of discovered property names + HashSet + propertyChangedNames = new(), + propertyChangingNames = new(); + + // Process the annotated fields + foreach (var items in syntaxReceiver.GatheredInfo.GroupBy(static item => item.FieldSymbol.ContainingType, SymbolEqualityComparer.Default)) + { + if (items.Key.DeclaringSyntaxReferences.Length > 0 && + items.Key.DeclaringSyntaxReferences.First().GetSyntax() is ClassDeclarationSyntax classDeclaration) + { + try + { + OnExecuteForProperties(context, classDeclaration, items.Key, items, propertyChangedNames, propertyChangingNames); + } + catch + { + context.ReportDiagnostic(ObservablePropertyGeneratorError, items.Key, items.Key); + } + } + } + + // Process the fields for the cached args + OnExecuteForPropertyArgs(context, propertyChangedNames, propertyChangingNames); + } + + /// + /// Processes a given target type for declared observable properties. + /// + /// The input instance to use. + /// The node to process. + /// The for . + /// The sequence of fields to process. + /// The collection of discovered property changed names. + /// The collection of discovered property changing names. + private static void OnExecuteForProperties( + GeneratorExecutionContext context, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + IEnumerable items, + ICollection propertyChangedNames, + ICollection propertyChangingNames) + { + INamedTypeSymbol + iNotifyPropertyChangingSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanging")!, + observableObjectSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableObject")!, + observableObjectAttributeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableObjectAttribute")!, + observableValidatorSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableValidator")!; + + // Check whether the current type implements INotifyPropertyChanging and whether it inherits from ObservableValidator + bool + isObservableObject = classDeclarationSymbol.InheritsFrom(observableObjectSymbol), + isObservableValidator = classDeclarationSymbol.InheritsFrom(observableValidatorSymbol), + isNotifyPropertyChanging = + isObservableObject || + classDeclarationSymbol.AllInterfaces.Contains(iNotifyPropertyChangingSymbol, SymbolEqualityComparer.Default) || + classDeclarationSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, observableObjectAttributeSymbol)); + + // Create the class declaration for the user type. This will produce a tree as follows: + // + // + // { + // + // } + var classDeclarationSyntax = + ClassDeclaration(classDeclarationSymbol.Name) + .WithModifiers(classDeclaration.Modifiers) + .AddMembers(items.Select(item => + CreatePropertyDeclaration( + context, + item.LeadingTrivia, + item.FieldSymbol, + isNotifyPropertyChanging, + isObservableValidator, + propertyChangedNames, + propertyChangingNames)).ToArray()); + + TypeDeclarationSyntax typeDeclarationSyntax = classDeclarationSyntax; + + // Add all parent types in ascending order, if any + foreach (var parentType in classDeclaration.Ancestors().OfType()) + { + typeDeclarationSyntax = parentType + .WithMembers(SingletonList(typeDeclarationSyntax)) + .WithConstraintClauses(List()) + .WithBaseList(null) + .WithAttributeLists(List()) + .WithoutTrivia(); + } + + // Create the compilation unit with the namespace and target member. + // From this, we can finally generate the source code to output. + var namespaceName = classDeclarationSymbol.ContainingNamespace.ToDisplayString(new(typeQualificationStyle: NameAndContainingTypesAndNamespaces)); + + // Create the final compilation unit to generate (with leading trivia) + var source = + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName(namespaceName)).WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))) + .AddMembers(typeDeclarationSyntax)) + .NormalizeWhitespace() + .ToFullString(); + + // Add the partial type + context.AddSource($"{classDeclarationSymbol.GetFullMetadataNameForFileName()}.cs", SourceText.From(source, Encoding.UTF8)); + } + + /// + /// Creates a instance for a specified field. + /// + /// The input instance to use. + /// The leading trivia for the field to process. + /// The input instance to process. + /// Indicates whether or not is also implemented. + /// Indicates whether or not the containing type inherits from ObservableValidator. + /// The collection of discovered property changed names. + /// The collection of discovered property changing names. + /// A generated instance for the input field. + [Pure] + private static PropertyDeclarationSyntax CreatePropertyDeclaration( + GeneratorExecutionContext context, + SyntaxTriviaList leadingTrivia, + IFieldSymbol fieldSymbol, + bool isNotifyPropertyChanging, + bool isObservableValidator, + ICollection propertyChangedNames, + ICollection propertyChangingNames) + { + // Get the field type and the target property name + string + typeName = fieldSymbol.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), + propertyName = GetGeneratedPropertyName(fieldSymbol); + + INamedTypeSymbol alsoNotifyChangeForAttributeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.AlsoNotifyChangeForAttribute")!; + INamedTypeSymbol? validationAttributeSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute"); + + List dependentPropertyNotificationStatements = new(); + List validationAttributes = new(); + + foreach (AttributeData attributeData in fieldSymbol.GetAttributes()) + { + // Add dependent property notifications, if needed + if (SymbolEqualityComparer.Default.Equals(attributeData.AttributeClass, alsoNotifyChangeForAttributeSymbol)) + { + foreach (TypedConstant attributeArgument in attributeData.ConstructorArguments) + { + if (attributeArgument.IsNull) + { + continue; + } + + if (attributeArgument.Kind == TypedConstantKind.Primitive && + attributeArgument.Value is string dependentPropertyName) + { + propertyChangedNames.Add(dependentPropertyName); + + // OnPropertyChanged("OtherPropertyName"); + dependentPropertyNotificationStatements.Add(ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanged")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{dependentPropertyName}{nameof(PropertyChangedEventArgs)}")))))); + } + else if (attributeArgument.Kind == TypedConstantKind.Array) + { + foreach (TypedConstant nestedAttributeArgument in attributeArgument.Values) + { + if (nestedAttributeArgument.IsNull) + { + continue; + } + + string currentPropertyName = (string)nestedAttributeArgument.Value!; + + propertyChangedNames.Add(currentPropertyName); + + // Additional property names + dependentPropertyNotificationStatements.Add(ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanged")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{currentPropertyName}{nameof(PropertyChangedEventArgs)}")))))); + } + } + } + } + else if (validationAttributeSymbol is not null && + attributeData.AttributeClass?.InheritsFrom(validationAttributeSymbol) == true) + { + // Track the current validation attribute + validationAttributes.Add(attributeData.AsAttributeSyntax()); + } + } + + BlockSyntax setterBlock; + + if (validationAttributes.Count > 0) + { + // Emit a diagnostic if the current type doesn't inherit from ObservableValidator + if (!isObservableValidator) + { + context.ReportDiagnostic( + MissingObservableValidatorInheritanceError, + fieldSymbol, + fieldSymbol.ContainingType, + fieldSymbol.Name, + validationAttributes.Count); + + setterBlock = Block(); + } + else + { + propertyChangedNames.Add(propertyName); + propertyChangingNames.Add(propertyName); + + // Generate the inner setter block as follows: + // + // if (!global::System.Collections.Generic.EqualityComparer<>.Default.Equals(, value)) + // { + // OnPropertyChanging(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangingEventArgs); // Optional + // = value; + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangedEventArgs); + // ValidateProperty(value, ); + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.Property1PropertyChangedEventArgs); // Optional + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.Property2PropertyChangedEventArgs); + // ... + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNPropertyChangedEventArgs); + // } + // + // The reason why the code is explicitly generated instead of just calling ObservableValidator.SetProperty() is so that we can + // take advantage of the cached property changed arguments for the current property as well, not just for the dependent ones. + setterBlock = Block( + IfStatement( + PrefixUnaryExpression( + SyntaxKind.LogicalNotExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + GenericName(Identifier("global::System.Collections.Generic.EqualityComparer")) + .AddTypeArgumentListArguments(IdentifierName(typeName)), + IdentifierName("Default")), + IdentifierName("Equals"))) + .AddArgumentListArguments( + Argument(IdentifierName(fieldSymbol.Name)), + Argument(IdentifierName("value")))), + Block( + ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanging")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{propertyName}{nameof(PropertyChangingEventArgs)}"))))), + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(fieldSymbol.Name), + IdentifierName("value"))), + ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanged")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{propertyName}{nameof(PropertyChangedEventArgs)}"))))), + ExpressionStatement( + InvocationExpression(IdentifierName("ValidateProperty")) + .AddArgumentListArguments( + Argument(IdentifierName("value")), + Argument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(propertyName)))))) + .AddStatements(dependentPropertyNotificationStatements.ToArray()))); + } + } + else + { + BlockSyntax updateAndNotificationBlock = Block(); + + // Add OnPropertyChanging(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangingEventArgs) if necessary + if (isNotifyPropertyChanging) + { + propertyChangingNames.Add(propertyName); + + updateAndNotificationBlock = updateAndNotificationBlock.AddStatements(ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanging")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{propertyName}{nameof(PropertyChangingEventArgs)}")))))); + } + + propertyChangedNames.Add(propertyName); + + // Add the following statements: + // + // = value; + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangedEventArgs); + updateAndNotificationBlock = updateAndNotificationBlock.AddStatements( + ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + IdentifierName(fieldSymbol.Name), + IdentifierName("value"))), + ExpressionStatement( + InvocationExpression(IdentifierName("OnPropertyChanged")) + .AddArgumentListArguments(Argument(MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs"), + IdentifierName($"{propertyName}{nameof(PropertyChangedEventArgs)}")))))); + + // Add the dependent property notifications at the end + updateAndNotificationBlock = updateAndNotificationBlock.AddStatements(dependentPropertyNotificationStatements.ToArray()); + + // Generate the inner setter block as follows: + // + // if (!global::System.Collections.Generic.EqualityComparer<>.Default.Equals(, value)) + // { + // OnPropertyChanging(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangingEventArgs); // Optional + // = value; + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNamePropertyChangedEventArgs); + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.Property1PropertyChangedEventArgs); // Optional + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.Property2PropertyChangedEventArgs); + // ... + // OnPropertyChanged(global::Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__KnownINotifyPropertyChangedOrChangingArgs.PropertyNPropertyChangedEventArgs); + // } + setterBlock = Block( + IfStatement( + PrefixUnaryExpression( + SyntaxKind.LogicalNotExpression, + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + GenericName(Identifier("global::System.Collections.Generic.EqualityComparer")) + .AddTypeArgumentListArguments(IdentifierName(typeName)), + IdentifierName("Default")), + IdentifierName("Equals"))) + .AddArgumentListArguments( + Argument(IdentifierName(fieldSymbol.Name)), + Argument(IdentifierName("value")))), + updateAndNotificationBlock)); + } + + // Get the right type for the declared property (including nullability annotations) + TypeSyntax propertyType = IdentifierName(typeName); + + if (fieldSymbol.Type is { IsReferenceType: true, NullableAnnotation: NullableAnnotation.Annotated }) + { + propertyType = NullableType(propertyType); + } + + // Construct the generated property as follows: + // + // + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // // Optional + // + // ... + // + // public + // { + // get => ; + // set + // { + // + // } + // } + return + PropertyDeclaration(propertyType, Identifier(propertyName)) + .AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ObservablePropertyGenerator).FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ObservablePropertyGenerator).Assembly.GetName().Version.ToString())))))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage"))))) + .AddAttributeLists(validationAttributes.Select(static a => AttributeList(SingletonSeparatedList(a))).ToArray()) + .WithLeadingTrivia(leadingTrivia) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddAccessorListAccessors( + AccessorDeclaration(SyntaxKind.GetAccessorDeclaration) + .WithExpressionBody(ArrowExpressionClause(IdentifierName(fieldSymbol.Name))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)), + AccessorDeclaration(SyntaxKind.SetAccessorDeclaration) + .WithBody(setterBlock)); + } + + /// + /// Get the generated property name for an input field. + /// + /// The input instance to process. + /// The generated property name for . + [Pure] + private static string GetGeneratedPropertyName(IFieldSymbol fieldSymbol) + { + string propertyName = fieldSymbol.Name; + + if (propertyName.StartsWith("m_")) + { + propertyName = propertyName.Substring(2); + } + else if (propertyName.StartsWith("_")) + { + propertyName = propertyName.TrimStart('_'); + } + + return $"{char.ToUpper(propertyName[0])}{propertyName.Substring(1)}"; + } + + /// + /// Processes the cached property changed/changing args. + /// + /// The input instance to use. + /// The collection of discovered property changed names. + /// The collection of discovered property changing names. + public void OnExecuteForPropertyArgs(GeneratorExecutionContext context, IReadOnlyCollection propertyChangedNames, IReadOnlyCollection propertyChangingNames) + { + if (propertyChangedNames.Count == 0 && + propertyChangingNames.Count == 0) + { + return; + } + + INamedTypeSymbol + propertyChangedEventArgsSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.PropertyChangedEventArgs")!, + propertyChangingEventArgsSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.PropertyChangingEventArgs")!; + + // Create a static method to validate all properties in a given class. + // This code takes a class symbol and produces a compilation unit as follows: + // + // // Licensed to the .NET Foundation under one or more agreements. + // // The .NET Foundation licenses this file to you under the MIT license. + // // See the LICENSE file in the project root for more information. + // + // #pragma warning disable + // + // namespace Microsoft.Toolkit.Mvvm.ComponentModel.__Internals + // { + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This type is not intended to be used directly by user code")] + // internal static class __KnownINotifyPropertyChangedOrChangingArgs + // { + // + // } + // } + var source = + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName("Microsoft.Toolkit.Mvvm.ComponentModel.__Internals")).WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))).AddMembers( + ClassDeclaration("__KnownINotifyPropertyChangedOrChangingArgs").AddModifiers( + Token(SyntaxKind.InternalKeyword), + Token(SyntaxKind.StaticKeyword)).AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName($"global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().Assembly.GetName().Version.ToString())))))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This type is not intended to be used directly by user code"))))))) + .AddMembers(propertyChangedNames.Select(name => CreateFieldDeclaration(propertyChangedEventArgsSymbol, name)).ToArray()) + .AddMembers(propertyChangingNames.Select(name => CreateFieldDeclaration(propertyChangingEventArgsSymbol, name)).ToArray()))) + .NormalizeWhitespace() + .ToFullString(); + + // Add the partial type + context.AddSource("__KnownINotifyPropertyChangedOrChangingArgs.cs", SourceText.From(source, Encoding.UTF8)); + } + + /// + /// Creates a field declaration for a cached property change name. + /// + /// The type of cached property change argument (either or ). + /// The name of the cached property name. + /// A instance for the input cached property name. + [Pure] + private static FieldDeclarationSyntax CreateFieldDeclaration(INamedTypeSymbol type, string propertyName) + { + // Create a static field with a cached property changed/changing argument for a specified property. + // This code produces a field declaration as follows: + // + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This field is not intended to be referenced directly by user code")] + // public static readonly = new(""); + return + FieldDeclaration( + VariableDeclaration(IdentifierName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) + .AddVariables( + VariableDeclarator(Identifier($"{propertyName}{type.Name}")) + .WithInitializer(EqualsValueClause( + ImplicitObjectCreationExpression() + .AddArgumentListArguments(Argument( + LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(propertyName)))))))) + .AddModifiers( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.ReadOnlyKeyword)) + .AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This field is not intended to be referenced directly by user code"))))))); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs new file mode 100644 index 00000000000..7a2228bfde0 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableRecipientGenerator.cs @@ -0,0 +1,131 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for the ObservableRecipientAttribute type. + /// + [Generator] + public sealed class ObservableRecipientGenerator : TransitiveMembersGenerator + { + /// + /// Initializes a new instance of the class. + /// + public ObservableRecipientGenerator() + : base("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableRecipientAttribute") + { + } + + /// + protected override DiagnosticDescriptor TargetTypeErrorDescriptor => ObservableRecipientGeneratorError; + + /// + protected override bool ValidateTargetType( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + [NotNullWhen(false)] out DiagnosticDescriptor? descriptor) + { + INamedTypeSymbol + observableRecipientSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableRecipient")!, + observableObjectSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableObject")!, + observableObjectAttributeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableObjectAttribute")!, + iNotifyPropertyChangedSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.INotifyPropertyChanged")!; + + // Check if the type already inherits from ObservableRecipient + if (classDeclarationSymbol.InheritsFrom(observableRecipientSymbol)) + { + descriptor = DuplicateObservableRecipientError; + + return false; + } + + // In order to use [ObservableRecipient], the target type needs to inherit from ObservableObject, + // or be annotated with [ObservableObject] or [INotifyPropertyChanged] (with additional helpers). + if (!classDeclarationSymbol.InheritsFrom(observableObjectSymbol) && + !classDeclarationSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, observableObjectAttributeSymbol)) && + !classDeclarationSymbol.GetAttributes().Any(a => + SymbolEqualityComparer.Default.Equals(a.AttributeClass, iNotifyPropertyChangedSymbol) && + !a.HasNamedArgument("IncludeAdditionalHelperMethods", false))) + { + descriptor = MissingBaseObservableObjectFunctionalityError; + + return false; + } + + descriptor = null; + + return true; + } + + /// + protected override IEnumerable FilterDeclaredMembers( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + ClassDeclarationSyntax sourceDeclaration) + { + // If the target type has no constructors, generate constructors as well + if (classDeclarationSymbol.InstanceConstructors.Length == 1 && + classDeclarationSymbol.InstanceConstructors[0] is + { + Parameters: { IsEmpty: true }, + DeclaringSyntaxReferences: { IsEmpty: true }, + IsImplicitlyDeclared: true + }) + { + foreach (ConstructorDeclarationSyntax ctor in sourceDeclaration.Members.OfType()) + { + string + text = ctor.NormalizeWhitespace().ToFullString(), + replaced = text.Replace("ObservableRecipient", classDeclarationSymbol.Name); + + // Adjust the visibility of the constructors based on whether the target type is abstract. + // If that is not the case, the constructors have to be declared as public and not protected. + if (!classDeclarationSymbol.IsAbstract) + { + replaced = replaced.Replace("protected", "public"); + } + + yield return (ConstructorDeclarationSyntax)ParseMemberDeclaration(replaced)!; + } + } + + INamedTypeSymbol observableValidatorSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableValidator")!; + + // Skip the SetProperty overloads if the target type inherits from ObservableValidator, to avoid conflicts + if (classDeclarationSymbol.InheritsFrom(observableValidatorSymbol)) + { + foreach (MemberDeclarationSyntax member in sourceDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) + { + if (member is not MethodDeclarationSyntax { Identifier: { ValueText: "SetProperty" } }) + { + yield return member; + } + } + + yield break; + } + + // If the target type has at least one custom constructor, only generate methods + foreach (MemberDeclarationSyntax member in sourceDeclaration.Members.Where(static member => member is not ConstructorDeclarationSyntax)) + { + yield return member; + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.SyntaxReceiver.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.SyntaxReceiver.cs new file mode 100644 index 00000000000..ece1653b246 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.SyntaxReceiver.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + public sealed partial class ObservableValidatorValidateAllPropertiesGenerator + { + /// + /// An that selects candidate nodes to process. + /// + private sealed class SyntaxReceiver : ISyntaxContextReceiver + { + /// + /// The list of info gathered during exploration. + /// + private readonly List gatheredInfo = new(); + + /// + /// Gets the collection of gathered info to process. + /// + public IReadOnlyCollection GatheredInfo => this.gatheredInfo; + + /// + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + { + if (context.Node is ClassDeclarationSyntax classDeclaration && + context.SemanticModel.GetDeclaredSymbol(classDeclaration) is INamedTypeSymbol { IsGenericType: false } classSymbol && + context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.ComponentModel.ObservableValidator") is INamedTypeSymbol validatorSymbol && + classSymbol.InheritsFrom(validatorSymbol)) + { + this.gatheredInfo.Add(classSymbol); + } + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs new file mode 100644 index 00000000000..45a6bb3d449 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/ObservableValidatorValidateAllPropertiesGenerator.cs @@ -0,0 +1,214 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Diagnostics.Contracts; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for properties validation without relying on compiled LINQ expressions. + /// + [Generator] + public sealed partial class ObservableValidatorValidateAllPropertiesGenerator : ISourceGenerator + { + /// + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(static () => new SyntaxReceiver()); + } + + /// + public void Execute(GeneratorExecutionContext context) + { + // Get the syntax receiver with the candidate nodes + if (context.SyntaxContextReceiver is not SyntaxReceiver syntaxReceiver || + syntaxReceiver.GatheredInfo.Count == 0) + { + return; + } + + // Get the symbol for the ValidationAttribute type + INamedTypeSymbol validationSymbol = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute")!; + + // Prepare the attributes to add to the first class declaration + AttributeListSyntax[] classAttributes = new[] + { + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName($"global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().Assembly.GetName().Version.ToString())))))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This type is not intended to be used directly by user code")))))) + }; + + foreach (INamedTypeSymbol classSymbol in syntaxReceiver.GatheredInfo) + { + // Create a static factory method creating a delegate that can be used to validate all properties in a given class. + // This pattern is used so that the library doesn't have to use MakeGenericType(...) at runtime, nor use unsafe casts + // over the created delegate to be able to cache it as an Action instance. This pattern enables the same + // functionality and with almost identical performance (not noticeable in this context anyway), but while preserving + // full runtime type safety (as a safe cast is used to validate the input argument), and with less reflection needed. + // Note that we're deliberately creating a new delegate instance here and not using code that could see the C# compiler + // create a static class to cache a reusable delegate, because each generated method will only be called at most once, + // as the returned delegate will be cached by the MVVM Toolkit itself. So this ensures the the produced code is minimal, + // and that there will be no unnecessary static fields and objects being created and possibly never collected. + // This code takes a class symbol and produces a compilation unit as follows: + // + // // Licensed to the .NET Foundation under one or more agreements. + // // The .NET Foundation licenses this file to you under the MIT license. + // // See the LICENSE file in the project root for more information. + // + // #pragma warning disable + // + // namespace Microsoft.Toolkit.Mvvm.ComponentModel.__Internals + // { + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This type is not intended to be used directly by user code")] + // internal static partial class __ObservableValidatorExtensions + // { + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This method is not intended to be called directly by user code")] + // public static global::System.Action CreateAllPropertiesValidator( _) + // { + // static void ValidateAllProperties(object obj) + // { + // var instance = ()obj; + // + // } + // + // return ValidateAllProperties; + // } + // } + // } + var source = + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName("Microsoft.Toolkit.Mvvm.ComponentModel.__Internals")).WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))).AddMembers( + ClassDeclaration("__ObservableValidatorExtensions").AddModifiers( + Token(SyntaxKind.InternalKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.PartialKeyword)).AddAttributeLists(classAttributes).AddMembers( + MethodDeclaration( + GenericName("global::System.Action").AddTypeArgumentListArguments(PredefinedType(Token(SyntaxKind.ObjectKeyword))), + Identifier("CreateAllPropertiesValidator")).AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This method is not intended to be called directly by user code"))))))).AddModifiers( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)).AddParameterListParameters( + Parameter(Identifier("_")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))) + .WithBody(Block( + LocalFunctionStatement( + PredefinedType(Token(SyntaxKind.VoidKeyword)), + Identifier("ValidateAllProperties")) + .AddModifiers(Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters( + Parameter(Identifier("obj")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(IdentifierName("var")) // Cannot Token(SyntaxKind.VarKeyword) here (throws an ArgumentException) + .AddVariables( + VariableDeclarator(Identifier("instance")) + .WithInitializer(EqualsValueClause( + CastExpression( + IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + IdentifierName("obj"))))))) + .AddStatements(EnumerateValidationStatements(classSymbol, validationSymbol).ToArray())), + ReturnStatement(IdentifierName("ValidateAllProperties"))))))) + .NormalizeWhitespace() + .ToFullString(); + + // Reset the attributes list (so the same class doesn't get duplicate attributes) + classAttributes = Array.Empty(); + + // Add the partial type + context.AddSource($"{classSymbol.GetFullMetadataNameForFileName()}.cs", SourceText.From(source, Encoding.UTF8)); + } + } + + /// + /// Gets a sequence of statements to validate declared properties. + /// + /// The input instance to process. + /// The type symbol for the ValidationAttribute type. + /// The sequence of instances to validate declared properties. + [Pure] + private static IEnumerable EnumerateValidationStatements(INamedTypeSymbol classSymbol, INamedTypeSymbol validationSymbol) + { + foreach (var propertySymbol in classSymbol.GetMembers().OfType()) + { + if (propertySymbol.IsIndexer) + { + continue; + } + + ImmutableArray attributes = propertySymbol.GetAttributes(); + + if (!attributes.Any(a => a.AttributeClass?.InheritsFrom(validationSymbol) == true)) + { + continue; + } + + // This enumerator produces a sequence of statements as follows: + // + // __ObservableValidatorHelper.ValidateProperty(instance, instance., nameof(instance.)); + // __ObservableValidatorHelper.ValidateProperty(instance, instance., nameof(instance.)); + // ... + // __ObservableValidatorHelper.ValidateProperty(instance, instance., nameof(instance.)); + yield return + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("__ObservableValidatorHelper"), + IdentifierName("ValidateProperty"))) + .AddArgumentListArguments( + Argument(IdentifierName("instance")), + Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("instance"), + IdentifierName(propertySymbol.Name))), + Argument( + InvocationExpression(IdentifierName("nameof")) + .AddArgumentListArguments(Argument( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("instance"), + IdentifierName(propertySymbol.Name))))))); + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.SyntaxReceiver.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.SyntaxReceiver.cs new file mode 100644 index 00000000000..b370089d644 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.SyntaxReceiver.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + public abstract partial class TransitiveMembersGenerator + { + /// + /// An that selects candidate nodes to process. + /// + private sealed class SyntaxReceiver : ISyntaxContextReceiver + { + /// + /// The fully qualified name of the attribute type to look for. + /// + private readonly string attributeTypeFullName; + + /// + /// The list of info gathered during exploration. + /// + private readonly List gatheredInfo = new(); + + /// + /// Initializes a new instance of the class. + /// + /// The fully qualified name of the attribute type to look for. + public SyntaxReceiver(string attributeTypeFullName) + { + this.attributeTypeFullName = attributeTypeFullName; + } + + /// + /// Gets the collection of gathered info to process. + /// + public IReadOnlyCollection GatheredInfo => this.gatheredInfo; + + /// + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + { + if (context.Node is ClassDeclarationSyntax { AttributeLists: { Count: > 0 } } classDeclaration && + context.SemanticModel.GetDeclaredSymbol(classDeclaration) is INamedTypeSymbol classSymbol && + context.SemanticModel.Compilation.GetTypeByMetadataName(this.attributeTypeFullName) is INamedTypeSymbol attributeSymbol && + classSymbol.GetAttributes().FirstOrDefault(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, attributeSymbol)) is AttributeData attributeData && + attributeData.ApplicationSyntaxReference is SyntaxReference syntaxReference && + syntaxReference.GetSyntax() is AttributeSyntax attributeSyntax) + { + this.gatheredInfo.Add(new Item(classDeclaration, classSymbol, attributeSyntax, attributeData)); + } + } + + /// + /// A model for a group of item representing a discovered type to process. + /// + /// The instance for the target class declaration. + /// The instance for . + /// The instance for the target attribute over . + /// The instance for . + public sealed record Item( + ClassDeclarationSyntax ClassDeclaration, + INamedTypeSymbol ClassSymbol, + AttributeSyntax AttributeSyntax, + AttributeData AttributeData); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs new file mode 100644 index 00000000000..bd3f48e284c --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/ComponentModel/TransitiveMembersGenerator.cs @@ -0,0 +1,279 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.CodeAnalysis.SymbolDisplayTypeQualificationStyle; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for a given attribute type. + /// + public abstract partial class TransitiveMembersGenerator : ISourceGenerator + { + /// + /// The fully qualified name of the attribute type to look for. + /// + private readonly string attributeTypeFullName; + + /// + /// The name of the attribute type to look for. + /// + private readonly string attributeTypeName; + + /// + /// Initializes a new instance of the class. + /// + /// The fully qualified name of the attribute type to look for. + protected TransitiveMembersGenerator(string attributeTypeFullName) + { + this.attributeTypeFullName = attributeTypeFullName; + this.attributeTypeName = attributeTypeFullName.Split('.').Last(); + } + + /// + /// Gets a indicating when the generation failed for a given type. + /// + protected abstract DiagnosticDescriptor TargetTypeErrorDescriptor { get; } + + /// + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(() => new SyntaxReceiver(this.attributeTypeFullName)); + } + + /// + public void Execute(GeneratorExecutionContext context) + { + // Get the syntax receiver with the candidate nodes + if (context.SyntaxContextReceiver is not SyntaxReceiver syntaxReceiver || + syntaxReceiver.GatheredInfo.Count == 0) + { + return; + } + + // Load the syntax tree with the members to generate + SyntaxTree sourceSyntaxTree = LoadSourceSyntaxTree(); + + foreach (SyntaxReceiver.Item item in syntaxReceiver.GatheredInfo) + { + if (!ValidateTargetType(context, item.AttributeData, item.ClassDeclaration, item.ClassSymbol, out var descriptor)) + { + context.ReportDiagnostic(descriptor, item.AttributeSyntax, item.ClassSymbol); + + continue; + } + + try + { + OnExecute(context, item.AttributeData, item.ClassDeclaration, item.ClassSymbol, sourceSyntaxTree); + } + catch + { + context.ReportDiagnostic(TargetTypeErrorDescriptor, item.AttributeSyntax, item.ClassSymbol); + } + } + } + + /// + /// Loads the source syntax tree for the current generator. + /// + /// The syntax tree with the elements to emit in the generated code. + [Pure] + private SyntaxTree LoadSourceSyntaxTree() + { + string filename = $"Microsoft.Toolkit.Mvvm.SourceGenerators.EmbeddedResources.{this.attributeTypeName.Replace("Attribute", string.Empty)}.cs"; + + Stream stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(filename); + StreamReader reader = new(stream); + + string observableObjectSource = reader.ReadToEnd(); + + return CSharpSyntaxTree.ParseText(observableObjectSource); + } + + /// + /// Processes a given target type. + /// + /// The input instance to use. + /// The for the current attribute being processed. + /// The node to process. + /// The for . + /// The for the target parsed source. + private void OnExecute( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + SyntaxTree sourceSyntaxTree) + { + ClassDeclarationSyntax sourceDeclaration = sourceSyntaxTree.GetRoot().DescendantNodes().OfType().First(); + UsingDirectiveSyntax[] usingDirectives = sourceSyntaxTree.GetRoot().DescendantNodes().OfType().ToArray(); + BaseListSyntax? baseListSyntax = BaseList(SeparatedList( + sourceDeclaration.BaseList?.Types + .OfType() + .Select(static t => t.Type) + .OfType() + .Where(static t => t.Identifier.ValueText.StartsWith("I")) + .Select(static t => SimpleBaseType(t)) + .ToArray() + ?? Array.Empty())); + + if (baseListSyntax.Types.Count == 0) + { + baseListSyntax = null; + } + + // Create the class declaration for the user type. This will produce a tree as follows: + // + // : + // { + // + // } + var classDeclarationSyntax = + ClassDeclaration(classDeclaration.Identifier.Text) + .WithModifiers(classDeclaration.Modifiers) + .WithBaseList(baseListSyntax) + .AddMembers(OnLoadDeclaredMembers(context, attributeData, classDeclaration, classDeclarationSymbol, sourceDeclaration).ToArray()); + + TypeDeclarationSyntax typeDeclarationSyntax = classDeclarationSyntax; + + // Add all parent types in ascending order, if any + foreach (var parentType in classDeclaration.Ancestors().OfType()) + { + typeDeclarationSyntax = parentType + .WithMembers(SingletonList(typeDeclarationSyntax)) + .WithConstraintClauses(List()) + .WithBaseList(null) + .WithAttributeLists(List()) + .WithoutTrivia(); + } + + // Create the compilation unit with the namespace and target member. + // From this, we can finally generate the source code to output. + var namespaceName = classDeclarationSymbol.ContainingNamespace.ToDisplayString(new(typeQualificationStyle: NameAndContainingTypesAndNamespaces)); + + // Create the final compilation unit to generate (with using directives and the full type declaration) + var source = + CompilationUnit() + .AddMembers(NamespaceDeclaration(IdentifierName(namespaceName)) + .AddMembers(typeDeclarationSyntax)) + .AddUsings(usingDirectives.First().WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))))) + .AddUsings(usingDirectives.Skip(1).ToArray()) + .NormalizeWhitespace() + .ToFullString(); + + // Add the partial type + context.AddSource($"{classDeclarationSymbol.GetFullMetadataNameForFileName()}.cs", SourceText.From(source, Encoding.UTF8)); + } + + /// + /// Loads the nodes to generate from the input parsed tree. + /// + /// The input instance to use. + /// The for the current attribute being processed. + /// The node to process. + /// The for . + /// The parsed instance with the source nodes. + /// A sequence of nodes to emit in the generated file. + private IEnumerable OnLoadDeclaredMembers( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + ClassDeclarationSyntax sourceDeclaration) + { + IEnumerable generatedMembers = FilterDeclaredMembers(context, attributeData, classDeclaration, classDeclarationSymbol, sourceDeclaration); + + // Add the attributes on each member + return generatedMembers.Select(member => + { + // [GeneratedCode] is always present + member = member + .WithoutLeadingTrivia() + .AddAttributeLists(AttributeList(SingletonSeparatedList( + Attribute(IdentifierName($"global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().Assembly.GetName().Version.ToString()))))))) + .WithLeadingTrivia(member.GetLeadingTrivia()); + + // [DebuggerNonUserCode] is not supported over interfaces, events or fields + if (member.Kind() is not SyntaxKind.InterfaceDeclaration and not SyntaxKind.EventFieldDeclaration and not SyntaxKind.FieldDeclaration) + { + member = member.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode"))))); + } + + // [ExcludeFromCodeCoverage] is not supported on interfaces and fields + if (member.Kind() is not SyntaxKind.InterfaceDeclaration and not SyntaxKind.FieldDeclaration) + { + member = member.AddAttributeLists(AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage"))))); + } + + // If the target class is sealed, make protected members private and remove the virtual modifier + if (classDeclarationSymbol.IsSealed) + { + return member + .ReplaceModifier(SyntaxKind.ProtectedKeyword, SyntaxKind.PrivateKeyword) + .RemoveModifier(SyntaxKind.VirtualKeyword); + } + + return member; + }); + } + + /// + /// Validates a target type being processed. + /// + /// The input instance to use. + /// The for the current attribute being processed. + /// The node to process. + /// The for . + /// The resulting to emit in case the target type isn't valid. + /// Whether or not the target type is valid and can be processed normally. + protected abstract bool ValidateTargetType( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + [NotNullWhen(false)] out DiagnosticDescriptor? descriptor); + + /// + /// Filters the nodes to generate from the input parsed tree. + /// + /// The input instance to use. + /// The for the current attribute being processed. + /// The node to process. + /// The for . + /// The parsed instance with the source nodes. + /// A sequence of nodes to emit in the generated file. + protected virtual IEnumerable FilterDeclaredMembers( + GeneratorExecutionContext context, + AttributeData attributeData, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + ClassDeclarationSyntax sourceDeclaration) + { + return sourceDeclaration.Members; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs new file mode 100644 index 00000000000..81fa09ee7ee --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticDescriptors.cs @@ -0,0 +1,207 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.ComponentModel; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics +{ + /// + /// A container for all instances for errors reported by analyzers in this project. + /// + internal static class DiagnosticDescriptors + { + /// + /// Gets a indicating when failed to run on a given type. + /// + /// Format: "The generator INotifyPropertyChangedGenerator failed to execute on type {0}". + /// + /// + public static readonly DiagnosticDescriptor INotifyPropertyChangedGeneratorError = new( + id: "MVVMTK0001", + title: $"Internal error for {nameof(INotifyPropertyChangedGenerator)}", + messageFormat: $"The generator {nameof(INotifyPropertyChangedGenerator)} failed to execute on type {{0}}", + category: typeof(INotifyPropertyChangedGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"The {nameof(INotifyPropertyChangedGenerator)} generator encountered an error while processing a type. Please report this issue at https://aka.ms/mvvmtoolkit.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when failed to run on a given type. + /// + /// Format: "The generator ObservableObjectGenerator failed to execute on type {0}". + /// + /// + public static readonly DiagnosticDescriptor ObservableObjectGeneratorError = new( + id: "MVVMTK0002", + title: $"Internal error for {nameof(ObservableObjectGenerator)}", + messageFormat: $"The generator {nameof(ObservableObjectGenerator)} failed to execute on type {{0}}", + category: typeof(ObservableObjectGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"The {nameof(ObservableObjectGenerator)} generator encountered an error while processing a type. Please report this issue at https://aka.ms/mvvmtoolkit.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when failed to run on a given type. + /// + /// Format: "The generator ObservableRecipientGenerator failed to execute on type {0}". + /// + /// + public static readonly DiagnosticDescriptor ObservableRecipientGeneratorError = new( + id: "MVVMTK0003", + title: $"Internal error for {nameof(ObservableRecipientGenerator)}", + messageFormat: $"The generator {nameof(ObservableRecipientGenerator)} failed to execute on type {{0}}", + category: typeof(ObservableRecipientGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"The {nameof(ObservableRecipientGenerator)} generator encountered an error while processing a type. Please report this issue at https://aka.ms/mvvmtoolkit.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when a duplicate declaration of would happen. + /// + /// Format: "Cannot apply [INotifyPropertyChangedAttribute] to type {0}, as it already declares the INotifyPropertyChanged interface". + /// + /// + public static readonly DiagnosticDescriptor DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError = new( + id: "MVVMTK0004", + title: $"Duplicate {nameof(INotifyPropertyChanged)} definition", + messageFormat: $"Cannot apply [INotifyPropertyChanged] to type {{0}}, as it already declares the {nameof(INotifyPropertyChanged)} interface", + category: typeof(INotifyPropertyChangedGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [INotifyPropertyChanged] to a type that already declares the {nameof(INotifyPropertyChanged)} interface.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when a duplicate declaration of would happen. + /// + /// Format: "Cannot apply [ObservableObjectAttribute] to type {0}, as it already declares the INotifyPropertyChanged interface". + /// + /// + public static readonly DiagnosticDescriptor DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError = new( + id: "MVVMTK0005", + title: $"Duplicate {nameof(INotifyPropertyChanged)} definition", + messageFormat: $"Cannot apply [ObservableObject] to type {{0}}, as it already declares the {nameof(INotifyPropertyChanged)} interface", + category: typeof(ObservableObjectGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ObservableObject] to a type that already declares the {nameof(INotifyPropertyChanged)} interface.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when a duplicate declaration of would happen. + /// + /// Format: "Cannot apply [ObservableObjectAttribute] to type {0}, as it already declares the INotifyPropertyChanging interface". + /// + /// + public static readonly DiagnosticDescriptor DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError = new( + id: "MVVMTK0006", + title: $"Duplicate {nameof(INotifyPropertyChanging)} definition", + messageFormat: $"Cannot apply [ObservableObject] to type {{0}}, as it already declares the {nameof(INotifyPropertyChanging)} interface", + category: typeof(ObservableObjectGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ObservableObject] to a type that already declares the {nameof(INotifyPropertyChanging)} interface.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when a duplicate declaration of would happen. + /// + /// Format: "Cannot apply [ObservableRecipientAttribute] to type {0}, as it already inherits from the ObservableRecipient class". + /// + /// + public static readonly DiagnosticDescriptor DuplicateObservableRecipientError = new( + id: "MVVMTK0007", + title: "Duplicate ObservableRecipient definition", + messageFormat: $"Cannot apply [ObservableRecipient] to type {{0}}, as it already inherits from the ObservableRecipient class", + category: typeof(ObservableRecipientGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ObservableRecipient] to a type that already inherits from the ObservableRecipient class.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when there is a missing base functionality to enable ObservableRecipientAttribute. + /// + /// Format: "Cannot apply [ObservableRecipientAttribute] to type {0}, as it lacks necessary base functionality (it should either inherit from ObservableObject, or be annotated with [ObservableObjectAttribute] or [INotifyPropertyChangedAttribute])". + /// + /// + public static readonly DiagnosticDescriptor MissingBaseObservableObjectFunctionalityError = new( + id: "MVVMTK0008", + title: "Missing base ObservableObject functionality", + messageFormat: $"Cannot apply [ObservableRecipient] to type {{0}}, as it lacks necessary base functionality (it should either inherit from ObservableObject, or be annotated with [ObservableObject] or [INotifyPropertyChanged])", + category: typeof(ObservableRecipientGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ObservableRecipient] to a type that lacks necessary base functionality (it should either inherit from ObservableObject, or be annotated with [ObservableObject] or [INotifyPropertyChanged]).", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when the target type doesn't inherit from the ObservableValidator class. + /// + /// Format: "The field {0}.{1} cannot be used to generate an observable property, as it has {2} validation attribute(s) but is declared in a type that doesn't inherit from ObservableValidator". + /// + /// + public static readonly DiagnosticDescriptor MissingObservableValidatorInheritanceError = new( + id: "MVVMTK0009", + title: "Missing ObservableValidator inheritance", + messageFormat: "The field {0}.{1} cannot be used to generate an observable property, as it has {2} validation attribute(s) but is declared in a type that doesn't inherit from ObservableValidator", + category: typeof(ObservablePropertyGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ObservableProperty] to fields with validation attributes if they are declared in a type that doesn't inherit from ObservableValidator.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when failed to run on a given type. + /// + /// Format: "The generator ObservablePropertyGenerator failed to execute on type {0}". + /// + /// + public static readonly DiagnosticDescriptor ObservablePropertyGeneratorError = new( + id: "MVVMTK0010", + title: $"Internal error for {nameof(ObservablePropertyGenerator)}", + messageFormat: $"The generator {nameof(ObservablePropertyGenerator)} failed to execute on type {{0}}", + category: typeof(ObservableObjectGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"The {nameof(ObservablePropertyGenerator)} generator encountered an error while processing a type. Please report this issue at https://aka.ms/mvvmtoolkit.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when failed to run on a given type. + /// + /// Format: "The generator ICommandGenerator failed to execute on type {0}". + /// + /// + public static readonly DiagnosticDescriptor ICommandGeneratorError = new( + id: "MVVMTK0011", + title: $"Internal error for {nameof(ICommandGenerator)}", + messageFormat: $"The generator {nameof(ICommandGenerator)} failed to execute on type {{0}}", + category: typeof(ICommandGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"The {nameof(ICommandGenerator)} generator encountered an error while processing a type. Please report this issue at https://aka.ms/mvvmtoolkit.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + + /// + /// Gets a indicating when an annotated method to generate a command for has an invalid signature. + /// + /// Format: "The method {0}.{1} cannot be used to generate a command property, as its signature isn't compatible with any of the existing relay command types". + /// + /// + public static readonly DiagnosticDescriptor InvalidICommandMethodSignatureError = new( + id: "MVVMTK0012", + title: "Invalid ICommand method signature", + messageFormat: "The method {0}.{1} cannot be used to generate a command property, as its signature isn't compatible with any of the existing relay command types", + category: typeof(ICommandGenerator).FullName, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true, + description: $"Cannot apply [ICommand] to methods with a signature that doesn't match any of the existing relay command types.", + helpLinkUri: "https://aka.ms/mvvmtoolkit"); + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs new file mode 100644 index 00000000000..5e7e6830f88 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Diagnostics/DiagnosticExtensions.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics +{ + /// + /// Extension methods for , specifically for reporting diagnostics. + /// + internal static class DiagnosticExtensions + { + /// + /// Adds a new diagnostics to the current compilation. + /// + /// The instance currently in use. + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + public static void ReportDiagnostic( + this GeneratorExecutionContext context, + DiagnosticDescriptor descriptor, + ISymbol symbol, + params object[] args) + { + context.ReportDiagnostic(Diagnostic.Create(descriptor, symbol.Locations.FirstOrDefault(), args)); + } + + /// + /// Adds a new diagnostics to the current compilation. + /// + /// The instance currently in use. + /// The input for the diagnostics to create. + /// The source to attach the diagnostics to. + /// The optional arguments for the formatted message to include. + public static void ReportDiagnostic( + this GeneratorExecutionContext context, + DiagnosticDescriptor descriptor, + SyntaxNode node, + params object[] args) + { + context.ReportDiagnostic(Diagnostic.Create(descriptor, node.GetLocation(), args)); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/EmbeddedResources/INotifyPropertyChanged.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/EmbeddedResources/INotifyPropertyChanged.cs new file mode 100644 index 00000000000..72d14de5394 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/EmbeddedResources/INotifyPropertyChanged.cs @@ -0,0 +1,496 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma warning disable + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.Runtime.CompilerServices; +using System.Threading.Tasks; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// A base class for objects implementing . + /// + public abstract class NotifyPropertyChanged : INotifyPropertyChanged + { + /// + public event PropertyChangedEventHandler? PropertyChanged; + + /// + /// Raises the event. + /// + /// The input instance. + protected virtual void OnPropertyChanged(PropertyChangedEventArgs e) + { + PropertyChanged?.Invoke(this, e); + } + + /// + /// Raises the event. + /// + /// (optional) The name of the property that changed. + protected void OnPropertyChanged([CallerMemberName] string? propertyName = null) + { + OnPropertyChanged(new PropertyChangedEventArgs(propertyName)); + } + + /// + /// Compares the current and new values for a given property. If the value has changed, updates + /// the property with the new value, then raises the event. + /// + /// The type of the property that changed. + /// The field storing the property's value. + /// The property's value after the change occurred. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are the same. + /// + protected bool SetProperty(ref T field, T newValue, [CallerMemberName] string? propertyName = null) + { + if (EqualityComparer.Default.Equals(field, newValue)) + { + return false; + } + + field = newValue; + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given property. If the value has changed, updates + /// the property with the new value, then raises the event. + /// See additional notes about this overload in . + /// + /// The type of the property that changed. + /// The field storing the property's value. + /// The property's value after the change occurred. + /// The instance to use to compare the input values. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + protected bool SetProperty(ref T field, T newValue, IEqualityComparer comparer, [CallerMemberName] string? propertyName = null) + { + if (comparer.Equals(field, newValue)) + { + return false; + } + + field = newValue; + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given property. If the value has changed, updates + /// the property with the new value, then raises the event. + /// This overload is much less efficient than and it + /// should only be used when the former is not viable (eg. when the target property being + /// updated does not directly expose a backing field that can be passed by reference). + /// For performance reasons, it is recommended to use a stateful callback if possible through + /// the whenever possible + /// instead of this overload, as that will allow the C# compiler to cache the input callback and + /// reduce the memory allocations. More info on that overload are available in the related XML + /// docs. This overload is here for completeness and in cases where that is not applicable. + /// + /// The type of the property that changed. + /// The current property value. + /// The property's value after the change occurred. + /// A callback to invoke to update the property value. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are the same. + /// + protected bool SetProperty(T oldValue, T newValue, Action callback, [CallerMemberName] string? propertyName = null) + { + if (EqualityComparer.Default.Equals(oldValue, newValue)) + { + return false; + } + + callback(newValue); + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given property. If the value has changed, updates + /// the property with the new value, then raises the event. + /// See additional notes about this overload in . + /// + /// The type of the property that changed. + /// The current property value. + /// The property's value after the change occurred. + /// The instance to use to compare the input values. + /// A callback to invoke to update the property value. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + protected bool SetProperty(T oldValue, T newValue, IEqualityComparer comparer, Action callback, [CallerMemberName] string? propertyName = null) + { + if (comparer.Equals(oldValue, newValue)) + { + return false; + } + + callback(newValue); + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given nested property. If the value has changed, + /// updates the property and then raises the event. + /// The behavior mirrors that of , + /// with the difference being that this method is used to relay properties from a wrapped model in the + /// current instance. This type is useful when creating wrapping, bindable objects that operate over + /// models that lack support for notification (eg. for CRUD operations). + /// Suppose we have this model (eg. for a database row in a table): + /// + /// public class Person + /// { + /// public string Name { get; set; } + /// } + /// + /// We can then use a property to wrap instances of this type into our observable model (which supports + /// notifications), injecting the notification to the properties of that model, like so: + /// + /// [INotifyPropertyChanged] + /// public partial class BindablePerson + /// { + /// public Model { get; } + /// + /// public BindablePerson(Person model) + /// { + /// Model = model; + /// } + /// + /// public string Name + /// { + /// get => Model.Name; + /// set => Set(Model.Name, value, Model, (model, name) => model.Name = name); + /// } + /// } + /// + /// This way we can then use the wrapping object in our application, and all those "proxy" properties will + /// also raise notifications when changed. Note that this method is not meant to be a replacement for + /// , and it should only be used when relaying properties to a model that + /// doesn't support notifications, and only if you can't implement notifications to that model directly (eg. by having + /// it inherit from ). The syntax relies on passing the target model and a stateless callback + /// to allow the C# compiler to cache the function, which results in much better performance and no memory usage. + /// + /// The type of model whose property (or field) to set. + /// The type of property (or field) to set. + /// The current property value. + /// The property's value after the change occurred. + /// The model containing the property being updated. + /// The callback to invoke to set the target property value, if a change has occurred. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are the same. + /// + protected bool SetProperty(T oldValue, T newValue, TModel model, Action callback, [CallerMemberName] string? propertyName = null) + where TModel : class + { + if (EqualityComparer.Default.Equals(oldValue, newValue)) + { + return false; + } + + callback(model, newValue); + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given nested property. If the value has changed, + /// updates the property and then raises the event. + /// The behavior mirrors that of , + /// with the difference being that this method is used to relay properties from a wrapped model in the + /// current instance. See additional notes about this overload in . + /// + /// The type of model whose property (or field) to set. + /// The type of property (or field) to set. + /// The current property value. + /// The property's value after the change occurred. + /// The instance to use to compare the input values. + /// The model containing the property being updated. + /// The callback to invoke to set the target property value, if a change has occurred. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + protected bool SetProperty(T oldValue, T newValue, IEqualityComparer comparer, TModel model, Action callback, [CallerMemberName] string? propertyName = null) + where TModel : class + { + if (comparer.Equals(oldValue, newValue)) + { + return false; + } + + callback(model, newValue); + + OnPropertyChanged(propertyName); + + return true; + } + + /// + /// Compares the current and new values for a given field (which should be the backing field for a property). + /// If the value has changed, updates the field and then raises the event. + /// The behavior mirrors that of , with the difference being that + /// this method will also monitor the new value of the property (a generic ) and will also + /// raise the again for the target property when it completes. + /// This can be used to update bindings observing that or any of its properties. + /// This method and its overload specifically rely on the type, which needs + /// to be used in the backing field for the target property. The field doesn't need to be + /// initialized, as this method will take care of doing that automatically. The + /// type also includes an implicit operator, so it can be assigned to any instance directly. + /// Here is a sample property declaration using this method: + /// + /// private TaskNotifier myTask; + /// + /// public Task MyTask + /// { + /// get => myTask; + /// private set => SetAndNotifyOnCompletion(ref myTask, value); + /// } + /// + /// + /// The field notifier to modify. + /// The property's value after the change occurred. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are + /// the same. The return value being only indicates that the new value being assigned to + /// is different than the previous one, and it does not mean the new + /// instance passed as argument is in any particular state. + /// + protected bool SetPropertyAndNotifyOnCompletion(ref TaskNotifier? taskNotifier, Task? newValue, [CallerMemberName] string? propertyName = null) + { + return SetPropertyAndNotifyOnCompletion(taskNotifier ??= new(), newValue, static _ => { }, propertyName); + } + + /// + /// Compares the current and new values for a given field (which should be the backing field for a property). + /// If the value has changed, updates the field and then raises the event. + /// This method is just like , + /// with the difference being an extra parameter with a callback being invoked + /// either immediately, if the new task has already completed or is , or upon completion. + /// + /// The field notifier to modify. + /// The property's value after the change occurred. + /// A callback to invoke to update the property value. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are the same. + /// + protected bool SetPropertyAndNotifyOnCompletion(ref TaskNotifier? taskNotifier, Task? newValue, Action callback, [CallerMemberName] string? propertyName = null) + { + return SetPropertyAndNotifyOnCompletion(taskNotifier ??= new(), newValue, callback, propertyName); + } + + /// + /// Compares the current and new values for a given field (which should be the backing field for a property). + /// If the value has changed, updates the field and then raises the event. + /// The behavior mirrors that of , with the difference being that + /// this method will also monitor the new value of the property (a generic ) and will also + /// raise the again for the target property when it completes. + /// This can be used to update bindings observing that or any of its properties. + /// This method and its overload specifically rely on the type, which needs + /// to be used in the backing field for the target property. The field doesn't need to be + /// initialized, as this method will take care of doing that automatically. The + /// type also includes an implicit operator, so it can be assigned to any instance directly. + /// Here is a sample property declaration using this method: + /// + /// private TaskNotifier<int> myTask; + /// + /// public Task<int> MyTask + /// { + /// get => myTask; + /// private set => SetAndNotifyOnCompletion(ref myTask, value); + /// } + /// + /// + /// The type of result for the to set and monitor. + /// The field notifier to modify. + /// The property's value after the change occurred. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are + /// the same. The return value being only indicates that the new value being assigned to + /// is different than the previous one, and it does not mean the new + /// instance passed as argument is in any particular state. + /// + protected bool SetPropertyAndNotifyOnCompletion(ref TaskNotifier? taskNotifier, Task? newValue, [CallerMemberName] string? propertyName = null) + { + return SetPropertyAndNotifyOnCompletion(taskNotifier ??= new(), newValue, static _ => { }, propertyName); + } + + /// + /// Compares the current and new values for a given field (which should be the backing field for a property). + /// If the value has changed, updates the field and then raises the event. + /// This method is just like , + /// with the difference being an extra parameter with a callback being invoked + /// either immediately, if the new task has already completed or is , or upon completion. + /// + /// The type of result for the to set and monitor. + /// The field notifier to modify. + /// The property's value after the change occurred. + /// A callback to invoke to update the property value. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + /// + /// The event is not raised if the current and new value for the target property are the same. + /// + protected bool SetPropertyAndNotifyOnCompletion(ref TaskNotifier? taskNotifier, Task? newValue, Action?> callback, [CallerMemberName] string? propertyName = null) + { + return SetPropertyAndNotifyOnCompletion(taskNotifier ??= new(), newValue, callback, propertyName); + } + + /// + /// Implements the notification logic for the related methods. + /// + /// The type of to set and monitor. + /// The field notifier. + /// The property's value after the change occurred. + /// A callback to invoke to update the property value. + /// (optional) The name of the property that changed. + /// if the property was changed, otherwise. + private bool SetPropertyAndNotifyOnCompletion(ITaskNotifier taskNotifier, TTask? newValue, Action callback, [CallerMemberName] string? propertyName = null) + where TTask : Task + { + if (ReferenceEquals(taskNotifier.Task, newValue)) + { + return false; + } + + bool isAlreadyCompletedOrNull = newValue?.IsCompleted ?? true; + + taskNotifier.Task = newValue; + + OnPropertyChanged(propertyName); + + if (isAlreadyCompletedOrNull) + { + callback(newValue); + + return true; + } + + async void MonitorTask() + { + try + { + await newValue!; + } + catch + { + } + + if (ReferenceEquals(taskNotifier.Task, newValue)) + { + OnPropertyChanged(propertyName); + } + + callback(newValue); + } + + MonitorTask(); + + return true; + } + + /// + /// An interface for task notifiers of a specified type. + /// + /// The type of value to store. + private interface ITaskNotifier + where TTask : Task + { + /// + /// Gets or sets the wrapped value. + /// + TTask? Task { get; set; } + } + + /// + /// A wrapping class that can hold a value. + /// + protected sealed class TaskNotifier : ITaskNotifier + { + /// + /// Initializes a new instance of the class. + /// + internal TaskNotifier() + { + } + + private Task? task; + + /// + Task? ITaskNotifier.Task + { + get => this.task; + set => this.task = value; + } + + /// + /// Unwraps the value stored in the current instance. + /// + /// The input instance. + public static implicit operator Task?(TaskNotifier? notifier) + { + return notifier?.task; + } + } + + /// + /// A wrapping class that can hold a value. + /// + /// The type of value for the wrapped instance. + protected sealed class TaskNotifier : ITaskNotifier> + { + /// + /// Initializes a new instance of the class. + /// + internal TaskNotifier() + { + } + + private Task? task; + + /// + Task? ITaskNotifier>.Task + { + get => this.task; + set => this.task = value; + } + + /// + /// Unwraps the value stored in the current instance. + /// + /// The input instance. + public static implicit operator Task?(TaskNotifier? notifier) + { + return notifier?.task; + } + } + } +} \ No newline at end of file diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/AttributeDataExtensions.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/AttributeDataExtensions.cs new file mode 100644 index 00000000000..93f9ac9f1d4 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/AttributeDataExtensions.cs @@ -0,0 +1,118 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions +{ + /// + /// Extension methods for the type. + /// + internal static class AttributeDataExtensions + { + /// + /// Checks whether a given instance contains a specified named argument. + /// + /// The type of argument to check. + /// The target instance to check. + /// The name of the argument to check. + /// The expected value for the target named argument. + /// Whether or not contains an argument named with the expected value. + [Pure] + public static bool HasNamedArgument(this AttributeData attributeData, string name, T? value) + { + foreach (KeyValuePair properties in attributeData.NamedArguments) + { + if (properties.Key == name) + { + return + properties.Value.Value is T argumentValue && + EqualityComparer.Default.Equals(argumentValue, value); + } + } + + return false; + } + + /// + /// Creates an node that is equivalent to the input instance. + /// + /// The input instance to process. + /// An replicating the data in . + [Pure] + public static AttributeSyntax AsAttributeSyntax(this AttributeData attributeData) + { + IdentifierNameSyntax attributeType = IdentifierName(attributeData.AttributeClass!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + AttributeArgumentSyntax[] arguments = + attributeData.ConstructorArguments + .Select(static arg => AttributeArgument(ToExpression(arg))).Concat( + attributeData.NamedArguments + .Select(static arg => + AttributeArgument(ToExpression(arg.Value)) + .WithNameEquals(NameEquals(IdentifierName(arg.Key))))).ToArray(); + + return Attribute(attributeType, AttributeArgumentList(SeparatedList(SeparatedList(arguments)))); + + static ExpressionSyntax ToExpression(TypedConstant arg) + { + if (arg.IsNull) + { + return LiteralExpression(SyntaxKind.NullLiteralExpression); + } + + if (arg.Kind == TypedConstantKind.Array) + { + string elementType = ((IArrayTypeSymbol)arg.Type!).ElementType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + return + ArrayCreationExpression( + ArrayType(IdentifierName(elementType)) + .AddRankSpecifiers(ArrayRankSpecifier(SingletonSeparatedList(OmittedArraySizeExpression())))) + .WithInitializer(InitializerExpression(SyntaxKind.ArrayInitializerExpression) + .AddExpressions(arg.Values.Select(ToExpression).ToArray())); + } + + switch ((arg.Kind, arg.Value)) + { + case (TypedConstantKind.Primitive, string text): + return LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(text)); + case (TypedConstantKind.Primitive, bool flag) when flag: + return LiteralExpression(SyntaxKind.TrueLiteralExpression); + case (TypedConstantKind.Primitive, bool): + return LiteralExpression(SyntaxKind.FalseLiteralExpression); + case (TypedConstantKind.Primitive, object value): + return LiteralExpression(SyntaxKind.NumericLiteralExpression, value switch + { + byte b => Literal(b), + char c => Literal(c), + double d => Literal(d), + float f => Literal(f), + int i => Literal(i), + long l => Literal(l), + sbyte sb => Literal(sb), + short sh => Literal(sh), + uint ui => Literal(ui), + ulong ul => Literal(ul), + ushort ush => Literal(ush), + _ => throw new ArgumentException() + }); + case (TypedConstantKind.Type, ITypeSymbol type): + return TypeOfExpression(IdentifierName(type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))); + case (TypedConstantKind.Enum, object value): + return CastExpression( + IdentifierName(arg.Type!.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + LiteralExpression(SyntaxKind.NumericLiteralExpression, ParseToken(value.ToString()))); + default: throw new ArgumentException(); + } + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/INamedTypeSymbolExtensions.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/INamedTypeSymbolExtensions.cs new file mode 100644 index 00000000000..e04728d019c --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/INamedTypeSymbolExtensions.cs @@ -0,0 +1,76 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.Contracts; +using System.Text; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions +{ + /// + /// Extension methods for the type. + /// + internal static class INamedTypeSymbolExtensions + { + /// + /// Gets the full metadata name for a given instance. + /// + /// The input instance. + /// The full metadata name for . + [Pure] + public static string GetFullMetadataName(this INamedTypeSymbol symbol) + { + static StringBuilder BuildFrom(ISymbol? symbol, StringBuilder builder) + { + return symbol switch + { + INamespaceSymbol ns when ns.IsGlobalNamespace => builder, + INamespaceSymbol ns when ns.ContainingNamespace is { IsGlobalNamespace: false } + => BuildFrom(ns.ContainingNamespace, builder.Insert(0, $".{ns.MetadataName}")), + ITypeSymbol ts when ts.ContainingType is ISymbol pt => BuildFrom(pt, builder.Insert(0, $"+{ts.MetadataName}")), + ITypeSymbol ts when ts.ContainingNamespace is ISymbol pn => BuildFrom(pn, builder.Insert(0, $".{ts.MetadataName}")), + ISymbol => BuildFrom(symbol.ContainingSymbol, builder.Insert(0, symbol.MetadataName)), + _ => builder + }; + } + + return BuildFrom(symbol, new StringBuilder(256)).ToString(); + } + + /// + /// Gets a valid filename for a given instance. + /// + /// The input instance. + /// The full metadata name for that is also a valid filename. + [Pure] + public static string GetFullMetadataNameForFileName(this INamedTypeSymbol symbol) + { + return symbol.GetFullMetadataName().Replace('`', '-').Replace('+', '.'); + } + + /// + /// Checks whether or not a given inherits from a specified type. + /// + /// The target instance to check. + /// The type symbol of the type to check for inheritance. + /// Whether or not inherits from . + [Pure] + public static bool InheritsFrom(this INamedTypeSymbol typeSymbol, INamedTypeSymbol targetTypeSymbol) + { + INamedTypeSymbol? baseType = typeSymbol.BaseType; + + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, targetTypeSymbol)) + { + return true; + } + + baseType = baseType.BaseType; + } + + return false; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/MemberDeclarationSyntaxExtensions.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/MemberDeclarationSyntaxExtensions.cs new file mode 100644 index 00000000000..cdc63118306 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Extensions/MemberDeclarationSyntaxExtensions.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics.Contracts; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions +{ + /// + /// Extension methods for the type. + /// + internal static class MemberDeclarationSyntaxExtensions + { + /// + /// Replaces a specific modifier. + /// + /// The input instance. + /// The target modifier kind to replace. + /// The new modifier kind to add or replace. + /// A instance with the target modifier. + [Pure] + public static MemberDeclarationSyntax ReplaceModifier(this MemberDeclarationSyntax memberDeclaration, SyntaxKind oldKind, SyntaxKind newKind) + { + int index = memberDeclaration.Modifiers.IndexOf(oldKind); + + if (index != -1) + { + return memberDeclaration.WithModifiers(memberDeclaration.Modifiers.Replace(memberDeclaration.Modifiers[index], Token(newKind))); + } + + return memberDeclaration; + } + + /// + /// Removes a specific modifier. + /// + /// The input instance. + /// The modifier kind to remove. + /// A instance without the specified modifier. + [Pure] + public static MemberDeclarationSyntax RemoveModifier(this MemberDeclarationSyntax memberDeclaration, SyntaxKind kind) + { + int index = memberDeclaration.Modifiers.IndexOf(kind); + + if (index != -1) + { + return memberDeclaration.WithModifiers(memberDeclaration.Modifiers.RemoveAt(index)); + } + + return memberDeclaration; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.SyntaxReceiver.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.SyntaxReceiver.cs new file mode 100644 index 00000000000..5e365a90b97 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.SyntaxReceiver.cs @@ -0,0 +1,51 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + public sealed partial class ICommandGenerator + { + /// + /// An that selects candidate nodes to process. + /// + private sealed class SyntaxReceiver : ISyntaxContextReceiver + { + /// + /// The list of info gathered during exploration. + /// + private readonly List gatheredInfo = new(); + + /// + /// Gets the collection of gathered info to process. + /// + public IReadOnlyCollection GatheredInfo => this.gatheredInfo; + + /// + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + { + if (context.Node is MethodDeclarationSyntax methodDeclaration && + context.SemanticModel.GetDeclaredSymbol(methodDeclaration) is IMethodSymbol methodSymbol && + context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.ICommandAttribute") is INamedTypeSymbol iCommandSymbol && + methodSymbol.GetAttributes().Any(a => SymbolEqualityComparer.Default.Equals(a.AttributeClass, iCommandSymbol))) + { + this.gatheredInfo.Add(new Item(methodDeclaration.GetLeadingTrivia(), methodSymbol)); + } + } + + /// + /// A model for a group of item representing a discovered type to process. + /// + /// The leading trivia for the field declaration. + /// The instance for the target method. + public sealed record Item(SyntaxTriviaList LeadingTrivia, IMethodSymbol MethodSymbol); + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs new file mode 100644 index 00000000000..ffcca24f30a --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Input/ICommandGenerator.cs @@ -0,0 +1,340 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.Linq; +using System.Text; +using System.Text.RegularExpressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; +using static Microsoft.CodeAnalysis.SymbolDisplayTypeQualificationStyle; +using static Microsoft.Toolkit.Mvvm.SourceGenerators.Diagnostics.DiagnosticDescriptors; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for generating command properties from annotated methods. + /// + [Generator] + public sealed partial class ICommandGenerator : ISourceGenerator + { + /// + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(static () => new SyntaxReceiver()); + } + + /// + public void Execute(GeneratorExecutionContext context) + { + // Get the syntax receiver with the candidate nodes + if (context.SyntaxContextReceiver is not SyntaxReceiver syntaxReceiver || + syntaxReceiver.GatheredInfo.Count == 0) + { + return; + } + + foreach (var items in syntaxReceiver.GatheredInfo.GroupBy(static item => item.MethodSymbol.ContainingType, SymbolEqualityComparer.Default)) + { + if (items.Key.DeclaringSyntaxReferences.Length > 0 && + items.Key.DeclaringSyntaxReferences.First().GetSyntax() is ClassDeclarationSyntax classDeclaration) + { + try + { + OnExecute(context, classDeclaration, items.Key, items); + } + catch + { + context.ReportDiagnostic(ICommandGeneratorError, classDeclaration, items.Key); + } + } + } + } + + /// + /// Processes a given target type. + /// + /// The input instance to use. + /// The node to process. + /// The for . + /// The sequence of instances to process. + private static void OnExecute( + GeneratorExecutionContext context, + ClassDeclarationSyntax classDeclaration, + INamedTypeSymbol classDeclarationSymbol, + IEnumerable items) + { + // Create the class declaration for the user type. This will produce a tree as follows: + // + // + // { + // + // } + var classDeclarationSyntax = + ClassDeclaration(classDeclarationSymbol.Name) + .WithModifiers(classDeclaration.Modifiers) + .AddMembers(items.Select(item => CreateCommandMembers(context, item.LeadingTrivia, item.MethodSymbol)).SelectMany(static g => g).ToArray()); + + TypeDeclarationSyntax typeDeclarationSyntax = classDeclarationSyntax; + + // Add all parent types in ascending order, if any + foreach (var parentType in classDeclaration.Ancestors().OfType()) + { + typeDeclarationSyntax = parentType + .WithMembers(SingletonList(typeDeclarationSyntax)) + .WithConstraintClauses(List()) + .WithBaseList(null) + .WithAttributeLists(List()) + .WithoutTrivia(); + } + + // Create the compilation unit with the namespace and target member. + // From this, we can finally generate the source code to output. + var namespaceName = classDeclarationSymbol.ContainingNamespace.ToDisplayString(new(typeQualificationStyle: NameAndContainingTypesAndNamespaces)); + + // Create the final compilation unit to generate (with leading trivia) + var source = + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName(namespaceName)).WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))) + .AddMembers(typeDeclarationSyntax)) + .NormalizeWhitespace() + .ToFullString(); + + // Add the partial type + context.AddSource($"{classDeclarationSymbol.GetFullMetadataNameForFileName()}.cs", SourceText.From(source, Encoding.UTF8)); + } + + /// + /// Creates the instances for a specified command. + /// + /// The input instance to use. + /// The leading trivia for the field to process. + /// The input instance to process. + /// The instances for the input command. + [Pure] + private static IEnumerable CreateCommandMembers(GeneratorExecutionContext context, SyntaxTriviaList leadingTrivia, IMethodSymbol methodSymbol) + { + // Get the command member names + var (fieldName, propertyName) = GetGeneratedFieldAndPropertyNames(context, methodSymbol); + + // Get the command type symbols + if (!TryMapCommandTypesFromMethod( + context, + methodSymbol, + out ITypeSymbol? commandInterfaceTypeSymbol, + out ITypeSymbol? commandClassTypeSymbol, + out ITypeSymbol? delegateTypeSymbol)) + { + context.ReportDiagnostic(InvalidICommandMethodSignatureError, methodSymbol, methodSymbol.ContainingType, methodSymbol); + + return Array.Empty(); + } + + // Construct the generated field as follows: + // + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // private ? ; + FieldDeclarationSyntax fieldDeclaration = + FieldDeclaration( + VariableDeclaration(NullableType(IdentifierName(commandClassTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))) + .AddVariables(VariableDeclarator(Identifier(fieldName)))) + .AddModifiers(Token(SyntaxKind.PrivateKeyword)) + .AddAttributeLists(AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString()))))))); + + SyntaxTriviaList summaryTrivia = SyntaxTriviaList.Empty; + + // Parse the docs, if present + foreach (SyntaxTrivia trivia in leadingTrivia) + { + if (trivia.IsKind(SyntaxKind.SingleLineCommentTrivia) || + trivia.IsKind(SyntaxKind.SingleLineDocumentationCommentTrivia)) + { + string text = trivia.ToString(); + + Match match = Regex.Match(text, @".*?<\/summary>", RegexOptions.Singleline); + + if (match.Success) + { + summaryTrivia = TriviaList(Comment($"/// {match.Value}")); + + break; + } + } + } + + // Construct the generated property as follows (the explicit delegate cast is needed to avoid overload resolution conflicts): + // + // + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // public => ??= new (new ()); + PropertyDeclarationSyntax propertyDeclaration = + PropertyDeclaration( + IdentifierName(commandInterfaceTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + Identifier(propertyName)) + .AddModifiers(Token(SyntaxKind.PublicKeyword)) + .AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(typeof(ICommandGenerator).Assembly.GetName().Version.ToString())))))) + .WithOpenBracketToken(Token(summaryTrivia, SyntaxKind.OpenBracketToken, TriviaList())), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage"))))) + .WithExpressionBody( + ArrowExpressionClause( + AssignmentExpression( + SyntaxKind.CoalesceAssignmentExpression, + IdentifierName(fieldName), + ObjectCreationExpression(IdentifierName(commandClassTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) + .AddArgumentListArguments(Argument( + ObjectCreationExpression(IdentifierName(delegateTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))) + .AddArgumentListArguments(Argument(IdentifierName(methodSymbol.Name)))))))) + .WithSemicolonToken(Token(SyntaxKind.SemicolonToken)); + + return new MemberDeclarationSyntax[] { fieldDeclaration, propertyDeclaration }; + } + + /// + /// Get the generated field and property names for the input method. + /// + /// The input instance to use. + /// The input instance to process. + /// The generated field and property names for . + [Pure] + private static (string FieldName, string PropertyName) GetGeneratedFieldAndPropertyNames(GeneratorExecutionContext context, IMethodSymbol methodSymbol) + { + string propertyName = methodSymbol.Name; + + if (SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType, context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")) && + methodSymbol.Name.EndsWith("Async")) + { + propertyName = propertyName.Substring(0, propertyName.Length - "Async".Length); + } + + propertyName += "Command"; + + string fieldName = $"{char.ToLower(propertyName[0])}{propertyName.Substring(1)}"; + + return (fieldName, propertyName); + } + + /// + /// Gets the type symbols for the input method, if supported. + /// + /// The input instance to use. + /// The input instance to process. + /// The command interface type symbol. + /// The command class type symbol. + /// The delegate type symbol for the wrapped method. + /// Whether or not was valid and the requested types have been set. + private static bool TryMapCommandTypesFromMethod( + GeneratorExecutionContext context, + IMethodSymbol methodSymbol, + [NotNullWhen(true)] out ITypeSymbol? commandInterfaceTypeSymbol, + [NotNullWhen(true)] out ITypeSymbol? commandClassTypeSymbol, + [NotNullWhen(true)] out ITypeSymbol? delegateTypeSymbol) + { + // Map to IRelayCommand, RelayCommand, Action + if (methodSymbol.ReturnsVoid && methodSymbol.Parameters.Length == 0) + { + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IRelayCommand")!; + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.RelayCommand")!; + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Action")!; + + return true; + } + + // Map to IRelayCommand, RelayCommand, Action + if (methodSymbol.ReturnsVoid && + methodSymbol.Parameters.Length == 1 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } parameter) + { + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IRelayCommand`1")!.Construct(parameter.Type); + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.RelayCommand`1")!.Construct(parameter.Type); + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Action`1")!.Construct(parameter.Type); + + return true; + } + + if (SymbolEqualityComparer.Default.Equals(methodSymbol.ReturnType, context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")!)) + { + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (methodSymbol.Parameters.Length == 0) + { + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IAsyncRelayCommand")!; + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.AsyncRelayCommand")!; + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Func`1")!.Construct(context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")!); + + return true; + } + + if (methodSymbol.Parameters.Length == 1 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } singleParameter) + { + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (SymbolEqualityComparer.Default.Equals(singleParameter.Type, context.Compilation.GetTypeByMetadataName("System.Threading.CancellationToken")!)) + { + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IAsyncRelayCommand")!; + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.AsyncRelayCommand")!; + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Func`2")!.Construct( + context.Compilation.GetTypeByMetadataName("System.Threading.CancellationToken")!, + context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")!); + + return true; + } + + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IAsyncRelayCommand`1")!.Construct(singleParameter.Type); + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.AsyncRelayCommand`1")!.Construct(singleParameter.Type); + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Func`2")!.Construct( + singleParameter.Type, + context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")!); + + return true; + } + + // Map to IAsyncRelayCommand, AsyncRelayCommand, Func + if (methodSymbol.Parameters.Length == 2 && + methodSymbol.Parameters[0] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } firstParameter && + methodSymbol.Parameters[1] is IParameterSymbol { RefKind: RefKind.None, Type: { IsRefLikeType: false, TypeKind: not TypeKind.Pointer and not TypeKind.FunctionPointer } } secondParameter && + SymbolEqualityComparer.Default.Equals(secondParameter.Type, context.Compilation.GetTypeByMetadataName("System.Threading.CancellationToken")!)) + { + commandInterfaceTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.IAsyncRelayCommand`1")!.Construct(firstParameter.Type); + commandClassTypeSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Input.AsyncRelayCommand`1")!.Construct(firstParameter.Type); + delegateTypeSymbol = context.Compilation.GetTypeByMetadataName("System.Func`3")!.Construct( + firstParameter.Type, + secondParameter.Type, + context.Compilation.GetTypeByMetadataName("System.Threading.Tasks.Task")!); + + return true; + } + } + + commandInterfaceTypeSymbol = null; + commandClassTypeSymbol = null; + delegateTypeSymbol = null; + + return false; + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.SyntaxReceiver.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.SyntaxReceiver.cs new file mode 100644 index 00000000000..979cfb8f492 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.SyntaxReceiver.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + public sealed partial class IMessengerRegisterAllGenerator + { + /// + /// An that selects candidate nodes to process. + /// + private sealed class SyntaxReceiver : ISyntaxContextReceiver + { + /// + /// The list of info gathered during exploration. + /// + private readonly List gatheredInfo = new(); + + /// + /// Gets the collection of gathered info to process. + /// + public IReadOnlyCollection GatheredInfo => this.gatheredInfo; + + /// + public void OnVisitSyntaxNode(GeneratorSyntaxContext context) + { + if (context.Node is ClassDeclarationSyntax classDeclaration && + context.SemanticModel.GetDeclaredSymbol(classDeclaration) is INamedTypeSymbol { IsGenericType: false } classSymbol && + context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Messaging.IRecipient`1") is INamedTypeSymbol iRecipientSymbol && + classSymbol.AllInterfaces.Any(i => SymbolEqualityComparer.Default.Equals(i.OriginalDefinition, iRecipientSymbol))) + { + this.gatheredInfo.Add(classSymbol); + } + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs new file mode 100644 index 00000000000..91e4925dd99 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Messaging/IMessengerRegisterAllGenerator.cs @@ -0,0 +1,285 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Diagnostics.Contracts; +using System.Linq; +using System.Text; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Text; +using Microsoft.Toolkit.Mvvm.SourceGenerators.Extensions; +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; + +namespace Microsoft.Toolkit.Mvvm.SourceGenerators +{ + /// + /// A source generator for message registration without relying on compiled LINQ expressions. + /// + [Generator] + public sealed partial class IMessengerRegisterAllGenerator : ISourceGenerator + { + /// + public void Initialize(GeneratorInitializationContext context) + { + context.RegisterForSyntaxNotifications(static () => new SyntaxReceiver()); + } + + /// + public void Execute(GeneratorExecutionContext context) + { + // Get the syntax receiver with the candidate nodes + if (context.SyntaxContextReceiver is not SyntaxReceiver syntaxReceiver || + syntaxReceiver.GatheredInfo.Count == 0) + { + return; + } + + // Get the symbol for the IRecipient interface type + INamedTypeSymbol iRecipientSymbol = context.Compilation.GetTypeByMetadataName("Microsoft.Toolkit.Mvvm.Messaging.IRecipient`1")!; + + // Prepare the attributes to add to the first class declaration + AttributeListSyntax[] classAttributes = new[] + { + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName($"global::System.CodeDom.Compiler.GeneratedCode")) + .AddArgumentListArguments( + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().FullName))), + AttributeArgument(LiteralExpression(SyntaxKind.StringLiteralExpression, Literal(GetType().Assembly.GetName().Version.ToString())))))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.DebuggerNonUserCode")))), + AttributeList(SingletonSeparatedList(Attribute(IdentifierName("global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage")))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This type is not intended to be used directly by user code")))))) + }; + + foreach (INamedTypeSymbol classSymbol in syntaxReceiver.GatheredInfo) + { + // Create a static factory method to register all messages for a given recipient type. + // This follows the same pattern used in ObservableValidatorValidateAllPropertiesGenerator, + // with the same advantages mentioned there (type safety, more AOT-friendly, etc.). + // There are two versions that are generated: a non-generic one doing the registration + // with no tokens, which is the most common scenario and will help particularly in AOT + // scenarios, and a generic version that will support all other cases with custom tokens. + // Note: the generic overload has a different name to simplify the lookup with reflection. + // This code takes a class symbol and produces a compilation unit as follows: + // + // // Licensed to the .NET Foundation under one or more agreements. + // // The .NET Foundation licenses this file to you under the MIT license. + // // See the LICENSE file in the project root for more information. + // + // #pragma warning disable + // + // namespace Microsoft.Toolkit.Mvvm.Messaging.__Internals + // { + // [global::System.CodeDom.Compiler.GeneratedCode("...", "...")] + // [global::System.Diagnostics.DebuggerNonUserCode] + // [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This type is not intended to be used directly by user code")] + // internal static partial class __IMessengerExtensions + // { + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This method is not intended to be called directly by user code")] + // public static global::System.Action CreateAllMessagesRegistrator( _) + // { + // static void RegisterAll(IMessenger messenger, object obj) + // { + // var recipient = ()obj; + // + // } + // + // return RegisterAll; + // } + // + // [global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)] + // [global::System.Obsolete("This method is not intended to be called directly by user code")] + // public static global::System.Action CreateAllMessagesRegistratorWithToken( _) + // where TToken : global::System.IEquatable + // { + // static void RegisterAll(IMessenger messenger, object obj, TToken token) + // { + // var recipient = ()obj; + // + // } + // + // return RegisterAll; + // } + // } + // } + var source = + CompilationUnit().AddMembers( + NamespaceDeclaration(IdentifierName("Microsoft.Toolkit.Mvvm.Messaging.__Internals")).WithLeadingTrivia(TriviaList( + Comment("// Licensed to the .NET Foundation under one or more agreements."), + Comment("// The .NET Foundation licenses this file to you under the MIT license."), + Comment("// See the LICENSE file in the project root for more information."), + Trivia(PragmaWarningDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true)))).AddMembers( + ClassDeclaration("__IMessengerExtensions").AddModifiers( + Token(SyntaxKind.InternalKeyword), + Token(SyntaxKind.StaticKeyword), + Token(SyntaxKind.PartialKeyword)).AddAttributeLists(classAttributes).AddMembers( + MethodDeclaration( + GenericName("global::System.Action").AddTypeArgumentListArguments( + IdentifierName("IMessenger"), + PredefinedType(Token(SyntaxKind.ObjectKeyword))), + Identifier("CreateAllMessagesRegistrator")).AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This method is not intended to be called directly by user code"))))))).AddModifiers( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)).AddParameterListParameters( + Parameter(Identifier("_")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))) + .WithBody(Block( + LocalFunctionStatement( + PredefinedType(Token(SyntaxKind.VoidKeyword)), + Identifier("RegisterAll")) + .AddModifiers(Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters( + Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")), + Parameter(Identifier("obj")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword)))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(IdentifierName("var")) + .AddVariables( + VariableDeclarator(Identifier("recipient")) + .WithInitializer(EqualsValueClause( + CastExpression( + IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + IdentifierName("obj"))))))) + .AddStatements(EnumerateRegistrationStatements(classSymbol, iRecipientSymbol).ToArray())), + ReturnStatement(IdentifierName("RegisterAll")))), + MethodDeclaration( + GenericName("global::System.Action").AddTypeArgumentListArguments( + IdentifierName("IMessenger"), + PredefinedType(Token(SyntaxKind.ObjectKeyword)), + IdentifierName("TToken")), + Identifier("CreateAllMessagesRegistratorWithToken")).AddAttributeLists( + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.ComponentModel.EditorBrowsable")).AddArgumentListArguments( + AttributeArgument(ParseExpression("global::System.ComponentModel.EditorBrowsableState.Never"))))), + AttributeList(SingletonSeparatedList( + Attribute(IdentifierName("global::System.Obsolete")).AddArgumentListArguments( + AttributeArgument(LiteralExpression( + SyntaxKind.StringLiteralExpression, + Literal("This method is not intended to be called directly by user code"))))))).AddModifiers( + Token(SyntaxKind.PublicKeyword), + Token(SyntaxKind.StaticKeyword)).AddParameterListParameters( + Parameter(Identifier("_")).WithType(IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)))) + .AddTypeParameterListParameters(TypeParameter("TToken")) + .AddConstraintClauses( + TypeParameterConstraintClause("TToken") + .AddConstraints(TypeConstraint(GenericName("global::System.IEquatable").AddTypeArgumentListArguments(IdentifierName("TToken"))))) + .WithBody(Block( + LocalFunctionStatement( + PredefinedType(Token(SyntaxKind.VoidKeyword)), + Identifier("RegisterAll")) + .AddModifiers(Token(SyntaxKind.StaticKeyword)) + .AddParameterListParameters( + Parameter(Identifier("messenger")).WithType(IdentifierName("IMessenger")), + Parameter(Identifier("obj")).WithType(PredefinedType(Token(SyntaxKind.ObjectKeyword))), + Parameter(Identifier("token")).WithType(IdentifierName("TToken"))) + .WithBody(Block( + LocalDeclarationStatement( + VariableDeclaration(IdentifierName("var")) + .AddVariables( + VariableDeclarator(Identifier("recipient")) + .WithInitializer(EqualsValueClause( + CastExpression( + IdentifierName(classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + IdentifierName("obj"))))))) + .AddStatements(EnumerateRegistrationStatementsWithTokens(classSymbol, iRecipientSymbol).ToArray())), + ReturnStatement(IdentifierName("RegisterAll"))))))) + .NormalizeWhitespace() + .ToFullString(); + + // Reset the attributes list (so the same class doesn't get duplicate attributes) + classAttributes = Array.Empty(); + + // Add the partial type + context.AddSource($"{classSymbol.GetFullMetadataNameForFileName()}.cs", SourceText.From(source, Encoding.UTF8)); + } + } + + /// + /// Gets a sequence of statements to register declared message handlers. + /// + /// The input instance to process. + /// The type symbol for the IRecipient<T> interface. + /// The sequence of instances to register message handleers. + [Pure] + private static IEnumerable EnumerateRegistrationStatements(INamedTypeSymbol classSymbol, INamedTypeSymbol iRecipientSymbol) + { + foreach (var interfaceSymbol in classSymbol.AllInterfaces) + { + if (!SymbolEqualityComparer.Default.Equals(interfaceSymbol.OriginalDefinition, iRecipientSymbol)) + { + continue; + } + + // This enumerator produces a sequence of statements as follows: + // + // messenger.Register<>(recipient); + // messenger.Register<>(recipient); + // ... + // messenger.Register<>(recipient); + yield return + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("messenger"), + GenericName(Identifier("Register")).AddTypeArgumentListArguments( + IdentifierName(interfaceSymbol.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))))) + .AddArgumentListArguments(Argument(IdentifierName("recipient")))); + } + } + + /// + /// Gets a sequence of statements to register declared message handlers with custom tokens. + /// + /// The input instance to process. + /// The type symbol for the IRecipient<T> interface. + /// The sequence of instances to register message handleers. + [Pure] + private static IEnumerable EnumerateRegistrationStatementsWithTokens(INamedTypeSymbol classSymbol, INamedTypeSymbol iRecipientSymbol) + { + foreach (var interfaceSymbol in classSymbol.AllInterfaces) + { + if (!SymbolEqualityComparer.Default.Equals(interfaceSymbol.OriginalDefinition, iRecipientSymbol)) + { + continue; + } + + // This enumerator produces a sequence of statements as follows: + // + // messenger.Register<, TToken>(recipient, token); + // messenger.Register<, TToken>(recipient, token); + // ... + // messenger.Register<, TToken>(recipient, token); + yield return + ExpressionStatement( + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + IdentifierName("messenger"), + GenericName(Identifier("Register")).AddTypeArgumentListArguments( + IdentifierName(interfaceSymbol.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)), + IdentifierName("TToken")))) + .AddArgumentListArguments(Argument(IdentifierName("recipient")), Argument(IdentifierName("token")))); + } + } + } +} diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/Microsoft.Toolkit.Mvvm.SourceGenerators.csproj b/Microsoft.Toolkit.Mvvm.SourceGenerators/Microsoft.Toolkit.Mvvm.SourceGenerators.csproj new file mode 100644 index 00000000000..8e718f68287 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/Microsoft.Toolkit.Mvvm.SourceGenerators.csproj @@ -0,0 +1,37 @@ + + + + netstandard2.0 + 9.0 + enable + false + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + + + + + + + + + + diff --git a/Microsoft.Toolkit.Mvvm.SourceGenerators/System.Runtime.CompilerServices/IsExternalInit.cs b/Microsoft.Toolkit.Mvvm.SourceGenerators/System.Runtime.CompilerServices/IsExternalInit.cs new file mode 100644 index 00000000000..cadcf5a8570 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm.SourceGenerators/System.Runtime.CompilerServices/IsExternalInit.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.ComponentModel; + +namespace System.Runtime.CompilerServices +{ + /// + /// Reserved to be used by the compiler for tracking metadata. + /// This class should not be used by developers in source code. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + internal static class IsExternalInit + { + } +} \ No newline at end of file diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/AlsoNotifyChangeForAttribute.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/AlsoNotifyChangeForAttribute.cs new file mode 100644 index 00000000000..ae5a66feaf3 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/AlsoNotifyChangeForAttribute.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; +using System.Linq; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// An attribute that can be used to support in generated properties. When this attribute is + /// used, the generated property setter will also call (or the equivalent + /// method in the target class) for the properties specified in the attribute data. This can be useful to keep the code compact when + /// there are one or more dependent properties that should also be reported as updated when the value of the annotated observable + /// property is changed. If this attribute is used in a field without , it is ignored. + /// + /// In order to use this attribute, the containing type has to implement the interface + /// and expose a method with the same signature as . If the containing + /// type also implements the interface and exposes a method with the same signature as + /// , then this method will be invoked as well by the property setter. + /// + /// + /// This attribute can be used as follows: + /// + /// partial class MyViewModel : ObservableObject + /// { + /// [ObservableProperty] + /// [AlsoNotifyChangeFor(nameof(FullName))] + /// private string name; + /// + /// [ObservableProperty] + /// [AlsoNotifyChangeFor(nameof(FullName))] + /// private string surname; + /// + /// public string FullName => $"{Name} {Surname}"; + /// } + /// + /// + /// And with this, code analogous to this will be generated: + /// + /// partial class MyViewModel + /// { + /// public string Name + /// { + /// get => name; + /// set + /// { + /// if (SetProperty(ref name, value)) + /// { + /// OnPropertyChanged(nameof(FullName)); + /// } + /// } + /// } + /// + /// public string Surname + /// { + /// get => surname; + /// set + /// { + /// if (SetProperty(ref surname, value)) + /// { + /// OnPropertyChanged(nameof(FullName)); + /// } + /// } + /// } + /// } + /// + /// + [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = false)] + public sealed class AlsoNotifyChangeForAttribute : Attribute + { + /// + /// Initializes a new instance of the class. + /// + /// The name of the property to also notify when the annotated property changes. + public AlsoNotifyChangeForAttribute(string propertyName) + { + PropertyNames = new[] { propertyName }; + } + + /// + /// Initializes a new instance of the class. + /// + /// The name of the property to also notify when the annotated property changes. + /// + /// The other property names to also notify when the annotated property changes. This parameter can optionally + /// be used to indicate a series of dependent properties from the same attribute, to keep the code more compact. + /// + public AlsoNotifyChangeForAttribute(string propertyName, string[] otherPropertyNames) + { + PropertyNames = new[] { propertyName }.Concat(otherPropertyNames).ToArray(); + } + + /// + /// Gets the property names to also notify when the annotated property changes. + /// + public string[] PropertyNames { get; } + } +} diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/INotifyPropertyChangedAttribute.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/INotifyPropertyChangedAttribute.cs new file mode 100644 index 00000000000..eef91ae0ef8 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/INotifyPropertyChangedAttribute.cs @@ -0,0 +1,38 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// An attribute that indicates that a given type should implement the interface and + /// have minimal built-in functionality to support it. This includes exposing the necessary event and having two methods + /// to raise it that mirror and + /// . For more extensive support, use . + /// + /// This attribute can be used as follows: + /// + /// [INotifyPropertyChanged] + /// partial class MyViewModel : SomeOtherClass + /// { + /// // Other members here... + /// } + /// + /// + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] + public sealed class INotifyPropertyChangedAttribute : Attribute + { + /// + /// Gets or sets a value indicating whether or not to also generate all the additional helper methods that are found + /// in as well (eg. . + /// If set to , only the event and + /// the two overloads will be generated. + /// The default value is . + /// + public bool IncludeAdditionalHelperMethods { get; set; } = true; + } +} diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableObjectAttribute.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableObjectAttribute.cs new file mode 100644 index 00000000000..72c4b078c85 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableObjectAttribute.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// An attribute that indicates that a given type should have all the members from + /// generated into it, as well as the and + /// interfaces. This can be useful when you want the same functionality from into a class + /// that already inherits from another one (since C# doesn't support multiple inheritance). This attribute will trigger + /// the source generator to just create the same APIs directly into the decorated class. + /// + /// This attribute can be used as follows: + /// + /// [ObservableObject] + /// partial class MyViewModel : SomeOtherClass + /// { + /// // Other members here... + /// } + /// + /// + /// And with this, the same APIs from will be available on this type as well. + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] + public sealed class ObservableObjectAttribute : Attribute + { + } +} diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservablePropertyAttribute.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservablePropertyAttribute.cs new file mode 100644 index 00000000000..2395ec31ad9 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservablePropertyAttribute.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.ComponentModel; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// An attribute that indicates that a given field should be wrapped by a generated observable property. + /// In order to use this attribute, the containing type has to implement the interface + /// and expose a method with the same signature as . If the containing + /// type also implements the interface and exposes a method with the same signature as + /// , then this method will be invoked as well by the property setter. + /// + /// This attribute can be used as follows: + /// + /// partial class MyViewModel : ObservableObject + /// { + /// [ObservableProperty] + /// private string name; + /// + /// [ObservableProperty] + /// private bool isEnabled; + /// } + /// + /// + /// And with this, code analogous to this will be generated: + /// + /// partial class MyViewModel + /// { + /// public string Name + /// { + /// get => name; + /// set => SetProperty(ref name, value); + /// } + /// + /// public bool IsEnabled + /// { + /// get => name; + /// set => SetProperty(ref isEnabled, value); + /// } + /// } + /// + /// + /// + /// The generated properties will automatically use the UpperCamelCase format for their names, + /// which will be derived from the field names. The generator can also recognize fields using either + /// the _lowerCamel or m_lowerCamel naming scheme. Otherwise, the first character in the + /// source field name will be converted to uppercase (eg. isEnabled to IsEnabled). + /// + [AttributeUsage(AttributeTargets.Field, AllowMultiple = false, Inherited = false)] + public sealed class ObservablePropertyAttribute : Attribute + { + } +} diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableRecipientAttribute.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableRecipientAttribute.cs new file mode 100644 index 00000000000..382cce5c7e1 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/Attributes/ObservableRecipientAttribute.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel +{ + /// + /// An attribute that indicates that a given type should have all the members from + /// generated into it. This can be useful when you want the same functionality from into + /// a class that already inherits from another one (since C# doesn't support multiple inheritance). This attribute will trigger + /// the source generator to just create the same APIs directly into the decorated class. For instance, this attribute can be + /// used to easily combine the functionality from both and , + /// by using as the base class and adding this attribute to the declared type. + /// + /// This attribute can be used as follows: + /// + /// [ObservableRecipient] + /// partial class MyViewModel : ObservableValidator + /// { + /// // Other members here... + /// } + /// + /// + /// And with this, the same APIs from will be available on this type as well. + /// + /// To avoid conflicts with other APIs in types where the new members are being generated, constructors are only generated when the annotated + /// type doesn't have any explicit constructors being declared. If that is the case, the same constructors from + /// are emitted, with the accessibility adapted to that of the annotated type. Otherwise, they are skipped, so the type being annotated has the + /// respondibility of properly initializing the property. Additionally, if the annotated type inherits + /// from , the overloads will be skipped + /// as well, as they would conflict with the methods. + /// + /// + /// + /// In order to work, needs to be applied to a type that inherits from + /// (either directly or indirectly), or to one decorated with . + /// This is because the methods rely on some of the inherited members to work. + /// If this condition is not met, the code will fail to build. + /// + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = false)] + public sealed class ObservableRecipientAttribute : Attribute + { + } +} diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableObject.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableObject.cs index c7e4f053306..5685e3ed413 100644 --- a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableObject.cs +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableObject.cs @@ -7,6 +7,14 @@ // This file is inspired from the MvvmLight library (lbugnion/MvvmLight), // more info in ThirdPartyNotices.txt in the root of the project. +// ================================== NOTE ================================== +// This file is mirrored in the trimmed-down INotifyPropertyChanged file in +// the source generator project, to be used with the [INotifyPropertyChanged], +// attribute, along with the ObservableObject annotated copy (for debugging info). +// If any changes are made to this file, they should also be appropriately +// ported to that file as well to keep the behavior consistent. +// ========================================================================== + using System; using System.Collections.Generic; using System.ComponentModel; diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs index dc5d5c85690..dbf15fe1d6e 100644 --- a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableRecipient.cs @@ -7,6 +7,13 @@ // This file is inspired from the MvvmLight library (lbugnion/MvvmLight), // more info in ThirdPartyNotices.txt in the root of the project. +// ================================= NOTE ================================= +// This file is mirrored in the ObservableRecipient annotated copy +// (for debugging info) in the Mvvm.SourceGenerators project. +// If any changes are made to this file, they should also be appropriately +// ported to that file as well to keep the behavior consistent. +// ======================================================================== + using System; using System.Collections.Generic; using System.Runtime.CompilerServices; diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableValidator.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableValidator.cs index 58bf60577d1..0390ba60247 100644 --- a/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableValidator.cs +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/ObservableValidator.cs @@ -471,7 +471,21 @@ IEnumerable GetAllErrors() /// protected void ValidateAllProperties() { + // Fast path that tries to create a delegate from a generated type-specific method. This + // is used to make this method more AOT-friendly and faster, as there is no dynamic code. static Action GetValidationAction(Type type) + { + if (type.Assembly.GetType("Microsoft.Toolkit.Mvvm.ComponentModel.__Internals.__ObservableValidatorExtensions") is Type extensionsType && + extensionsType.GetMethod("CreateAllPropertiesValidator", new[] { type }) is MethodInfo methodInfo) + { + return (Action)methodInfo.Invoke(null, new object?[] { null })!; + } + + return GetValidationActionFallback(type); + } + + // Fallback method to create the delegate with a compiled LINQ expression + static Action GetValidationActionFallback(Type type) { // MyViewModel inst0 = (MyViewModel)arg0; ParameterExpression arg0 = Expression.Parameter(typeof(object)); @@ -489,6 +503,7 @@ static Action GetValidationAction(Type type) // inst0.ValidateProperty(inst0.Property0, nameof(MyViewModel.Property0)); // inst0.ValidateProperty(inst0.Property1, nameof(MyViewModel.Property1)); // ... + // inst0.ValidateProperty(inst0.PropertyN, nameof(MyViewModel.PropertyN)); // } // =============================================================================== // We also add an explicit object conversion to represent boxing, if a given property @@ -523,7 +538,7 @@ where getter is not null /// The value to test for the specified property. /// The name of the property to validate. /// Thrown when is . - protected void ValidateProperty(object? value, [CallerMemberName] string? propertyName = null) + protected internal void ValidateProperty(object? value, [CallerMemberName] string? propertyName = null) { if (propertyName is null) { diff --git a/Microsoft.Toolkit.Mvvm/ComponentModel/__Internals/__ObservableValidatorHelper.cs b/Microsoft.Toolkit.Mvvm/ComponentModel/__Internals/__ObservableValidatorHelper.cs new file mode 100644 index 00000000000..bd9ee7ede80 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/ComponentModel/__Internals/__ObservableValidatorHelper.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma warning disable SA1300 + +using System; +using System.ComponentModel; + +namespace Microsoft.Toolkit.Mvvm.ComponentModel.__Internals +{ + /// + /// An internal helper to support the source generator APIs related to . + /// This type is not intended to be used directly by user code. + /// + [EditorBrowsable(EditorBrowsableState.Never)] + [Obsolete("This type is not intended to be used directly by user code")] + public static class __ObservableValidatorHelper + { + /// + /// Invokes externally on a target instance. + /// + /// The target instance. + /// The value to test for the specified property. + /// The name of the property to validate. + [EditorBrowsable(EditorBrowsableState.Never)] + [Obsolete("This method is not intended to be called directly by user code")] + public static void ValidateProperty(ObservableValidator instance, object? value, string propertyName) + { + instance.ValidateProperty(value, propertyName); + } + } +} \ No newline at end of file diff --git a/Microsoft.Toolkit.Mvvm/Input/Attributes/ICommandAttribute.cs b/Microsoft.Toolkit.Mvvm/Input/Attributes/ICommandAttribute.cs new file mode 100644 index 00000000000..2a701d86b72 --- /dev/null +++ b/Microsoft.Toolkit.Mvvm/Input/Attributes/ICommandAttribute.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Windows.Input; + +namespace Microsoft.Toolkit.Mvvm.Input +{ + /// + /// An attribute that can be used to automatically generate properties from declared methods. When this attribute + /// is used to decorate a method, a generator will create a command property with the corresponding interface + /// depending on the signature of the method. If an invalid method signature is used, the generator will report an error. + /// + /// In order to use this attribute, the containing type doesn't need to implement any interfaces. The generated properties will be lazily + /// assigned but their value will never change, so there is no need to support property change notifications or other additional functionality. + /// + /// + /// This attribute can be used as follows: + /// + /// partial class MyViewModel + /// { + /// [ICommand] + /// private void GreetUser(User? user) + /// { + /// Console.WriteLine($"Hello {user.Name}!"); + /// } + /// } + /// + /// And with this, code analogous to this will be generated: + /// + /// partial class MyViewModel + /// { + /// private IRelayCommand? greetUserCommand; + /// + /// public IRelayCommand GreetUserCommand => greetUserCommand ??= new RelayCommand(GreetUser); + /// } + /// + /// + /// + /// The following signatures are supported for annotated methods: + /// + /// void Method(); + /// + /// Will generate an property (using a instance). + /// + /// void Method(T?); + /// + /// Will generate an property (using a instance). + /// + /// Task Method(); + /// Task Method(CancellationToken); + /// + /// Will both generate an property (using an instance). + /// + /// Task Method(T?); + /// Task Method(T?, CancellationToken); + /// + /// Will both generate an property (using an instance). + /// + /// + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false, Inherited = false)] + public sealed class ICommandAttribute : Attribute + { + } +} diff --git a/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs b/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs index 43b9dc42add..2c1fe66f629 100644 --- a/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs +++ b/Microsoft.Toolkit.Mvvm/Messaging/IMessengerExtensions.cs @@ -33,6 +33,17 @@ private static class MethodInfos public static readonly MethodInfo RegisterIRecipient = new Action, Unit>(Register).Method.GetGenericMethodDefinition(); } + /// + /// A non-generic version of . + /// + private static class DiscoveredRecipients + { + /// + /// The instance used to track the preloaded registration action for each recipient. + /// + public static readonly ConditionalWeakTable?> RegistrationMethods = new(); + } + /// /// A class that acts as a static container to associate a instance to each /// type in use. This is done because we can only use a single type as key, but we need to track @@ -45,9 +56,9 @@ private static class DiscoveredRecipients where TToken : IEquatable { /// - /// The instance used to track the preloaded registration actions for each recipient. + /// The instance used to track the preloaded registration action for each recipient. /// - public static readonly ConditionalWeakTable[]> RegistrationMethods = new(); + public static readonly ConditionalWeakTable> RegistrationMethods = new(); } /// @@ -73,7 +84,33 @@ public static bool IsRegistered(this IMessenger messenger, object reci /// See notes for for more info. public static void RegisterAll(this IMessenger messenger, object recipient) { - messenger.RegisterAll(recipient, default(Unit)); + // We use this method as a callback for the conditional weak table, which will handle + // thread-safety for us. This first callback will try to find a generated method for the + // target recipient type, and just invoke it to get the delegate to cache and use later. + static Action? LoadRegistrationMethodsForType(Type recipientType) + { + if (recipientType.Assembly.GetType("Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType && + extensionsType.GetMethod("CreateAllMessagesRegistrator", new[] { recipientType }) is MethodInfo methodInfo) + { + return (Action)methodInfo.Invoke(null, new object?[] { null })!; + } + + return null; + } + + // Try to get the cached delegate, if the generatos has run correctly + Action? registrationAction = DiscoveredRecipients.RegistrationMethods.GetValue( + recipient.GetType(), + static t => LoadRegistrationMethodsForType(t)); + + if (registrationAction is not null) + { + registrationAction(messenger, recipient); + } + else + { + messenger.RegisterAll(recipient, default(Unit)); + } } /// @@ -93,25 +130,28 @@ public static void RegisterAll(this IMessenger messenger, object recipient) public static void RegisterAll(this IMessenger messenger, object recipient, TToken token) where TToken : IEquatable { - // We use this method as a callback for the conditional weak table, which will both - // handle thread-safety for us, as well as avoiding all the LINQ codegen bloat here. - // This method is only invoked once per recipient type and token type, so we're not - // worried about making it super efficient, and we can use the LINQ code for clarity. - static Action[] LoadRegistrationMethodsForType(Type type) + // We use this method as a callback for the conditional weak table, which will handle + // thread-safety for us. This first callback will try to find a generated method for the + // target recipient type, and just invoke it to get the delegate to cache and use later. + // In this case we also need to create a generic instantiation of the target method first. + static Action LoadRegistrationMethodsForType(Type recipientType) { - return ( - from interfaceType in type.GetInterfaces() - where interfaceType.IsGenericType && - interfaceType.GetGenericTypeDefinition() == typeof(IRecipient<>) - let messageType = interfaceType.GenericTypeArguments[0] - let registrationMethod = MethodInfos.RegisterIRecipient.MakeGenericMethod(messageType, typeof(TToken)) - let registrationAction = GetRegistrationAction(type, registrationMethod) - select registrationAction).ToArray(); + if (recipientType.Assembly.GetType("Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions") is Type extensionsType && + extensionsType.GetMethod("CreateAllMessagesRegistratorWithToken", new[] { recipientType }) is MethodInfo methodInfo) + { + MethodInfo genericMethodInfo = methodInfo.MakeGenericMethod(typeof(TToken)); + + return (Action)genericMethodInfo.Invoke(null, new object?[] { null })!; + } + + return LoadRegistrationMethodsForTypeFallback(recipientType); } - // Helper method to build and compile an expression tree to a message handler to use for the registration - // This is used to reduce the overhead of repeated calls to MethodInfo.Invoke (which is over 10 times slower). - static Action GetRegistrationAction(Type type, MethodInfo methodInfo) + // Fallback method when a generated method is not found. + // This method is only invoked once per recipient type and token type, so we're not + // worried about making it super efficient, and we can use the LINQ code for clarity. + // The LINQ codegen bloat is not really important for the same reason. + static Action LoadRegistrationMethodsForTypeFallback(Type recipientType) { // Input parameters (IMessenger instance, non-generic recipient, token) ParameterExpression @@ -119,31 +159,51 @@ static Action GetRegistrationAction(Type type, Metho arg1 = Expression.Parameter(typeof(object)), arg2 = Expression.Parameter(typeof(TToken)); - // Cast the recipient and invoke the registration method - MethodCallExpression body = Expression.Call(null, methodInfo, new Expression[] - { - arg0, - Expression.Convert(arg1, type), - arg2 - }); + // Declare a local resulting from the (RecipientType)recipient cast + UnaryExpression inst1 = Expression.Convert(arg1, recipientType); + + // We want a single compiled LINQ expression that executes the registration for all + // the declared message types in the input type. To do so, we create a block with the + // unrolled invocations for the indivudual message registration (for each IRecipient). + // The code below will generate the following block expression: + // =============================================================================== + // { + // var inst1 = (RecipientType)arg1; + // IMessengerExtensions.Register(arg0, inst1, arg2); + // IMessengerExtensions.Register(arg0, inst1, arg2); + // ... + // IMessengerExtensions.Register(arg0, inst1, arg2); + // } + // =============================================================================== + // We also add an explicit object conversion to cast the input recipient type to + // the actual specific type, so that the exposed message handlers are accessible. + BlockExpression body = Expression.Block( + from interfaceType in recipientType.GetInterfaces() + where interfaceType.IsGenericType && + interfaceType.GetGenericTypeDefinition() == typeof(IRecipient<>) + let messageType = interfaceType.GenericTypeArguments[0] + let registrationMethod = MethodInfos.RegisterIRecipient.MakeGenericMethod(messageType, typeof(TToken)) + select Expression.Call(registrationMethod, new Expression[] + { + arg0, + inst1, + arg2 + })); - // Create the expression tree and compile to a target delegate return Expression.Lambda>(body, arg0, arg1, arg2).Compile(); } - // Get or compute the registration methods for the current recipient type. + // Get or compute the registration method for the current recipient type. // As in Microsoft.Toolkit.Diagnostics.TypeExtensions.ToTypeString, we use a lambda // expression instead of a method group expression to leverage the statically initialized // delegate and avoid repeated allocations for each invocation of this method. // For more info on this, see the related issue at https://github.com/dotnet/roslyn/issues/5835. - Action[] registrationActions = DiscoveredRecipients.RegistrationMethods.GetValue( + Action registrationAction = DiscoveredRecipients.RegistrationMethods.GetValue( recipient.GetType(), static t => LoadRegistrationMethodsForType(t)); - foreach (Action registrationAction in registrationActions) - { - registrationAction(messenger, recipient, token); - } + // Invoke the cached delegate to actually execute the message registration + registrationAction(messenger, recipient, token); } /// diff --git a/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj b/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj index 93ac0794922..efa6bad4691 100644 --- a/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj +++ b/Microsoft.Toolkit.Mvvm/Microsoft.Toolkit.Mvvm.csproj @@ -20,6 +20,7 @@ - Ioc: a helper class to configure dependency injection service containers. MVVM;Toolkit;MVVMToolkit;INotifyPropertyChanged;Observable;IOC;DI;Dependency Injection;Object Messaging;Extensions;Helpers + $(TargetsForTfmSpecificContentInPackage);CopyAnalyzerProjectReferencesToPackage @@ -34,5 +35,22 @@ + + + + + + + + + + + analyzers\dotnet\cs + + + \ No newline at end of file diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_ICommandAttribute.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_ICommandAttribute.cs new file mode 100644 index 00000000000..2bcea9e41dd --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_ICommandAttribute.cs @@ -0,0 +1,112 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma warning disable CS0618 + +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Toolkit.Mvvm.Input; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_ICommandAttribute + { + [TestCategory("Mvvm")] + [TestMethod] + public async Task Test_ICommandAttribute_RelayCommand() + { + var model = new MyViewModel(); + + Assert.AreEqual(model.Counter, 0); + + model.IncrementCounterCommand.Execute(null); + + Assert.AreEqual(model.Counter, 1); + + model.IncrementCounterWithValueCommand.Execute(5); + + Assert.AreEqual(model.Counter, 6); + + await model.DelayAndIncrementCounterCommand.ExecuteAsync(null); + + Assert.AreEqual(model.Counter, 7); + + await model.DelayAndIncrementCounterWithTokenCommand.ExecuteAsync(null); + + Assert.AreEqual(model.Counter, 8); + + await model.DelayAndIncrementCounterWithValueCommand.ExecuteAsync(5); + + Assert.AreEqual(model.Counter, 13); + + await model.DelayAndIncrementCounterWithValueAndTokenCommand.ExecuteAsync(5); + + Assert.AreEqual(model.Counter, 18); + } + + public sealed partial class MyViewModel + { + public int Counter { get; private set; } + + /// This is a single line summary. + [ICommand] + private void IncrementCounter() + { + Counter++; + } + + /// + /// This is a multiline summary + /// + [ICommand] + private void IncrementCounterWithValue(int count) + { + Counter += count; + } + + /// This is single line with also other stuff below + /// Foo bar baz + /// A task + [ICommand] + private async Task DelayAndIncrementCounterAsync() + { + await Task.Delay(50); + + Counter += 1; + } + + /// + /// This is multi line with also other stuff below + /// + /// Foo bar baz + /// A task + [ICommand] + private async Task DelayAndIncrementCounterWithTokenAsync(CancellationToken token) + { + await Task.Delay(50); + + Counter += 1; + } + + // This should not be ported over + [ICommand] + private async Task DelayAndIncrementCounterWithValueAsync(int count) + { + await Task.Delay(50); + + Counter += count; + } + + [ICommand] + private async Task DelayAndIncrementCounterWithValueAndTokenAsync(int count, CancellationToken token) + { + await Task.Delay(50); + + Counter += count; + } + } + } +} diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_INotifyPropertyChangedAttribute.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_INotifyPropertyChangedAttribute.cs new file mode 100644 index 00000000000..5eba0389465 --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_INotifyPropertyChangedAttribute.cs @@ -0,0 +1,77 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.ComponentModel; +using System.Reflection; +using Microsoft.Toolkit.Mvvm.ComponentModel; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_INotifyPropertyChangedAttribute + { + [TestCategory("Mvvm")] + [TestMethod] + public void Test_INotifyPropertyChanged_Events() + { + var model = new SampleModel(); + + (PropertyChangedEventArgs, int) changed = default; + + model.PropertyChanged += (s, e) => + { + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changed = (e, model.Data); + }; + + model.Data = 42; + + Assert.AreEqual(changed.Item1?.PropertyName, nameof(SampleModel.Data)); + Assert.AreEqual(changed.Item2, 42); + } + + [INotifyPropertyChanged] + public partial class SampleModel + { + private int data; + + public int Data + { + get => data; + set => SetProperty(ref data, value); + } + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_INotifyPropertyChanged_WithoutHelpers() + { + Assert.IsTrue(typeof(INotifyPropertyChanged).IsAssignableFrom(typeof(SampleModelWithoutHelpers))); + Assert.IsFalse(typeof(INotifyPropertyChanging).IsAssignableFrom(typeof(SampleModelWithoutHelpers))); + + // This just needs to check that it compiles + _ = nameof(SampleModelWithoutHelpers.PropertyChanged); + + var methods = typeof(SampleModelWithoutHelpers).GetMethods(BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.DeclaredOnly); + + Assert.AreEqual(methods.Length, 2); + Assert.AreEqual(methods[0].Name, "OnPropertyChanged"); + Assert.AreEqual(methods[1].Name, "OnPropertyChanged"); + + var types = typeof(SampleModelWithoutHelpers).GetNestedTypes(BindingFlags.NonPublic); + + Assert.AreEqual(types.Length, 0); + } + + [INotifyPropertyChanged(IncludeAdditionalHelperMethods = false)] + public partial class SampleModelWithoutHelpers + { + } + } +} diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_IRecipientGenerator.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_IRecipientGenerator.cs new file mode 100644 index 00000000000..dece3973370 --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_IRecipientGenerator.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#pragma warning disable CS0618 + +using System; +using Microsoft.Toolkit.Mvvm.Messaging; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_IRecipientGenerator + { + [TestCategory("Mvvm")] + [TestMethod] + public void Test_IRecipientGenerator_GeneratedRegistration() + { + var messenger = new StrongReferenceMessenger(); + var recipient = new RecipientWithSomeMessages(); + + var messageA = new MessageA(); + var messageB = new MessageB(); + + Action registrator = Microsoft.Toolkit.Mvvm.Messaging.__Internals.__IMessengerExtensions.CreateAllMessagesRegistratorWithToken(recipient); + + registrator(messenger, recipient, 42); + + Assert.IsTrue(messenger.IsRegistered(recipient, 42)); + Assert.IsTrue(messenger.IsRegistered(recipient, 42)); + + Assert.IsNull(recipient.A); + Assert.IsNull(recipient.B); + + messenger.Send(messageA, 42); + + Assert.AreSame(recipient.A, messageA); + Assert.IsNull(recipient.B); + + messenger.Send(messageB, 42); + + Assert.AreSame(recipient.A, messageA); + Assert.AreSame(recipient.B, messageB); + } + + public sealed class RecipientWithSomeMessages : + IRecipient, + IRecipient + { + public MessageA A { get; private set; } + + public MessageB B { get; private set; } + + public void Receive(MessageA message) + { + A = message; + } + + public void Receive(MessageB message) + { + B = message; + } + } + + public sealed class MessageA + { + } + + public sealed class MessageB + { + } + } +} diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableObjectAttribute.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableObjectAttribute.cs new file mode 100644 index 00000000000..6b3cdf0dbe3 --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableObjectAttribute.cs @@ -0,0 +1,116 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.ComponentModel; +using Microsoft.Toolkit.Mvvm.ComponentModel; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_ObservableObjectAttribute + { + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableObjectAttribute_Events() + { + var model = new SampleModel(); + + (PropertyChangingEventArgs, int) changing = default; + (PropertyChangedEventArgs, int) changed = default; + + model.PropertyChanging += (s, e) => + { + Assert.IsNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changing = (e, model.Data); + }; + + model.PropertyChanged += (s, e) => + { + Assert.IsNotNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changed = (e, model.Data); + }; + + model.Data = 42; + + Assert.AreEqual(changing.Item1?.PropertyName, nameof(SampleModel.Data)); + Assert.AreEqual(changing.Item2, 0); + Assert.AreEqual(changed.Item1?.PropertyName, nameof(SampleModel.Data)); + Assert.AreEqual(changed.Item2, 42); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableObjectAttribute_OnSealedClass_Events() + { + var model = new SampleModelSealed(); + + (PropertyChangingEventArgs, int) changing = default; + (PropertyChangedEventArgs, int) changed = default; + + model.PropertyChanging += (s, e) => + { + Assert.IsNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changing = (e, model.Data); + }; + + model.PropertyChanged += (s, e) => + { + Assert.IsNotNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changed = (e, model.Data); + }; + + model.Data = 42; + + Assert.AreEqual(changing.Item1?.PropertyName, nameof(SampleModelSealed.Data)); + Assert.AreEqual(changing.Item2, 0); + Assert.AreEqual(changed.Item1?.PropertyName, nameof(SampleModelSealed.Data)); + Assert.AreEqual(changed.Item2, 42); + } + + [ObservableObject] + public partial class SampleModel + { + private int data; + + public int Data + { + get => data; + set => SetProperty(ref data, value); + } + } + + [ObservableObject] + public sealed partial class SampleModelSealed + { + private int data; + + public int Data + { + get => data; + set => SetProperty(ref data, value); + } + } + } +} diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservablePropertyAttribute.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservablePropertyAttribute.cs new file mode 100644 index 00000000000..97ef214b38c --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservablePropertyAttribute.cs @@ -0,0 +1,196 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.ComponentModel; +using System.ComponentModel.DataAnnotations; +using System.Reflection; +using Microsoft.Toolkit.Mvvm.ComponentModel; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +#nullable enable + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_ObservablePropertyAttribute + { + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservablePropertyAttribute_Events() + { + var model = new SampleModel(); + + (PropertyChangingEventArgs, int) changing = default; + (PropertyChangedEventArgs, int) changed = default; + + model.PropertyChanging += (s, e) => + { + Assert.IsNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changing = (e, model.Data); + }; + + model.PropertyChanged += (s, e) => + { + Assert.IsNotNull(changing.Item1); + Assert.IsNull(changed.Item1); + Assert.AreSame(model, s); + Assert.IsNotNull(s); + Assert.IsNotNull(e); + + changed = (e, model.Data); + }; + + model.Data = 42; + + Assert.AreEqual(changing.Item1?.PropertyName, nameof(SampleModel.Data)); + Assert.AreEqual(changing.Item2, 0); + Assert.AreEqual(changed.Item1?.PropertyName, nameof(SampleModel.Data)); + Assert.AreEqual(changed.Item2, 42); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_AlsoNotifyChangeForAttribute_Events() + { + var model = new DependentPropertyModel(); + + List propertyNames = new(); + + model.PropertyChanged += (s, e) => propertyNames.Add(e.PropertyName); + + model.Name = "Bob"; + model.Surname = "Ross"; + + CollectionAssert.AreEqual(new[] { nameof(model.Name), nameof(model.FullName), nameof(model.Surname), nameof(model.FullName) }, propertyNames); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ValidationAttributes() + { + var nameProperty = typeof(MyFormViewModel).GetProperty(nameof(MyFormViewModel.Name))!; + + Assert.IsNotNull(nameProperty.GetCustomAttribute()); + Assert.IsNotNull(nameProperty.GetCustomAttribute()); + Assert.AreEqual(nameProperty.GetCustomAttribute()!.Length, 1); + Assert.IsNotNull(nameProperty.GetCustomAttribute()); + Assert.AreEqual(nameProperty.GetCustomAttribute()!.Length, 100); + + var ageProperty = typeof(MyFormViewModel).GetProperty(nameof(MyFormViewModel.Age))!; + + Assert.IsNotNull(ageProperty.GetCustomAttribute()); + Assert.AreEqual(ageProperty.GetCustomAttribute()!.Minimum, 0); + Assert.AreEqual(ageProperty.GetCustomAttribute()!.Maximum, 120); + + var emailProperty = typeof(MyFormViewModel).GetProperty(nameof(MyFormViewModel.Email))!; + + Assert.IsNotNull(emailProperty.GetCustomAttribute()); + + var comboProperty = typeof(MyFormViewModel).GetProperty(nameof(MyFormViewModel.IfThisWorksThenThatsGreat))!; + + TestValidationAttribute testAttribute = comboProperty.GetCustomAttribute()!; + + Assert.IsNotNull(testAttribute); + Assert.IsNull(testAttribute.O); + Assert.AreEqual(testAttribute.T, typeof(SampleModel)); + Assert.AreEqual(testAttribute.Flag, true); + Assert.AreEqual(testAttribute.D, 6.28); + CollectionAssert.AreEqual(testAttribute.Names, new[] { "Bob", "Ross" }); + + object[] nestedArray = (object[])testAttribute.NestedArray; + + Assert.AreEqual(nestedArray.Length, 3); + Assert.AreEqual(nestedArray[0], 1); + Assert.AreEqual(nestedArray[1], "Hello"); + Assert.IsTrue(nestedArray[2] is int[]); + CollectionAssert.AreEqual((int[])nestedArray[2], new[] { 2, 3, 4 }); + + Assert.AreEqual(testAttribute.Animal, Animal.Llama); + } + + public partial class SampleModel : ObservableObject + { + /// + /// This is a sample data field within of type . + /// + [ObservableProperty] + private int data; + } + + [INotifyPropertyChanged] + public sealed partial class DependentPropertyModel + { + [ObservableProperty] + [AlsoNotifyChangeFor(nameof(FullName))] + private string? name; + + [ObservableProperty] + [AlsoNotifyChangeFor(nameof(FullName))] + private string? surname; + + public string FullName => $"{Name} {Surname}"; + } + + public partial class MyFormViewModel : ObservableValidator + { + [ObservableProperty] + [Required] + [MinLength(1)] + [MaxLength(100)] + private string? name; + + [ObservableProperty] + [Range(0, 120)] + private int age; + + [ObservableProperty] + [EmailAddress] + private string? email; + + [ObservableProperty] + [TestValidation(null, typeof(SampleModel), true, 6.28, new[] { "Bob", "Ross" }, NestedArray = new object[] { 1, "Hello", new int[] { 2, 3, 4 } }, Animal = Animal.Llama)] + private int ifThisWorksThenThatsGreat; + } + + private sealed class TestValidationAttribute : ValidationAttribute + { + public TestValidationAttribute(object? o, Type t, bool flag, double d, string[] names) + { + O = o; + T = t; + Flag = flag; + D = d; + Names = names; + } + + public object? O { get; } + + public Type T { get; } + + public bool Flag { get; } + + public double D { get; } + + public string[] Names { get; } + + public object NestedArray { get; set; } + + public Animal Animal { get; set; } + } + + public enum Animal + { + Cat, + Dog, + Llama + } + } +} diff --git a/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableRecipientAttribute.cs b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableRecipientAttribute.cs new file mode 100644 index 00000000000..c2889fb4795 --- /dev/null +++ b/UnitTests/UnitTests.NetCore/Mvvm/Test_ObservableRecipientAttribute.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.ComponentModel; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using System.Reflection; +using Microsoft.Toolkit.Mvvm.ComponentModel; +using Microsoft.Toolkit.Mvvm.Messaging; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public partial class Test_ObservableRecipientAttribute + { + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableRecipientAttribute_Events() + { + var model = new Person(); + var args = new List(); + + model.PropertyChanged += (s, e) => args.Add(e); + + Assert.IsFalse(model.HasErrors); + + model.Name = "No"; + + Assert.IsTrue(model.HasErrors); + Assert.AreEqual(args.Count, 2); + Assert.AreEqual(args[0].PropertyName, nameof(Person.Name)); + Assert.AreEqual(args[1].PropertyName, nameof(INotifyDataErrorInfo.HasErrors)); + + model.Name = "Valid"; + + Assert.IsFalse(model.HasErrors); + Assert.AreEqual(args.Count, 4); + Assert.AreEqual(args[2].PropertyName, nameof(Person.Name)); + Assert.AreEqual(args[3].PropertyName, nameof(INotifyDataErrorInfo.HasErrors)); + + Assert.IsNotNull(typeof(Person).GetProperty("Messenger", BindingFlags.Instance | BindingFlags.NonPublic)); + Assert.AreEqual(typeof(Person).GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic).Length, 0); + } + + [ObservableRecipient] + public partial class Person : ObservableValidator + { + public Person() + { + Messenger = WeakReferenceMessenger.Default; + } + + private string name; + + [MinLength(4)] + [MaxLength(20)] + [Required] + public string Name + { + get => this.name; + set => SetProperty(ref this.name, value, true); + } + + public void TestCompile() + { + // Validates that the method Broadcast is correctly being generated + Broadcast(0, 1, nameof(TestCompile)); + } + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableRecipientAttribute_AbstractConstructors() + { + var ctors = typeof(AbstractPerson).GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic); + + Assert.AreEqual(ctors.Length, 2); + Assert.IsTrue(ctors.All(static ctor => ctor.IsFamily)); + } + + [ObservableRecipient] + public abstract partial class AbstractPerson : ObservableObject + { + } + + [TestCategory("Mvvm")] + [TestMethod] + public void Test_ObservableRecipientAttribute_NonAbstractConstructors() + { + var ctors = typeof(NonAbstractPerson).GetConstructors(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + + Assert.AreEqual(ctors.Length, 2); + Assert.IsTrue(ctors.All(static ctor => ctor.IsPublic)); + } + + [ObservableRecipient] + public partial class NonAbstractPerson : ObservableObject + { + } + } +} diff --git a/UnitTests/UnitTests.NetCore/UnitTests.NetCore.csproj b/UnitTests/UnitTests.NetCore/UnitTests.NetCore.csproj index 32b82e9bf13..3f7353455ff 100644 --- a/UnitTests/UnitTests.NetCore/UnitTests.NetCore.csproj +++ b/UnitTests/UnitTests.NetCore/UnitTests.NetCore.csproj @@ -2,12 +2,14 @@ netcoreapp2.1;netcoreapp3.1;net5.0 + 9.0 + diff --git a/UnitTests/UnitTests.SourceGenerators/Test_SourceGeneratorsDiagnostics.cs b/UnitTests/UnitTests.SourceGenerators/Test_SourceGeneratorsDiagnostics.cs new file mode 100644 index 00000000000..1df3ac3ab44 --- /dev/null +++ b/UnitTests/UnitTests.SourceGenerators/Test_SourceGeneratorsDiagnostics.cs @@ -0,0 +1,277 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.ComponentModel.DataAnnotations; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.Toolkit.Mvvm.SourceGenerators; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace UnitTests.Mvvm +{ + [TestClass] + public class Test_SourceGeneratorsDiagnostics + { + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError_Explicit() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + [INotifyPropertyChanged] + public partial class SampleViewModel : INotifyPropertyChanged + { + public event PropertyChangedEventHandler? PropertyChanged; + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0004"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangedInterfaceForINotifyPropertyChangedAttributeError_Inherited() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace Microsoft.Toolkit.Mvvm.ComponentModel + { + public abstract class ObservableObject : INotifyPropertyChanged, INotifyPropertyChanging + { + public event PropertyChangedEventHandler? PropertyChanged; + public event PropertyChangingEventHandler? PropertyChanging; + } + } + + namespace MyApp + { + [INotifyPropertyChanged] + public partial class SampleViewModel : ObservableObject + { + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0004"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError_Explicit() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + [ObservableObject] + public partial class SampleViewModel : INotifyPropertyChanged + { + public event PropertyChangedEventHandler? PropertyChanged; + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0005"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangedInterfaceForObservableObjectAttributeError_Inherited() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace Microsoft.Toolkit.Mvvm.ComponentModel + { + public abstract class ObservableObject : INotifyPropertyChanged + { + public event PropertyChangedEventHandler? PropertyChanged; + } + } + + namespace MyApp + { + [ObservableObject] + public partial class SampleViewModel : ObservableObject + { + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0005"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError_Explicit() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + [ObservableObject] + public partial class SampleViewModel : INotifyPropertyChanging + { + public event PropertyChangingEventHandler? PropertyChanging; + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0006"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateINotifyPropertyChangingInterfaceForObservableObjectAttributeError_Inherited() + { + string source = @" + using System.ComponentModel; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + public abstract class MyBaseViewModel : INotifyPropertyChanging + { + public event PropertyChangingEventHandler? PropertyChanging; + } + + [ObservableObject] + public partial class SampleViewModel : MyBaseViewModel + { + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0006"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void DuplicateObservableRecipientError() + { + string source = @" + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace Microsoft.Toolkit.Mvvm.ComponentModel + { + public abstract class ObservableRecipient : ObservableObject + { + } + } + + namespace MyApp + { + [ObservableRecipient] + public partial class SampleViewModel : ObservableRecipient + { + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0007"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void MissingBaseObservableObjectFunctionalityError() + { + string source = @" + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + [ObservableRecipient] + public partial class SampleViewModel + { + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0008"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void MissingObservableValidatorInheritanceError() + { + string source = @" + using System.ComponentModel.DataAnnotations; + using Microsoft.Toolkit.Mvvm.ComponentModel; + + namespace MyApp + { + [INotifyPropertyChanged] + public partial class SampleViewModel + { + [ObservableProperty] + [Required] + private string name; + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0009"); + } + + [TestCategory("Mvvm")] + [TestMethod] + public void InvalidICommandMethodSignatureError() + { + string source = @" + using Microsoft.Toolkit.Mvvm.Input; + + namespace MyApp + { + public partial class SampleViewModel + { + [ICommand] + private string GreetUser() => ""Hello world!""; + } + }"; + + VerifyGeneratedDiagnostics(source, "MVVMTK0012"); + } + + /// + /// Verifies the output of a source generator. + /// + /// The generator type to use. + /// The input source to process. + /// The diagnostic ids to expect for the input source code. + private void VerifyGeneratedDiagnostics(string source, params string[] diagnosticsIds) + where TGenerator : class, ISourceGenerator, new() + { + Type validationAttributeType = typeof(ValidationAttribute); + + SyntaxTree syntaxTree = CSharpSyntaxTree.ParseText(source); + + IEnumerable references = + from assembly in AppDomain.CurrentDomain.GetAssemblies() + where !assembly.IsDynamic + let reference = MetadataReference.CreateFromFile(assembly.Location) + select reference; + + CSharpCompilation compilation = CSharpCompilation.Create("original", new SyntaxTree[] { syntaxTree }, references, new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + + ISourceGenerator generator = new TGenerator(); + + CSharpGeneratorDriver driver = CSharpGeneratorDriver.Create(generator); + + driver.RunGeneratorsAndUpdateCompilation(compilation, out Compilation outputCompilation, out ImmutableArray diagnostics); + + HashSet resultingIds = diagnostics.Select(diagnostic => diagnostic.Id).ToHashSet(); + + Assert.IsTrue(resultingIds.SetEquals(diagnosticsIds)); + + GC.KeepAlive(validationAttributeType); + } + } +} diff --git a/UnitTests/UnitTests.SourceGenerators/UnitTests.SourceGenerators.csproj b/UnitTests/UnitTests.SourceGenerators/UnitTests.SourceGenerators.csproj new file mode 100644 index 00000000000..6b580aa0997 --- /dev/null +++ b/UnitTests/UnitTests.SourceGenerators/UnitTests.SourceGenerators.csproj @@ -0,0 +1,20 @@ + + + + net5.0 + false + 9.0 + + + + + + + + + + + + + + diff --git a/Windows Community Toolkit (NET).slnf b/Windows Community Toolkit (NET).slnf index 82ef7943535..41a012b46b7 100644 --- a/Windows Community Toolkit (NET).slnf +++ b/Windows Community Toolkit (NET).slnf @@ -4,12 +4,14 @@ "projects": [ "Microsoft.Toolkit.Diagnostics\\Microsoft.Toolkit.Diagnostics.csproj", "Microsoft.Toolkit.HighPerformance\\Microsoft.Toolkit.HighPerformance.csproj", + "Microsoft.Toolkit.Mvvm.SourceGenerators\\Microsoft.Toolkit.Mvvm.SourceGenerators.csproj", "Microsoft.Toolkit.Mvvm\\Microsoft.Toolkit.Mvvm.csproj", "Microsoft.Toolkit\\Microsoft.Toolkit.csproj", "UnitTests\\UnitTests.HighPerformance.NetCore\\UnitTests.HighPerformance.NetCore.csproj", "UnitTests\\UnitTests.HighPerformance.Shared\\UnitTests.HighPerformance.Shared.shproj", "UnitTests\\UnitTests.NetCore\\UnitTests.NetCore.csproj", - "UnitTests\\UnitTests.Shared\\UnitTests.Shared.shproj" + "UnitTests\\UnitTests.Shared\\UnitTests.Shared.shproj", + "UnitTests\\UnitTests.SourceGenerators\\UnitTests.SourceGenerators.csproj", ] } } \ No newline at end of file diff --git a/Windows Community Toolkit.sln b/Windows Community Toolkit.sln index 991177c199b..9b0c8e72f87 100644 --- a/Windows Community Toolkit.sln +++ b/Windows Community Toolkit.sln @@ -157,6 +157,10 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.Toolkit.Uwp.UI.Co EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Toolkit.Uwp.UI.Controls", "Microsoft.Toolkit.Uwp.UI.Controls\Microsoft.Toolkit.Uwp.UI.Controls.csproj", "{099B60FD-DAD6-4648-9DE2-8DBF9DCD9557}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Toolkit.Mvvm.SourceGenerators", "Microsoft.Toolkit.Mvvm.SourceGenerators\Microsoft.Toolkit.Mvvm.SourceGenerators.csproj", "{E24D1146-5AD8-498F-A518-4890D8BF4937}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "UnitTests.SourceGenerators", "UnitTests\UnitTests.SourceGenerators\UnitTests.SourceGenerators.csproj", "{338C3BE4-2E71-4F21-AD30-03FDBB47A272}" +EndProject Global GlobalSection(SharedMSBuildProjectFiles) = preSolution UITests\UITests.Tests.Shared\UITests.Tests.Shared.projitems*{05c83067-fa46-45e2-bec4-edee84ad18d0}*SharedItemsImports = 4 @@ -1110,6 +1114,46 @@ Global {099B60FD-DAD6-4648-9DE2-8DBF9DCD9557}.Release|x64.Build.0 = Release|Any CPU {099B60FD-DAD6-4648-9DE2-8DBF9DCD9557}.Release|x86.ActiveCfg = Release|Any CPU {099B60FD-DAD6-4648-9DE2-8DBF9DCD9557}.Release|x86.Build.0 = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|ARM.ActiveCfg = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|ARM.Build.0 = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|ARM64.ActiveCfg = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|ARM64.Build.0 = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|x64.ActiveCfg = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|x64.Build.0 = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|x86.ActiveCfg = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Debug|x86.Build.0 = Debug|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|Any CPU.Build.0 = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|ARM.ActiveCfg = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|ARM.Build.0 = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|ARM64.ActiveCfg = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|ARM64.Build.0 = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|x64.ActiveCfg = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|x64.Build.0 = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|x86.ActiveCfg = Release|Any CPU + {E24D1146-5AD8-498F-A518-4890D8BF4937}.Release|x86.Build.0 = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|Any CPU.Build.0 = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|ARM.ActiveCfg = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|ARM.Build.0 = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|ARM64.ActiveCfg = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|ARM64.Build.0 = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|x64.ActiveCfg = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|x64.Build.0 = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|x86.ActiveCfg = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Debug|x86.Build.0 = Debug|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|Any CPU.ActiveCfg = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|Any CPU.Build.0 = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|ARM.ActiveCfg = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|ARM.Build.0 = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|ARM64.ActiveCfg = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|ARM64.Build.0 = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|x64.ActiveCfg = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|x64.Build.0 = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|x86.ActiveCfg = Release|Any CPU + {338C3BE4-2E71-4F21-AD30-03FDBB47A272}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1159,6 +1203,7 @@ Global {AF1BE4E9-E2E1-4729-B076-B3725D8E21EE} = {F1AFFFA7-28FE-4770-BA48-10D76F3E59BC} {3307BC1D-5D71-41C6-A1B3-B113B8242D08} = {F1AFFFA7-28FE-4770-BA48-10D76F3E59BC} {099B60FD-DAD6-4648-9DE2-8DBF9DCD9557} = {F1AFFFA7-28FE-4770-BA48-10D76F3E59BC} + {338C3BE4-2E71-4F21-AD30-03FDBB47A272} = {B30036C4-D514-4E5B-A323-587A061772CE} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {5403B0C4-F244-4F73-A35C-FE664D0F4345}