Revision 63cae12bce9861cec309798d34701cf3da20bc71 authored by Peter Zijlstra on 09 December 2016, 13:59:00 UTC, committed by Ingo Molnar on 14 January 2017, 09:56:10 UTC
There is a problem with installing an event in a task that is 'stuck' on
an offline CPU.

Blocked tasks are not disassociated from offlined CPUs; after all, a
blocked task doesn't run and doesn't require a CPU. Only on
wakeup do we amend the situation and place the task on an available
CPU.

If we hit such a task with perf_install_in_context(), we'll loop until
either that task wakes up or the CPU comes back online. If the task
waking depends on the event being installed, we're stuck.

While looking into this issue, I also spotted another problem: if we
hit a task with perf_install_in_context() that is in the middle of
being migrated, that is we observe the old CPU before sending the IPI,
but run the IPI (on the old CPU) while the task is already running on
the new CPU, things also go sideways.

Rework things to rely on task_curr() -- outside of rq->lock -- which
is rather tricky. Imagine the following scenario where we're trying to
install the first event into our task 't':

CPU0            CPU1            CPU2

                (current == t)

t->perf_event_ctxp[] = ctx;
smp_mb();
cpu = task_cpu(t);

                switch(t, n);
                                migrate(t, 2);
                                switch(p, t);

                                ctx = t->perf_event_ctxp[]; // must not be NULL

smp_function_call(cpu, ..);

                generic_exec_single()
                  func();
                    spin_lock(ctx->lock);
                    if (task_curr(t)) // false

                    add_event_to_ctx();
                    spin_unlock(ctx->lock);

                                perf_event_context_sched_in();
                                  spin_lock(ctx->lock);
                                  // sees event

So it's CPU0's store of t->perf_event_ctxp[] that must not go 'missing'.
Because if CPU2's load of that variable were to observe NULL, it would
not try to schedule the ctx and we'd have a task running without its
counter, which would be 'bad'.

As long as we observe !NULL, we'll acquire ctx->lock. If we acquire it
first and do not see the event yet, then CPU0 must observe task_curr()
and retry. If the install happens first, then we must see the event on
sched-in and all is well.

I think we can translate the first part (until the 'must not be NULL')
of the scenario to a litmus test like:

  C C-peterz

  {
  }

  P0(int *x, int *y)
  {
          int r1;

          WRITE_ONCE(*x, 1);
          smp_mb();
          r1 = READ_ONCE(*y);
  }

  P1(int *y, int *z)
  {
          WRITE_ONCE(*y, 1);
          smp_store_release(z, 1);
  }

  P2(int *x, int *z)
  {
          int r1;
          int r2;

          r1 = smp_load_acquire(z);
          smp_mb();
          r2 = READ_ONCE(*x);
  }

  exists
  (0:r1=0 /\ 2:r1=1 /\ 2:r2=0)

Where:
  x is perf_event_ctxp[],
  y is our task's CPU, and
  z is our task being placed on the rq of CPU2.

The P0 smp_mb() is the one added by this patch, ordering the store to
perf_event_ctxp[] from find_get_context() and the load of task_cpu()
in task_function_call().

The smp_store_release/smp_load_acquire model the RCpc locking of the
rq->lock and the smp_mb() of P2 is the context switch switching from
whatever CPU2 was running to our task 't'.

This litmus test evaluates into:

  Test C-peterz Allowed
  States 7
  0:r1=0; 2:r1=0; 2:r2=0;
  0:r1=0; 2:r1=0; 2:r2=1;
  0:r1=0; 2:r1=1; 2:r2=1;
  0:r1=1; 2:r1=0; 2:r2=0;
  0:r1=1; 2:r1=0; 2:r2=1;
  0:r1=1; 2:r1=1; 2:r2=0;
  0:r1=1; 2:r1=1; 2:r2=1;
  No
  Witnesses
  Positive: 0 Negative: 7
  Condition exists (0:r1=0 /\ 2:r1=1 /\ 2:r2=0)
  Observation C-peterz Never 0 7
  Hash=e427f41d9146b2a5445101d3e2fcaa34

And the strong and weak model agree.

Reported-by: Mark Rutland <mark.rutland@arm.com>
Tested-by: Mark Rutland <mark.rutland@arm.com>
Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
Cc: Alexander Shishkin <alexander.shishkin@linux.intel.com>
Cc: Arnaldo Carvalho de Melo <acme@kernel.org>
Cc: Arnaldo Carvalho de Melo <acme@redhat.com>
Cc: Jiri Olsa <jolsa@redhat.com>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Sebastian Andrzej Siewior <bigeasy@linutronix.de>
Cc: Stephane Eranian <eranian@google.com>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Vince Weaver <vincent.weaver@maine.edu>
Cc: Will Deacon <will.deacon@arm.com>
Cc: jeremy.linton@arm.com
Link: http://lkml.kernel.org/r/20161209135900.GU3174@twins.programming.kicks-ass.net
Signed-off-by: Ingo Molnar <mingo@kernel.org>
1 parent ad5013d
Raw File
skcipher.c
/*
 * Symmetric key cipher operations.
 *
 * Generic encrypt/decrypt wrapper for ciphers, handles operations across
 * multiple page boundaries by using temporary blocks.  In user context,
 * the kernel is given a chance to schedule us once per page.
 *
 * Copyright (c) 2015 Herbert Xu <herbert@gondor.apana.org.au>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 */

#include <crypto/internal/aead.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/bug.h>
#include <linux/cryptouser.h>
#include <linux/list.h>
#include <linux/module.h>
#include <linux/rtnetlink.h>
#include <linux/seq_file.h>
#include <net/netlink.h>

#include "internal.h"

/*
 * State flags kept in skcipher_walk::flags.
 */
enum {
	/* Asynchronous walk (set by skcipher_walk_async()). */
	SKCIPHER_WALK_PHYS = 1 << 0,
	/* Current step is bounced through an aligned block (skcipher_next_slow()). */
	SKCIPHER_WALK_SLOW = 1 << 1,
	/* Current step was copied into walk->page (skcipher_next_copy()). */
	SKCIPHER_WALK_COPY = 1 << 2,
	/* Source and destination were mapped separately (skcipher_next_fast()). */
	SKCIPHER_WALK_DIFF = 1 << 3,
	/* Walk may sleep: GFP_KERNEL allocations and crypto_yield() between steps. */
	SKCIPHER_WALK_SLEEP = 1 << 4,
};

/*
 * Deferred write-back record for asynchronous (SKCIPHER_WALK_PHYS) walks.
 * Queued on walk->buffers by skcipher_queue_write() and copied out to @dst
 * by skcipher_walk_complete().
 */
struct skcipher_walk_buffer {
	struct list_head entry;		/* link in walk->buffers */
	struct scatter_walk dst;	/* destination position to copy to */
	unsigned int len;		/* number of bytes to copy back */
	u8 *data;			/* source of the copy, or NULL to use @buffer */
	u8 buffer[];			/* inline bounce storage (slow path) */
};

/* Sets up the next step of a walk; also called from skcipher_walk_done(). */
static int skcipher_walk_next(struct skcipher_walk *walk);

/*
 * Undo skcipher_map(). Only highmem pages were kmap_atomic()ed; lowmem
 * pages were addressed directly and need no unmapping.
 */
static inline void skcipher_unmap(struct scatter_walk *walk, void *vaddr)
{
	struct page *page = scatterwalk_page(walk);

	if (!PageHighMem(page))
		return;

	kunmap_atomic(vaddr);
}

/*
 * Return a kernel virtual address for the current position of @walk.
 * Highmem pages are temporarily mapped with kmap_atomic().
 */
static inline void *skcipher_map(struct scatter_walk *walk)
{
	struct page *page = scatterwalk_page(walk);
	void *base;

	if (PageHighMem(page))
		base = kmap_atomic(page);
	else
		base = page_address(page);

	return base + offset_in_page(walk->offset);
}

/* Map the current source position and record it as the source address. */
static inline void skcipher_map_src(struct skcipher_walk *walk)
{
	void *vaddr = skcipher_map(&walk->in);

	walk->src.virt.addr = vaddr;
}

/* Map the current destination position and record it as the dst address. */
static inline void skcipher_map_dst(struct skcipher_walk *walk)
{
	void *vaddr = skcipher_map(&walk->out);

	walk->dst.virt.addr = vaddr;
}

/* Release a mapping obtained with skcipher_map_src(). */
static inline void skcipher_unmap_src(struct skcipher_walk *walk)
{
	void *vaddr = walk->src.virt.addr;

	skcipher_unmap(&walk->in, vaddr);
}

/* Release a mapping obtained with skcipher_map_dst(). */
static inline void skcipher_unmap_dst(struct skcipher_walk *walk)
{
	void *vaddr = walk->dst.virt.addr;

	skcipher_unmap(&walk->out, vaddr);
}

/* Allocation flags for the walk: may block only if the walk may sleep. */
static inline gfp_t skcipher_walk_gfp(struct skcipher_walk *walk)
{
	if (walk->flags & SKCIPHER_WALK_SLEEP)
		return GFP_KERNEL;

	return GFP_ATOMIC;
}

/*
 * Find a spot of @len bytes at or after @start that does not straddle a
 * page. If [start, start + len) crosses a page boundary, the spot is moved
 * up to the start of the page that holds the last byte. The caller must
 * have reserved enough room for that shift.
 */
static inline u8 *skcipher_get_spot(u8 *start, unsigned int len)
{
	unsigned long last = (unsigned long)(start + len - 1);
	u8 *last_page = (u8 *)(last & PAGE_MASK);

	return last_page > start ? last_page : start;
}

/*
 * Finish a slow-path step: copy @bsize bytes from the aligned bounce block
 * in walk->buffer to the destination. Asynchronous walks pass mode 2, so
 * the destination is only skipped over; their data is written back by
 * skcipher_walk_complete().
 *
 * Always returns 0. The caller advances both walks by this value, and the
 * copy has already moved walk->out, as skcipher_next_slow()'s copy moved
 * walk->in.
 */
static int skcipher_done_slow(struct skcipher_walk *walk, unsigned int bsize)
{
	unsigned long base = (unsigned long)walk->buffer;
	int mode = (walk->flags & SKCIPHER_WALK_PHYS) ? 2 : 1;
	u8 *spot;

	spot = skcipher_get_spot((u8 *)ALIGN(base, walk->alignmask + 1),
				 bsize);
	scatterwalk_copychunks(spot, &walk->out, bsize, mode);

	return 0;
}

/**
 * skcipher_walk_done() - finish the current step and set up the next one
 * @walk: walk state
 * @err: a negative error code to abort the walk, otherwise the number of
 *	 bytes of the current step (walk->nbytes) that were left unprocessed
 *
 * Releases or writes back the current step according to the flags set when
 * it was prepared. Then advances both scatterlists by the processed amount.
 * Finally it either prepares the next step or, when nothing is left (or on
 * error), releases the walk's temporary buffers.
 *
 * Return: the result of skcipher_walk_next() while data remains, otherwise
 * 0 on success or a negative error code.
 */
int skcipher_walk_done(struct skcipher_walk *walk, int err)
{
	/* Bytes processed in this step (overridden below if err < 0). */
	unsigned int n = walk->nbytes - err;
	unsigned int nbytes;

	/* Bytes still to walk after this step. */
	nbytes = walk->total - n;

	if (unlikely(err < 0)) {
		/* Abort: advance nothing and tear the walk down below. */
		nbytes = 0;
		n = 0;
	} else if (likely(!(walk->flags & (SKCIPHER_WALK_PHYS |
					   SKCIPHER_WALK_SLOW |
					   SKCIPHER_WALK_COPY |
					   SKCIPHER_WALK_DIFF)))) {
		/* In-place fast path: dst shares the single source mapping. */
unmap_src:
		skcipher_unmap_src(walk);
	} else if (walk->flags & SKCIPHER_WALK_DIFF) {
		/* src and dst were mapped separately: unmap both. */
		skcipher_unmap_dst(walk);
		goto unmap_src;
	} else if (walk->flags & SKCIPHER_WALK_COPY) {
		/*
		 * The step was processed in walk->page; copy the result to
		 * the real destination.
		 * NOTE(review): with SKCIPHER_WALK_PHYS set,
		 * skcipher_next_copy() has already advanced or cleared
		 * walk->page before we get here, and write-back happens in
		 * skcipher_walk_complete(). Confirm this branch is safe for
		 * asynchronous walks.
		 */
		skcipher_map_dst(walk);
		memcpy(walk->dst.virt.addr, walk->page, n);
		skcipher_unmap_dst(walk);
	} else if (unlikely(walk->flags & SKCIPHER_WALK_SLOW)) {
		/* A bounced block must be processed completely. */
		if (WARN_ON(err)) {
			err = -EINVAL;
			nbytes = 0;
		} else
			n = skcipher_done_slow(walk, n);
	}

	/* A positive err only reported unprocessed bytes, not a failure. */
	if (err > 0)
		err = 0;

	walk->total = nbytes;
	walk->nbytes = nbytes;

	scatterwalk_advance(&walk->in, n);
	scatterwalk_advance(&walk->out, n);
	scatterwalk_done(&walk->in, 0, nbytes);
	scatterwalk_done(&walk->out, 1, nbytes);

	if (nbytes) {
		/* Let the scheduler in between steps if the walk may sleep. */
		crypto_yield(walk->flags & SKCIPHER_WALK_SLEEP ?
			     CRYPTO_TFM_REQ_MAY_SLEEP : 0);
		return skcipher_walk_next(walk);
	}

	/* Short-circuit for the common/fast path. */
	if (!((unsigned long)walk->buffer | (unsigned long)walk->page))
		goto out;

	/* Asynchronous walks are torn down by skcipher_walk_complete(). */
	if (walk->flags & SKCIPHER_WALK_PHYS)
		goto out;

	/* Hand a bounced IV back to the caller's IV buffer. */
	if (walk->iv != walk->oiv)
		memcpy(walk->oiv, walk->iv, walk->ivsize);
	/* walk->buffer may alias walk->page (see skcipher_next_slow()). */
	if (walk->buffer != walk->page)
		kfree(walk->buffer);
	if (walk->page)
		free_page((unsigned long)walk->page);

out:
	return err;
}
EXPORT_SYMBOL_GPL(skcipher_walk_done);

/**
 * skcipher_walk_complete() - finish an asynchronous (SKCIPHER_WALK_PHYS) walk
 * @walk: walk state
 * @err: result of the operation; if non-zero, nothing is copied back
 *
 * Copies every queued skcipher_walk_buffer back to its destination and frees
 * the queue. It then writes a bounced IV back to the caller's IV buffer and
 * releases the walk's temporary buffers.
 */
void skcipher_walk_complete(struct skcipher_walk *walk, int err)
{
	struct skcipher_walk_buffer *p, *tmp;

	list_for_each_entry_safe(p, tmp, &walk->buffers, entry) {
		u8 *data;

		/* On failure the queued data is simply discarded. */
		if (err)
			goto done;

		data = p->data;
		if (!data) {
			/*
			 * Slow-path entry: the data sits in the inline bounce
			 * block. Locate it exactly as skcipher_next_slow() did,
			 * i.e. with the block's own length (p->len == bsize),
			 * not walk->chunksize. A shorter final block may not
			 * have been shifted by skcipher_get_spot(), and the
			 * allocation was sized only for p->len bytes.
			 */
			data = PTR_ALIGN(&p->buffer[0], walk->alignmask + 1);
			data = skcipher_get_spot(data, p->len);
		}

		scatterwalk_copychunks(data, &p->dst, p->len, 1);

		/*
		 * Same test as in skcipher_next_copy(): this entry used up
		 * the end of the page walk->page was handing out, so it owns
		 * (and frees) that page.
		 */
		if (offset_in_page(p->data) + p->len + walk->chunksize >
		    PAGE_SIZE)
			free_page((unsigned long)p->data);

done:
		list_del(&p->entry);
		kfree(p);
	}

	if (!err && walk->iv != walk->oiv)
		memcpy(walk->oiv, walk->iv, walk->ivsize);
	/* walk->buffer may alias walk->page; avoid freeing it twice. */
	if (walk->buffer != walk->page)
		kfree(walk->buffer);
	if (walk->page)
		free_page((unsigned long)walk->page);
}
EXPORT_SYMBOL_GPL(skcipher_walk_complete);

/* Queue @p for write-back at the current destination position. */
static void skcipher_queue_write(struct skcipher_walk *walk,
				 struct skcipher_walk_buffer *p)
{
	struct list_head *queue = &walk->buffers;

	p->dst = walk->out;
	list_add_tail(&p->entry, queue);
}

/*
 * Slow path, used when fewer than @bsize bytes are contiguous in the source
 * or destination. It gathers exactly @bsize input bytes into an aligned
 * bounce block that does not straddle a page; the caller then processes
 * them in place there.
 *
 * Synchronous walks keep the bounce block in walk->buffer (reusing
 * walk->page or allocating a new block if necessary), and
 * skcipher_done_slow() copies the result out. Asynchronous
 * (SKCIPHER_WALK_PHYS) walks allocate a skcipher_walk_buffer for every such
 * step and queue it for skcipher_walk_complete().
 *
 * Return: 0, or the result of aborting the walk with -ENOMEM.
 */
static int skcipher_next_slow(struct skcipher_walk *walk, unsigned int bsize)
{
	bool phys = walk->flags & SKCIPHER_WALK_PHYS;
	unsigned alignmask = walk->alignmask;
	struct skcipher_walk_buffer *p;
	unsigned a;
	unsigned n;
	u8 *buffer;
	void *v;

	if (!phys) {
		/*
		 * Reuse an existing walk->buffer (from skcipher_copy_iv() or
		 * an earlier slow step), else fall back to walk->page.
		 */
		if (!walk->buffer)
			walk->buffer = walk->page;
		buffer = walk->buffer;
		if (buffer)
			goto ok;
	}

	/* Start with the minimum alignment of kmalloc. */
	a = crypto_tfm_ctx_alignment() - 1;
	n = bsize;

	if (phys) {
		/* Calculate the minimum alignment of p->buffer. */
		a &= (sizeof(*p) ^ (sizeof(*p) - 1)) >> 1;
		n += sizeof(*p);
	}

	/* Minimum size to align p->buffer by alignmask. */
	n += alignmask & ~a;

	/* Minimum size to ensure p->buffer does not straddle a page. */
	n += (bsize - 1) & ~(alignmask | a);

	v = kzalloc(n, skcipher_walk_gfp(walk));
	if (!v)
		return skcipher_walk_done(walk, -ENOMEM);

	if (phys) {
		/* kzalloc() leaves p->data NULL: the data lives in p->buffer. */
		p = v;
		p->len = bsize;
		skcipher_queue_write(walk, p);
		buffer = p->buffer;
	} else {
		walk->buffer = v;
		buffer = v;
	}

ok:
	/* Process in place: src and dst both point at the bounce block. */
	walk->dst.virt.addr = PTR_ALIGN(buffer, alignmask + 1);
	walk->dst.virt.addr = skcipher_get_spot(walk->dst.virt.addr, bsize);
	walk->src.virt.addr = walk->dst.virt.addr;

	/* Gather the input block; this also moves walk->in forward. */
	scatterwalk_copychunks(walk->src.virt.addr, &walk->in, bsize, 0);

	walk->nbytes = bsize;
	walk->flags |= SKCIPHER_WALK_SLOW;

	return 0;
}

/*
 * Copy path, used when the data is contiguous but misaligned. It copies
 * walk->nbytes bytes of input into walk->page; the caller then processes
 * them in place there.
 *
 * For synchronous walks, skcipher_walk_done() copies the result back. For
 * asynchronous (SKCIPHER_WALK_PHYS) walks, the slot is queued for
 * skcipher_walk_complete() and walk->page moves past it. If another
 * walk->chunksize bytes would not fit after the slot, the queued entry
 * keeps the page and walk->page is cleared, so skcipher_walk_next()
 * allocates a new page next time.
 *
 * Return: 0 on success or -ENOMEM.
 */
static int skcipher_next_copy(struct skcipher_walk *walk)
{
	struct skcipher_walk_buffer *p;
	u8 *tmp = walk->page;

	skcipher_map_src(walk);
	memcpy(tmp, walk->src.virt.addr, walk->nbytes);
	skcipher_unmap_src(walk);

	/* Process in place inside the aligned copy. */
	walk->src.virt.addr = tmp;
	walk->dst.virt.addr = tmp;

	if (!(walk->flags & SKCIPHER_WALK_PHYS))
		return 0;

	p = kmalloc(sizeof(*p), skcipher_walk_gfp(walk));
	if (!p)
		return -ENOMEM;

	p->data = walk->page;
	p->len = walk->nbytes;
	skcipher_queue_write(walk, p);

	/* Hand the page over to @p once it cannot hold another chunk. */
	if (offset_in_page(walk->page) + walk->nbytes + walk->chunksize >
	    PAGE_SIZE)
		walk->page = NULL;
	else
		walk->page += walk->nbytes;

	return 0;
}

/*
 * Fast path: data is contiguous and aligned, so process it where it is.
 * Records the source/destination pages and offsets. Asynchronous walks
 * stop there; otherwise the source is mapped, and the destination is mapped
 * separately (SKCIPHER_WALK_DIFF) only when it is a different page or
 * offset.
 *
 * Return: always 0.
 */
static int skcipher_next_fast(struct skcipher_walk *walk)
{
	unsigned long mismatch;

	walk->src.phys.page = scatterwalk_page(&walk->in);
	walk->src.phys.offset = offset_in_page(walk->in.offset);
	walk->dst.phys.page = scatterwalk_page(&walk->out);
	walk->dst.phys.offset = offset_in_page(walk->out.offset);

	if (walk->flags & SKCIPHER_WALK_PHYS)
		return 0;

	mismatch = walk->src.phys.offset - walk->dst.phys.offset;
	mismatch |= walk->src.virt.page - walk->dst.virt.page;

	skcipher_map_src(walk);
	walk->dst.virt.addr = walk->src.virt.addr;

	if (!mismatch)
		return 0;

	walk->flags |= SKCIPHER_WALK_DIFF;
	skcipher_map_dst(walk);

	return 0;
}

/*
 * Set up the next step of the walk.
 *
 * Clears the per-step flags and computes how many bytes (n) are contiguous
 * in both the source and destination, bounded by walk->total. It then picks
 * one of three strategies:
 *  - slow path: fewer than bsize bytes are contiguous, so bounce exactly
 *    bsize bytes through an aligned block (skcipher_next_slow()). This is
 *    also used if no page can be allocated for the copy path;
 *  - copy path: data is contiguous but src or dst is misaligned, so copy
 *    it into walk->page (skcipher_next_copy());
 *  - fast path: process the data in place (skcipher_next_fast()).
 *
 * This helper is static and used only within this file, so it is not
 * exported: EXPORT_SYMBOL on a static symbol is invalid.
 *
 * Return: 0 on success or a negative error code.
 */
static int skcipher_walk_next(struct skcipher_walk *walk)
{
	unsigned int bsize;
	unsigned int n;
	int err;

	walk->flags &= ~(SKCIPHER_WALK_SLOW | SKCIPHER_WALK_COPY |
			 SKCIPHER_WALK_DIFF);

	n = walk->total;
	bsize = min(walk->chunksize, max(n, walk->blocksize));
	n = scatterwalk_clamp(&walk->in, n);
	n = scatterwalk_clamp(&walk->out, n);

	if (unlikely(n < bsize)) {
		/* Less than one block remains: the request length is invalid. */
		if (unlikely(walk->total < walk->blocksize))
			return skcipher_walk_done(walk, -EINVAL);

slow_path:
		err = skcipher_next_slow(walk, bsize);
		goto set_phys_lowmem;
	}

	if (unlikely((walk->in.offset | walk->out.offset) & walk->alignmask)) {
		if (!walk->page) {
			gfp_t gfp = skcipher_walk_gfp(walk);

			walk->page = (void *)__get_free_page(gfp);
			if (!walk->page)
				goto slow_path;
		}

		walk->nbytes = min_t(unsigned, n,
				     PAGE_SIZE - offset_in_page(walk->page));
		walk->flags |= SKCIPHER_WALK_COPY;
		err = skcipher_next_copy(walk);
		goto set_phys_lowmem;
	}

	walk->nbytes = n;

	return skcipher_next_fast(walk);

set_phys_lowmem:
	/*
	 * Asynchronous walks describe the (lowmem) bounce buffers as
	 * page/offset pairs.
	 * NOTE(review): this presumably relies on phys.offset sharing
	 * storage with virt.addr, so masking leaves the in-page offset;
	 * confirm against the struct skcipher_walk definition.
	 */
	if (!err && (walk->flags & SKCIPHER_WALK_PHYS)) {
		walk->src.phys.page = virt_to_page(walk->src.virt.addr);
		walk->dst.phys.page = virt_to_page(walk->dst.virt.addr);
		walk->src.phys.offset &= PAGE_SIZE - 1;
		walk->dst.phys.offset &= PAGE_SIZE - 1;
	}
	return err;
}

/*
 * Bounce a misaligned IV into a new buffer aligned to walk->alignmask + 1,
 * and point walk->iv at the copy. skcipher_walk_done() and
 * skcipher_walk_complete() later copy it back to walk->oiv.
 *
 * Synchronous walks reuse this allocation as the slow-path bounce block
 * (skcipher_next_slow() and skcipher_done_slow() locate it with PTR_ALIGN
 * plus skcipher_get_spot()). The IV is therefore placed after an
 * alignment-rounded block of walk->chunksize bytes. Asynchronous
 * (SKCIPHER_WALK_PHYS) walks allocate their bounce blocks separately, so
 * only room for the IV is reserved and the IV sits at the aligned start.
 *
 * Return: 0 on success, -ENOMEM if the buffer could not be allocated.
 */
static int skcipher_copy_iv(struct skcipher_walk *walk)
{
	unsigned a = crypto_tfm_ctx_alignment() - 1;
	unsigned alignmask = walk->alignmask;
	unsigned ivsize = walk->ivsize;
	unsigned bs = walk->chunksize;
	unsigned aligned_bs;
	unsigned size;
	u8 *iv;

	/* ALIGN() takes the (power-of-two) alignment, not the mask. */
	aligned_bs = ALIGN(bs, alignmask + 1);

	/* Minimum size to align buffer by alignmask. */
	size = alignmask & ~a;

	if (walk->flags & SKCIPHER_WALK_PHYS)
		size += ivsize;
	else {
		size += aligned_bs + ivsize;

		/* Minimum size to ensure buffer does not straddle a page. */
		size += (bs - 1) & ~(alignmask | a);
	}

	walk->buffer = kmalloc(size, skcipher_walk_gfp(walk));
	if (!walk->buffer)
		return -ENOMEM;

	iv = PTR_ALIGN(walk->buffer, alignmask + 1);
	if (!(walk->flags & SKCIPHER_WALK_PHYS)) {
		/*
		 * Skip the slow-path block that precedes the IV. Only this
		 * layout reserved room for it; the PHYS allocation above did
		 * not.
		 */
		iv = skcipher_get_spot(iv, bs) + aligned_bs;
	}

	walk->iv = memcpy(iv, walk->iv, walk->ivsize);
	return 0;
}

/*
 * Common start of every walk: bounce the IV if it is misaligned and set up
 * the first step.
 *
 * Return: -EDEADLK if called from hard interrupt context; 0 with
 * walk->nbytes == 0 for an empty request; a negative error if the IV could
 * not be bounced; otherwise the result of skcipher_walk_next().
 */
static int skcipher_walk_first(struct skcipher_walk *walk)
{
	walk->nbytes = 0;

	/*
	 * Initialise the temporary-buffer pointers before any early return.
	 * skcipher_walk_done() and skcipher_walk_complete() test and free
	 * them, and nothing earlier in this file initialises them.
	 */
	walk->buffer = NULL;
	walk->page = NULL;

	if (WARN_ON_ONCE(in_irq()))
		return -EDEADLK;

	if (unlikely(!walk->total))
		return 0;

	if (unlikely(((unsigned long)walk->iv & walk->alignmask))) {
		int err = skcipher_copy_iv(walk);
		if (err)
			return err;
	}

	walk->nbytes = walk->total;

	return skcipher_walk_next(walk);
}

/*
 * Initialise @walk from an skcipher request: source/destination
 * scatterlists, length, IV, the sleep permission taken from the request
 * flags, and the tfm's size/alignment parameters. Then start the walk.
 */
static int skcipher_walk_skcipher(struct skcipher_walk *walk,
				  struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);

	scatterwalk_start(&walk->in, req->src);
	scatterwalk_start(&walk->out, req->dst);

	walk->total = req->cryptlen;
	walk->oiv = req->iv;
	walk->iv = req->iv;

	if (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP)
		walk->flags |= SKCIPHER_WALK_SLEEP;
	else
		walk->flags &= ~SKCIPHER_WALK_SLEEP;

	walk->blocksize = crypto_skcipher_blocksize(tfm);
	walk->chunksize = crypto_skcipher_chunksize(tfm);
	walk->ivsize = crypto_skcipher_ivsize(tfm);
	walk->alignmask = crypto_skcipher_alignmask(tfm);

	return skcipher_walk_first(walk);
}

/*
 * Start a synchronous walk over virtual addresses.
 * @atomic: if true, SKCIPHER_WALK_SLEEP is cleared once the first step has
 * been set up, so later steps neither sleep nor allocate with GFP_KERNEL.
 */
int skcipher_walk_virt(struct skcipher_walk *walk,
		       struct skcipher_request *req, bool atomic)
{
	int ret;

	walk->flags &= ~SKCIPHER_WALK_PHYS;
	ret = skcipher_walk_skcipher(walk, req);

	if (atomic)
		walk->flags &= ~SKCIPHER_WALK_SLEEP;

	return ret;
}
EXPORT_SYMBOL_GPL(skcipher_walk_virt);

/* Forbid sleeping (and GFP_KERNEL allocations) for the rest of @walk. */
void skcipher_walk_atomise(struct skcipher_walk *walk)
{
	walk->flags = walk->flags & ~SKCIPHER_WALK_SLEEP;
}
EXPORT_SYMBOL_GPL(skcipher_walk_atomise);

/*
 * Start an asynchronous (SKCIPHER_WALK_PHYS) walk. Bounce buffers are
 * queued on walk->buffers and written back by skcipher_walk_complete().
 */
int skcipher_walk_async(struct skcipher_walk *walk,
			struct skcipher_request *req)
{
	INIT_LIST_HEAD(&walk->buffers);
	walk->flags |= SKCIPHER_WALK_PHYS;

	return skcipher_walk_skcipher(walk, req);
}
EXPORT_SYMBOL_GPL(skcipher_walk_async);

/*
 * Shared set-up for AEAD walks. The caller sets walk->total first. Both
 * scatterlists are positioned past req->assoclen bytes of associated data
 * before the walk starts.
 * @atomic: clear SKCIPHER_WALK_SLEEP once the first step is set up.
 */
static int skcipher_walk_aead_common(struct skcipher_walk *walk,
				     struct aead_request *req, bool atomic)
{
	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
	int ret;

	walk->flags &= ~SKCIPHER_WALK_PHYS;

	scatterwalk_start(&walk->in, req->src);
	scatterwalk_start(&walk->out, req->dst);

	/* Skip (mode 2: no copy) the associated data in both lists. */
	scatterwalk_copychunks(NULL, &walk->in, req->assoclen, 2);
	scatterwalk_copychunks(NULL, &walk->out, req->assoclen, 2);

	walk->oiv = req->iv;
	walk->iv = req->iv;

	walk->flags &= ~SKCIPHER_WALK_SLEEP;
	if (req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP)
		walk->flags |= SKCIPHER_WALK_SLEEP;

	walk->blocksize = crypto_aead_blocksize(tfm);
	walk->chunksize = crypto_aead_chunksize(tfm);
	walk->ivsize = crypto_aead_ivsize(tfm);
	walk->alignmask = crypto_aead_alignmask(tfm);

	ret = skcipher_walk_first(walk);

	walk->flags &= atomic ? ~SKCIPHER_WALK_SLEEP : ~0;

	return ret;
}

/* Walk the full req->cryptlen bytes following the associated data. */
int skcipher_walk_aead(struct skcipher_walk *walk, struct aead_request *req,
		       bool atomic)
{
	unsigned int len = req->cryptlen;

	walk->total = len;
	return skcipher_walk_aead_common(walk, req, atomic);
}
EXPORT_SYMBOL_GPL(skcipher_walk_aead);

/* Encryption walk: all req->cryptlen bytes of plaintext are processed. */
int skcipher_walk_aead_encrypt(struct skcipher_walk *walk,
			       struct aead_request *req, bool atomic)
{
	unsigned int len = req->cryptlen;

	walk->total = len;
	return skcipher_walk_aead_common(walk, req, atomic);
}
EXPORT_SYMBOL_GPL(skcipher_walk_aead_encrypt);

/*
 * Decryption walk: the trailing authsize bytes of req->cryptlen are
 * excluded from the walk.
 */
int skcipher_walk_aead_decrypt(struct skcipher_walk *walk,
			       struct aead_request *req, bool atomic)
{
	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
	unsigned int authsize = crypto_aead_authsize(tfm);

	walk->total = req->cryptlen - authsize;
	return skcipher_walk_aead_common(walk, req, atomic);
}
EXPORT_SYMBOL_GPL(skcipher_walk_aead_decrypt);

/*
 * Extra tfm context size. Legacy (a)blkcipher algorithms only need room
 * for a pointer to the wrapped child tfm.
 */
static unsigned int crypto_skcipher_extsize(struct crypto_alg *alg)
{
	const struct crypto_type *type = alg->cra_type;

	if (type == &crypto_blkcipher_type)
		return sizeof(struct crypto_blkcipher *);
	if (type == &crypto_ablkcipher_type || type == &crypto_givcipher_type)
		return sizeof(struct crypto_ablkcipher *);

	return crypto_alg_extsize(alg);
}

static int skcipher_setkey_blkcipher(struct crypto_skcipher *tfm,
				     const u8 *key, unsigned int keylen)
{
	struct crypto_blkcipher **ctx = crypto_skcipher_ctx(tfm);
	struct crypto_blkcipher *blkcipher = *ctx;
	int err;

	crypto_blkcipher_clear_flags(blkcipher, ~0);
	crypto_blkcipher_set_flags(blkcipher, crypto_skcipher_get_flags(tfm) &
					      CRYPTO_TFM_REQ_MASK);
	err = crypto_blkcipher_setkey(blkcipher, key, keylen);
	crypto_skcipher_set_flags(tfm, crypto_blkcipher_get_flags(blkcipher) &
				       CRYPTO_TFM_RES_MASK);

	return err;
}

/*
 * Run @crypt (the blkcipher's encrypt or decrypt hook) on the request's
 * scatterlists, using a descriptor built around the wrapped child tfm.
 */
static int skcipher_crypt_blkcipher(struct skcipher_request *req,
				    int (*crypt)(struct blkcipher_desc *,
						 struct scatterlist *,
						 struct scatterlist *,
						 unsigned int))
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_blkcipher **child = crypto_skcipher_ctx(skcipher);
	struct blkcipher_desc desc = {
		.flags = req->base.flags,
		.info = req->iv,
		.tfm = *child,
	};

	return crypt(&desc, req->dst, req->src, req->cryptlen);
}

static int skcipher_encrypt_blkcipher(struct skcipher_request *req)
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_tfm *tfm = crypto_skcipher_tfm(skcipher);
	struct blkcipher_alg *alg = &tfm->__crt_alg->cra_blkcipher;

	return skcipher_crypt_blkcipher(req, alg->encrypt);
}

static int skcipher_decrypt_blkcipher(struct skcipher_request *req)
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_tfm *tfm = crypto_skcipher_tfm(skcipher);
	struct blkcipher_alg *alg = &tfm->__crt_alg->cra_blkcipher;

	return skcipher_crypt_blkcipher(req, alg->decrypt);
}

/* Free the blkcipher child tfm stored in the context. */
static void crypto_exit_skcipher_ops_blkcipher(struct crypto_tfm *tfm)
{
	struct crypto_blkcipher **slot = crypto_tfm_ctx(tfm);
	struct crypto_blkcipher *child = *slot;

	crypto_free_blkcipher(child);
}

/*
 * Wrap a legacy blkcipher algorithm: allocate a blkcipher child tfm, store
 * it in the context and route the skcipher operations through it.
 *
 * Return: 0, -EAGAIN if the module reference cannot be taken, or the error
 * from allocating the child.
 */
static int crypto_init_skcipher_ops_blkcipher(struct crypto_tfm *tfm)
{
	struct crypto_alg *calg = tfm->__crt_alg;
	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
	struct crypto_blkcipher **ctx = crypto_tfm_ctx(tfm);
	struct crypto_blkcipher *child;
	struct crypto_tfm *child_tfm;

	/* Module reference for the child; dropped here only on failure. */
	if (!crypto_mod_get(calg))
		return -EAGAIN;

	child_tfm = __crypto_alloc_tfm(calg, CRYPTO_ALG_TYPE_BLKCIPHER,
				       CRYPTO_ALG_TYPE_MASK);
	if (IS_ERR(child_tfm)) {
		crypto_mod_put(calg);
		return PTR_ERR(child_tfm);
	}

	child = __crypto_blkcipher_cast(child_tfm);
	*ctx = child;
	tfm->exit = crypto_exit_skcipher_ops_blkcipher;

	skcipher->setkey = skcipher_setkey_blkcipher;
	skcipher->encrypt = skcipher_encrypt_blkcipher;
	skcipher->decrypt = skcipher_decrypt_blkcipher;

	skcipher->ivsize = crypto_blkcipher_ivsize(child);
	skcipher->keysize = calg->cra_blkcipher.max_keysize;

	return 0;
}

static int skcipher_setkey_ablkcipher(struct crypto_skcipher *tfm,
				      const u8 *key, unsigned int keylen)
{
	struct crypto_ablkcipher **ctx = crypto_skcipher_ctx(tfm);
	struct crypto_ablkcipher *ablkcipher = *ctx;
	int err;

	crypto_ablkcipher_clear_flags(ablkcipher, ~0);
	crypto_ablkcipher_set_flags(ablkcipher,
				    crypto_skcipher_get_flags(tfm) &
				    CRYPTO_TFM_REQ_MASK);
	err = crypto_ablkcipher_setkey(ablkcipher, key, keylen);
	crypto_skcipher_set_flags(tfm,
				  crypto_ablkcipher_get_flags(ablkcipher) &
				  CRYPTO_TFM_RES_MASK);

	return err;
}

/*
 * Forward an skcipher request to the wrapped ablkcipher. The sub-request
 * lives in the request context and inherits the caller's flags, completion
 * callback, buffers and IV.
 */
static int skcipher_crypt_ablkcipher(struct skcipher_request *req,
				     int (*crypt)(struct ablkcipher_request *))
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_ablkcipher **child = crypto_skcipher_ctx(skcipher);
	struct ablkcipher_request *subreq = skcipher_request_ctx(req);

	ablkcipher_request_set_tfm(subreq, *child);
	ablkcipher_request_set_callback(subreq, skcipher_request_flags(req),
					req->base.complete, req->base.data);
	ablkcipher_request_set_crypt(subreq, req->src, req->dst,
				     req->cryptlen, req->iv);

	return crypt(subreq);
}

static int skcipher_encrypt_ablkcipher(struct skcipher_request *req)
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_tfm *tfm = crypto_skcipher_tfm(skcipher);
	struct ablkcipher_alg *alg = &tfm->__crt_alg->cra_ablkcipher;

	return skcipher_crypt_ablkcipher(req, alg->encrypt);
}

static int skcipher_decrypt_ablkcipher(struct skcipher_request *req)
{
	struct crypto_skcipher *skcipher = crypto_skcipher_reqtfm(req);
	struct crypto_tfm *tfm = crypto_skcipher_tfm(skcipher);
	struct ablkcipher_alg *alg = &tfm->__crt_alg->cra_ablkcipher;

	return skcipher_crypt_ablkcipher(req, alg->decrypt);
}

/* Free the ablkcipher child tfm stored in the context. */
static void crypto_exit_skcipher_ops_ablkcipher(struct crypto_tfm *tfm)
{
	struct crypto_ablkcipher **slot = crypto_tfm_ctx(tfm);
	struct crypto_ablkcipher *child = *slot;

	crypto_free_ablkcipher(child);
}

/*
 * Wrap a legacy ablkcipher/givcipher algorithm: allocate the child tfm,
 * store it in the context and route the skcipher operations through it.
 *
 * Return: 0, -EAGAIN if the module reference cannot be taken, or the error
 * from allocating the child.
 */
static int crypto_init_skcipher_ops_ablkcipher(struct crypto_tfm *tfm)
{
	struct crypto_alg *calg = tfm->__crt_alg;
	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
	struct crypto_ablkcipher **ctx = crypto_tfm_ctx(tfm);
	struct crypto_ablkcipher *child;
	struct crypto_tfm *child_tfm;

	/* Module reference for the child; dropped here only on failure. */
	if (!crypto_mod_get(calg))
		return -EAGAIN;

	child_tfm = __crypto_alloc_tfm(calg, 0, 0);
	if (IS_ERR(child_tfm)) {
		crypto_mod_put(calg);
		return PTR_ERR(child_tfm);
	}

	child = __crypto_ablkcipher_cast(child_tfm);
	*ctx = child;
	tfm->exit = crypto_exit_skcipher_ops_ablkcipher;

	skcipher->setkey = skcipher_setkey_ablkcipher;
	skcipher->encrypt = skcipher_encrypt_ablkcipher;
	skcipher->decrypt = skcipher_decrypt_ablkcipher;

	skcipher->ivsize = crypto_ablkcipher_ivsize(child);
	/* Each request embeds an ablkcipher sub-request in its context. */
	skcipher->reqsize = sizeof(struct ablkcipher_request) +
			    crypto_ablkcipher_reqsize(child);
	skcipher->keysize = calg->cra_ablkcipher.max_keysize;

	return 0;
}

static void crypto_skcipher_exit_tfm(struct crypto_tfm *tfm)
{
	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
	struct skcipher_alg *alg = crypto_skcipher_alg(skcipher);

	alg->exit(skcipher);
}

/*
 * Initialise an skcipher tfm. Legacy blkcipher and ablkcipher/givcipher
 * algorithms are wrapped. Native skcipher algorithms have their operations
 * copied in, and their optional init/exit hooks are wired up.
 */
static int crypto_skcipher_init_tfm(struct crypto_tfm *tfm)
{
	struct crypto_skcipher *skcipher = __crypto_skcipher_cast(tfm);
	struct skcipher_alg *alg = crypto_skcipher_alg(skcipher);
	const struct crypto_type *type = tfm->__crt_alg->cra_type;

	if (type == &crypto_blkcipher_type)
		return crypto_init_skcipher_ops_blkcipher(tfm);
	if (type == &crypto_ablkcipher_type || type == &crypto_givcipher_type)
		return crypto_init_skcipher_ops_ablkcipher(tfm);

	skcipher->setkey = alg->setkey;
	skcipher->encrypt = alg->encrypt;
	skcipher->decrypt = alg->decrypt;
	skcipher->ivsize = alg->ivsize;
	skcipher->keysize = alg->max_keysize;

	if (alg->exit)
		skcipher->base.exit = crypto_skcipher_exit_tfm;

	return alg->init ? alg->init(skcipher) : 0;
}

/* Release a template instance through its skcipher-specific free hook. */
static void crypto_skcipher_free_instance(struct crypto_instance *inst)
{
	struct skcipher_instance *skcipher;

	skcipher = container_of(inst, struct skcipher_instance, s.base);
	skcipher->free(skcipher);
}

static void crypto_skcipher_show(struct seq_file *m, struct crypto_alg *alg)
	__attribute__ ((unused));
/*
 * Text description of an skcipher algorithm. It is hooked into
 * crypto_skcipher_type2 only under CONFIG_PROC_FS.
 */
static void crypto_skcipher_show(struct seq_file *m, struct crypto_alg *alg)
{
	struct skcipher_alg *skcipher;
	const char *async;

	skcipher = container_of(alg, struct skcipher_alg, base);
	async = (alg->cra_flags & CRYPTO_ALG_ASYNC) ? "yes" : "no";

	seq_printf(m, "type         : skcipher\n");
	seq_printf(m, "async        : %s\n", async);
	seq_printf(m, "blocksize    : %u\n", alg->cra_blocksize);
	seq_printf(m, "min keysize  : %u\n", skcipher->min_keysize);
	seq_printf(m, "max keysize  : %u\n", skcipher->max_keysize);
	seq_printf(m, "ivsize       : %u\n", skcipher->ivsize);
	seq_printf(m, "chunksize    : %u\n", skcipher->chunksize);
}

#ifdef CONFIG_NET
/*
 * Fill a crypto_report_blkcipher describing @alg and add it to @skb as a
 * CRYPTOCFGA_REPORT_BLKCIPHER netlink attribute.
 *
 * Return: 0 on success, -EMSGSIZE if the attribute does not fit.
 */
static int crypto_skcipher_report(struct sk_buff *skb, struct crypto_alg *alg)
{
	struct crypto_report_blkcipher rblkcipher;
	struct skcipher_alg *skcipher = container_of(alg, struct skcipher_alg,
						     base);

	/*
	 * nla_put() copies the whole structure into the message, so clear it
	 * first; no uninitialised stack bytes (e.g. padding) must escape.
	 */
	memset(&rblkcipher, 0, sizeof(rblkcipher));

	strncpy(rblkcipher.type, "skcipher", sizeof(rblkcipher.type));
	strncpy(rblkcipher.geniv, "<none>", sizeof(rblkcipher.geniv));

	rblkcipher.blocksize = alg->cra_blocksize;
	rblkcipher.min_keysize = skcipher->min_keysize;
	rblkcipher.max_keysize = skcipher->max_keysize;
	rblkcipher.ivsize = skcipher->ivsize;

	if (nla_put(skb, CRYPTOCFGA_REPORT_BLKCIPHER,
		    sizeof(struct crypto_report_blkcipher), &rblkcipher))
		goto nla_put_failure;
	return 0;

nla_put_failure:
	return -EMSGSIZE;
}
#else
/* Without CONFIG_NET there is no netlink reporting. */
static int crypto_skcipher_report(struct sk_buff *skb, struct crypto_alg *alg)
{
	return -ENOSYS;
}
#endif

/*
 * crypto_type for skcipher algorithms. skcipher_prepare_alg() installs it,
 * and the allocation/grab helpers use it as the frontend.
 */
static const struct crypto_type crypto_skcipher_type2 = {
	.extsize = crypto_skcipher_extsize,
	.init_tfm = crypto_skcipher_init_tfm,
	.free = crypto_skcipher_free_instance,
#ifdef CONFIG_PROC_FS
	.show = crypto_skcipher_show,
#endif
	.report = crypto_skcipher_report,
	.maskclear = ~CRYPTO_ALG_TYPE_MASK,
	.maskset = CRYPTO_ALG_TYPE_BLKCIPHER_MASK,
	.type = CRYPTO_ALG_TYPE_SKCIPHER,
	/* Offset of the embedded base tfm within struct crypto_skcipher. */
	.tfmsize = offsetof(struct crypto_skcipher, base),
};

/* Grab an skcipher spawn for use by a template instance. */
int crypto_grab_skcipher(struct crypto_skcipher_spawn *spawn,
			  const char *name, u32 type, u32 mask)
{
	struct crypto_spawn *base = &spawn->base;

	base->frontend = &crypto_skcipher_type2;
	return crypto_grab_spawn(base, name, type, mask);
}
EXPORT_SYMBOL_GPL(crypto_grab_skcipher);

/* Allocate an skcipher tfm by algorithm name. */
struct crypto_skcipher *crypto_alloc_skcipher(const char *alg_name,
					      u32 type, u32 mask)
{
	const struct crypto_type *frontend = &crypto_skcipher_type2;

	return crypto_alloc_tfm(alg_name, frontend, type, mask);
}
EXPORT_SYMBOL_GPL(crypto_alloc_skcipher);

/* Check whether an skcipher algorithm with this name is available. */
int crypto_has_skcipher2(const char *alg_name, u32 type, u32 mask)
{
	const struct crypto_type *frontend = &crypto_skcipher_type2;

	return crypto_type_has_alg(alg_name, frontend, type, mask);
}
EXPORT_SYMBOL_GPL(crypto_has_skcipher2);

/*
 * Validate and finalise an skcipher_alg before registration. ivsize and
 * chunksize are limited to PAGE_SIZE / 8, chunksize defaults to the block
 * size, and the generic type fields are set.
 *
 * Return: 0, or -EINVAL if a size limit is exceeded.
 */
static int skcipher_prepare_alg(struct skcipher_alg *alg)
{
	struct crypto_alg *base = &alg->base;
	const unsigned long limit = PAGE_SIZE / 8;

	if (alg->ivsize > limit || alg->chunksize > limit)
		return -EINVAL;

	if (alg->chunksize == 0)
		alg->chunksize = base->cra_blocksize;

	base->cra_type = &crypto_skcipher_type2;
	base->cra_flags = (base->cra_flags & ~CRYPTO_ALG_TYPE_MASK) |
			  CRYPTO_ALG_TYPE_SKCIPHER;

	return 0;
}

/* Prepare and register a single skcipher algorithm. */
int crypto_register_skcipher(struct skcipher_alg *alg)
{
	int err = skcipher_prepare_alg(alg);

	if (err)
		return err;

	return crypto_register_alg(&alg->base);
}
EXPORT_SYMBOL_GPL(crypto_register_skcipher);

/* Unregister an skcipher algorithm registered with crypto_register_skcipher(). */
void crypto_unregister_skcipher(struct skcipher_alg *alg)
{
	struct crypto_alg *base = &alg->base;

	crypto_unregister_alg(base);
}
EXPORT_SYMBOL_GPL(crypto_unregister_skcipher);

/*
 * Register @count algorithms. On failure, the ones already registered are
 * unregistered (newest first) and the error is returned.
 */
int crypto_register_skciphers(struct skcipher_alg *algs, int count)
{
	int i;
	int err;

	for (i = 0; i < count; i++) {
		err = crypto_register_skcipher(&algs[i]);
		if (err) {
			while (i-- > 0)
				crypto_unregister_skcipher(&algs[i]);
			return err;
		}
	}

	return 0;
}
EXPORT_SYMBOL_GPL(crypto_register_skciphers);

/* Unregister @count algorithms in reverse order of registration. */
void crypto_unregister_skciphers(struct skcipher_alg *algs, int count)
{
	int i = count;

	while (--i >= 0)
		crypto_unregister_skcipher(&algs[i]);
}
EXPORT_SYMBOL_GPL(crypto_unregister_skciphers);

/* Prepare and register an skcipher template instance. */
int skcipher_register_instance(struct crypto_template *tmpl,
			   struct skcipher_instance *inst)
{
	int err = skcipher_prepare_alg(&inst->alg);

	if (err)
		return err;

	return crypto_register_instance(tmpl, skcipher_crypto_instance(inst));
}
EXPORT_SYMBOL_GPL(skcipher_register_instance);

/* Module metadata. */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Symmetric key cipher type");
back to top