feat: restrict non credential provider interactions (#871)

* wip: add provider field to sqlite user table

* feat: disable invites when credentials provider is not used

* wip: add migration for provider field in user table with sqlite

* wip: remove fields that can not be modified by non credential users

* wip: make username, mail and avatar disabled instead of hidden

* wip: external users membership of group cannot be managed manually

* feat: add alerts to inform about disabled fields and managing group members

* wip: add mysql migration for provider on user table

* chore: fix format issues

* chore: address pull request feedback

* fix: build issue

* fix: deepsource issues

* fix: tests not working

* feat: restrict login to specific auth providers

* chore: address pull request feedback

* fix: deepsource issue
This commit is contained in:
Meier Lukas
2024-07-27 11:38:51 +02:00
committed by GitHub
parent eba4052522
commit 6f7327b774
36 changed files with 2989 additions and 116 deletions

View File

@@ -2,6 +2,7 @@ import { notFound } from "next/navigation";
import { Card, Center, Stack, Text, Title } from "@mantine/core"; import { Card, Center, Stack, Text, Title } from "@mantine/core";
import { auth } from "@homarr/auth/next"; import { auth } from "@homarr/auth/next";
import { isProviderEnabled } from "@homarr/auth/server";
import { and, db, eq } from "@homarr/db"; import { and, db, eq } from "@homarr/db";
import { invites } from "@homarr/db/schema/sqlite"; import { invites } from "@homarr/db/schema/sqlite";
import { getScopedI18n } from "@homarr/translation/server"; import { getScopedI18n } from "@homarr/translation/server";
@@ -19,6 +20,8 @@ interface InviteUsagePageProps {
} }
export default async function InviteUsagePage({ params, searchParams }: InviteUsagePageProps) { export default async function InviteUsagePage({ params, searchParams }: InviteUsagePageProps) {
if (!isProviderEnabled("credentials")) notFound();
const session = await auth(); const session = await auth();
if (session) notFound(); if (session) notFound();

View File

@@ -22,6 +22,7 @@ import {
IconUsersGroup, IconUsersGroup,
} from "@tabler/icons-react"; } from "@tabler/icons-react";
import { isProviderEnabled } from "@homarr/auth/server";
import { getScopedI18n } from "@homarr/translation/server"; import { getScopedI18n } from "@homarr/translation/server";
import { MainHeader } from "~/components/layout/header"; import { MainHeader } from "~/components/layout/header";
@@ -65,6 +66,7 @@ export default async function ManageLayout({ children }: PropsWithChildren) {
label: t("items.users.items.invites"), label: t("items.users.items.invites"),
icon: IconMailForward, icon: IconMailForward,
href: "/manage/users/invites", href: "/manage/users/invites",
hidden: !isProviderEnabled("credentials"),
}, },
{ {
label: t("items.users.items.groups"), label: t("items.users.items.groups"),

View File

@@ -3,6 +3,7 @@ import { Card, Group, SimpleGrid, Space, Stack, Text } from "@mantine/core";
import { IconArrowRight } from "@tabler/icons-react"; import { IconArrowRight } from "@tabler/icons-react";
import { api } from "@homarr/api/server"; import { api } from "@homarr/api/server";
import { isProviderEnabled } from "@homarr/auth/server";
import { getScopedI18n } from "@homarr/translation/server"; import { getScopedI18n } from "@homarr/translation/server";
import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb"; import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb";
@@ -14,6 +15,7 @@ interface LinkProps {
subtitle: string; subtitle: string;
count: number; count: number;
href: string; href: string;
hidden?: boolean;
} }
export async function generateMetadata() { export async function generateMetadata() {
@@ -42,6 +44,7 @@ export default async function ManagementPage() {
title: t("statistic.createUser"), title: t("statistic.createUser"),
}, },
{ {
hidden: !isProviderEnabled("credentials"),
count: statistics.countInvites, count: statistics.countInvites,
href: "/manage/users/invites", href: "/manage/users/invites",
subtitle: t("statisticLabel.authentication"), subtitle: t("statisticLabel.authentication"),
@@ -72,24 +75,27 @@ export default async function ManagementPage() {
<HeroBanner /> <HeroBanner />
<Space h="md" /> <Space h="md" />
<SimpleGrid cols={{ xs: 1, sm: 2, md: 3 }}> <SimpleGrid cols={{ xs: 1, sm: 2, md: 3 }}>
{links.map((link, index) => ( {links.map(
<Card component={Link} href={link.href} key={`link-${index}`} withBorder> (link) =>
<Group justify="space-between" wrap="nowrap"> !link.hidden && (
<Group wrap="nowrap"> <Card component={Link} href={link.href} key={link.href} withBorder>
<Text size="2.4rem" fw="bolder"> <Group justify="space-between" wrap="nowrap">
{link.count} <Group wrap="nowrap">
</Text> <Text size="2.4rem" fw="bolder">
<Stack gap={0}> {link.count}
<Text c="red" size="xs"> </Text>
{link.subtitle} <Stack gap={0}>
</Text> <Text c="red" size="xs">
<Text fw="bold">{link.title}</Text> {link.subtitle}
</Stack> </Text>
</Group> <Text fw="bold">{link.title}</Text>
<IconArrowRight /> </Stack>
</Group> </Group>
</Card> <IconArrowRight />
))} </Group>
</Card>
),
)}
</SimpleGrid> </SimpleGrid>
</> </>
); );

View File

@@ -93,24 +93,38 @@ export const UserProfileAvatarForm = ({ user }: UserProfileAvatarForm) => {
}); });
}, [mutate, user.id, openConfirmModal, tManageAvatar]); }, [mutate, user.id, openConfirmModal, tManageAvatar]);
const isCredentialsUser = user.provider === "credentials";
return ( return (
<Box pos="relative"> <Box pos="relative">
<Menu opened={opened} keepMounted onChange={toggle} position="bottom-start" withArrow> <Menu
opened={opened}
keepMounted
onChange={isCredentialsUser ? toggle : undefined}
position="bottom-start"
withArrow
>
<Menu.Target> <Menu.Target>
<UnstyledButton onClick={toggle}> <UnstyledButton
component={isCredentialsUser ? undefined : "div"}
style={{ cursor: !isCredentialsUser ? "default" : undefined }}
onClick={isCredentialsUser ? toggle : undefined}
>
<UserAvatar user={user} size={200} /> <UserAvatar user={user} size={200} />
<Button {isCredentialsUser && (
component="div" <Button
pos="absolute" component="div"
bottom={0} pos="absolute"
left={0} bottom={0}
size="compact-md" left={0}
fw="normal" size="compact-md"
variant="default" fw="normal"
leftSection={<IconPencil size={18} stroke={1.5} />} variant="default"
> leftSection={<IconPencil size={18} stroke={1.5} />}
{t("common.action.edit")} >
</Button> {t("common.action.edit")}
</Button>
)}
</UnstyledButton> </UnstyledButton>
</Menu.Target> </Menu.Target>
<Menu.Dropdown> <Menu.Dropdown>

View File

@@ -51,8 +51,12 @@ export const UserProfileForm = ({ user }: UserProfileFormProps) => {
}, },
}); });
// Only credentials users can edit their profile
const isProviderCredentials = user.provider === "credentials";
const handleSubmit = useCallback( const handleSubmit = useCallback(
(values: FormType) => { (values: FormType) => {
if (!isProviderCredentials) return;
mutate({ mutate({
...values, ...values,
id: user.id, id: user.id,
@@ -64,14 +68,25 @@ export const UserProfileForm = ({ user }: UserProfileFormProps) => {
return ( return (
<form onSubmit={form.onSubmit(handleSubmit)}> <form onSubmit={form.onSubmit(handleSubmit)}>
<Stack> <Stack>
<TextInput label={t("user.field.username.label")} withAsterisk {...form.getInputProps("name")} /> <TextInput
<TextInput label={t("user.field.email.label")} {...form.getInputProps("email")} /> disabled={!isProviderCredentials}
label={t("user.field.username.label")}
withAsterisk
{...form.getInputProps("name")}
/>
<TextInput
disabled={!isProviderCredentials}
label={t("user.field.email.label")}
{...form.getInputProps("email")}
/>
<Group justify="end"> {isProviderCredentials && (
<Button type="submit" color="teal" disabled={!form.isDirty()} loading={isPending}> <Group justify="end">
{t("common.action.saveChanges")} <Button type="submit" color="teal" disabled={!form.isDirty()} loading={isPending}>
</Button> {t("common.action.saveChanges")}
</Group> </Button>
</Group>
)}
</Stack> </Stack>
</form> </form>
); );

View File

@@ -1,5 +1,6 @@
import { notFound } from "next/navigation"; import { notFound } from "next/navigation";
import { Box, Group, Stack, Title } from "@mantine/core"; import { Alert, Box, Group, Stack, Title } from "@mantine/core";
import { IconExclamationCircle } from "@tabler/icons-react";
import { api } from "@homarr/api/server"; import { api } from "@homarr/api/server";
import { auth } from "@homarr/auth/next"; import { auth } from "@homarr/auth/next";
@@ -53,8 +54,14 @@ export default async function EditUserPage({ params }: Props) {
notFound(); notFound();
} }
const isCredentialsUser = user.provider === "credentials";
return ( return (
<Stack> <Stack>
<Alert variant="light" color="yellow" icon={<IconExclamationCircle size="1rem" stroke={1.5} />}>
{t("management.page.user.fieldsDisabledExternalProvider")}
</Alert>
<Title>{tGeneral("title")}</Title> <Title>{tGeneral("title")}</Title>
<Group gap="xl"> <Group gap="xl">
<Box flex={1}> <Box flex={1}>
@@ -67,13 +74,15 @@ export default async function EditUserPage({ params }: Props) {
<ProfileLanguageChange /> <ProfileLanguageChange />
<DangerZoneRoot> {isCredentialsUser && (
<DangerZoneItem <DangerZoneRoot>
label={t("user.action.delete.label")} <DangerZoneItem
description={t("user.action.delete.description")} label={t("user.action.delete.label")}
action={<DeleteUserButton user={user} />} description={t("user.action.delete.description")}
/> action={<DeleteUserButton user={user} />}
</DangerZoneRoot> />
</DangerZoneRoot>
)}
</Stack> </Stack>
); );
} }

View File

@@ -28,6 +28,8 @@ export default async function Layout({ children, params }: PropsWithChildren<Lay
notFound(); notFound();
} }
const isCredentialsUser = user.provider === "credentials";
return ( return (
<ManageContainer size="xl"> <ManageContainer size="xl">
<DynamicBreadcrumb <DynamicBreadcrumb
@@ -57,11 +59,13 @@ export default async function Layout({ children, params }: PropsWithChildren<Lay
label={tUser("setting.general.title")} label={tUser("setting.general.title")}
icon={<IconSettings size="1rem" stroke={1.5} />} icon={<IconSettings size="1rem" stroke={1.5} />}
/> />
<NavigationLink {isCredentialsUser && (
href={`/manage/users/${params.userId}/security`} <NavigationLink
label={tUser("setting.security.title")} href={`/manage/users/${params.userId}/security`}
icon={<IconShieldLock size="1rem" stroke={1.5} />} label={tUser("setting.security.title")}
/> icon={<IconShieldLock size="1rem" stroke={1.5} />}
/>
)}
</Stack> </Stack>
</Stack> </Stack>
</GridCol> </GridCol>

View File

@@ -28,6 +28,10 @@ export default async function UserSecurityPage({ params }: Props) {
notFound(); notFound();
} }
if (user.provider !== "credentials") {
notFound();
}
return ( return (
<Stack> <Stack>
<Title>{tSecurity("title")}</Title> <Title>{tSecurity("title")}</Title>

View File

@@ -1,8 +1,11 @@
import Link from "next/link"; import Link from "next/link";
import { Anchor, Center, Group, Stack, Table, TableTbody, TableTd, TableTr, Text, Title } from "@mantine/core"; import { Alert, Anchor, Center, Group, Stack, Table, TableTbody, TableTd, TableTr, Text, Title } from "@mantine/core";
import { IconExclamationCircle } from "@tabler/icons-react";
import type { RouterOutputs } from "@homarr/api"; import type { RouterOutputs } from "@homarr/api";
import { api } from "@homarr/api/server"; import { api } from "@homarr/api/server";
import { env } from "@homarr/auth/env.mjs";
import { isProviderEnabled } from "@homarr/auth/server";
import { getI18n, getScopedI18n } from "@homarr/translation/server"; import { getI18n, getScopedI18n } from "@homarr/translation/server";
import { SearchInput, UserAvatar } from "@homarr/ui"; import { SearchInput, UserAvatar } from "@homarr/ui";
@@ -28,9 +31,22 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
group.members.filter((member) => member.name?.toLowerCase().includes(searchParams.search!.trim().toLowerCase())) group.members.filter((member) => member.name?.toLowerCase().includes(searchParams.search!.trim().toLowerCase()))
: group.members; : group.members;
const providerTypes = isProviderEnabled("credentials")
? env.AUTH_PROVIDERS.length > 1
? "mixed"
: "credentials"
: "external";
return ( return (
<Stack> <Stack>
<Title>{tMembers("title")}</Title> <Title>{tMembers("title")}</Title>
{providerTypes !== "credentials" && (
<Alert variant="light" color="yellow" icon={<IconExclamationCircle size="1rem" stroke={1.5} />}>
{t(`group.memberNotice.${providerTypes}`)}
</Alert>
)}
<Group justify="space-between"> <Group justify="space-between">
<SearchInput <SearchInput
placeholder={t("common.rtl", { placeholder={t("common.rtl", {
@@ -39,7 +55,9 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
})} })}
defaultValue={searchParams.search} defaultValue={searchParams.search}
/> />
<AddGroupMember groupId={group.id} presentUserIds={group.members.map((member) => member.id)} /> {isProviderEnabled("credentials") && (
<AddGroupMember groupId={group.id} presentUserIds={group.members.map((member) => member.id)} />
)}
</Group> </Group>
{filteredMembers.length === 0 && ( {filteredMembers.length === 0 && (
<Center py="sm"> <Center py="sm">
@@ -60,7 +78,7 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
} }
interface RowProps { interface RowProps {
member: RouterOutputs["group"]["getPaginated"]["items"][number]["members"][number]; member: RouterOutputs["group"]["getById"]["members"][number];
groupId: string; groupId: string;
} }
@@ -70,13 +88,13 @@ const Row = ({ member, groupId }: RowProps) => {
<TableTd> <TableTd>
<Group> <Group>
<UserAvatar size="sm" user={member} /> <UserAvatar size="sm" user={member} />
<Anchor component={Link} href={`/manage/users/${member.id}`}> <Anchor component={Link} href={`/manage/users/${member.id}/general`}>
{member.name} {member.name}
</Anchor> </Anchor>
</Group> </Group>
</TableTd> </TableTd>
<TableTd w={100}> <TableTd w={100}>
<RemoveGroupMember user={member} groupId={groupId} /> {member.provider === "credentials" && <RemoveGroupMember user={member} groupId={groupId} />}
</TableTd> </TableTd>
</TableTr> </TableTr>
); );

View File

@@ -1,9 +1,16 @@
import { notFound } from "next/navigation";
import { api } from "@homarr/api/server"; import { api } from "@homarr/api/server";
import { isProviderEnabled } from "@homarr/auth/server";
import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb"; import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb";
import { InviteListComponent } from "./_components/invite-list"; import { InviteListComponent } from "./_components/invite-list";
export default async function InvitesOverviewPage() { export default async function InvitesOverviewPage() {
if (!isProviderEnabled("credentials")) {
notFound();
}
const initialInvites = await api.invite.getAll(); const initialInvites = await api.invite.getAll();
return ( return (
<> <>

View File

@@ -22,18 +22,24 @@ export const MainNavigation = ({ headerSection, footerSection, links }: MainNavi
component={ScrollArea} component={ScrollArea}
> >
{links.map((link, index) => { {links.map((link, index) => {
if (link.hidden) {
return null;
}
const { icon: TablerIcon, ...props } = link; const { icon: TablerIcon, ...props } = link;
const Icon = <TablerIcon size={20} stroke={1.5} />; const Icon = <TablerIcon size={20} stroke={1.5} />;
let clientLink: ClientNavigationLink; let clientLink: ClientNavigationLink;
if ("items" in props) { if ("items" in props) {
clientLink = { clientLink = {
...props, ...props,
items: props.items.map((item) => { items: props.items
return { .filter((item) => !item.hidden)
...item, .map((item) => {
icon: <item.icon size={20} stroke={1.5} />, return {
}; ...item,
}), icon: <item.icon size={20} stroke={1.5} />,
};
}),
} as ClientNavigationLink; } as ClientNavigationLink;
} else { } else {
clientLink = props as ClientNavigationLink; clientLink = props as ClientNavigationLink;
@@ -49,6 +55,7 @@ export const MainNavigation = ({ headerSection, footerSection, links }: MainNavi
interface CommonNavigationLinkProps { interface CommonNavigationLinkProps {
label: string; label: string;
icon: TablerIcon; icon: TablerIcon;
hidden?: boolean;
} }
interface NavigationLinkHref extends CommonNavigationLinkProps { interface NavigationLinkHref extends CommonNavigationLinkProps {

View File

@@ -57,6 +57,7 @@ export const groupRouter = createTRPCRouter({
name: true, name: true,
email: true, email: true,
image: true, image: true,
provider: true,
}, },
}, },
}, },

View File

@@ -6,9 +6,11 @@ import { invites } from "@homarr/db/schema/sqlite";
import { z } from "@homarr/validation"; import { z } from "@homarr/validation";
import { createTRPCRouter, protectedProcedure } from "../trpc"; import { createTRPCRouter, protectedProcedure } from "../trpc";
import { throwIfCredentialsDisabled } from "./invite/checks";
export const inviteRouter = createTRPCRouter({ export const inviteRouter = createTRPCRouter({
getAll: protectedProcedure.query(async ({ ctx }) => { getAll: protectedProcedure.query(async ({ ctx }) => {
throwIfCredentialsDisabled();
const dbInvites = await ctx.db.query.invites.findMany({ const dbInvites = await ctx.db.query.invites.findMany({
orderBy: asc(invites.expirationDate), orderBy: asc(invites.expirationDate),
columns: { columns: {
@@ -32,6 +34,7 @@ export const inviteRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ ctx, input }) => { .mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const id = createId(); const id = createId();
const token = randomBytes(20).toString("hex"); const token = randomBytes(20).toString("hex");
@@ -54,6 +57,7 @@ export const inviteRouter = createTRPCRouter({
}), }),
) )
.mutation(async ({ ctx, input }) => { .mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const dbInvite = await ctx.db.query.invites.findFirst({ const dbInvite = await ctx.db.query.invites.findFirst({
where: eq(invites.id, input.id), where: eq(invites.id, input.id),
}); });

View File

@@ -0,0 +1,12 @@
import { TRPCError } from "@trpc/server";
import { env } from "@homarr/auth/env.mjs";
export const throwIfCredentialsDisabled = () => {
if (!env.AUTH_PROVIDERS.includes("credentials")) {
throw new TRPCError({
code: "FORBIDDEN",
message: "Credentials provider is disabled",
});
}
};

View File

@@ -170,8 +170,8 @@ describe("byId should return group by id including members and permissions", ()
expect(result.members.length).toBe(1); expect(result.members.length).toBe(1);
const userKeys = Object.keys(result.members[0] ?? {}); const userKeys = Object.keys(result.members[0] ?? {});
expect(userKeys.length).toBe(4); expect(userKeys.length).toBe(5);
expect(["id", "name", "email", "image"].some((key) => userKeys.includes(key))); expect(["id", "name", "email", "image", "provider"].some((key) => userKeys.includes(key)));
expect(result.permissions.length).toBe(1); expect(result.permissions.length).toBe(1);
expect(result.permissions[0]).toBe("admin"); expect(result.permissions[0]).toBe("admin");
}); });

View File

@@ -22,6 +22,15 @@ vi.mock("@homarr/auth", async () => {
return { ...mod, auth: () => ({}) as Session }; return { ...mod, auth: () => ({}) as Session };
}); });
// Mock the env module to return the credentials provider
vi.mock("@homarr/auth/env.mjs", () => {
return {
env: {
AUTH_PROVIDERS: ["credentials"],
},
};
});
describe("all should return all existing invites without sensitive informations", () => { describe("all should return all existing invites without sensitive informations", () => {
test("invites should not contain sensitive informations", async () => { test("invites should not contain sensitive informations", async () => {
// Arrange // Arrange

View File

@@ -13,6 +13,15 @@ vi.mock("@homarr/auth", async () => {
return { ...mod, auth: () => ({}) as Session }; return { ...mod, auth: () => ({}) as Session };
}); });
// Mock the env module to return the credentials provider
vi.mock("@homarr/auth/env.mjs", () => {
return {
env: {
AUTH_PROVIDERS: ["credentials"],
},
};
});
describe("initUser should initialize the first user", () => { describe("initUser should initialize the first user", () => {
it("should throw an error if a user already exists", async () => { it("should throw an error if a user already exists", async () => {
const db = createDb(); const db = createDb();
@@ -230,6 +239,7 @@ describe("editProfile shoud update user", () => {
password: null, password: null,
image: null, image: null,
homeBoardId: null, homeBoardId: null,
provider: "credentials",
}); });
}); });
@@ -270,6 +280,7 @@ describe("editProfile shoud update user", () => {
password: null, password: null,
image: null, image: null,
homeBoardId: null, homeBoardId: null,
provider: "credentials",
}); });
}); });
}); });
@@ -294,6 +305,7 @@ describe("delete should delete user", () => {
password: null, password: null,
salt: null, salt: null,
homeBoardId: null, homeBoardId: null,
provider: "ldap" as const,
}, },
{ {
id: userToDelete, id: userToDelete,
@@ -314,6 +326,7 @@ describe("delete should delete user", () => {
password: null, password: null,
salt: null, salt: null,
homeBoardId: null, homeBoardId: null,
provider: "oidc" as const,
}, },
]; ];

View File

@@ -4,12 +4,17 @@ import { createSaltAsync, hashPasswordAsync } from "@homarr/auth";
import type { Database } from "@homarr/db"; import type { Database } from "@homarr/db";
import { and, createId, eq, schema } from "@homarr/db"; import { and, createId, eq, schema } from "@homarr/db";
import { groupMembers, groupPermissions, groups, invites, users } from "@homarr/db/schema/sqlite"; import { groupMembers, groupPermissions, groups, invites, users } from "@homarr/db/schema/sqlite";
import type { SupportedAuthProvider } from "@homarr/definitions";
import { logger } from "@homarr/log";
import { validation, z } from "@homarr/validation"; import { validation, z } from "@homarr/validation";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "../trpc"; import { createTRPCRouter, protectedProcedure, publicProcedure } from "../trpc";
import { throwIfCredentialsDisabled } from "./invite/checks";
export const userRouter = createTRPCRouter({ export const userRouter = createTRPCRouter({
initUser: publicProcedure.input(validation.user.init).mutation(async ({ ctx, input }) => { initUser: publicProcedure.input(validation.user.init).mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const firstUser = await ctx.db.query.users.findFirst({ const firstUser = await ctx.db.query.users.findFirst({
columns: { columns: {
id: true, id: true,
@@ -40,6 +45,7 @@ export const userRouter = createTRPCRouter({
}); });
}), }),
register: publicProcedure.input(validation.user.registrationApi).mutation(async ({ ctx, input }) => { register: publicProcedure.input(validation.user.registrationApi).mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const inviteWhere = and(eq(invites.id, input.inviteId), eq(invites.token, input.token)); const inviteWhere = and(eq(invites.id, input.inviteId), eq(invites.token, input.token));
const dbInvite = await ctx.db.query.invites.findFirst({ const dbInvite = await ctx.db.query.invites.findFirst({
columns: { columns: {
@@ -56,7 +62,7 @@ export const userRouter = createTRPCRouter({
}); });
} }
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.username); await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.username);
await createUserAsync(ctx.db, input); await createUserAsync(ctx.db, input);
@@ -64,7 +70,8 @@ export const userRouter = createTRPCRouter({
await ctx.db.delete(invites).where(inviteWhere); await ctx.db.delete(invites).where(inviteWhere);
}), }),
create: publicProcedure.input(validation.user.create).mutation(async ({ ctx, input }) => { create: publicProcedure.input(validation.user.create).mutation(async ({ ctx, input }) => {
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.username); throwIfCredentialsDisabled();
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.username);
await createUserAsync(ctx.db, input); await createUserAsync(ctx.db, input);
}), }),
@@ -93,6 +100,7 @@ export const userRouter = createTRPCRouter({
columns: { columns: {
id: true, id: true,
image: true, image: true,
provider: true,
}, },
where: eq(users.id, input.userId), where: eq(users.id, input.userId),
}); });
@@ -104,6 +112,13 @@ export const userRouter = createTRPCRouter({
}); });
} }
if (user.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Profile image can not be changed for users with external providers",
});
}
await ctx.db await ctx.db
.update(users) .update(users)
.set({ .set({
@@ -112,13 +127,14 @@ export const userRouter = createTRPCRouter({
.where(eq(users.id, input.userId)); .where(eq(users.id, input.userId));
}), }),
getAll: publicProcedure.query(async ({ ctx }) => { getAll: publicProcedure.query(async ({ ctx }) => {
return ctx.db.query.users.findMany({ return await ctx.db.query.users.findMany({
columns: { columns: {
id: true, id: true,
name: true, name: true,
email: true, email: true,
emailVerified: true, emailVerified: true,
image: true, image: true,
provider: true,
}, },
}); });
}), }),
@@ -139,6 +155,7 @@ export const userRouter = createTRPCRouter({
email: true, email: true,
emailVerified: true, emailVerified: true,
image: true, image: true,
provider: true,
}, },
where: eq(users.id, input.userId), where: eq(users.id, input.userId),
}); });
@@ -154,7 +171,7 @@ export const userRouter = createTRPCRouter({
}), }),
editProfile: publicProcedure.input(validation.user.editProfile).mutation(async ({ input, ctx }) => { editProfile: publicProcedure.input(validation.user.editProfile).mutation(async ({ input, ctx }) => {
const user = await ctx.db.query.users.findFirst({ const user = await ctx.db.query.users.findFirst({
columns: { email: true }, columns: { email: true, provider: true },
where: eq(users.id, input.id), where: eq(users.id, input.id),
}); });
@@ -165,7 +182,14 @@ export const userRouter = createTRPCRouter({
}); });
} }
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.name, input.id); if (user.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Username and email can not be changed for users with external providers",
});
}
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.name, input.id);
const emailDirty = input.email && user.email !== input.email; const emailDirty = input.email && user.email !== input.email;
await ctx.db await ctx.db
@@ -190,26 +214,38 @@ export const userRouter = createTRPCRouter({
}); });
} }
const dbUser = await ctx.db.query.users.findFirst({
columns: {
id: true,
password: true,
salt: true,
provider: true,
},
where: eq(users.id, input.userId),
});
if (!dbUser) {
throw new TRPCError({
code: "NOT_FOUND",
message: "User not found",
});
}
if (dbUser.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Password can not be changed for users with external providers",
});
}
// Admins can change the password of other users without providing the previous password // Admins can change the password of other users without providing the previous password
const isPreviousPasswordRequired = ctx.session.user.id === input.userId; const isPreviousPasswordRequired = ctx.session.user.id === input.userId;
logger.info(
`User ${user.id} is changing password for user ${input.userId}, previous password is required: ${isPreviousPasswordRequired}`,
);
if (isPreviousPasswordRequired) { if (isPreviousPasswordRequired) {
const dbUser = await ctx.db.query.users.findFirst({
columns: {
id: true,
password: true,
salt: true,
},
where: eq(users.id, input.userId),
});
if (!dbUser) {
throw new TRPCError({
code: "NOT_FOUND",
message: "User not found",
});
}
const previousPasswordHash = await hashPasswordAsync(input.previousPassword, dbUser.salt ?? ""); const previousPasswordHash = await hashPasswordAsync(input.previousPassword, dbUser.salt ?? "");
const isValid = previousPasswordHash === dbUser.password; const isValid = previousPasswordHash === dbUser.password;
@@ -249,9 +285,14 @@ const createUserAsync = async (db: Database, input: z.infer<typeof validation.us
return userId; return userId;
}; };
const checkUsernameAlreadyTakenAndThrowAsync = async (db: Database, username: string, ignoreId?: string) => { const checkUsernameAlreadyTakenAndThrowAsync = async (
db: Database,
provider: SupportedAuthProvider,
username: string,
ignoreId?: string,
) => {
const user = await db.query.users.findFirst({ const user = await db.query.users.findFirst({
where: eq(users.name, username.toLowerCase()), where: and(eq(users.name, username.toLowerCase()), eq(users.provider, provider)),
}); });
if (!user) return; if (!user) return;

View File

@@ -0,0 +1,9 @@
import type { SupportedAuthProvider } from "@homarr/definitions";
import { env } from "../env.mjs";
export const isProviderEnabled = (provider: SupportedAuthProvider) => {
// The question mark is placed there because isProviderEnabled is called during static build of about page
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
return env.AUTH_PROVIDERS?.includes(provider);
};

View File

@@ -1,7 +1,7 @@
import bcrypt from "bcrypt"; import bcrypt from "bcrypt";
import type { Database } from "@homarr/db"; import type { Database } from "@homarr/db";
import { eq } from "@homarr/db"; import { and, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite"; import { users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log"; import { logger } from "@homarr/log";
import type { validation, z } from "@homarr/validation"; import type { validation, z } from "@homarr/validation";
@@ -11,7 +11,7 @@ export const authorizeWithBasicCredentialsAsync = async (
credentials: z.infer<typeof validation.user.signIn>, credentials: z.infer<typeof validation.user.signIn>,
) => { ) => {
const user = await db.query.users.findFirst({ const user = await db.query.users.findFirst({
where: eq(users.name, credentials.name), where: and(eq(users.name, credentials.name), eq(users.provider, "credentials")),
}); });
if (!user?.password) { if (!user?.password) {

View File

@@ -1,7 +1,8 @@
import type { Adapter } from "@auth/core/adapters";
import { CredentialsSignin } from "@auth/core/errors"; import { CredentialsSignin } from "@auth/core/errors";
import { createId } from "@homarr/db"; import type { Database } from "@homarr/db";
import { and, createId, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log"; import { logger } from "@homarr/log";
import type { validation } from "@homarr/validation"; import type { validation } from "@homarr/validation";
import { z } from "@homarr/validation"; import { z } from "@homarr/validation";
@@ -10,7 +11,7 @@ import { env } from "../../../env.mjs";
import { LdapClient } from "../ldap-client"; import { LdapClient } from "../ldap-client";
export const authorizeWithLdapCredentialsAsync = async ( export const authorizeWithLdapCredentialsAsync = async (
adapter: Adapter, db: Database,
credentials: z.infer<typeof validation.user.signIn>, credentials: z.infer<typeof validation.user.signIn>,
) => { ) => {
logger.info(`user ${credentials.name} is trying to log in using LDAP. Connecting to LDAP server...`); logger.info(`user ${credentials.name} is trying to log in using LDAP. Connecting to LDAP server...`);
@@ -89,18 +90,30 @@ export const authorizeWithLdapCredentialsAsync = async (
await client.disconnectAsync(); await client.disconnectAsync();
// Create or update user in the database // Create or update user in the database
let user = await adapter.getUserByEmail?.(mailResult.data); let user = await db.query.users.findFirst({
columns: {
id: true,
name: true,
image: true,
email: true,
emailVerified: true,
provider: true,
},
where: and(eq(users.email, mailResult.data), eq(users.provider, "ldap")),
});
if (!user) { if (!user) {
logger.info(`User ${credentials.name} not found in the database. Creating...`); logger.info(`User ${credentials.name} not found in the database. Creating...`);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion user = {
user = await adapter.createUser!({
id: createId(), id: createId(),
name: credentials.name, name: credentials.name,
email: mailResult.data, email: mailResult.data,
emailVerified: new Date(), // assume email is verified emailVerified: new Date(), // assume email is verified
}); image: null,
provider: "ldap",
};
await db.insert(users).values(user);
logger.info(`User ${credentials.name} created successfully.`); logger.info(`User ${credentials.name} created successfully.`);
} }
@@ -108,11 +121,9 @@ export const authorizeWithLdapCredentialsAsync = async (
if (user.name !== credentials.name) { if (user.name !== credentials.name) {
logger.warn(`User ${credentials.name} found in the database but with different name. Updating...`); logger.warn(`User ${credentials.name} found in the database but with different name. Updating...`);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion user.name = credentials.name;
user = await adapter.updateUser!({
id: user.id, await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
name: credentials.name,
});
logger.info(`User ${credentials.name} updated successfully.`); logger.info(`User ${credentials.name} updated successfully.`);
} }

View File

@@ -3,7 +3,6 @@ import type Credentials from "@auth/core/providers/credentials";
import type { Database } from "@homarr/db"; import type { Database } from "@homarr/db";
import { validation } from "@homarr/validation"; import { validation } from "@homarr/validation";
import { adapter } from "../../adapter";
import { authorizeWithBasicCredentialsAsync } from "./authorization/basic-authorization"; import { authorizeWithBasicCredentialsAsync } from "./authorization/basic-authorization";
import { authorizeWithLdapCredentialsAsync } from "./authorization/ldap-authorization"; import { authorizeWithLdapCredentialsAsync } from "./authorization/ldap-authorization";
@@ -32,7 +31,7 @@ export const createCredentialsConfiguration = (db: Database) =>
const data = await validation.user.signIn.parseAsync(credentials); const data = await validation.user.signIn.parseAsync(credentials);
if (data.credentialType === "ldap") { if (data.credentialType === "ldap") {
return await authorizeWithLdapCredentialsAsync(adapter, data).catch(() => null); return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
} }
return await authorizeWithBasicCredentialsAsync(db, data); return await authorizeWithBasicCredentialsAsync(db, data);

View File

@@ -32,6 +32,7 @@ export const OidcProvider = (headers: ReadonlyHeaders | null): OIDCConfig<Profil
// Use the name as the username if the preferred_username is an email address // Use the name as the username if the preferred_username is an email address
name: profile.preferred_username.includes("@") ? profile.name : profile.preferred_username, name: profile.preferred_username.includes("@") ? profile.name : profile.preferred_username,
email: profile.email, email: profile.email,
provider: "oidc",
}; };
}, },
}); });

View File

@@ -1,9 +1,8 @@
import type { Adapter } from "@auth/core/adapters";
import { CredentialsSignin } from "@auth/core/errors"; import { CredentialsSignin } from "@auth/core/errors";
import { DrizzleAdapter } from "@auth/drizzle-adapter";
import { describe, expect, test, vi } from "vitest"; import { describe, expect, test, vi } from "vitest";
import { createId, eq } from "@homarr/db"; import type { Database } from "@homarr/db";
import { and, createId, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite"; import { users } from "@homarr/db/schema/sqlite";
import { createDb } from "@homarr/db/test"; import { createDb } from "@homarr/db/test";
@@ -32,7 +31,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act // Act
const act = () => const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, { authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -55,7 +54,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act // Act
const act = () => const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, { authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -85,7 +84,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act // Act
const act = () => const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, { authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -118,7 +117,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act // Act
const act = () => const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, { authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -132,7 +131,6 @@ describe("authorizeWithLdapCredentials", () => {
test("should authorize user with correct credentials and create user", async () => { test("should authorize user with correct credentials and create user", async () => {
// Arrange // Arrange
const db = createDb(); const db = createDb();
const adapter = DrizzleAdapter(db);
const spy = vi.spyOn(ldapClient, "LdapClient"); const spy = vi.spyOn(ldapClient, "LdapClient");
spy.mockImplementation( spy.mockImplementation(
() => () =>
@@ -151,7 +149,7 @@ describe("authorizeWithLdapCredentials", () => {
); );
// Act // Act
const result = await authorizeWithLdapCredentialsAsync(adapter, { const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -166,13 +164,68 @@ describe("authorizeWithLdapCredentials", () => {
expect(dbUser?.id).toBe(result.id); expect(dbUser?.id).toBe(result.id);
expect(dbUser?.email).toBe("test@gmail.com"); expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.emailVerified).not.toBeNull(); expect(dbUser?.emailVerified).not.toBeNull();
expect(dbUser?.provider).toBe("ldap");
});
test("should authorize user with correct credentials and create user with same email when credentials user already exists", async () => {
// Arrange
const db = createDb();
const spy = vi.spyOn(ldapClient, "LdapClient");
const salt = await createSaltAsync();
spy.mockImplementation(
() =>
({
bindAsync: vi.fn(() => Promise.resolve()),
searchAsync: vi.fn(() =>
Promise.resolve([
{
dn: "test",
mail: "test@gmail.com",
},
]),
),
disconnectAsync: vi.fn(),
}) as unknown as ldapClient.LdapClient,
);
await db.insert(users).values({
id: createId(),
name: "test",
salt,
password: await hashPasswordAsync("test", salt),
email: "test@gmail.com",
provider: "credentials",
});
// Act
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result.name).toBe("test");
const dbUser = await db.query.users.findFirst({
where: and(eq(users.name, "test"), eq(users.provider, "ldap")),
});
expect(dbUser).toBeDefined();
expect(dbUser?.id).toBe(result.id);
expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.emailVerified).not.toBeNull();
expect(dbUser?.provider).toBe("ldap");
const credentialsUser = await db.query.users.findFirst({
where: and(eq(users.name, "test"), eq(users.provider, "credentials")),
});
expect(credentialsUser).toBeDefined();
expect(credentialsUser?.id).not.toBe(result.id);
}); });
test("should authorize user with correct credentials and update name", async () => { test("should authorize user with correct credentials and update name", async () => {
// Arrange // Arrange
const userId = createId(); const userId = createId();
const db = createDb(); const db = createDb();
const adapter = DrizzleAdapter(db);
const salt = await createSaltAsync(); const salt = await createSaltAsync();
await db.insert(users).values({ await db.insert(users).values({
id: userId, id: userId,
@@ -180,10 +233,11 @@ describe("authorizeWithLdapCredentials", () => {
salt, salt,
password: await hashPasswordAsync("test", salt), password: await hashPasswordAsync("test", salt),
email: "test@gmail.com", email: "test@gmail.com",
provider: "ldap",
}); });
// Act // Act
const result = await authorizeWithLdapCredentialsAsync(adapter, { const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test", name: "test",
password: "test", password: "test",
credentialType: "ldap", credentialType: "ldap",
@@ -200,5 +254,6 @@ describe("authorizeWithLdapCredentials", () => {
expect(dbUser?.id).toBe(userId); expect(dbUser?.id).toBe(userId);
expect(dbUser?.name).toBe("test"); expect(dbUser?.name).toBe("test");
expect(dbUser?.email).toBe("test@gmail.com"); expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.provider).toBe("ldap");
}); });
}); });

View File

@@ -1 +1,2 @@
export { hasQueryAccessToIntegrationsAsync } from "./permissions/integration-query-permissions"; export { hasQueryAccessToIntegrationsAsync } from "./permissions/integration-query-permissions";
export { isProviderEnabled } from "./providers/check-provider";

View File

@@ -0,0 +1 @@
ALTER TABLE `user` ADD `provider` varchar(64) DEFAULT 'credentials' NOT NULL;

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,13 @@
"when": 1720113913876, "when": 1720113913876,
"tag": "0004_noisy_giant_girl", "tag": "0004_noisy_giant_girl",
"breakpoints": true "breakpoints": true
},
{
"idx": 5,
"version": "5",
"when": 1722068832607,
"tag": "0005_soft_microbe",
"breakpoints": true
} }
] ]
} }

View File

@@ -0,0 +1 @@
ALTER TABLE `user` ADD `provider` text DEFAULT 'credentials' NOT NULL;

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,13 @@
"when": 1720036615408, "when": 1720036615408,
"tag": "0004_peaceful_red_ghost", "tag": "0004_peaceful_red_ghost",
"breakpoints": true "breakpoints": true
},
{
"idx": 5,
"version": "6",
"when": 1722014142492,
"tag": "0005_lean_random",
"breakpoints": true
} }
] ]
} }

View File

@@ -13,6 +13,7 @@ import type {
IntegrationPermission, IntegrationPermission,
IntegrationSecretKind, IntegrationSecretKind,
SectionKind, SectionKind,
SupportedAuthProvider,
WidgetKind, WidgetKind,
} from "@homarr/definitions"; } from "@homarr/definitions";
import { backgroundImageAttachments, backgroundImageRepeats, backgroundImageSizes } from "@homarr/definitions"; import { backgroundImageAttachments, backgroundImageRepeats, backgroundImageSizes } from "@homarr/definitions";
@@ -25,6 +26,7 @@ export const users = mysqlTable("user", {
image: text("image"), image: text("image"),
password: text("password"), password: text("password"),
salt: text("salt"), salt: text("salt"),
provider: varchar("provider", { length: 64 }).$type<SupportedAuthProvider>().default("credentials").notNull(),
homeBoardId: varchar("homeBoardId", { length: 64 }).references((): AnyMySqlColumn => boards.id, { homeBoardId: varchar("homeBoardId", { length: 64 }).references((): AnyMySqlColumn => boards.id, {
onDelete: "set null", onDelete: "set null",
}), }),

View File

@@ -15,6 +15,7 @@ import type {
IntegrationPermission, IntegrationPermission,
IntegrationSecretKind, IntegrationSecretKind,
SectionKind, SectionKind,
SupportedAuthProvider,
WidgetKind, WidgetKind,
} from "@homarr/definitions"; } from "@homarr/definitions";
@@ -26,6 +27,7 @@ export const users = sqliteTable("user", {
image: text("image"), image: text("image"),
password: text("password"), password: text("password"),
salt: text("salt"), salt: text("salt"),
provider: text("provider").$type<SupportedAuthProvider>().default("credentials").notNull(),
homeBoardId: text("homeBoardId").references((): AnySQLiteColumn => boards.id, { homeBoardId: text("homeBoardId").references((): AnySQLiteColumn => boards.id, {
onDelete: "set null", onDelete: "set null",
}), }),

View File

@@ -0,0 +1,2 @@
export const supportedAuthProviders = ["credentials", "oidc", "ldap"] as const;
export type SupportedAuthProvider = (typeof supportedAuthProviders)[number];

View File

@@ -4,3 +4,4 @@ export * from "./section";
export * from "./widget"; export * from "./widget";
export * from "./permissions"; export * from "./permissions";
export * from "./docker"; export * from "./docker";
export * from "./auth";

View File

@@ -195,6 +195,10 @@ export default {
}, },
}, },
}, },
memberNotice: {
mixed: "Some members are from external providers and cannot be managed here",
external: "All members are from external providers and cannot be managed here",
},
action: { action: {
create: { create: {
label: "New group", label: "New group",
@@ -1334,6 +1338,8 @@ export default {
}, },
user: { user: {
back: "Back to users", back: "Back to users",
fieldsDisabledExternalProvider:
"Certain fields are disabled because they are managed by an external authentication provider.",
setting: { setting: {
general: { general: {
title: "General", title: "General",
@@ -1379,7 +1385,7 @@ export default {
}, },
}, },
invite: { invite: {
title: "Manager user invites", title: "Manage user invites",
action: { action: {
new: { new: {
title: "New invite", title: "New invite",