Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { ensureProductIdOrInlineProduct } from "@/lib/payments";
import { validateRedirectUrl } from "@/lib/redirect-urls";
import { getStripeForAccount } from "@/lib/stripe";
import { globalPrismaClient } from "@/prisma-client";
import { createSmartRouteHandler } from "@/route-handlers/smart-route-handler";
import { CustomerType } from "@prisma/client";
import { KnownErrors } from "@stackframe/stack-shared/dist/known-errors";
import { adaptSchema, clientOrHigherAuthTypeSchema, inlineProductSchema, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { adaptSchema, clientOrHigherAuthTypeSchema, inlineProductSchema, urlSchema, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { getEnvVariable } from "@stackframe/stack-shared/dist/utils/env";
import { throwErr } from "@stackframe/stack-shared/dist/utils/errors";
import { purchaseUrlVerificationCodeHandler } from "../verification-code-handler";
Expand All @@ -24,6 +25,7 @@ export const POST = createSmartRouteHandler({
customer_id: yupString().defined(),
product_id: yupString().optional(),
product_inline: inlineProductSchema.optional(),
return_url: urlSchema.optional(),
}),
}),
response: yupObject({
Expand Down Expand Up @@ -77,6 +79,12 @@ export const POST = createSmartRouteHandler({

const fullCode = `${tenancy.id}_${code}`;
const url = new URL(`/purchase/${fullCode}`, getEnvVariable("NEXT_PUBLIC_STACK_DASHBOARD_URL"));
if (req.body.return_url) {
if (!validateRedirectUrl(req.body.return_url, tenancy)) {
throw new KnownErrors.RedirectUrlNotWhitelisted();
}
url.searchParams.set("return_url", req.body.return_url);
}

return {
statusCode: 200,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import { getSubscriptions, isActiveSubscription } from "@/lib/payments";
import { validateRedirectUrl } from "@/lib/redirect-urls";
import { getTenancy } from "@/lib/tenancies";
import { getPrismaClientForTenancy } from "@/prisma-client";
import { createSmartRouteHandler } from "@/route-handlers/smart-route-handler";
import { inlineProductSchema, yupArray, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { KnownErrors } from "@stackframe/stack-shared";
import { inlineProductSchema, urlSchema, yupArray, yupBoolean, yupNumber, yupObject, yupString } from "@stackframe/stack-shared/dist/schema-fields";
import { SUPPORTED_CURRENCIES } from "@stackframe/stack-shared/dist/utils/currency-constants";
import { StackAssertionError } from "@stackframe/stack-shared/dist/utils/errors";
import { filterUndefined, getOrUndefined, typedEntries, typedFromEntries } from "@stackframe/stack-shared/dist/utils/objects";
Expand All @@ -22,6 +24,7 @@ export const POST = createSmartRouteHandler({
request: yupObject({
body: yupObject({
full_code: yupString().defined(),
return_url: urlSchema.optional(),
}),
}),
response: yupObject({
Expand All @@ -44,6 +47,9 @@ export const POST = createSmartRouteHandler({
if (!tenancy) {
throw new StackAssertionError(`No tenancy found for given tenancyId`);
}
if (body.return_url && !validateRedirectUrl(body.return_url, tenancy)) {
throw new KnownErrors.RedirectUrlNotWhitelisted();
}
const product = verificationCode.data.product;
const productData: yup.InferType<typeof productDataSchema> = {
display_name: product.displayName ?? "Product",
Expand Down Expand Up @@ -100,3 +106,43 @@ export const POST = createSmartRouteHandler({
};
},
});


export const GET = createSmartRouteHandler({
metadata: {
hidden: true,
},
request: yupObject({
query: yupObject({
full_code: yupString().defined(),
return_url: urlSchema.optional(),
}),
}),
response: yupObject({
statusCode: yupNumber().oneOf([200]).defined(),
bodyType: yupString().oneOf(["json"]).defined(),
body: yupObject({
valid: yupBoolean().defined(),
}).defined(),
}),
async handler({ query }) {
const tenancyId = query.full_code.split("_")[0];
if (!tenancyId) {
throw new KnownErrors.VerificationCodeNotFound();
}
const tenancy = await getTenancy(tenancyId);
if (!tenancy) {
throw new KnownErrors.VerificationCodeNotFound();
}
if (query.return_url && !validateRedirectUrl(query.return_url, tenancy)) {
throw new KnownErrors.RedirectUrlNotWhitelisted();
}
return {
statusCode: 200,
bodyType: "json",
body: {
valid: true,
},
};
},
});
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ function TeamAddUserDialog(props: {
const onSubmit = async (values: yup.InferType<typeof inviteFormSchema>) => {
if (users.length + 1 > quantity) {
alert("You have reached the maximum number of dashboard admins. Please upgrade your plan to add more admins.");
const checkoutUrl = await props.team.createCheckoutUrl({ productId: "team" });
window.open(checkoutUrl, "_blank", "noopener");
const checkoutUrl = await props.team.createCheckoutUrl({
productId: "team",
returnUrl: window.location.href,
});
window.location.assign(checkoutUrl);
return "prevent-close-and-prevent-reset";
}
await props.onSubmit(values.email);
Expand Down
16 changes: 13 additions & 3 deletions apps/dashboard/src/app/(main)/purchase/[code]/page-client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { typedEntries } from "@stackframe/stack-shared/dist/utils/objects";
import { runAsynchronouslyWithAlert } from "@stackframe/stack-shared/dist/utils/promises";
import { Alert, AlertDescription, AlertTitle, Button, Card, CardContent, Input, Skeleton, Typography } from "@stackframe/stack-ui";
import { ArrowRight, Minus, Plus } from "lucide-react";
import { useSearchParams } from "next/navigation";
import { useCallback, useEffect, useMemo, useState } from "react";
import * as yup from "yup";

Expand All @@ -30,8 +31,10 @@ export default function PageClient({ code }: { code: string }) {
const [error, setError] = useState<string | null>(null);
const [selectedPriceId, setSelectedPriceId] = useState<string | null>(null);
const [quantityInput, setQuantityInput] = useState<string>("1");
const searchParams = useSearchParams();
const user = useUser({ projectIdMustMatch: "internal" });
const [adminApp, setAdminApp] = useState<StackAdminApp>();
const returnUrl = searchParams.get("return_url");

useEffect(() => {
if (!user || !data) return;
Expand Down Expand Up @@ -92,7 +95,10 @@ export default function PageClient({ code }: { code: string }) {
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ full_code: code }),
body: JSON.stringify({
full_code: code,
return_url: returnUrl ?? undefined,
}),
});
if (!response.ok) {
throw new Error('Failed to validate code');
Expand All @@ -103,7 +109,7 @@ export default function PageClient({ code }: { code: string }) {
const firstPriceId = Object.keys(result.product.prices)[0];
setSelectedPriceId(firstPriceId);
}
}, [code]);
}, [code, returnUrl]);

useEffect(() => {
setLoading(true);
Expand Down Expand Up @@ -138,8 +144,11 @@ export default function PageClient({ code }: { code: string }) {
const url = new URL(`/purchase/return`, window.location.origin);
url.searchParams.set("bypass", "1");
url.searchParams.set("purchase_full_code", code);
if (returnUrl) {
url.searchParams.set("return_url", returnUrl);
}
window.location.assign(url.toString());
}, [code, adminApp, selectedPriceId, quantityNumber, isTooLarge]);
}, [code, adminApp, selectedPriceId, quantityNumber, isTooLarge, returnUrl]);

return (
<div className="flex flex-row">
Expand Down Expand Up @@ -281,6 +290,7 @@ export default function PageClient({ code }: { code: string }) {
fullCode={code}
stripeAccountId={data.stripe_account_id}
setupSubscription={setupSubscription}
returnUrl={returnUrl ?? undefined}
disabled={quantityNumber < 1 || isTooLarge || data.already_bought_non_stackable === true}
/>
</StripeElementsProvider>
Expand Down
29 changes: 26 additions & 3 deletions apps/dashboard/src/app/(main)/purchase/return/page-client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import { StyledLink } from "@/components/link";
import { getPublicEnvVar } from "@/lib/env";
import { throwErr } from "@stackframe/stack-shared/dist/utils/errors";
import { runAsynchronously } from "@stackframe/stack-shared/dist/utils/promises";
import { Typography } from "@stackframe/stack-ui";
import { loadStripe } from "@stripe/stripe-js";
import { useSearchParams } from "next/navigation";
import { useCallback, useEffect, useState } from "react";

type Props = {
Expand All @@ -22,14 +24,33 @@ type ViewState =
| { kind: "error", message: string };

const stripePublicKey = getPublicEnvVar("NEXT_PUBLIC_STACK_STRIPE_PUBLISHABLE_KEY") ?? "";
const apiUrl = getPublicEnvVar("NEXT_PUBLIC_STACK_API_URL") ?? throwErr("NEXT_PUBLIC_STACK_API_URL is not set");
const baseUrl = new URL("/api/v1", apiUrl).toString();

export default function ReturnClient({ clientSecret, stripeAccountId, purchaseFullCode, bypass }: Props) {
const [state, setState] = useState<ViewState>({ kind: "loading" });
const searchParams = useSearchParams();
const returnUrl = searchParams.get("return_url");

const checkAndReturnUser = useCallback(async () => {
if (!returnUrl || !purchaseFullCode) {
return;
}
const url = new URL(`${baseUrl}/payments/purchases/validate-code`);
url.searchParams.set("full_code", purchaseFullCode);
url.searchParams.set("return_url", returnUrl);
const response = await fetch(url);
if (response.ok) {
window.location.assign(returnUrl);
}
}, [returnUrl, purchaseFullCode]);

const updateViewState = useCallback(async (): Promise<void> => {
try {
if (bypass === "1") {
setState({ kind: "success", message: "Bypassed in test mode. No payment processed." });
runAsynchronously(checkAndReturnUser());
const message = `Bypassed in test mode. No payment processed.${returnUrl ? " You will be redirected shortly." : ""}`;
setState({ kind: "success", message });
return;
}
const stripe = await loadStripe(stripePublicKey, { stripeAccount: stripeAccountId });
Expand All @@ -40,7 +61,9 @@ export default function ReturnClient({ clientSecret, stripeAccountId, purchaseFu
const lastErrorMessage = result.paymentIntent?.last_payment_error?.message;

if (status === "succeeded") {
setState({ kind: "success", message: "Payment succeeded. You can close this page." });
runAsynchronously(checkAndReturnUser());
const message = `Payment succeeded.${returnUrl ? " You will be redirected shortly." : " You can safely close this page."}`;
setState({ kind: "success", message });
return;
}
if (status === "processing") {
Expand All @@ -64,7 +87,7 @@ export default function ReturnClient({ clientSecret, stripeAccountId, purchaseFu
const message = e instanceof Error ? e.message : "Unexpected error retrieving payment.";
setState({ kind: "error", message });
}
}, [clientSecret, stripeAccountId, bypass]);
}, [clientSecret, stripeAccountId, bypass, returnUrl, checkAndReturnUser]);

useEffect(() => {
runAsynchronously(updateViewState());
Expand Down
14 changes: 9 additions & 5 deletions apps/dashboard/src/components/payments/checkout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ type Props = {
setupSubscription: () => Promise<string>,
stripeAccountId: string,
fullCode: string,
returnUrl?: string,
disabled?: boolean,
};

export function CheckoutForm({ setupSubscription, stripeAccountId, fullCode, disabled }: Props) {
export function CheckoutForm({ setupSubscription, stripeAccountId, fullCode, returnUrl, disabled }: Props) {
const stripe = useStripe();
const elements = useElements();
const [message, setMessage] = useState<string | null>(null);
Expand All @@ -39,15 +40,18 @@ export function CheckoutForm({ setupSubscription, stripeAccountId, fullCode, dis
}

const clientSecret = await setupSubscription();
const returnUrl = new URL(`/purchase/return`, window.location.origin);
returnUrl.searchParams.set("stripe_account_id", stripeAccountId);
returnUrl.searchParams.set("purchase_full_code", fullCode);
const stripeReturnUrl = new URL(`/purchase/return`, window.location.origin);
stripeReturnUrl.searchParams.set("stripe_account_id", stripeAccountId);
stripeReturnUrl.searchParams.set("purchase_full_code", fullCode);
if (returnUrl) {
stripeReturnUrl.searchParams.set("return_url", returnUrl);
}

const { error } = await stripe.confirmPayment({
elements,
clientSecret,
confirmParams: {
return_url: returnUrl.toString(),
return_url: stripeReturnUrl.toString(),
},
}) as { error?: StripeError };

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { generateUuid } from "@stackframe/stack-shared/dist/utils/uuids";
import { it } from "../../../../../helpers";
import { Auth, Payments, Project, User, niceBackendFetch } from "../../../../backend-helpers";
import { Auth, niceBackendFetch, Payments, Project, User } from "../../../../backend-helpers";

it("should not be able to create purchase URL without product_id or product_inline", async ({ expect }) => {
await Project.createAndSwitch();
Expand Down Expand Up @@ -309,9 +309,62 @@ it("should allow valid product_id", async ({ expect }) => {
customer_type: "user",
customer_id: userId,
product_id: "test-product",
return_url: "http://stack-test.localhost/after-purchase",
},
});
expect(response.status).toBe(200);
const body = response.body as { url: string };
expect(body.url).toMatch(/^https?:\/\/localhost:8101\/purchase\/[a-z0-9-_]+$/);
expect(body.url).toMatch(/^https?:\/\/localhost:8101\/purchase\/[a-z0-9-_]+\?return_url=/);
const urlObj = new URL(body.url);
const returnUrl = urlObj.searchParams.get("return_url");
expect(returnUrl).toBe("http://stack-test.localhost/after-purchase");
});

it("should error for untrusted return_url", async ({ expect }) => {
await Project.createAndSwitch({ config: { magic_link_enabled: true } });
await Payments.setup();
await Project.updateConfig({
payments: {
products: {
"test-product": {
displayName: "Test Product",
customerType: "user",
serverOnly: false,
stackable: false,
prices: {
"monthly": {
USD: "1000",
interval: [1, "month"],
},
},
includedItems: {},
},
},
},
});

const { userId } = await User.create();
const response = await niceBackendFetch("/api/latest/payments/purchases/create-purchase-url", {
method: "POST",
accessType: "client",
body: {
customer_type: "user",
customer_id: userId,
product_id: "test-product",
return_url: "https://malicious.com/callback",
},
});
expect(response).toMatchInlineSnapshot(`
NiceResponse {
"status": 400,
"body": {
"code": "REDIRECT_URL_NOT_WHITELISTED",
"error": "Redirect URL not whitelisted. Did you forget to add this domain to the trusted domains list on the Stack Auth dashboard?",
},
"headers": Headers {
"x-stack-known-error": "REDIRECT_URL_NOT_WHITELISTED",
<some fields may have been hidden>,
},
}
`);
});
Loading