@@ -1080,8 +1080,7 @@ static ParamDesc makeParamDesc(ASTContext &Ctx, StringRef Name, QualType Ty) {
10801080}
10811081
10821082static void unsupportedFreeFunctionParamType () {
1083- llvm::report_fatal_error (" Only scalars and pointers are permitted as "
1084- " free function parameters" );
1083+ llvm::report_fatal_error (" Unsupported free kernel parameter type!" );
10851084}
10861085
10871086class MarkWIScopeFnVisitor : public RecursiveASTVisitor <MarkWIScopeFnVisitor> {
@@ -2080,13 +2079,7 @@ class SyclKernelFieldChecker : public SyclKernelFieldHandler {
20802079 }
20812080
20822081 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2083- if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory) &&
2084- !SemaSYCL::isSyclType (ParamTy,
2085- SYCLTypeAttr::dynamic_work_group_memory)) {
2086- Diag.Report (PD->getLocation (), diag::err_bad_kernel_param_type)
2087- << ParamTy;
2088- IsInvalid = true ;
2089- }
2082+ IsInvalid |= checkSyclSpecialType (ParamTy, PD->getLocation ());
20902083 return isValid ();
20912084 }
20922085
@@ -2238,10 +2231,7 @@ class SyclKernelUnionChecker : public SyclKernelFieldHandler {
22382231 }
22392232
22402233 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
2241- if (!SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory) &&
2242- !SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::dynamic_work_group_memory))
2243- unsupportedFreeFunctionParamType (); // TODO
2244- return true ;
2234+ return checkType (PD->getLocation (), ParamTy);
22452235 }
22462236
22472237 bool handleSyclSpecialType (const CXXRecordDecl *, const CXXBaseSpecifier &BS,
@@ -2830,23 +2820,34 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
28302820 // kernel parameters from __init method parameters. We will use __init method
28312821 // and kernel parameters which we build here to initialize special objects in
28322822 // the kernel body.
2833- bool handleSpecialType (FieldDecl *FD, QualType FieldTy) {
2834- const auto *RecordDecl = FieldTy->getAsCXXRecordDecl ();
2835- assert (RecordDecl && " The type must be a RecordDecl" );
2823+ // ParentDecl parameterizes whether we are in a free function kernel or a
2824+ // lambda kernel by taking the value ParmVarDecl or FieldDecl respectively.
2825+ template <typename ParentDecl>
2826+ bool handleSpecialType (ParentDecl *decl, QualType Ty) {
2827+ const auto *RD = Ty->getAsCXXRecordDecl ();
2828+ assert (RD && " The type must be a RecordDecl" );
28362829 llvm::StringLiteral MethodName =
2837- KernelDecl->hasAttr <SYCLSimdAttr>() && isSyclAccessorType (FieldTy )
2830+ KernelDecl->hasAttr <SYCLSimdAttr>() && isSyclAccessorType (Ty )
28382831 ? InitESIMDMethodName
28392832 : InitMethodName;
2840- CXXMethodDecl *InitMethod = getMethodByName (RecordDecl , MethodName);
2833+ CXXMethodDecl *InitMethod = getMethodByName (RD , MethodName);
28412834 assert (InitMethod && " The type must have the __init method" );
28422835
28432836 // Don't do -1 here because we count on this to be the first parameter added
28442837 // (if any).
28452838 size_t ParamIndex = Params.size ();
28462839 for (const ParmVarDecl *Param : InitMethod->parameters ()) {
28472840 QualType ParamTy = Param->getType ();
2848- addParam (FD, ParamTy.getCanonicalType ());
2849-
2841+ // For lambda kernels the arguments to the OpenCL kernel are named
2842+ // based on the position they have as fields in the definition of the
2843+ // special type structure i.e __arg_field1, __arg_field2 and so on.
2844+ // For free function kernels the arguments are named in direct mapping
2845+ // with the names they have in the __init method i.e __arg_Ptr for work
2846+ // group memory since its init function takes a parameter with Ptr name.
2847+ if constexpr (std::is_same_v<ParentDecl, FieldDecl>)
2848+ addParam (decl, ParamTy.getCanonicalType ());
2849+ else
2850+ addParam (Param, ParamTy.getCanonicalType ());
28502851 // Propagate add_ir_attributes_kernel_parameter attribute.
28512852 if (const auto *AddIRAttr =
28522853 Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
@@ -2858,8 +2859,8 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
28582859 // handleAccessorPropertyList. If new classes with property list are
28592860 // added, this code needs to be refactored to call
28602861 // handleAccessorPropertyList for each class which requires it.
2861- if (ParamTy.getTypePtr ()->isPointerType () && isSyclAccessorType (FieldTy ))
2862- handleAccessorType (FieldTy, RecordDecl, FD ->getBeginLoc ());
2862+ if (ParamTy.getTypePtr ()->isPointerType () && isSyclAccessorType (Ty ))
2863+ handleAccessorType (Ty, RD, decl ->getBeginLoc ());
28632864 }
28642865 LastParamIndex = ParamIndex;
28652866 return true ;
@@ -3026,28 +3027,7 @@ class SyclKernelDeclCreator : public SyclKernelFieldHandler {
30263027 }
30273028
30283029 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
3029- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory) ||
3030- SemaSYCL::isSyclType (ParamTy,
3031- SYCLTypeAttr::dynamic_work_group_memory)) {
3032- const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
3033- assert (RecordDecl && " The type must be a RecordDecl" );
3034- CXXMethodDecl *InitMethod = getMethodByName (RecordDecl, InitMethodName);
3035- assert (InitMethod && " The type must have the __init method" );
3036- // Don't do -1 here because we count on this to be the first parameter
3037- // added (if any).
3038- size_t ParamIndex = Params.size ();
3039- for (const ParmVarDecl *Param : InitMethod->parameters ()) {
3040- QualType ParamTy = Param->getType ();
3041- addParam (Param, ParamTy.getCanonicalType ());
3042- // Propagate add_ir_attributes_kernel_parameter attribute.
3043- if (const auto *AddIRAttr =
3044- Param->getAttr <SYCLAddIRAttributesKernelParameterAttr>())
3045- Params.back ()->addAttr (AddIRAttr->clone (SemaSYCLRef.getASTContext ()));
3046- }
3047- LastParamIndex = ParamIndex;
3048- } else // TODO
3049- unsupportedFreeFunctionParamType ();
3050- return true ;
3030+ return handleSpecialType (PD, ParamTy);
30513031 }
30523032
30533033 RecordDecl *wrapField (FieldDecl *Field, QualType FieldTy) {
@@ -4540,47 +4520,48 @@ class FreeFunctionKernelBodyCreator : public SyclKernelFieldHandler {
45404520 // TODO: Revisit this approach once https://github.com/intel/llvm/issues/16061
45414521 // is closed.
45424522 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4543- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory) ||
4544- SemaSYCL::isSyclType (ParamTy,
4545- SYCLTypeAttr::dynamic_work_group_memory)) {
4546- const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4547- AccessSpecifier DefaultConstructorAccess;
4548- auto DefaultConstructor =
4549- std::find_if (RecordDecl->ctor_begin (), RecordDecl->ctor_end (),
4550- [](auto it) { return it->isDefaultConstructor (); });
4551- DefaultConstructorAccess = DefaultConstructor->getAccess ();
4552- DefaultConstructor->setAccess (AS_public);
4553-
4554- QualType Ty = PD->getOriginalType ();
4555- ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4556- VarDecl *WorkGroupMemoryClone = VarDecl::Create (
4557- Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
4558- FreeFunctionSrcLoc, PD->getIdentifier (), PD->getType (),
4559- Ctx.getTrivialTypeSourceInfo (Ty), SC_None);
4560- InitializedEntity VarEntity =
4561- InitializedEntity::InitializeVariable (WorkGroupMemoryClone);
4562- InitializationKind InitKind =
4563- InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4564- InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4565- std::nullopt );
4566- ExprResult Init = InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity,
4567- InitKind, std::nullopt );
4568- WorkGroupMemoryClone->setInit (
4569- SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4570- WorkGroupMemoryClone->setInitStyle (VarDecl::CallInit);
4571- DefaultConstructor->setAccess (DefaultConstructorAccess);
4572-
4573- Stmt *DS = new (SemaSYCLRef.getASTContext ())
4574- DeclStmt (DeclGroupRef (WorkGroupMemoryClone), FreeFunctionSrcLoc,
4575- FreeFunctionSrcLoc);
4576- BodyStmts.push_back (DS);
4577- Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4578- WorkGroupMemoryClone, Ty, VK_PRValue, FreeFunctionSrcLoc);
4579- createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
4580- BodyStmts);
4581- ArgExprs.push_back (MemberBaseExpr);
4582- } else // TODO
4583- unsupportedFreeFunctionParamType ();
4523+ // The code produced looks like this in the case of a work group memory
4524+ // parameter:
4525+ // void auto_generated_kernel(__local int * arg) {
4526+ // work_group_memory wgm;
4527+ // wgm.__init(arg);
4528+ // user_kernel(some arguments..., wgm, some arguments...);
4529+ // }
4530+ const auto *RecordDecl = ParamTy->getAsCXXRecordDecl ();
4531+ AccessSpecifier DefaultConstructorAccess;
4532+ auto DefaultConstructor =
4533+ std::find_if (RecordDecl->ctor_begin (), RecordDecl->ctor_end (),
4534+ [](auto it) { return it->isDefaultConstructor (); });
4535+ DefaultConstructorAccess = DefaultConstructor->getAccess ();
4536+ DefaultConstructor->setAccess (AS_public);
4537+
4538+ ASTContext &Ctx = SemaSYCLRef.SemaRef .getASTContext ();
4539+ VarDecl *SpecialObjectClone =
4540+ VarDecl::Create (Ctx, DeclCreator.getKernelDecl (), FreeFunctionSrcLoc,
4541+ FreeFunctionSrcLoc, PD->getIdentifier (), ParamTy,
4542+ Ctx.getTrivialTypeSourceInfo (ParamTy), SC_None);
4543+ InitializedEntity VarEntity =
4544+ InitializedEntity::InitializeVariable (SpecialObjectClone);
4545+ InitializationKind InitKind =
4546+ InitializationKind::CreateDefault (FreeFunctionSrcLoc);
4547+ InitializationSequence InitSeq (SemaSYCLRef.SemaRef , VarEntity, InitKind,
4548+ std::nullopt );
4549+ ExprResult Init =
4550+ InitSeq.Perform (SemaSYCLRef.SemaRef , VarEntity, InitKind, std::nullopt );
4551+ SpecialObjectClone->setInit (
4552+ SemaSYCLRef.SemaRef .MaybeCreateExprWithCleanups (Init.get ()));
4553+ SpecialObjectClone->setInitStyle (VarDecl::CallInit);
4554+ DefaultConstructor->setAccess (DefaultConstructorAccess);
4555+
4556+ Stmt *DS = new (SemaSYCLRef.getASTContext ())
4557+ DeclStmt (DeclGroupRef (SpecialObjectClone), FreeFunctionSrcLoc,
4558+ FreeFunctionSrcLoc);
4559+ BodyStmts.push_back (DS);
4560+ Expr *MemberBaseExpr = SemaSYCLRef.SemaRef .BuildDeclRefExpr (
4561+ SpecialObjectClone, ParamTy, VK_PRValue, FreeFunctionSrcLoc);
4562+ createSpecialMethodCall (RecordDecl, InitMethodName, MemberBaseExpr,
4563+ BodyStmts);
4564+ ArgExprs.push_back (MemberBaseExpr);
45844565 return true ;
45854566 }
45864567
@@ -4874,14 +4855,45 @@ class SyclKernelIntHeaderCreator : public SyclKernelFieldHandler {
48744855 }
48754856
48764857 bool handleSyclSpecialType (ParmVarDecl *PD, QualType ParamTy) final {
4877- if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory))
4858+ const auto *ClassTy = ParamTy->getAsCXXRecordDecl ();
4859+ assert (ClassTy && " Type must be a C++ record type" );
4860+ if (isSyclAccessorType (ParamTy)) {
4861+ const auto *AccTy =
4862+ cast<ClassTemplateSpecializationDecl>(ParamTy->getAsRecordDecl ());
4863+ assert (AccTy->getTemplateArgs ().size () >= 2 &&
4864+ " Incorrect template args for Accessor Type" );
4865+ int Dims = static_cast <int >(
4866+ AccTy->getTemplateArgs ()[1 ].getAsIntegral ().getExtValue ());
4867+ int Info = getAccessTarget (ParamTy, AccTy) | (Dims << 11 );
4868+ Header.addParamDesc (SYCLIntegrationHeader::kind_accessor, Info,
4869+ CurOffset);
4870+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::stream)) {
4871+ addParam (PD, ParamTy, SYCLIntegrationHeader::kind_stream);
4872+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::work_group_memory)) {
48784873 addParam (PD, ParamTy, SYCLIntegrationHeader::kind_work_group_memory);
4879- else if (SemaSYCL::isSyclType (ParamTy,
4880- SYCLTypeAttr::dynamic_work_group_memory))
4874+ } else if (SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::sampler) ||
4875+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::annotated_ptr) ||
4876+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::annotated_arg)) {
4877+ CXXMethodDecl *InitMethod = getMethodByName (ClassTy, InitMethodName);
4878+ assert (InitMethod && " type must have __init method" );
4879+ const ParmVarDecl *InitArg = InitMethod->getParamDecl (0 );
4880+ assert (InitArg && " Init method must have arguments" );
4881+ QualType T = InitArg->getType ();
4882+ SYCLIntegrationHeader::kernel_param_kind_t ParamKind =
4883+ SemaSYCL::isSyclType (ParamTy, SYCLTypeAttr::sampler)
4884+ ? SYCLIntegrationHeader::kind_sampler
4885+ : (T->isPointerType () ? SYCLIntegrationHeader::kind_pointer
4886+ : SYCLIntegrationHeader::kind_std_layout);
4887+ addParam (PD, ParamTy, ParamKind);
4888+ } else if (SemaSYCL::isSyclType (ParamTy,
4889+ SYCLTypeAttr::dynamic_work_group_memory))
48814890 addParam (PD, ParamTy,
48824891 SYCLIntegrationHeader::kind_dynamic_work_group_memory);
4883- else
4884- unsupportedFreeFunctionParamType (); // TODO
4892+
4893+ else {
4894+ llvm_unreachable (
4895+ " Unexpected SYCL special class when generating integration header" );
4896+ }
48854897 return true ;
48864898 }
48874899
@@ -6666,6 +6678,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
66666678 O << " #include <sycl/detail/defines_elementary.hpp>\n " ;
66676679 O << " #include <sycl/detail/kernel_desc.hpp>\n " ;
66686680 O << " #include <sycl/ext/oneapi/experimental/free_function_traits.hpp>\n " ;
6681+ O << " #include <sycl/access/access.hpp>\n " ;
66696682 O << " \n " ;
66706683
66716684 LangOptions LO;
@@ -6977,6 +6990,7 @@ void SYCLIntegrationHeader::emit(raw_ostream &O) {
69776990 Policy.PolishForDeclaration = true ;
69786991 Policy.FullyQualifiedName = true ;
69796992 Policy.EnforceScopeForElaboratedTypes = true ;
6993+ Policy.UseFullyQualifiedEnumerators = true ;
69806994
69816995 // Now we need to print the declaration of the kernel itself.
69826996 // Example:
0 commit comments