aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/clang/Basic/DiagnosticSemaKinds.td2
-rw-r--r--lib/Sema/SemaExpr.cpp14
-rw-r--r--test/SemaCUDA/kernel-call.cu8
3 files changed, 24 insertions, 0 deletions
diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td
index 98523ee57f..1ff3c0f316 100644
--- a/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3121,6 +3121,8 @@ def err_kern_type_not_void_return : Error<
"kernel function type %0 must have void return type">;
def err_config_scalar_return : Error<
"CUDA special function 'cudaConfigureCall' must have scalar return type">;
+def err_kern_call_not_global_function : Error<
+ "kernel call to non-global function %0">;
def err_cannot_pass_objc_interface_to_vararg : Error<
diff --git a/lib/Sema/SemaExpr.cpp b/lib/Sema/SemaExpr.cpp
index 9e2b21aca2..1ba8ea62b4 100644
--- a/lib/Sema/SemaExpr.cpp
+++ b/lib/Sema/SemaExpr.cpp
@@ -4625,6 +4625,20 @@ Sema::BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl,
return ExprError(Diag(LParenLoc, diag::err_typecheck_call_not_function)
<< Fn->getType() << Fn->getSourceRange());
+ if (getLangOptions().CUDA) {
+ if (Config) {
+ // CUDA: Kernel calls must be to global functions
+ if (FDecl && !FDecl->hasAttr<CUDAGlobalAttr>())
+ return ExprError(Diag(LParenLoc,diag::err_kern_call_not_global_function)
+ << FDecl->getName() << Fn->getSourceRange());
+
+ // CUDA: Kernel function must have 'void' return type
+ if (!FuncT->getResultType()->isVoidType())
+ return ExprError(Diag(LParenLoc, diag::err_kern_type_not_void_return)
+ << Fn->getType() << Fn->getSourceRange());
+ }
+ }
+
// Check for a valid return type
if (CheckCallReturnType(FuncT->getResultType(),
Fn->getSourceRange().getBegin(), TheCall,
diff --git a/test/SemaCUDA/kernel-call.cu b/test/SemaCUDA/kernel-call.cu
index 6d51695522..7bc7ae1131 100644
--- a/test/SemaCUDA/kernel-call.cu
+++ b/test/SemaCUDA/kernel-call.cu
@@ -8,8 +8,16 @@ template <typename T> void t1(T arg) {
g1<<<arg, arg>>>(1);
}
+void h1(int x) {}
+int h2(int x) { return 1; }
+
int main(void) {
g1<<<1, 1>>>(42);
t1(1);
+
+ h1<<<1, 1>>>(42); // expected-error {{kernel call to non-global function h1}}
+
+ int (*fp)(int) = h2;
+ fp<<<1, 1>>>(42); // expected-error {{must have void return type}}
}